fix(gateway): harden REQUEST_GUILD_MEMBERS path against DoS floods

This commit is contained in:
Hampus Kraft
2026-02-22 13:41:25 +00:00
parent 4f5704fa1f
commit d843d6f3f8
2 changed files with 270 additions and 59 deletions

View File

@@ -24,6 +24,15 @@
-define(CHUNK_SIZE, 1000).
-define(MAX_USER_IDS, 100).
-define(MAX_NONCE_LENGTH, 32).
-define(FULL_MEMBER_LIST_LIMIT, 100000).
-define(DEFAULT_QUERY_LIMIT, 25).
-define(MAX_MEMBER_QUERY_LIMIT, 100).
-define(REQUEST_MEMBERS_RATE_LIMIT_TABLE, guild_request_members_rate_limit).
-define(REQUEST_MEMBERS_RATE_LIMIT_WINDOW_MS, 10000).
-define(REQUEST_MEMBERS_RATE_LIMIT_MAX_EVENTS, 5).
-define(REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE, guild_request_members_guild_rate_limit).
-define(REQUEST_MEMBERS_GUILD_RATE_LIMIT_WINDOW_MS, 10000).
-define(REQUEST_MEMBERS_GUILD_RATE_LIMIT_MAX_EVENTS, 25).
-type session_state() :: map().
-type request_data() :: map().
@@ -121,7 +130,8 @@ ensure_binary(Value) when is_binary(Value) -> Value;
ensure_binary(_) -> <<>>.
-spec ensure_limit(term()) -> non_neg_integer().
ensure_limit(Limit) when is_integer(Limit), Limit >= 0 -> Limit;
ensure_limit(Limit) when is_integer(Limit), Limit >= 0 ->
min(Limit, ?MAX_MEMBER_QUERY_LIMIT);
ensure_limit(_) -> 0.
-spec normalize_nonce(term()) -> binary() | null.
@@ -135,13 +145,99 @@ process_request(Request, SocketPid, SessionState) ->
#{guild_id := GuildId, query := Query, limit := Limit, user_ids := UserIds} = Request,
UserIdBin = maps:get(user_id, SessionState),
UserId = type_conv:to_integer(UserIdBin),
case check_permission(UserId, GuildId, Query, Limit, UserIds, SessionState) of
case check_request_rate_limit(UserId) of
ok ->
fetch_and_send_members(Request, SocketPid, SessionState);
case check_guild_request_rate_limit(GuildId) of
ok ->
case check_permission(UserId, GuildId, Query, Limit, UserIds, SessionState) of
ok ->
fetch_and_send_members(Request, SocketPid, SessionState);
{error, Reason} ->
{error, Reason}
end;
{error, Reason} ->
{error, Reason}
end;
{error, Reason} ->
{error, Reason}
end.
-spec check_request_rate_limit(integer() | undefined) -> ok | {error, atom()}.
check_request_rate_limit(UserId) when is_integer(UserId), UserId > 0 ->
ensure_request_rate_limit_table(),
Now = erlang:system_time(millisecond),
case ets:lookup(?REQUEST_MEMBERS_RATE_LIMIT_TABLE, UserId) of
[] ->
ets:insert(?REQUEST_MEMBERS_RATE_LIMIT_TABLE, {UserId, [Now]}),
ok;
[{UserId, Timestamps}] ->
RecentTimestamps =
[T || T <- Timestamps, (Now - T) < ?REQUEST_MEMBERS_RATE_LIMIT_WINDOW_MS],
case length(RecentTimestamps) >= ?REQUEST_MEMBERS_RATE_LIMIT_MAX_EVENTS of
true ->
{error, rate_limited};
false ->
ets:insert(?REQUEST_MEMBERS_RATE_LIMIT_TABLE, {UserId, [Now | RecentTimestamps]}),
ok
end
end;
check_request_rate_limit(_) ->
{error, invalid_session}.
-spec check_guild_request_rate_limit(integer()) -> ok | {error, atom()}.
check_guild_request_rate_limit(GuildId) when is_integer(GuildId), GuildId > 0 ->
ensure_guild_request_rate_limit_table(),
Now = erlang:system_time(millisecond),
case ets:lookup(?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE, GuildId) of
[] ->
ets:insert(?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE, {GuildId, [Now]}),
ok;
[{GuildId, Timestamps}] ->
RecentTimestamps =
[T || T <- Timestamps, (Now - T) < ?REQUEST_MEMBERS_GUILD_RATE_LIMIT_WINDOW_MS],
case length(RecentTimestamps) >= ?REQUEST_MEMBERS_GUILD_RATE_LIMIT_MAX_EVENTS of
true ->
{error, rate_limited};
false ->
ets:insert(
?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE, {GuildId, [Now | RecentTimestamps]}
),
ok
end
end;
check_guild_request_rate_limit(_) ->
{error, invalid_guild_id}.
-spec ensure_request_rate_limit_table() -> ok.
ensure_request_rate_limit_table() ->
case ets:whereis(?REQUEST_MEMBERS_RATE_LIMIT_TABLE) of
undefined ->
try
ets:new(?REQUEST_MEMBERS_RATE_LIMIT_TABLE, [named_table, public, set]),
ok
catch
error:badarg ->
ok
end;
_ ->
ok
end.
-spec ensure_guild_request_rate_limit_table() -> ok.
ensure_guild_request_rate_limit_table() ->
case ets:whereis(?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE) of
undefined ->
try
ets:new(?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE, [named_table, public, set]),
ok
catch
error:badarg ->
ok
end;
_ ->
ok
end.
-spec check_permission(
integer(), integer(), binary(), non_neg_integer(), [integer()], session_state()
) ->
@@ -215,18 +311,9 @@ fetch_and_send_members(Request, _SocketPid, SessionState) ->
-spec fetch_members(pid(), binary(), non_neg_integer(), [integer()]) -> [member()].
fetch_members(GuildPid, _Query, _Limit, UserIds) when UserIds =/= [] ->
case gen_server:call(GuildPid, {list_guild_members, #{limit => 100000, offset => 0}}, 10000) of
#{members := AllMembers} ->
filter_members_by_ids(AllMembers, UserIds);
_ ->
[]
end;
fetch_members_by_user_ids(GuildPid, UserIds);
fetch_members(GuildPid, Query, Limit, []) ->
ActualLimit =
case Limit of
0 -> 100000;
L -> L
end,
ActualLimit = resolve_member_limit(Query, Limit),
case
gen_server:call(GuildPid, {list_guild_members, #{limit => ActualLimit, offset => 0}}, 10000)
of
@@ -241,17 +328,33 @@ fetch_members(GuildPid, Query, Limit, []) ->
[]
end.
-spec filter_members_by_ids([member()], [integer()]) -> [member()].
filter_members_by_ids(Members, UserIds) ->
UserIdSet = sets:from_list(UserIds),
lists:filter(
fun(Member) ->
UserId = extract_user_id(Member),
UserId =/= undefined andalso sets:is_element(UserId, UserIdSet)
-spec fetch_members_by_user_ids(pid(), [integer()]) -> [member()].
fetch_members_by_user_ids(GuildPid, UserIds) ->
lists:filtermap(
fun(UserId) ->
try
case gen_server:call(GuildPid, {get_guild_member, #{user_id => UserId}}, 5000) of
#{success := true, member_data := Member} when is_map(Member) ->
{true, Member};
_ ->
false
end
catch
exit:_ ->
false
end
end,
Members
lists:usort(UserIds)
).
-spec resolve_member_limit(binary(), non_neg_integer()) -> pos_integer().
resolve_member_limit(<<>>, 0) ->
?FULL_MEMBER_LIST_LIMIT;
resolve_member_limit(_Query, 0) ->
?DEFAULT_QUERY_LIMIT;
resolve_member_limit(_Query, Limit) ->
Limit.
-spec filter_members_by_query([member()], binary(), non_neg_integer()) -> [member()].
filter_members_by_query(Members, Query, Limit) ->
NormalizedQuery = string:lowercase(binary_to_list(Query)),
@@ -518,6 +621,64 @@ ensure_limit_negative_test() ->
ensure_limit_non_integer_test() ->
?assertEqual(0, ensure_limit(<<"10">>)).
ensure_limit_clamped_test() ->
?assertEqual(?MAX_MEMBER_QUERY_LIMIT, ensure_limit(?MAX_MEMBER_QUERY_LIMIT + 1)).
resolve_member_limit_full_scan_test() ->
?assertEqual(?FULL_MEMBER_LIST_LIMIT, resolve_member_limit(<<>>, 0)).
resolve_member_limit_query_default_test() ->
?assertEqual(?DEFAULT_QUERY_LIMIT, resolve_member_limit(<<"ab">>, 0)).
resolve_member_limit_explicit_test() ->
?assertEqual(25, resolve_member_limit(<<"ab">>, 25)).
check_request_rate_limit_allows_initial_request_test() ->
UserId = 987654321,
clear_request_rate_limit(UserId),
?assertEqual(ok, check_request_rate_limit(UserId)),
clear_request_rate_limit(UserId).
check_request_rate_limit_blocks_burst_test() ->
UserId = 987654322,
clear_request_rate_limit(UserId),
ensure_request_rate_limit_table(),
Now = erlang:system_time(millisecond),
Timestamps = lists:duplicate(?REQUEST_MEMBERS_RATE_LIMIT_MAX_EVENTS, Now - 1000),
ets:insert(?REQUEST_MEMBERS_RATE_LIMIT_TABLE, {UserId, Timestamps}),
?assertEqual({error, rate_limited}, check_request_rate_limit(UserId)),
clear_request_rate_limit(UserId).
check_request_rate_limit_invalid_user_test() ->
?assertEqual({error, invalid_session}, check_request_rate_limit(undefined)).
check_guild_request_rate_limit_allows_initial_request_test() ->
GuildId = 87654321,
clear_guild_request_rate_limit(GuildId),
?assertEqual(ok, check_guild_request_rate_limit(GuildId)),
clear_guild_request_rate_limit(GuildId).
check_guild_request_rate_limit_blocks_burst_test() ->
GuildId = 87654322,
clear_guild_request_rate_limit(GuildId),
ensure_guild_request_rate_limit_table(),
Now = erlang:system_time(millisecond),
Timestamps = lists:duplicate(?REQUEST_MEMBERS_GUILD_RATE_LIMIT_MAX_EVENTS, Now - 1000),
ets:insert(?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE, {GuildId, Timestamps}),
?assertEqual({error, rate_limited}, check_guild_request_rate_limit(GuildId)),
clear_guild_request_rate_limit(GuildId).
check_guild_request_rate_limit_invalid_guild_test() ->
?assertEqual({error, invalid_guild_id}, check_guild_request_rate_limit(undefined)).
clear_request_rate_limit(UserId) ->
ensure_request_rate_limit_table(),
ets:delete(?REQUEST_MEMBERS_RATE_LIMIT_TABLE, UserId).
clear_guild_request_rate_limit(GuildId) ->
ensure_guild_request_rate_limit_table(),
ets:delete(?REQUEST_MEMBERS_GUILD_RATE_LIMIT_TABLE, GuildId).
validate_guild_id_integer_test() ->
?assertEqual({ok, 123}, validate_guild_id(123)).
@@ -589,30 +750,6 @@ chunk_presences_no_matching_presences_test() ->
Result = chunk_presences(Presences, [Members]),
?assertEqual([[]], Result).
filter_members_by_ids_basic_test() ->
Members = [
#{<<"user">> => #{<<"id">> => <<"1">>}},
#{<<"user">> => #{<<"id">> => <<"2">>}},
#{<<"user">> => #{<<"id">> => <<"3">>}}
],
Result = filter_members_by_ids(Members, [1, 3]),
?assertEqual(2, length(Result)).
filter_members_by_ids_empty_ids_test() ->
Members = [#{<<"user">> => #{<<"id">> => <<"1">>}}],
Result = filter_members_by_ids(Members, []),
?assertEqual([], Result).
filter_members_by_ids_no_match_test() ->
Members = [#{<<"user">> => #{<<"id">> => <<"1">>}}],
Result = filter_members_by_ids(Members, [999]),
?assertEqual([], Result).
filter_members_by_ids_skips_invalid_members_test() ->
Members = [#{}, #{<<"user">> => #{}}, #{<<"user">> => #{<<"id">> => <<"1">>}}],
Result = filter_members_by_ids(Members, [1]),
?assertEqual(1, length(Result)).
filter_members_by_query_case_insensitive_test() ->
Members = [
#{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"Alice">>}},