fix: various fixes to sentry-reported errors and more

This commit is contained in:
Hampus Kraft
2026-02-18 15:38:51 +00:00
parent 302c0d2a0c
commit 0517a966a3
357 changed files with 25420 additions and 16281 deletions

View File

@@ -384,8 +384,6 @@ disconnect_user_after_pending_timeout(ConnectionId, UserId, SessionId, State) ->
VoiceStates = maps:get(voice_states, State),
case maps:is_key(UserId, VoiceStates) of
true ->
%% Keep the active participant state when LiveKit confirm arrives late.
%% Reconciliation and explicit leave/disconnect paths handle true ghosts.
{noreply, State#{pending_connections => NewPending}};
false ->
NewSessions = remove_session_entry(SessionId, maps:get(sessions, State)),

View File

@@ -23,11 +23,12 @@
start(_StartType, _StartArgs) ->
fluxer_gateway_env:load(),
otel_metrics:init(),
passive_sync_registry:init(),
guild_counts_cache:init(),
Port = fluxer_gateway_env:get(port),
Dispatch = cowboy_router:compile([
{'_', [
{<<"/_health">>, health_handler, []},
{<<"/_rpc">>, gateway_rpc_http_handler, []},
{<<"/_admin/reload">>, hot_reload_handler, []},
{<<"/">>, gateway_handler, []}
]}

View File

@@ -43,17 +43,15 @@ load_from(Path) when is_list(Path) ->
-spec build_config(map()) -> config().
build_config(Json) ->
Service = get_map(Json, [<<"services">>, <<"gateway">>]),
Gateway = get_map(Json, [<<"gateway">>]),
Nats = get_map(Json, [<<"services">>, <<"nats">>]),
Telemetry = get_map(Json, [<<"telemetry">>]),
Sentry = get_map(Json, [<<"sentry">>]),
Vapid = get_map(Json, [<<"auth">>, <<"vapid">>]),
#{
port => get_int(Service, <<"port">>, 8080),
rpc_tcp_port => get_int(Service, <<"rpc_tcp_port">>, 8772),
api_host => get_env_or_string("FLUXER_GATEWAY_API_HOST", Service, <<"api_host">>, "api"),
api_canary_host => get_optional_string(Service, <<"api_canary_host">>),
admin_reload_secret => get_optional_binary(Service, <<"admin_reload_secret">>),
rpc_secret_key => get_binary(Gateway, <<"rpc_secret">>, undefined),
nats_core_url => get_string(Nats, <<"core_url">>, "nats://127.0.0.1:4222"),
nats_auth_token => get_string(Nats, <<"auth_token">>, ""),
identify_rate_limit_enabled => get_bool(Service, <<"identify_rate_limit_enabled">>, false),
push_enabled => get_bool(Service, <<"push_enabled">>, true),
push_user_guild_settings_cache_mb => get_int(
@@ -73,18 +71,12 @@ build_config(Json) ->
get_int(Service, <<"push_badge_counts_cache_ttl_seconds">>, 60),
push_dispatcher_max_inflight => get_int(Service, <<"push_dispatcher_max_inflight">>, 16),
push_dispatcher_max_queue => get_int(Service, <<"push_dispatcher_max_queue">>, 2048),
gateway_http_rpc_connect_timeout_ms =>
get_int(Service, <<"gateway_http_rpc_connect_timeout_ms">>, 5000),
gateway_http_rpc_recv_timeout_ms =>
get_int(Service, <<"gateway_http_rpc_recv_timeout_ms">>, 30000),
gateway_http_push_connect_timeout_ms =>
get_int(Service, <<"gateway_http_push_connect_timeout_ms">>, 3000),
gateway_http_push_recv_timeout_ms =>
get_int(Service, <<"gateway_http_push_recv_timeout_ms">>, 5000),
gateway_http_rpc_max_concurrency =>
get_int(Service, <<"gateway_http_rpc_max_concurrency">>, 512),
gateway_rpc_tcp_max_input_buffer_bytes =>
get_int(Service, <<"gateway_rpc_tcp_max_input_buffer_bytes">>, 2097152),
gateway_http_push_max_concurrency =>
get_int(Service, <<"gateway_http_push_max_concurrency">>, 256),
gateway_http_failure_threshold =>
@@ -148,27 +140,6 @@ get_optional_bool(Map, Key) ->
get_string(Map, Key, Default) when is_list(Default) ->
to_string(get_value(Map, Key), Default).
-spec get_env_or_string(string(), map(), binary(), string()) -> string().
get_env_or_string(EnvVar, Map, Key, Default) when is_list(EnvVar), is_list(Default) ->
case os:getenv(EnvVar) of
false -> get_string(Map, Key, Default);
"" -> get_string(Map, Key, Default);
Value -> Value
end.
-spec get_optional_string(map(), binary()) -> string() | undefined.
get_optional_string(Map, Key) ->
case get_value(Map, Key) of
undefined ->
undefined;
Value ->
Clean = string:trim(to_string(Value, "")),
case Clean of
"" -> undefined;
_ -> Clean
end
end.
-spec get_binary(map(), binary(), binary() | undefined) -> binary() | undefined.
get_binary(Map, Key, Default) ->
to_binary(get_value(Map, Key), Default).

View File

@@ -32,7 +32,7 @@ init([]) ->
},
Children = [
child_spec(gateway_http_client, gateway_http_client),
child_spec(gateway_rpc_tcp_server, gateway_rpc_tcp_server),
child_spec(gateway_nats_rpc, gateway_nats_rpc),
child_spec(session_manager, session_manager),
child_spec(presence_cache, presence_cache),
child_spec(presence_bus, presence_bus),

View File

@@ -35,7 +35,6 @@
-spec parse_compression(binary() | undefined) -> compression().
parse_compression(<<"none">>) ->
none;
%% TODO: temporarily disabled re-enable zstd-stream once compression issues are resolved
parse_compression(<<"zstd-stream">>) ->
none;
parse_compression(_) ->
@@ -123,7 +122,6 @@ parse_compression_test_() ->
?_assertEqual(none, parse_compression(<<>>)),
?_assertEqual(none, parse_compression(<<"none">>)),
?_assertEqual(none, parse_compression(<<"invalid">>)),
%% zstd-stream temporarily disabled always returns none
?_assertEqual(none, parse_compression(<<"zstd-stream">>))
].

View File

@@ -517,11 +517,16 @@ handle_resume(Data, State) ->
Token = maps:get(<<"token">>, Data),
SessionId = maps:get(<<"session_id">>, Data),
Seq = maps:get(<<"seq">>, Data),
case session_manager:lookup(SessionId) of
{ok, Pid} when is_pid(Pid) ->
handle_resume_with_session(Pid, Token, SessionId, Seq, State);
{error, not_found} ->
handle_resume_session_not_found(State)
case is_binary(SessionId) of
false ->
handle_resume_session_not_found(State);
true ->
case session_manager:lookup(SessionId) of
{ok, Pid} when is_pid(Pid) ->
handle_resume_with_session(Pid, Token, SessionId, Seq, State);
{error, _} ->
handle_resume_session_not_found(State)
end
end.
-spec handle_voice_state_update(pid(), map(), state()) -> ws_result().

View File

@@ -0,0 +1,260 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(gateway_nats_rpc).
-behaviour(gen_server).
-export([start_link/0, get_connection/0]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-define(DEFAULT_MAX_HANDLERS, 1024).
-define(RECONNECT_DELAY_MS, 2000).
-define(RPC_SUBJECT_PREFIX, <<"rpc.gateway.">>).
-define(RPC_SUBJECT_WILDCARD, <<"rpc.gateway.>">>).
-define(QUEUE_GROUP, <<"gateway">>).
-spec start_link() -> {ok, pid()} | {error, term()}.
start_link() ->
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
-spec get_connection() -> {ok, nats:conn() | undefined} | {error, term()}.
get_connection() ->
gen_server:call(?MODULE, get_connection).
-spec init([]) -> {ok, map()}.
init([]) ->
process_flag(trap_exit, true),
self() ! connect,
{ok, #{
conn => undefined,
sub => undefined,
handler_count => 0,
max_handlers => max_handlers(),
monitor_ref => undefined
}}.
-spec handle_call(term(), gen_server:from(), map()) -> {reply, term(), map()}.
handle_call(get_connection, _From, #{conn := Conn} = State) ->
{reply, {ok, Conn}, State};
handle_call(_Request, _From, State) ->
{reply, ok, State}.
-spec handle_cast(term(), map()) -> {noreply, map()}.
handle_cast(_Msg, State) ->
{noreply, State}.
-spec handle_info(term(), map()) -> {noreply, map()}.
handle_info(connect, State) ->
{noreply, do_connect(State)};
handle_info({Conn, ready}, #{conn := Conn} = State) ->
{noreply, do_subscribe(State)};
handle_info({Conn, closed}, #{conn := Conn} = State) ->
logger:warning("Gateway NATS RPC connection closed, reconnecting"),
{noreply, schedule_reconnect(State#{conn => undefined, sub => undefined, monitor_ref => undefined})};
handle_info({Conn, {error, Reason}}, #{conn := Conn} = State) ->
logger:warning("Gateway NATS RPC connection error: ~p, reconnecting", [Reason]),
{noreply, schedule_reconnect(State#{conn => undefined, sub => undefined, monitor_ref => undefined})};
handle_info({Conn, _Sid, {msg, Subject, Payload, MsgOpts}},
#{conn := Conn, handler_count := HandlerCount, max_handlers := MaxHandlers} = State) ->
case maps:get(reply_to, MsgOpts, undefined) of
undefined ->
{noreply, State};
ReplyTo ->
case HandlerCount >= MaxHandlers of
true ->
ErrorResponse = iolist_to_binary(json:encode(#{
<<"ok">> => false,
<<"error">> => <<"overloaded">>
})),
nats:pub(Conn, ReplyTo, ErrorResponse),
{noreply, State};
false ->
Parent = self(),
spawn(fun() ->
try
handle_rpc_request(Conn, Subject, Payload, ReplyTo)
after
Parent ! {handler_done, self()}
end
end),
{noreply, State#{handler_count => HandlerCount + 1}}
end
end;
handle_info({handler_done, _Pid}, #{handler_count := HandlerCount} = State) when HandlerCount > 0 ->
{noreply, State#{handler_count => HandlerCount - 1}};
handle_info({handler_done, _Pid}, State) ->
{noreply, State};
handle_info({'DOWN', MRef, process, Conn, Reason}, #{conn := Conn, monitor_ref := MRef} = State) ->
logger:warning("Gateway NATS RPC connection process died: ~p, reconnecting", [Reason]),
{noreply, schedule_reconnect(State#{conn => undefined, sub => undefined, monitor_ref => undefined})};
handle_info(_Info, State) ->
{noreply, State}.
-spec terminate(term(), map()) -> ok.
terminate(_Reason, #{conn := Conn}) when Conn =/= undefined ->
catch nats:disconnect(Conn),
logger:info("Gateway NATS RPC subscriber stopped"),
ok;
terminate(_Reason, _State) ->
ok.
-spec code_change(term(), map(), term()) -> {ok, map()}.
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
-spec do_connect(map()) -> map().
do_connect(State) ->
NatsUrl = fluxer_gateway_env:get(nats_core_url),
AuthToken = fluxer_gateway_env:get(nats_auth_token),
case parse_nats_url(NatsUrl) of
{ok, Host, Port} ->
Opts = build_connect_opts(AuthToken),
case nats:connect(Host, Port, Opts) of
{ok, Conn} ->
MRef = nats:monitor(Conn),
logger:info("Gateway NATS RPC connected to ~s:~p", [Host, Port]),
State#{conn => Conn, monitor_ref => MRef};
{error, Reason} ->
logger:error("Gateway NATS RPC failed to connect: ~p", [Reason]),
schedule_reconnect(State)
end;
{error, Reason} ->
logger:error("Gateway NATS RPC failed to parse URL: ~p", [Reason]),
schedule_reconnect(State)
end.
-spec do_subscribe(map()) -> map().
do_subscribe(#{conn := Conn} = State) when Conn =/= undefined ->
case nats:sub(Conn, ?RPC_SUBJECT_WILDCARD, #{queue_group => ?QUEUE_GROUP}) of
{ok, Sid} ->
logger:info("Gateway NATS RPC subscribed to ~s with queue group ~s",
[?RPC_SUBJECT_WILDCARD, ?QUEUE_GROUP]),
State#{sub => Sid};
{error, Reason} ->
logger:error("Gateway NATS RPC failed to subscribe: ~p", [Reason]),
State
end;
do_subscribe(State) ->
State.
-spec handle_rpc_request(nats:conn(), binary(), binary(), binary()) -> ok.
handle_rpc_request(Conn, Subject, Payload, ReplyTo) ->
Method = strip_rpc_prefix(Subject),
Response = execute_rpc_method(Method, Payload),
ResponseBin = iolist_to_binary(json:encode(Response)),
nats:pub(Conn, ReplyTo, ResponseBin),
ok.
-spec strip_rpc_prefix(binary()) -> binary().
strip_rpc_prefix(<<"rpc.gateway.", Method/binary>>) ->
Method;
strip_rpc_prefix(Subject) ->
Subject.
-spec execute_rpc_method(binary(), binary()) -> map().
execute_rpc_method(Method, PayloadBin) ->
try
Params = json:decode(PayloadBin),
Result = gateway_rpc_router:execute(Method, Params),
#{<<"ok">> => true, <<"result">> => Result}
catch
throw:{error, Message} ->
#{<<"ok">> => false, <<"error">> => error_binary(Message)};
exit:timeout ->
#{<<"ok">> => false, <<"error">> => <<"timeout">>};
exit:{timeout, _} ->
#{<<"ok">> => false, <<"error">> => <<"timeout">>};
Class:Reason ->
logger:error(
"Gateway NATS RPC method execution failed. method=~ts class=~p reason=~p",
[Method, Class, Reason]
),
#{<<"ok">> => false, <<"error">> => <<"internal_error">>}
end.
-spec error_binary(term()) -> binary().
error_binary(Value) when is_binary(Value) ->
Value;
error_binary(Value) when is_list(Value) ->
unicode:characters_to_binary(Value);
error_binary(Value) when is_atom(Value) ->
atom_to_binary(Value, utf8);
error_binary(Value) ->
unicode:characters_to_binary(io_lib:format("~p", [Value])).
-spec parse_nats_url(term()) -> {ok, string(), inet:port_number()} | {error, term()}.
parse_nats_url(Url) when is_list(Url) ->
parse_nats_url(list_to_binary(Url));
parse_nats_url(<<"nats://", Rest/binary>>) ->
parse_host_port(Rest);
parse_nats_url(<<"tls://", Rest/binary>>) ->
parse_host_port(Rest);
parse_nats_url(Url) when is_binary(Url) ->
parse_host_port(Url);
parse_nats_url(_) ->
{error, invalid_nats_url}.
-spec parse_host_port(binary()) -> {ok, string(), inet:port_number()} | {error, term()}.
parse_host_port(HostPort) ->
case binary:split(HostPort, <<":">>) of
[Host, PortBin] ->
try
Port = binary_to_integer(PortBin),
{ok, binary_to_list(Host), Port}
catch
_:_ -> {error, invalid_port}
end;
[Host] ->
{ok, binary_to_list(Host), 4222}
end.
-spec build_connect_opts(term()) -> map().
build_connect_opts(AuthToken) when is_binary(AuthToken), byte_size(AuthToken) > 0 ->
#{auth_token => AuthToken, buffer_size => 0};
build_connect_opts(AuthToken) when is_list(AuthToken) ->
case AuthToken of
"" -> #{buffer_size => 0};
_ -> #{auth_token => list_to_binary(AuthToken), buffer_size => 0}
end;
build_connect_opts(_) ->
#{buffer_size => 0}.
-spec schedule_reconnect(map()) -> map().
schedule_reconnect(State) ->
erlang:send_after(?RECONNECT_DELAY_MS, self(), connect),
State.
-spec max_handlers() -> pos_integer().
max_handlers() ->
case fluxer_gateway_env:get(gateway_http_rpc_max_concurrency) of
Value when is_integer(Value), Value > 0 ->
Value;
_ ->
?DEFAULT_MAX_HANDLERS
end.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
parse_nats_url_test() ->
?assertEqual({ok, "127.0.0.1", 4222}, parse_nats_url(<<"nats://127.0.0.1:4222">>)),
?assertEqual({ok, "localhost", 4222}, parse_nats_url(<<"nats://localhost:4222">>)),
?assertEqual({ok, "localhost", 4222}, parse_nats_url(<<"nats://localhost">>)),
?assertEqual({ok, "127.0.0.1", 4222}, parse_nats_url("nats://127.0.0.1:4222")),
?assertEqual({error, invalid_nats_url}, parse_nats_url(undefined)).
-endif.

View File

@@ -35,15 +35,11 @@ execute_method(<<"guild.dispatch">>, #{
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
with_guild(GuildId, fun(Pid) ->
EventAtom = constants:dispatch_event_atom(Event),
case
gen_server:call(
Pid, {dispatch, #{event => EventAtom, data => Data}}, ?GUILD_CALL_TIMEOUT
)
of
ok ->
true;
_ -> throw({error, <<"dispatch_error">>})
end
IsAlive = erlang:is_process_alive(Pid),
logger:info("rpc guild.dispatch: guild_id=~p event=~p pid=~p alive=~p",
[GuildId, EventAtom, Pid, IsAlive]),
gen_server:cast(Pid, {dispatch, #{event => EventAtom, data => Data}}),
true
end);
execute_method(<<"guild.get_counts">>, #{<<"guild_id">> := GuildIdBin}) ->
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
@@ -78,29 +74,23 @@ execute_method(<<"guild.get_data">>, #{<<"guild_id">> := GuildIdBin, <<"user_id"
execute_method(<<"guild.get_member">>, #{<<"guild_id">> := GuildIdBin, <<"user_id">> := UserIdBin}) ->
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
with_guild(GuildId, fun(Pid) ->
Request = #{user_id => UserId},
case gen_server:call(Pid, {get_guild_member, Request}, ?GUILD_CALL_TIMEOUT) of
#{success := true, member_data := MemberData} ->
#{<<"success">> => true, <<"member_data">> => MemberData};
#{success := false} ->
#{<<"success">> => false};
_ ->
throw({error, <<"guild_member_error">>})
end
end);
case get_member_cached_or_rpc(GuildId, UserId) of
{ok, MemberData} when is_map(MemberData) ->
#{<<"success">> => true, <<"member_data">> => MemberData};
{ok, undefined} ->
#{<<"success">> => false};
error ->
throw({error, <<"guild_member_error">>})
end;
execute_method(<<"guild.has_member">>, #{<<"guild_id">> := GuildIdBin, <<"user_id">> := UserIdBin}) ->
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
with_guild(GuildId, fun(Pid) ->
Request = #{user_id => UserId},
case gen_server:call(Pid, {has_member, Request}, ?GUILD_CALL_TIMEOUT) of
#{has_member := HasMember} when is_boolean(HasMember) ->
#{<<"has_member">> => HasMember};
_ ->
throw({error, <<"membership_check_error">>})
end
end);
case get_has_member_cached_or_rpc(GuildId, UserId) of
{ok, HasMember} ->
#{<<"has_member">> => HasMember};
error ->
throw({error, <<"membership_check_error">>})
end;
execute_method(<<"guild.list_members">>, #{
<<"guild_id">> := GuildIdBin, <<"limit">> := Limit, <<"offset">> := Offset
}) ->
@@ -429,9 +419,9 @@ execute_method(<<"guild.update_member_voice">>, #{
}) ->
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
with_guild(GuildId, fun(Pid) ->
with_voice_server(GuildId, fun(VoicePid, _GuildPid) ->
Request = #{user_id => UserId, mute => Mute, deaf => Deaf},
case gen_server:call(Pid, {update_member_voice, Request}, ?GUILD_CALL_TIMEOUT) of
case gen_server:call(VoicePid, {update_member_voice, Request}, ?GUILD_CALL_TIMEOUT) of
#{success := true} -> #{<<"success">> => true};
#{error := Error} -> throw({error, normalize_voice_rpc_error(Error)})
end
@@ -443,9 +433,9 @@ execute_method(
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
ConnectionId = maps:get(<<"connection_id">>, Params, null),
with_guild(GuildId, fun(Pid) ->
with_voice_server(GuildId, fun(VoicePid, _GuildPid) ->
Request = #{user_id => UserId, connection_id => ConnectionId},
case gen_server:call(Pid, {disconnect_voice_user, Request}, ?GUILD_CALL_TIMEOUT) of
case gen_server:call(VoicePid, {disconnect_voice_user, Request}, ?GUILD_CALL_TIMEOUT) of
#{success := true} -> #{<<"success">> => true};
#{error := Error} -> throw({error, normalize_voice_rpc_error(Error)})
end
@@ -464,11 +454,11 @@ execute_method(
<<"expected_channel_id">>, ExpectedChannelIdBin
),
ConnectionId = maps:get(<<"connection_id">>, Params, undefined),
with_guild(GuildId, fun(Pid) ->
with_voice_server(GuildId, fun(VoicePid, _GuildPid) ->
Request = build_disconnect_request(UserId, ExpectedChannelId, ConnectionId),
case
gen_server:call(
Pid, {disconnect_voice_user_if_in_channel, Request}, ?GUILD_CALL_TIMEOUT
VoicePid, {disconnect_voice_user_if_in_channel, Request}, ?GUILD_CALL_TIMEOUT
)
of
#{success := true, ignored := true} -> #{<<"success">> => true, <<"ignored">> => true};
@@ -481,11 +471,11 @@ execute_method(<<"guild.disconnect_all_voice_users_in_channel">>, #{
}) ->
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
with_guild(GuildId, fun(Pid) ->
with_voice_server(GuildId, fun(VoicePid, _GuildPid) ->
Request = #{channel_id => ChannelId},
case
gen_server:call(
Pid, {disconnect_all_voice_users_in_channel, Request}, ?GUILD_CALL_TIMEOUT
VoicePid, {disconnect_all_voice_users_in_channel, Request}, ?GUILD_CALL_TIMEOUT
)
of
#{success := true, disconnected_count := Count} ->
@@ -499,11 +489,11 @@ execute_method(<<"guild.confirm_voice_connection_from_livekit">>, Params) ->
ConnectionId = maps:get(<<"connection_id">>, Params),
TokenNonce = maps:get(<<"token_nonce">>, Params, undefined),
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
with_guild(GuildId, fun(Pid) ->
with_voice_server(GuildId, fun(VoicePid, _GuildPid) ->
Request = #{connection_id => ConnectionId, token_nonce => TokenNonce},
case
gen_server:call(
Pid, {confirm_voice_connection_from_livekit, Request}, ?GUILD_CALL_TIMEOUT
VoicePid, {confirm_voice_connection_from_livekit, Request}, ?GUILD_CALL_TIMEOUT
)
of
#{success := true} -> #{<<"success">> => true};
@@ -518,8 +508,8 @@ execute_method(<<"guild.get_voice_states_for_channel">>, Params) ->
GuildIdBin = maps:get(<<"guild_id">>, Params),
ChannelIdBin = maps:get(<<"channel_id">>, Params),
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
with_guild(GuildId, fun(Pid) ->
case gen_server:call(Pid, {get_voice_states_for_channel, ChannelIdBin}, 10000) of
with_voice_server(GuildId, fun(VoicePid, _GuildPid) ->
case gen_server:call(VoicePid, {get_voice_states_for_channel, ChannelIdBin}, 10000) of
#{voice_states := VoiceStates} ->
#{<<"voice_states">> => VoiceStates};
_ ->
@@ -530,8 +520,8 @@ execute_method(<<"guild.get_pending_joins_for_channel">>, Params) ->
GuildIdBin = maps:get(<<"guild_id">>, Params),
ChannelIdBin = maps:get(<<"channel_id">>, Params),
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
with_guild(GuildId, fun(Pid) ->
case gen_server:call(Pid, {get_pending_joins_for_channel, ChannelIdBin}, 10000) of
with_voice_server(GuildId, fun(VoicePid, _GuildPid) ->
case gen_server:call(VoicePid, {get_pending_joins_for_channel, ChannelIdBin}, 10000) of
#{pending_joins := PendingJoins} ->
#{<<"pending_joins">> => PendingJoins};
_ ->
@@ -559,7 +549,7 @@ execute_method(<<"guild.move_member">>, #{
connection_id => ConnectionId
}
),
with_guild(GuildId, fun(Pid) ->
with_voice_server(GuildId, fun(VoicePid, GuildPid) ->
Request = #{
user_id => UserId,
moderator_id => ModeratorId,
@@ -567,10 +557,10 @@ execute_method(<<"guild.move_member">>, #{
connection_id => ConnectionId
},
handle_move_member_result(
gen_server:call(Pid, {move_member, Request}, ?GUILD_CALL_TIMEOUT),
gen_server:call(VoicePid, {move_member, Request}, ?GUILD_CALL_TIMEOUT),
GuildId,
ChannelId,
Pid
GuildPid
)
end);
execute_method(<<"guild.get_voice_state">>, #{
@@ -578,9 +568,9 @@ execute_method(<<"guild.get_voice_state">>, #{
}) ->
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
with_guild(GuildId, fun(Pid) ->
with_voice_server(GuildId, fun(VoicePid, _GuildPid) ->
Request = #{user_id => UserId},
case gen_server:call(Pid, {get_voice_state, Request}, ?GUILD_CALL_TIMEOUT) of
case gen_server:call(VoicePid, {get_voice_state, Request}, ?GUILD_CALL_TIMEOUT) of
#{voice_state := null} -> #{<<"voice_state">> => null};
#{voice_state := VoiceState} -> #{<<"voice_state">> => VoiceState};
_ -> throw({error, <<"voice_state_error">>})
@@ -591,11 +581,11 @@ execute_method(<<"guild.switch_voice_region">>, #{
}) ->
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
with_guild(GuildId, fun(Pid) ->
with_voice_server(GuildId, fun(VoicePid, GuildPid) ->
Request = #{channel_id => ChannelId},
case gen_server:call(Pid, {switch_voice_region, Request}, ?GUILD_CALL_TIMEOUT) of
case gen_server:call(VoicePid, {switch_voice_region, Request}, ?GUILD_CALL_TIMEOUT) of
#{success := true} ->
spawn(fun() -> guild_voice:switch_voice_region(GuildId, ChannelId, Pid) end),
spawn(fun() -> guild_voice:switch_voice_region(GuildId, ChannelId, GuildPid) end),
#{<<"success">> => true};
#{error := Error} ->
throw({error, normalize_voice_rpc_error(Error)})
@@ -644,12 +634,26 @@ execute_method(<<"guild.batch_voice_state_update">>, #{<<"updates">> := UpdatesB
-spec fetch_online_count_entry(integer()) -> map() | undefined.
fetch_online_count_entry(GuildId) ->
case guild_counts_cache:get(GuildId) of
{ok, MemberCount, OnlineCount} ->
#{
<<"guild_id">> => integer_to_binary(GuildId),
<<"member_count">> => MemberCount,
<<"online_count">> => OnlineCount
};
miss ->
fetch_online_count_entry_from_process(GuildId)
end.
-spec fetch_online_count_entry_from_process(integer()) -> map() | undefined.
fetch_online_count_entry_from_process(GuildId) ->
case get_guild_pid(GuildId) of
{ok, Pid} ->
case gen_server:call(Pid, {get_counts}, ?GUILD_CALL_TIMEOUT) of
#{presence_count := PresenceCount} ->
#{member_count := MemberCount, presence_count := PresenceCount} ->
#{
<<"guild_id">> => integer_to_binary(GuildId),
<<"member_count">> => MemberCount,
<<"online_count">> => PresenceCount
};
_ ->
@@ -670,6 +674,23 @@ with_guild(GuildId, Fun, NotFoundError) ->
_ -> throw({error, NotFoundError})
end.
-spec with_voice_server(integer(), fun((pid(), pid()) -> T)) -> T when T :: term().
with_voice_server(GuildId, Fun) ->
case get_guild_pid(GuildId) of
{ok, GuildPid} ->
VoicePid = resolve_voice_pid(GuildId, GuildPid),
Fun(VoicePid, GuildPid);
_ ->
throw({error, <<"guild_not_found">>})
end.
-spec resolve_voice_pid(integer(), pid()) -> pid().
resolve_voice_pid(GuildId, FallbackGuildPid) ->
case guild_voice_server:lookup(GuildId) of
{ok, VoicePid} -> VoicePid;
{error, not_found} -> FallbackGuildPid
end.
-spec get_guild_pid(integer()) -> {ok, pid()} | error.
get_guild_pid(GuildId) ->
case lookup_guild_pid_from_cache(GuildId) of
@@ -792,6 +813,56 @@ get_viewable_channels_via_rpc(GuildId, UserId) ->
error
end.
-spec get_has_member_cached_or_rpc(integer(), integer()) -> {ok, boolean()} | error.
get_has_member_cached_or_rpc(GuildId, UserId) ->
case guild_permission_cache:has_member(GuildId, UserId) of
{ok, HasMember} ->
{ok, HasMember};
{error, not_found} ->
get_has_member_via_rpc(GuildId, UserId)
end.
-spec get_has_member_via_rpc(integer(), integer()) -> {ok, boolean()} | error.
get_has_member_via_rpc(GuildId, UserId) ->
case get_guild_pid(GuildId) of
{ok, Pid} ->
Request = #{user_id => UserId},
case gen_server:call(Pid, {has_member, Request}, ?GUILD_CALL_TIMEOUT) of
#{has_member := HasMember} when is_boolean(HasMember) ->
{ok, HasMember};
_ ->
error
end;
error ->
error
end.
-spec get_member_cached_or_rpc(integer(), integer()) -> {ok, map() | undefined} | error.
get_member_cached_or_rpc(GuildId, UserId) ->
case guild_permission_cache:get_member(GuildId, UserId) of
{ok, MemberOrUndefined} ->
{ok, MemberOrUndefined};
{error, not_found} ->
get_member_via_rpc(GuildId, UserId)
end.
-spec get_member_via_rpc(integer(), integer()) -> {ok, map() | undefined} | error.
get_member_via_rpc(GuildId, UserId) ->
case get_guild_pid(GuildId) of
{ok, Pid} ->
Request = #{user_id => UserId},
case gen_server:call(Pid, {get_guild_member, Request}, ?GUILD_CALL_TIMEOUT) of
#{success := true, member_data := MemberData} ->
{ok, MemberData};
#{success := false} ->
{ok, undefined};
_ ->
error
end;
error ->
error
end.
-spec parse_channel_id(binary()) -> integer() | undefined.
parse_channel_id(<<"0">>) -> undefined;
parse_channel_id(ChannelIdBin) -> validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin).
@@ -892,11 +963,12 @@ parse_voice_update(
-spec process_voice_update({integer(), integer(), boolean(), boolean(), term()}) -> map().
process_voice_update({GuildId, UserId, Mute, Deaf, ConnectionId}) ->
case gen_server:call(guild_manager, {start_or_lookup, GuildId}, ?GUILD_LOOKUP_TIMEOUT) of
{ok, Pid} ->
{ok, GuildPid} ->
VoicePid = resolve_voice_pid(GuildId, GuildPid),
Request = #{
user_id => UserId, mute => Mute, deaf => Deaf, connection_id => ConnectionId
},
case gen_server:call(Pid, {update_member_voice, Request}, ?GUILD_CALL_TIMEOUT) of
case gen_server:call(VoicePid, {update_member_voice, Request}, ?GUILD_CALL_TIMEOUT) of
#{success := true} ->
#{
<<"guild_id">> => integer_to_binary(GuildId),
@@ -1062,4 +1134,50 @@ get_viewable_channels_cached_or_rpc_prefers_cache_test() ->
ok = guild_permission_cache:delete(GuildId)
end.
get_has_member_cached_or_rpc_prefers_cache_test() ->
GuildId = 12348,
UserId = 502,
Data = #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [],
<<"members">> => #{
UserId => #{
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
<<"roles">> => []
}
},
<<"channels">> => []
},
ok = guild_permission_cache:put_data(GuildId, Data),
try
?assertEqual({ok, true}, get_has_member_cached_or_rpc(GuildId, UserId)),
?assertEqual({ok, false}, get_has_member_cached_or_rpc(GuildId, 99999))
after
ok = guild_permission_cache:delete(GuildId)
end.
get_member_cached_or_rpc_prefers_cache_test() ->
GuildId = 12349,
UserId = 503,
Data = #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [],
<<"members">> => #{
UserId => #{
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
<<"roles">> => [],
<<"nick">> => <<"CacheNick">>
}
},
<<"channels">> => []
},
ok = guild_permission_cache:put_data(GuildId, Data),
try
{ok, MemberData} = get_member_cached_or_rpc(GuildId, UserId),
?assertEqual(<<"CacheNick">>, maps:get(<<"nick">>, MemberData)),
?assertEqual({ok, undefined}, get_member_cached_or_rpc(GuildId, 99999))
after
ok = guild_permission_cache:delete(GuildId)
end.
-endif.

View File

@@ -1,167 +0,0 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(gateway_rpc_http_handler).
-export([init/2]).
-define(JSON_HEADERS, #{<<"content-type">> => <<"application/json">>}).
-spec init(cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
init(Req0, State) ->
case cowboy_req:method(Req0) of
<<"POST">> ->
handle_post(Req0, State);
_ ->
Req = cowboy_req:reply(405, #{<<"allow">> => <<"POST">>}, <<>>, Req0),
{ok, Req, State}
end.
-spec handle_post(cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
handle_post(Req0, State) ->
case authorize(Req0) of
ok ->
case read_body(Req0) of
{ok, Decoded, Req1} ->
handle_decoded_body(Decoded, Req1, State);
{error, ErrorBody, Req1} ->
respond(400, ErrorBody, Req1, State)
end;
{error, Req1} ->
{ok, Req1, State}
end.
-spec handle_decoded_body(map(), cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
handle_decoded_body(Decoded, Req0, State) ->
case maps:get(<<"method">>, Decoded, undefined) of
undefined ->
respond(400, #{<<"error">> => <<"Missing method">>}, Req0, State);
Method when is_binary(Method) ->
ParamsValue = maps:get(<<"params">>, Decoded, #{}),
case is_map(ParamsValue) of
true ->
execute_method(Method, ParamsValue, Req0, State);
false ->
respond(400, #{<<"error">> => <<"Invalid params">>}, Req0, State)
end;
_ ->
respond(400, #{<<"error">> => <<"Invalid method">>}, Req0, State)
end.
-spec authorize(cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
authorize(Req0) ->
case cowboy_req:header(<<"authorization">>, Req0) of
undefined ->
Req = cowboy_req:reply(
401,
?JSON_HEADERS,
json:encode(#{<<"error">> => <<"Unauthorized">>}),
Req0
),
{error, Req};
AuthHeader ->
authorize_with_secret(AuthHeader, Req0)
end.
-spec authorize_with_secret(binary(), cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
authorize_with_secret(AuthHeader, Req0) ->
case fluxer_gateway_env:get(rpc_secret_key) of
undefined ->
Req = cowboy_req:reply(
500,
?JSON_HEADERS,
json:encode(#{<<"error">> => <<"RPC secret not configured">>}),
Req0
),
{error, Req};
Secret when is_binary(Secret) ->
Expected = <<"Bearer ", Secret/binary>>,
check_auth_header(AuthHeader, Expected, Req0)
end.
-spec check_auth_header(binary(), binary(), cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
check_auth_header(AuthHeader, Expected, Req0) ->
case secure_compare(AuthHeader, Expected) of
true ->
ok;
false ->
Req = cowboy_req:reply(
401,
?JSON_HEADERS,
json:encode(#{<<"error">> => <<"Unauthorized">>}),
Req0
),
{error, Req}
end.
-spec secure_compare(binary(), binary()) -> boolean().
secure_compare(Left, Right) when is_binary(Left), is_binary(Right) ->
case byte_size(Left) =:= byte_size(Right) of
true ->
crypto:hash_equals(Left, Right);
false ->
false
end.
-spec read_body(cowboy_req:req()) ->
{ok, map(), cowboy_req:req()} | {error, map(), cowboy_req:req()}.
read_body(Req0) ->
read_body_chunks(Req0, <<>>).
-spec read_body_chunks(cowboy_req:req(), binary()) ->
{ok, map(), cowboy_req:req()} | {error, map(), cowboy_req:req()}.
read_body_chunks(Req0, Acc) ->
case cowboy_req:read_body(Req0) of
{ok, Body, Req1} ->
FullBody = <<Acc/binary, Body/binary>>,
decode_body(FullBody, Req1);
{more, Body, Req1} ->
read_body_chunks(Req1, <<Acc/binary, Body/binary>>)
end.
-spec decode_body(binary(), cowboy_req:req()) ->
{ok, map(), cowboy_req:req()} | {error, map(), cowboy_req:req()}.
decode_body(Body, Req0) ->
case catch json:decode(Body) of
{'EXIT', _Reason} ->
{error, #{<<"error">> => <<"Invalid JSON payload">>}, Req0};
Decoded when is_map(Decoded) ->
{ok, Decoded, Req0};
_ ->
{error, #{<<"error">> => <<"Invalid request body">>}, Req0}
end.
-spec execute_method(binary(), map(), cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
execute_method(Method, Params, Req0, State) ->
try
Result = gateway_rpc_router:execute(Method, Params),
respond(200, #{<<"result">> => Result}, Req0, State)
catch
throw:{error, Message} ->
respond(400, #{<<"error">> => Message}, Req0, State);
exit:timeout ->
respond(504, #{<<"error">> => <<"timeout">>}, Req0, State);
exit:{timeout, _} ->
respond(504, #{<<"error">> => <<"timeout">>}, Req0, State);
_:_ ->
respond(500, #{<<"error">> => <<"Internal error">>}, Req0, State)
end.
-spec respond(pos_integer(), map(), cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
respond(Status, Body, Req0, State) ->
Req = cowboy_req:reply(Status, ?JSON_HEADERS, json:encode(Body), Req0),
{ok, Req, State}.

View File

@@ -1,446 +0,0 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(gateway_rpc_tcp_connection).
-export([serve/1]).
-define(DEFAULT_MAX_INFLIGHT, 1024).
-define(DEFAULT_MAX_INPUT_BUFFER_BYTES, 2097152).
-define(DEFAULT_DISPATCH_RESERVE_DIVISOR, 8).
-define(MAX_FRAME_BYTES, 1048576).
-define(PROTOCOL_VERSION, <<"fluxer.rpc.tcp.v1">>).
-type state() :: #{
socket := inet:socket(),
buffer := binary(),
authenticated := boolean(),
inflight := non_neg_integer(),
max_inflight := pos_integer(),
max_input_buffer_bytes := pos_integer()
}.
-type rpc_result() :: {ok, term()} | {error, binary()}.
-spec serve(inet:socket()) -> ok.
serve(Socket) ->
ok = inet:setopts(Socket, [{active, once}, {nodelay, true}, {keepalive, true}]),
State = #{
socket => Socket,
buffer => <<>>,
authenticated => false,
inflight => 0,
max_inflight => max_inflight(),
max_input_buffer_bytes => max_input_buffer_bytes()
},
loop(State).
-spec loop(state()) -> ok.
loop(#{socket := Socket} = State) ->
receive
{tcp, Socket, Data} ->
case handle_tcp_data(Data, State) of
{ok, NewState} ->
ok = inet:setopts(Socket, [{active, once}]),
loop(NewState);
{stop, Reason, _NewState} ->
logger:debug("Gateway TCP RPC connection closed: ~p", [Reason]),
close_socket(Socket),
ok
end;
{tcp_closed, Socket} ->
ok;
{tcp_error, Socket, Reason} ->
logger:warning("Gateway TCP RPC socket error: ~p", [Reason]),
close_socket(Socket),
ok;
{rpc_response, RequestId, Result} ->
NewState = handle_rpc_response(RequestId, Result, State),
loop(NewState);
_Other ->
loop(State)
end.
-spec handle_tcp_data(binary(), state()) -> {ok, state()} | {stop, term(), state()}.
handle_tcp_data(Data, #{buffer := Buffer, max_input_buffer_bytes := MaxInputBufferBytes} = State) ->
case byte_size(Buffer) + byte_size(Data) =< MaxInputBufferBytes of
false ->
_ = send_error_frame(State, protocol_error_binary(input_buffer_limit_exceeded)),
{stop, input_buffer_limit_exceeded, State};
true ->
Combined = <<Buffer/binary, Data/binary>>,
decode_tcp_frames(Combined, State)
end.
-spec decode_tcp_frames(binary(), state()) -> {ok, state()} | {stop, term(), state()}.
decode_tcp_frames(Combined, State) ->
case decode_frames(Combined, []) of
{ok, Frames, Rest} ->
process_frames(Frames, State#{buffer => Rest});
{error, Reason} ->
_ = send_error_frame(State, protocol_error_binary(Reason)),
{stop, Reason, State}
end.
-spec process_frames([map()], state()) -> {ok, state()} | {stop, term(), state()}.
process_frames([], State) ->
{ok, State};
process_frames([Frame | Rest], State) ->
case process_frame(Frame, State) of
{ok, NewState} ->
process_frames(Rest, NewState);
{stop, Reason, NewState} ->
{stop, Reason, NewState}
end.
-spec process_frame(map(), state()) -> {ok, state()} | {stop, term(), state()}.
process_frame(#{<<"type">> := <<"hello">>} = Frame, #{authenticated := false} = State) ->
handle_hello_frame(Frame, State);
process_frame(#{<<"type">> := <<"hello">>}, State) ->
_ = send_error_frame(State, <<"duplicate_hello">>),
{stop, duplicate_hello, State};
process_frame(#{<<"type">> := <<"request">>} = Frame, #{authenticated := true} = State) ->
handle_request_frame(Frame, State);
process_frame(#{<<"type">> := <<"request">>}, State) ->
_ = send_error_frame(State, <<"unauthorized">>),
{stop, unauthorized, State};
process_frame(#{<<"type">> := <<"ping">>}, State) ->
_ = send_frame(State, #{<<"type">> => <<"pong">>}),
{ok, State};
process_frame(#{<<"type">> := <<"pong">>}, State) ->
{ok, State};
process_frame(#{<<"type">> := <<"close">>}, State) ->
{stop, client_close, State};
process_frame(_Frame, State) ->
_ = send_error_frame(State, <<"unknown_frame_type">>),
{stop, unknown_frame_type, State}.
-spec handle_hello_frame(map(), state()) -> {ok, state()} | {stop, term(), state()}.
handle_hello_frame(Frame, State) ->
case {maps:get(<<"protocol">>, Frame, undefined), maps:get(<<"authorization">>, Frame, undefined)} of
{?PROTOCOL_VERSION, AuthHeader} when is_binary(AuthHeader) ->
authorize_hello(AuthHeader, State);
_ ->
_ = send_error_frame(State, <<"invalid_hello">>),
{stop, invalid_hello, State}
end.
-spec authorize_hello(binary(), state()) -> {ok, state()} | {stop, term(), state()}.
authorize_hello(AuthHeader, State) ->
case fluxer_gateway_env:get(rpc_secret_key) of
Secret when is_binary(Secret) ->
Expected = <<"Bearer ", Secret/binary>>,
case secure_compare(AuthHeader, Expected) of
true ->
HelloAck = #{
<<"type">> => <<"hello_ack">>,
<<"protocol">> => ?PROTOCOL_VERSION,
<<"max_in_flight">> => maps:get(max_inflight, State),
<<"ping_interval_ms">> => 15000
},
_ = send_frame(State, HelloAck),
{ok, State#{authenticated => true}};
false ->
_ = send_error_frame(State, <<"unauthorized">>),
{stop, unauthorized, State}
end;
_ ->
_ = send_error_frame(State, <<"rpc_secret_not_configured">>),
{stop, rpc_secret_not_configured, State}
end.
-spec handle_request_frame(map(), state()) -> {ok, state()}.
handle_request_frame(Frame, State) ->
RequestId = request_id_from_frame(Frame),
Method = maps:get(<<"method">>, Frame, undefined),
case should_reject_request(Method, State) of
true ->
_ =
send_response_frame(
State,
RequestId,
false,
undefined,
<<"overloaded">>
),
{ok, State};
false ->
case {Method, maps:get(<<"params">>, Frame, undefined)} of
{MethodName, Params} when is_binary(RequestId), is_binary(MethodName), is_map(Params) ->
Parent = self(),
_ = spawn(fun() ->
Parent ! {rpc_response, RequestId, execute_method(MethodName, Params)}
end),
{ok, increment_inflight(State)};
_ ->
_ =
send_response_frame(
State,
RequestId,
false,
undefined,
<<"invalid_request">>
),
{ok, State}
end
end.
-spec should_reject_request(term(), state()) -> boolean().
should_reject_request(Method, #{inflight := Inflight, max_inflight := MaxInflight}) ->
case is_dispatch_method(Method) of
true ->
Inflight >= MaxInflight;
false ->
Inflight >= non_dispatch_inflight_limit(MaxInflight)
end.
-spec non_dispatch_inflight_limit(pos_integer()) -> pos_integer().
non_dispatch_inflight_limit(MaxInflight) ->
Reserve = dispatch_reserve_slots(MaxInflight),
max(1, MaxInflight - Reserve).
-spec dispatch_reserve_slots(pos_integer()) -> pos_integer().
dispatch_reserve_slots(MaxInflight) ->
max(1, MaxInflight div ?DEFAULT_DISPATCH_RESERVE_DIVISOR).
-spec is_dispatch_method(term()) -> boolean().
is_dispatch_method(Method) when is_binary(Method) ->
Suffix = <<".dispatch">>,
MethodSize = byte_size(Method),
SuffixSize = byte_size(Suffix),
MethodSize >= SuffixSize andalso
binary:part(Method, MethodSize - SuffixSize, SuffixSize) =:= Suffix;
is_dispatch_method(_) ->
false.
-spec execute_method(binary(), map()) -> rpc_result().
execute_method(Method, Params) ->
try
Result = gateway_rpc_router:execute(Method, Params),
{ok, Result}
catch
throw:{error, Message} ->
{error, error_binary(Message)};
exit:timeout ->
{error, <<"timeout">>};
exit:{timeout, _} ->
{error, <<"timeout">>};
Class:Reason ->
logger:error(
"Gateway TCP RPC method execution failed. method=~ts class=~p reason=~p",
[Method, Class, Reason]
),
{error, <<"internal_error">>}
end.
-spec handle_rpc_response(binary(), rpc_result(), state()) -> state().
handle_rpc_response(RequestId, {ok, Result}, State) ->
_ = send_response_frame(State, RequestId, true, Result, undefined),
decrement_inflight(State);
handle_rpc_response(RequestId, {error, Error}, State) ->
_ = send_response_frame(State, RequestId, false, undefined, Error),
decrement_inflight(State).
-spec send_response_frame(state(), binary(), boolean(), term(), binary() | undefined) -> ok | {error, term()}.
send_response_frame(State, RequestId, true, Result, _Error) ->
send_frame(State, #{
<<"type">> => <<"response">>,
<<"id">> => RequestId,
<<"ok">> => true,
<<"result">> => Result
});
send_response_frame(State, RequestId, false, _Result, Error) ->
send_frame(State, #{
<<"type">> => <<"response">>,
<<"id">> => RequestId,
<<"ok">> => false,
<<"error">> => Error
}).
-spec send_error_frame(state(), binary()) -> ok | {error, term()}.
send_error_frame(State, Error) ->
send_frame(State, #{
<<"type">> => <<"error">>,
<<"error">> => Error
}).
-spec send_frame(state(), map()) -> ok | {error, term()}.
send_frame(#{socket := Socket}, Frame) ->
gen_tcp:send(Socket, encode_frame(Frame)).
-spec encode_frame(map()) -> binary().
encode_frame(Frame) ->
Payload = iolist_to_binary(json:encode(Frame)),
Length = integer_to_binary(byte_size(Payload)),
<<Length/binary, "\n", Payload/binary>>.
-spec decode_frames(binary(), [map()]) -> {ok, [map()], binary()} | {error, term()}.
decode_frames(Buffer, Acc) ->
case binary:match(Buffer, <<"\n">>) of
nomatch ->
{ok, lists:reverse(Acc), Buffer};
{Pos, 1} ->
LengthBin = binary:part(Buffer, 0, Pos),
case parse_length(LengthBin) of
{ok, Length} ->
HeaderSize = Pos + 1,
RequiredSize = HeaderSize + Length,
case byte_size(Buffer) >= RequiredSize of
false ->
{ok, lists:reverse(Acc), Buffer};
true ->
Payload = binary:part(Buffer, HeaderSize, Length),
RestSize = byte_size(Buffer) - RequiredSize,
Rest = binary:part(Buffer, RequiredSize, RestSize),
case decode_payload(Payload) of
{ok, Frame} ->
decode_frames(Rest, [Frame | Acc]);
{error, Reason} ->
{error, Reason}
end
end;
{error, Reason} ->
{error, Reason}
end
end.
-spec decode_payload(binary()) -> {ok, map()} | {error, term()}.
decode_payload(Payload) ->
case catch json:decode(Payload) of
{'EXIT', _} ->
{error, invalid_json};
Frame when is_map(Frame) ->
{ok, Frame};
_ ->
{error, invalid_json}
end.
-spec parse_length(binary()) -> {ok, non_neg_integer()} | {error, term()}.
parse_length(<<>>) ->
{error, invalid_frame_length};
parse_length(LengthBin) ->
try
Length = binary_to_integer(LengthBin),
case Length >= 0 andalso Length =< ?MAX_FRAME_BYTES of
true -> {ok, Length};
false -> {error, invalid_frame_length}
end
catch
_:_ ->
{error, invalid_frame_length}
end.
-spec secure_compare(binary(), binary()) -> boolean().
secure_compare(Left, Right) when is_binary(Left), is_binary(Right) ->
case byte_size(Left) =:= byte_size(Right) of
true ->
crypto:hash_equals(Left, Right);
false ->
false
end.
-spec request_id_from_frame(map()) -> binary().
request_id_from_frame(Frame) ->
case maps:get(<<"id">>, Frame, <<>>) of
Id when is_binary(Id) ->
Id;
Id when is_integer(Id) ->
integer_to_binary(Id);
_ ->
<<>>
end.
-spec increment_inflight(state()) -> state().
increment_inflight(#{inflight := Inflight} = State) ->
State#{inflight => Inflight + 1}.
-spec decrement_inflight(state()) -> state().
decrement_inflight(#{inflight := Inflight} = State) when Inflight > 0 ->
State#{inflight => Inflight - 1};
decrement_inflight(State) ->
State.
-spec error_binary(term()) -> binary().
error_binary(Value) when is_binary(Value) ->
Value;
error_binary(Value) when is_list(Value) ->
unicode:characters_to_binary(Value);
error_binary(Value) when is_atom(Value) ->
atom_to_binary(Value, utf8);
error_binary(Value) ->
unicode:characters_to_binary(io_lib:format("~p", [Value])).
-spec protocol_error_binary(term()) -> binary().
protocol_error_binary(invalid_json) ->
<<"invalid_json">>;
protocol_error_binary(invalid_frame_length) ->
<<"invalid_frame_length">>;
protocol_error_binary(input_buffer_limit_exceeded) ->
<<"input_buffer_limit_exceeded">>.
-spec close_socket(inet:socket()) -> ok.
close_socket(Socket) ->
catch gen_tcp:close(Socket),
ok.
-spec max_inflight() -> pos_integer().
max_inflight() ->
case fluxer_gateway_env:get(gateway_http_rpc_max_concurrency) of
Value when is_integer(Value), Value > 0 ->
Value;
_ ->
?DEFAULT_MAX_INFLIGHT
end.
-spec max_input_buffer_bytes() -> pos_integer().
max_input_buffer_bytes() ->
case fluxer_gateway_env:get(gateway_rpc_tcp_max_input_buffer_bytes) of
Value when is_integer(Value), Value > 0 ->
Value;
_ ->
?DEFAULT_MAX_INPUT_BUFFER_BYTES
end.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
decode_single_frame_test() ->
Frame = #{<<"type">> => <<"ping">>},
Encoded = encode_frame(Frame),
?assertEqual({ok, [Frame], <<>>}, decode_frames(Encoded, [])).
decode_multiple_frames_test() ->
FrameA = #{<<"type">> => <<"ping">>},
FrameB = #{<<"type">> => <<"pong">>},
Encoded = <<(encode_frame(FrameA))/binary, (encode_frame(FrameB))/binary>>,
?assertEqual({ok, [FrameA, FrameB], <<>>}, decode_frames(Encoded, [])).
decode_partial_frame_test() ->
Frame = #{<<"type">> => <<"ping">>},
Encoded = encode_frame(Frame),
Prefix = binary:part(Encoded, 0, 3),
?assertEqual({ok, [], Prefix}, decode_frames(Prefix, [])).
invalid_length_test() ->
?assertEqual({error, invalid_frame_length}, decode_frames(<<"x\n{}">>, [])).
secure_compare_test() ->
?assert(secure_compare(<<"abc">>, <<"abc">>)),
?assertNot(secure_compare(<<"abc">>, <<"abd">>)),
?assertNot(secure_compare(<<"abc">>, <<"abcd">>)).
-endif.

View File

@@ -1,108 +0,0 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(gateway_rpc_tcp_server).
-behaviour(gen_server).
-export([start_link/0, accept_loop/1]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-type state() :: #{
listen_socket := inet:socket(),
acceptor_pid := pid(),
port := inet:port_number()
}.
-spec start_link() -> {ok, pid()} | {error, term()}.
start_link() ->
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
-spec init([]) -> {ok, state()} | {stop, term()}.
init([]) ->
process_flag(trap_exit, true),
Port = fluxer_gateway_env:get(rpc_tcp_port),
case gen_tcp:listen(Port, listen_options()) of
{ok, ListenSocket} ->
AcceptorPid = spawn_link(?MODULE, accept_loop, [ListenSocket]),
logger:info("Gateway TCP RPC listener started on port ~p", [Port]),
{ok, #{
listen_socket => ListenSocket,
acceptor_pid => AcceptorPid,
port => Port
}};
{error, Reason} ->
{stop, {rpc_tcp_listen_failed, Port, Reason}}
end.
-spec handle_call(term(), gen_server:from(), state()) -> {reply, ok, state()}.
handle_call(_Request, _From, State) ->
{reply, ok, State}.
-spec handle_cast(term(), state()) -> {noreply, state()}.
handle_cast(_Msg, State) ->
{noreply, State}.
-spec handle_info(term(), state()) -> {noreply, state()}.
handle_info({'EXIT', Pid, Reason}, #{acceptor_pid := Pid, listen_socket := ListenSocket} = State) ->
case Reason of
normal ->
{noreply, State};
shutdown ->
{noreply, State};
_ ->
logger:error("Gateway TCP RPC acceptor crashed: ~p", [Reason]),
NewAcceptorPid = spawn_link(?MODULE, accept_loop, [ListenSocket]),
{noreply, State#{acceptor_pid => NewAcceptorPid}}
end;
handle_info(_Info, State) ->
{noreply, State}.
-spec terminate(term(), state()) -> ok.
terminate(_Reason, #{listen_socket := ListenSocket, port := Port}) ->
catch gen_tcp:close(ListenSocket),
logger:info("Gateway TCP RPC listener stopped on port ~p", [Port]),
ok.
-spec code_change(term(), state(), term()) -> {ok, state()}.
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
-spec accept_loop(inet:socket()) -> ok.
accept_loop(ListenSocket) ->
case gen_tcp:accept(ListenSocket) of
{ok, Socket} ->
_ = spawn_link(?MODULE, accept_loop, [ListenSocket]),
gateway_rpc_tcp_connection:serve(Socket);
{error, closed} ->
ok;
{error, Reason} ->
logger:error("Gateway TCP RPC accept failed: ~p", [Reason]),
timer:sleep(200),
accept_loop(ListenSocket)
end.
-spec listen_options() -> [gen_tcp:listen_option()].
listen_options() ->
[
binary,
{packet, raw},
{active, false},
{reuseaddr, true},
{nodelay, true},
{backlog, 4096},
{keepalive, true}
].

View File

@@ -316,6 +316,7 @@ is_fluxer_module(Module) ->
lists:prefix("gateway_http_", ModuleStr) orelse
lists:prefix("session", ModuleStr) orelse
lists:prefix("guild", ModuleStr) orelse
lists:prefix("passive_sync_registry", ModuleStr) orelse
lists:prefix("presence", ModuleStr) orelse
lists:prefix("push", ModuleStr) orelse
lists:prefix("push_dispatcher", ModuleStr) orelse

View File

@@ -17,96 +17,72 @@
-module(rpc_client).
-export([
call/1,
call/2,
get_rpc_url/0,
get_rpc_url/1,
get_rpc_headers/0
]).
-export([call/1]).
-define(NATS_RPC_SUBJECT, <<"rpc.api">>).
-define(NATS_RPC_TIMEOUT_MS, 10000).
-type rpc_request() :: map().
-type rpc_response() :: {ok, map()} | {error, term()}.
-type rpc_options() :: map().
-spec call(rpc_request()) -> rpc_response().
call(Request) ->
call(Request, #{}).
case gateway_nats_rpc:get_connection() of
{ok, undefined} ->
{error, not_connected};
{ok, Conn} ->
do_request(Conn, Request);
{error, Reason} ->
{error, {not_connected, Reason}}
end.
-spec call(rpc_request(), rpc_options()) -> rpc_response().
call(Request, _Options) ->
Url = get_rpc_url(),
Headers = get_rpc_headers(),
Body = json:encode(Request),
case gateway_http_client:request(rpc, post, Url, Headers, Body) of
{ok, 200, _RespHeaders, RespBody} ->
handle_success_response(RespBody);
{ok, StatusCode, _RespHeaders, RespBody} ->
handle_error_response(StatusCode, RespBody);
-spec do_request(nats:conn(), rpc_request()) -> rpc_response().
do_request(Conn, Request) ->
Payload = iolist_to_binary(json:encode(Request)),
case nats:request(Conn, ?NATS_RPC_SUBJECT, Payload, #{timeout => ?NATS_RPC_TIMEOUT_MS}) of
{ok, {ResponseBin, _MsgOpts}} ->
handle_nats_response(ResponseBin);
{error, timeout} ->
{error, timeout};
{error, no_responders} ->
{error, no_responders};
{error, Reason} ->
{error, Reason}
end.
-spec handle_success_response(binary()) -> rpc_response().
handle_success_response(RespBody) ->
Response = json:decode(RespBody),
Data = maps:get(<<"data">>, Response, #{}),
{ok, Data}.
-spec handle_error_response(pos_integer(), binary()) -> {error, term()}.
handle_error_response(StatusCode, RespBody) ->
{error, {http_error, StatusCode, RespBody}}.
-spec get_rpc_url() -> string().
get_rpc_url() ->
ApiHost = fluxer_gateway_env:get(api_host),
get_rpc_url(ApiHost).
-spec get_rpc_url(string() | binary()) -> string().
get_rpc_url(ApiHost) ->
BaseUrl = api_host_base_url(ApiHost),
BaseUrl ++ "/_rpc".
-spec api_host_base_url(string() | binary()) -> string().
api_host_base_url(ApiHost) ->
HostString = ensure_string(ApiHost),
Normalized = normalize_api_host(HostString),
strip_trailing_slash(Normalized).
-spec ensure_string(binary() | string()) -> string().
ensure_string(Value) when is_binary(Value) ->
binary_to_list(Value);
ensure_string(Value) when is_list(Value) ->
Value.
-spec normalize_api_host(string()) -> string().
normalize_api_host(Host) ->
Lower = string:lowercase(Host),
case {has_protocol_prefix(Lower, "http://"), has_protocol_prefix(Lower, "https://")} of
{true, _} -> Host;
{_, true} -> Host;
_ -> "http://" ++ Host
-spec handle_nats_response(iodata()) -> rpc_response().
handle_nats_response(ResponseBin) ->
Response = json:decode(iolist_to_binary(ResponseBin)),
case maps:get(<<"_error">>, Response, undefined) of
undefined ->
Data = maps:get(<<"data">>, Response, #{}),
{ok, Data};
_ ->
Status = maps:get(<<"status">>, Response, 500),
Message = maps:get(<<"message">>, Response, <<"unknown error">>),
{error, {rpc_error, Status, Message}}
end.
-spec has_protocol_prefix(string(), string()) -> boolean().
has_protocol_prefix(Str, Prefix) ->
case string:prefix(Str, Prefix) of
nomatch -> false;
_ -> true
end.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-spec strip_trailing_slash(string()) -> string().
strip_trailing_slash([]) ->
"";
strip_trailing_slash(Url) ->
case lists:last(Url) of
$/ -> strip_trailing_slash(lists:droplast(Url));
_ -> Url
end.
handle_nats_response_ok_test() ->
Response = json:encode(#{
<<"type">> => <<"session">>,
<<"data">> => #{<<"user">> => <<"test">>}
}),
?assertEqual({ok, #{<<"user">> => <<"test">>}}, handle_nats_response(Response)).
-spec get_rpc_headers() -> [{binary() | string(), binary() | string()}].
get_rpc_headers() ->
RpcSecretKey = fluxer_gateway_env:get(rpc_secret_key),
AuthHeader = {<<"Authorization">>, <<"Bearer ", RpcSecretKey/binary>>},
InitialHeaders = [AuthHeader],
gateway_tracing:inject_rpc_headers(InitialHeaders).
handle_nats_response_error_401_test() ->
Response = json:encode(#{<<"_error">> => true, <<"status">> => 401, <<"message">> => <<"Unauthorized">>}),
?assertEqual({error, {rpc_error, 401, <<"Unauthorized">>}}, handle_nats_response(Response)).
handle_nats_response_error_429_test() ->
Response = json:encode(#{<<"_error">> => true, <<"status">> => 429, <<"message">> => <<"Rate limited">>}),
?assertEqual({error, {rpc_error, 429, <<"Rate limited">>}}, handle_nats_response(Response)).
handle_nats_response_error_500_test() ->
Response = json:encode(#{<<"_error">> => true, <<"status">> => 500, <<"message">> => <<"Internal error">>}),
?assertEqual({error, {rpc_error, 500, <<"Internal error">>}}, handle_nats_response(Response)).
-endif.

File diff suppressed because it is too large Load Diff

View File

@@ -227,16 +227,7 @@ maybe_disconnect_voice_for_user(UserId, ProcessedUsers, State) ->
-spec ensure_unavailability_cache_table() -> ok.
ensure_unavailability_cache_table() ->
case ets:whereis(?GUILD_UNAVAILABILITY_CACHE) of
undefined ->
try ets:new(?GUILD_UNAVAILABILITY_CACHE, [named_table, public, set, {read_concurrency, true}]) of
_ -> ok
catch
error:badarg -> ok
end;
_ ->
ok
end.
guild_ets_utils:ensure_table(?GUILD_UNAVAILABILITY_CACHE, [named_table, public, set, {read_concurrency, true}]).
-spec set_cached_unavailability_mode(guild_id(), unavailability_mode()) -> ok.
set_cached_unavailability_mode(GuildId, available) ->
@@ -449,10 +440,7 @@ state_for_unavailability_transition_test(GuildId, NonStaffPid, StaffPid) ->
active_guilds => sets:new(),
user_roles => [],
bot => false,
is_staff => false,
previous_passive_updates => #{},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
is_staff => false
},
<<"staff">> => #{
session_id => <<"staff">>,
@@ -462,10 +450,7 @@ state_for_unavailability_transition_test(GuildId, NonStaffPid, StaffPid) ->
active_guilds => sets:new(),
user_roles => [],
bot => false,
is_staff => true,
previous_passive_updates => #{},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
is_staff => true
}
},
presence_subscriptions => #{},

View File

@@ -18,6 +18,7 @@
-module(guild_client).
-export([voice_state_update/3]).
-export([voice_state_update/4]).
-export_type([
voice_state_update_success/0,
@@ -54,17 +55,40 @@
-spec voice_state_update(pid(), map(), timeout()) -> voice_state_update_result().
voice_state_update(GuildPid, Request, Timeout) ->
ensure_table(),
case acquire_slot(GuildPid) of
TargetPid = GuildPid,
case acquire_slot(TargetPid) of
ok ->
try
execute_with_circuit_breaker(GuildPid, Request, Timeout)
execute_with_circuit_breaker(TargetPid, Request, Timeout)
after
release_slot(GuildPid)
release_slot(TargetPid)
end;
{error, Reason} ->
{error, Reason}
end.
-spec voice_state_update(pid(), integer(), map(), timeout()) -> voice_state_update_result().
voice_state_update(GuildPid, GuildId, Request, Timeout) ->
ensure_table(),
TargetPid = resolve_voice_pid(GuildId, GuildPid),
case acquire_slot(TargetPid) of
ok ->
try
execute_with_circuit_breaker(TargetPid, Request, Timeout)
after
release_slot(TargetPid)
end;
{error, Reason} ->
{error, Reason}
end.
-spec resolve_voice_pid(integer(), pid()) -> pid().
resolve_voice_pid(GuildId, FallbackGuildPid) ->
case guild_voice_server:lookup(GuildId) of
{ok, VoicePid} -> VoicePid;
{error, not_found} -> FallbackGuildPid
end.
-spec execute_with_circuit_breaker(pid(), map(), timeout()) -> voice_state_update_result().
execute_with_circuit_breaker(GuildPid, Request, Timeout) ->
case get_circuit_state(GuildPid) of
@@ -203,23 +227,13 @@ safe_lookup(GuildPid) ->
-spec ensure_table() -> ok.
ensure_table() ->
case ets:whereis(?CIRCUIT_BREAKER_TABLE) of
undefined ->
try
ets:new(?CIRCUIT_BREAKER_TABLE, [
named_table,
public,
set,
{read_concurrency, true},
{write_concurrency, true}
]),
ok
catch
error:badarg -> ok
end;
_ ->
ok
end.
guild_ets_utils:ensure_table(?CIRCUIT_BREAKER_TABLE, [
named_table,
public,
set,
{read_concurrency, true},
{write_concurrency, true}
]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").

View File

@@ -0,0 +1,278 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(guild_common).
-export([
safe_call/3,
parse_event_data/1,
relay_upsert_voice_state/2,
strip_members/1,
build_shard_state/4,
merge_cluster_state/2,
merge_user_set_maps/2
]).
-type guild_id() :: integer().
-type shard_index() :: non_neg_integer().
-spec safe_call(pid(), term(), timeout()) -> term().
safe_call(Pid, Msg, Timeout) when is_pid(Pid) ->
try gen_server:call(Pid, Msg, Timeout) of
Reply -> Reply
catch
exit:{timeout, _} -> {error, timeout};
exit:{noproc, _} -> {error, noproc};
exit:{normal, _} -> {error, noproc};
_:Reason -> {error, Reason}
end.
-spec parse_event_data(binary() | map()) -> map().
parse_event_data(EventData) when is_binary(EventData) ->
json:decode(EventData);
parse_event_data(EventData) when is_map(EventData) ->
EventData.
-spec relay_upsert_voice_state(map(), map()) -> map().
relay_upsert_voice_state(VoiceState, State) when is_map(VoiceState) ->
ConnectionId = maps:get(<<"connection_id">>, VoiceState, undefined),
case ConnectionId of
undefined ->
State;
_ ->
VoiceStates0 = maps:get(voice_states, State, #{}),
ChannelId = maps:get(<<"channel_id">>, VoiceState, null),
VoiceStates =
case ChannelId of
null -> maps:remove(ConnectionId, VoiceStates0);
_ -> maps:put(ConnectionId, VoiceState, VoiceStates0)
end,
maps:put(voice_states, VoiceStates, State)
end;
relay_upsert_voice_state(_, State) ->
State.
-spec strip_members(map()) -> map().
strip_members(Data) when is_map(Data) ->
Data1 = maps:remove(<<"members">>, Data),
maps:remove(<<"member_role_index">>, Data1);
strip_members(Data) ->
Data.
-spec build_shard_state(guild_id(), map(), pos_integer(), shard_index()) -> map().
build_shard_state(GuildId, Data, ShardCount, ShardIndex) ->
DisableCache = ShardIndex =/= 0,
MemberCount = guild_data_index:member_count(Data),
ShardData =
case DisableCache of
true -> strip_members(Data);
false -> Data
end,
ShardState0 = #{
id => GuildId,
data => ShardData,
sessions => #{},
member_count => MemberCount,
disable_push_notifications => true,
disable_member_list_updates => DisableCache,
disable_auto_stop_on_empty => true,
very_large_guild_coordinator_pid => self(),
very_large_guild_shard_count => ShardCount,
very_large_guild_shard_index => ShardIndex
},
case DisableCache of
true -> maps:put(disable_permission_cache_updates, true, ShardState0);
false -> ShardState0
end.
-spec merge_cluster_state(map(), map()) -> map().
merge_cluster_state(Acc, Frag) ->
SessionsAcc = maps:get(sessions, Acc, #{}),
SessionsFrag = maps:get(sessions, Frag, #{}),
VoiceAcc = maps:get(voice_states, Acc, #{}),
VoiceFrag = maps:get(voice_states, Frag, #{}),
VAAcc = maps:get(virtual_channel_access, Acc, #{}),
VAFrag = maps:get(virtual_channel_access, Frag, #{}),
PendingAcc = maps:get(virtual_channel_access_pending, Acc, #{}),
PendingFrag = maps:get(virtual_channel_access_pending, Frag, #{}),
PreserveAcc = maps:get(virtual_channel_access_preserve, Acc, #{}),
PreserveFrag = maps:get(virtual_channel_access_preserve, Frag, #{}),
MoveAcc = maps:get(virtual_channel_access_move_pending, Acc, #{}),
MoveFrag = maps:get(virtual_channel_access_move_pending, Frag, #{}),
Acc#{
sessions => maps:merge(SessionsAcc, SessionsFrag),
voice_states => maps:merge(VoiceAcc, VoiceFrag),
virtual_channel_access => merge_user_set_maps(VAAcc, VAFrag),
virtual_channel_access_pending => merge_user_set_maps(PendingAcc, PendingFrag),
virtual_channel_access_preserve => merge_user_set_maps(PreserveAcc, PreserveFrag),
virtual_channel_access_move_pending => merge_user_set_maps(MoveAcc, MoveFrag)
}.
-spec merge_user_set_maps(map(), map()) -> map().
merge_user_set_maps(A, B) ->
maps:fold(
fun(UserId, SetB, Acc) ->
case maps:get(UserId, Acc, undefined) of
undefined ->
maps:put(UserId, SetB, Acc);
SetA ->
maps:put(UserId, sets:union(SetA, SetB), Acc)
end
end,
A,
B
).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
safe_call_timeout_test() ->
Pid = spawn(fun() ->
receive
{'$gen_call', _From, _Msg} ->
timer:sleep(5000)
end
end),
Result = safe_call(Pid, ping, 50),
?assertEqual({error, timeout}, Result),
exit(Pid, kill),
ok.
safe_call_noproc_test() ->
Pid = spawn(fun() -> ok end),
timer:sleep(50),
Result = safe_call(Pid, ping, 100),
?assertMatch({error, _}, Result),
ok.
parse_event_data_binary_test() ->
Binary = <<"{\"key\":\"value\"}">>,
Result = parse_event_data(Binary),
?assertEqual(#{<<"key">> => <<"value">>}, Result).
parse_event_data_map_test() ->
Map = #{<<"key">> => <<"value">>},
Result = parse_event_data(Map),
?assertEqual(Map, Result).
relay_upsert_voice_state_adds_state_test() ->
VoiceState = #{
<<"connection_id">> => <<"conn-1">>,
<<"channel_id">> => <<"100">>,
<<"user_id">> => <<"42">>
},
State0 = #{voice_states => #{}},
State1 = relay_upsert_voice_state(VoiceState, State0),
VoiceStates = maps:get(voice_states, State1),
?assertEqual(VoiceState, maps:get(<<"conn-1">>, VoiceStates)).
relay_upsert_voice_state_removes_on_null_channel_test() ->
Existing = #{<<"connection_id">> => <<"conn-1">>, <<"channel_id">> => <<"100">>},
State0 = #{voice_states => #{<<"conn-1">> => Existing}},
RemoveState = #{<<"connection_id">> => <<"conn-1">>, <<"channel_id">> => null},
State1 = relay_upsert_voice_state(RemoveState, State0),
VoiceStates = maps:get(voice_states, State1),
?assertEqual(false, maps:is_key(<<"conn-1">>, VoiceStates)).
relay_upsert_voice_state_no_connection_id_test() ->
State0 = #{voice_states => #{}},
State1 = relay_upsert_voice_state(#{<<"channel_id">> => <<"100">>}, State0),
?assertEqual(State0, State1).
relay_upsert_voice_state_non_map_test() ->
State0 = #{voice_states => #{}},
State1 = relay_upsert_voice_state(not_a_map, State0),
?assertEqual(State0, State1).
strip_members_test() ->
Data = #{
<<"members">> => [#{<<"user">> => #{<<"id">> => <<"1">>}}],
<<"member_role_index">> => #{1 => [<<"role1">>]},
<<"channels">> => [#{<<"id">> => <<"10">>}],
<<"roles">> => [#{<<"id">> => <<"role1">>}]
},
Stripped = strip_members(Data),
?assertEqual(false, maps:is_key(<<"members">>, Stripped)),
?assertEqual(false, maps:is_key(<<"member_role_index">>, Stripped)),
?assertEqual([#{<<"id">> => <<"10">>}], maps:get(<<"channels">>, Stripped)),
?assertEqual([#{<<"id">> => <<"role1">>}], maps:get(<<"roles">>, Stripped)).
strip_members_empty_test() ->
?assertEqual(#{}, strip_members(#{})).
strip_members_non_map_test() ->
?assertEqual(not_a_map, strip_members(not_a_map)).
merge_user_set_maps_test() ->
SetA = sets:from_list([1, 2]),
SetB = sets:from_list([2, 3]),
MapA = #{10 => SetA},
MapB = #{10 => SetB, 20 => SetB},
Merged = merge_user_set_maps(MapA, MapB),
?assert(maps:is_key(10, Merged)),
?assert(maps:is_key(20, Merged)),
MergedSet10 = maps:get(10, Merged),
?assert(sets:is_element(1, MergedSet10)),
?assert(sets:is_element(2, MergedSet10)),
?assert(sets:is_element(3, MergedSet10)),
?assertEqual(3, sets:size(MergedSet10)),
?assertEqual(SetB, maps:get(20, Merged)).
merge_user_set_maps_empty_test() ->
?assertEqual(#{}, merge_user_set_maps(#{}, #{})).
merge_cluster_state_test() ->
Acc = #{
sessions => #{<<"s1">> => #{user_id => 1}},
voice_states => #{<<"c1">> => #{channel_id => 10}},
virtual_channel_access => #{},
virtual_channel_access_pending => #{},
virtual_channel_access_preserve => #{},
virtual_channel_access_move_pending => #{}
},
Frag = #{
sessions => #{<<"s2">> => #{user_id => 2}},
voice_states => #{<<"c2">> => #{channel_id => 20}},
virtual_channel_access => #{},
virtual_channel_access_pending => #{},
virtual_channel_access_preserve => #{},
virtual_channel_access_move_pending => #{}
},
Merged = merge_cluster_state(Acc, Frag),
?assert(maps:is_key(<<"s1">>, maps:get(sessions, Merged))),
?assert(maps:is_key(<<"s2">>, maps:get(sessions, Merged))),
?assert(maps:is_key(<<"c1">>, maps:get(voice_states, Merged))),
?assert(maps:is_key(<<"c2">>, maps:get(voice_states, Merged))).
build_shard_state_primary_test() ->
Data = #{<<"members">> => [#{<<"user">> => #{<<"id">> => <<"1">>}}]},
ShardState = build_shard_state(100, Data, 4, 0),
?assertEqual(100, maps:get(id, ShardState)),
?assertEqual(Data, maps:get(data, ShardState)),
?assertEqual(false, maps:get(disable_member_list_updates, ShardState)),
?assertEqual(false, maps:is_key(disable_permission_cache_updates, ShardState)).
build_shard_state_secondary_test() ->
Data = #{<<"members">> => [#{<<"user">> => #{<<"id">> => <<"1">>}}]},
ShardState = build_shard_state(100, Data, 4, 2),
?assertEqual(100, maps:get(id, ShardState)),
ShardData = maps:get(data, ShardState),
?assertEqual(false, maps:is_key(<<"members">>, ShardData)),
?assertEqual(true, maps:get(disable_member_list_updates, ShardState)),
?assertEqual(true, maps:get(disable_permission_cache_updates, ShardState)).
-endif.

View File

@@ -0,0 +1,128 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(guild_counts_cache).
-export([
init/0,
update/3,
get/1,
delete/1
]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-define(TABLE, guild_counts_cache).
-type guild_id() :: integer().
-spec init() -> ok.
init() ->
case ets:whereis(?TABLE) of
undefined ->
_ = ets:new(?TABLE, [
named_table,
public,
set,
{read_concurrency, true},
{write_concurrency, true}
]),
ok;
_ ->
ok
end.
-spec update(guild_id(), non_neg_integer(), non_neg_integer()) -> ok.
update(GuildId, MemberCount, OnlineCount) ->
ensure_table(),
ets:insert(?TABLE, {GuildId, MemberCount, OnlineCount}),
ok.
-spec get(guild_id()) -> {ok, non_neg_integer(), non_neg_integer()} | miss.
get(GuildId) ->
case catch ets:lookup(?TABLE, GuildId) of
[{GuildId, MemberCount, OnlineCount}] ->
{ok, MemberCount, OnlineCount};
_ ->
miss
end.
-spec delete(guild_id()) -> ok.
delete(GuildId) ->
catch ets:delete(?TABLE, GuildId),
ok.
-spec ensure_table() -> ok.
ensure_table() ->
case ets:whereis(?TABLE) of
undefined -> init();
_ -> ok
end.
-ifdef(TEST).
init_creates_table_test() ->
catch ets:delete(?TABLE),
ok = init(),
?assertNotEqual(undefined, ets:whereis(?TABLE)),
catch ets:delete(?TABLE).
init_idempotent_test() ->
catch ets:delete(?TABLE),
ok = init(),
ok = init(),
?assertNotEqual(undefined, ets:whereis(?TABLE)),
catch ets:delete(?TABLE).
update_and_get_test() ->
catch ets:delete(?TABLE),
ok = init(),
ok = update(100, 50, 25),
?assertEqual({ok, 50, 25}, guild_counts_cache:get(100)),
catch ets:delete(?TABLE).
get_miss_test() ->
catch ets:delete(?TABLE),
ok = init(),
?assertEqual(miss, guild_counts_cache:get(999)),
catch ets:delete(?TABLE).
update_overwrites_test() ->
catch ets:delete(?TABLE),
ok = init(),
ok = update(100, 50, 25),
ok = update(100, 60, 30),
?assertEqual({ok, 60, 30}, guild_counts_cache:get(100)),
catch ets:delete(?TABLE).
delete_removes_entry_test() ->
catch ets:delete(?TABLE),
ok = init(),
ok = update(100, 50, 25),
ok = delete(100),
?assertEqual(miss, guild_counts_cache:get(100)),
catch ets:delete(?TABLE).
delete_nonexistent_test() ->
catch ets:delete(?TABLE),
ok = init(),
ok = delete(999),
catch ets:delete(?TABLE).
-endif.

View File

@@ -477,4 +477,296 @@ put_member_and_remove_member_keep_member_role_index_in_sync_test() ->
Index2 = member_role_index(Data2),
?assertEqual(undefined, maps:get(30, Index2, undefined)).
normalize_data_empty_lists_test() ->
Data = #{
<<"members">> => [],
<<"roles">> => [],
<<"channels">> => []
},
Normalized = normalize_data(Data),
?assertEqual(#{}, maps:get(<<"members">>, Normalized)),
?assertEqual([], maps:get(<<"roles">>, Normalized)),
?assertEqual([], maps:get(<<"channels">>, Normalized)),
?assertEqual(#{}, maps:get(<<"role_index">>, Normalized)),
?assertEqual(#{}, maps:get(<<"channel_index">>, Normalized)),
?assertEqual(#{}, maps:get(<<"member_role_index">>, Normalized)).
normalize_data_non_map_input_test() ->
?assertEqual(not_a_map, normalize_data(not_a_map)),
?assertEqual(42, normalize_data(42)).
normalize_data_missing_keys_defaults_test() ->
Data = #{},
Normalized = normalize_data(Data),
?assertEqual(#{}, maps:get(<<"members">>, Normalized)),
?assertEqual([], maps:get(<<"roles">>, Normalized)),
?assertEqual([], maps:get(<<"channels">>, Normalized)).
normalize_data_already_map_members_test() ->
Data = #{
<<"members">> => #{
1 => #{<<"user">> => #{<<"id">> => <<"1">>}, <<"nick">> => <<"a">>}
},
<<"roles">> => [],
<<"channels">> => []
},
Normalized = normalize_data(Data),
Members = maps:get(<<"members">>, Normalized),
?assertEqual(1, map_size(Members)),
?assertMatch(#{1 := _}, Members).
member_map_non_map_input_test() ->
?assertEqual(#{}, member_map(not_a_map)),
?assertEqual(#{}, member_map(42)).
member_map_invalid_members_value_test() ->
Data = #{<<"members">> => <<"invalid">>},
?assertEqual(#{}, member_map(Data)).
member_map_members_without_user_test() ->
Data = #{<<"members">> => [#{<<"nick">> => <<"orphan">>}]},
?assertEqual(#{}, member_map(Data)).
member_map_duplicate_user_ids_last_wins_test() ->
Data = #{
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"1">>}, <<"nick">> => <<"first">>},
#{<<"user">> => #{<<"id">> => <<"1">>}, <<"nick">> => <<"second">>}
]
},
MemberMap = member_map(Data),
?assertEqual(1, map_size(MemberMap)),
?assertEqual(<<"second">>, maps:get(<<"nick">>, maps:get(1, MemberMap))).
member_list_empty_test() ->
Data = #{<<"members">> => #{}},
?assertEqual([], member_list(Data)).
member_ids_returns_all_user_ids_test() ->
Data = #{
<<"members">> => #{
5 => #{<<"user">> => #{<<"id">> => <<"5">>}},
3 => #{<<"user">> => #{<<"id">> => <<"3">>}},
8 => #{<<"user">> => #{<<"id">> => <<"8">>}}
}
},
Ids = lists:sort(member_ids(Data)),
?assertEqual([3, 5, 8], Ids).
get_member_non_integer_key_test() ->
Data = #{<<"members">> => #{1 => #{<<"user">> => #{<<"id">> => <<"1">>}}}},
?assertEqual(undefined, get_member(not_an_integer, Data)).
get_member_missing_user_test() ->
Data = #{<<"members">> => #{}},
?assertEqual(undefined, get_member(999, Data)).
put_member_no_user_id_returns_unchanged_test() ->
Data = #{<<"members">> => #{}},
?assertEqual(Data, put_member(#{<<"nick">> => <<"orphan">>}, Data)).
put_member_non_map_member_returns_unchanged_test() ->
Data = #{<<"members">> => #{}},
?assertEqual(Data, put_member(not_a_map, Data)).
put_member_non_map_data_returns_unchanged_test() ->
?assertEqual(not_a_map, put_member(#{<<"user">> => #{<<"id">> => <<"1">>}}, not_a_map)).
put_member_adds_new_member_test() ->
Data = #{<<"members">> => #{}},
UpdatedData = put_member(
#{<<"user">> => #{<<"id">> => <<"42">>}, <<"nick">> => <<"new">>},
Data
),
?assertMatch(#{42 := _}, maps:get(<<"members">>, UpdatedData)),
?assertEqual(<<"new">>, maps:get(<<"nick">>, get_member(42, UpdatedData))).
put_member_map_replaces_all_members_test() ->
Data = #{
<<"members">> => #{1 => #{<<"user">> => #{<<"id">> => <<"1">>}}}
},
NewMap = #{2 => #{<<"user">> => #{<<"id">> => <<"2">>}, <<"roles">> => [<<"10">>]}},
Updated = put_member_map(NewMap, Data),
?assertEqual(undefined, get_member(1, Updated)),
?assertMatch(#{2 := _}, maps:get(<<"members">>, Updated)).
put_member_map_non_map_returns_unchanged_test() ->
Data = #{<<"members">> => #{}},
?assertEqual(Data, put_member_map(not_a_map, Data)).
put_member_list_converts_and_stores_test() ->
Data = #{<<"members">> => #{}},
Members = [
#{<<"user">> => #{<<"id">> => <<"1">>}},
#{<<"user">> => #{<<"id">> => <<"2">>}}
],
Updated = put_member_list(Members, Data),
?assertEqual(2, map_size(maps:get(<<"members">>, Updated))).
put_member_list_non_list_returns_unchanged_test() ->
Data = #{<<"members">> => #{}},
?assertEqual(Data, put_member_list(not_a_list, Data)).
remove_member_non_integer_returns_unchanged_test() ->
Data = #{<<"members">> => #{1 => #{<<"user">> => #{<<"id">> => <<"1">>}}}},
?assertEqual(Data, remove_member(not_an_int, Data)).
remove_member_non_existent_test() ->
Data = #{<<"members">> => #{1 => #{<<"user">> => #{<<"id">> => <<"1">>}}}},
Updated = remove_member(999, Data),
?assertEqual(1, map_size(maps:get(<<"members">>, Updated))).
role_list_non_map_input_test() ->
?assertEqual([], role_list(not_a_map)).
role_list_non_list_roles_value_test() ->
Data = #{<<"roles">> => <<"invalid">>},
?assertEqual([], role_list(Data)).
role_index_non_map_input_test() ->
?assertEqual(#{}, role_index(not_a_map)).
role_index_from_list_test() ->
Data = #{
<<"roles">> => [
#{<<"id">> => <<"100">>, <<"name">> => <<"Admin">>},
#{<<"id">> => <<"200">>, <<"name">> => <<"Member">>}
]
},
Index = role_index(Data),
?assertEqual(2, map_size(Index)),
?assertEqual(<<"Admin">>, maps:get(<<"name">>, maps:get(100, Index))).
put_roles_updates_list_and_index_test() ->
Data = #{
<<"roles">> => [#{<<"id">> => <<"1">>, <<"name">> => <<"old">>}]
},
NewRoles = [
#{<<"id">> => <<"10">>, <<"name">> => <<"new1">>},
#{<<"id">> => <<"20">>, <<"name">> => <<"new2">>}
],
Updated = put_roles(NewRoles, Data),
?assertEqual(NewRoles, role_list(Updated)),
?assertEqual(2, map_size(role_index(Updated))).
put_roles_non_map_data_returns_unchanged_test() ->
?assertEqual(not_a_map, put_roles([], not_a_map)).
channel_list_non_map_input_test() ->
?assertEqual([], channel_list(not_a_map)).
channel_list_non_list_channels_value_test() ->
Data = #{<<"channels">> => <<"invalid">>},
?assertEqual([], channel_list(Data)).
channel_index_non_map_input_test() ->
?assertEqual(#{}, channel_index(not_a_map)).
channel_index_from_list_test() ->
Data = #{
<<"channels">> => [
#{<<"id">> => <<"300">>, <<"name">> => <<"general">>},
#{<<"id">> => <<"301">>, <<"name">> => <<"random">>}
]
},
Index = channel_index(Data),
?assertEqual(2, map_size(Index)),
?assertEqual(<<"general">>, maps:get(<<"name">>, maps:get(300, Index))).
put_channels_updates_list_and_index_test() ->
Data = #{<<"channels">> => []},
NewChannels = [
#{<<"id">> => <<"50">>, <<"name">> => <<"ch1">>},
#{<<"id">> => <<"51">>, <<"name">> => <<"ch2">>}
],
Updated = put_channels(NewChannels, Data),
?assertEqual(NewChannels, channel_list(Updated)),
?assertEqual(2, map_size(channel_index(Updated))).
put_channels_non_map_data_returns_unchanged_test() ->
?assertEqual(not_a_map, put_channels([], not_a_map)).
member_role_index_non_map_input_test() ->
?assertEqual(#{}, member_role_index(not_a_map)).
member_role_index_members_without_roles_test() ->
Data = #{
<<"members">> => #{
1 => #{<<"user">> => #{<<"id">> => <<"1">>}}
}
},
Index = member_role_index(Data),
?assertEqual(#{}, Index).
member_role_index_shared_roles_test() ->
Data = #{
<<"members">> => #{
1 => #{<<"user">> => #{<<"id">> => <<"1">>}, <<"roles">> => [<<"10">>]},
2 => #{<<"user">> => #{<<"id">> => <<"2">>}, <<"roles">> => [<<"10">>]},
3 => #{<<"user">> => #{<<"id">> => <<"3">>}, <<"roles">> => [<<"10">>, <<"20">>]}
}
},
Index = member_role_index(Data),
?assertEqual(#{1 => true, 2 => true, 3 => true}, maps:get(10, Index)),
?assertEqual(#{3 => true}, maps:get(20, Index)).
build_id_index_skips_items_without_id_test() ->
Items = [
#{<<"id">> => <<"1">>, <<"name">> => <<"first">>},
#{<<"name">> => <<"no_id">>},
#{<<"id">> => <<"2">>, <<"name">> => <<"second">>}
],
Index = build_id_index(Items),
?assertEqual(2, map_size(Index)),
?assertEqual(<<"first">>, maps:get(<<"name">>, maps:get(1, Index))).
build_id_index_empty_list_test() ->
?assertEqual(#{}, build_id_index([])).
extract_integer_list_mixed_types_test() ->
?assertEqual([1, 2, 3], extract_integer_list([<<"1">>, 2, <<"3">>])),
?assertEqual([1, 3], extract_integer_list([<<"1">>, <<"invalid">>, <<"3">>])),
?assertEqual([], extract_integer_list(not_a_list)).
ensure_list_test() ->
?assertEqual([1, 2], ensure_list([1, 2])),
?assertEqual([], ensure_list(not_a_list)),
?assertEqual([], ensure_list(#{})).
normalize_member_map_with_binary_keys_test() ->
MemberMap = #{
<<"42">> => #{<<"user">> => #{<<"id">> => <<"42">>}, <<"nick">> => <<"test">>}
},
Normalized = normalize_member_map(MemberMap),
?assertMatch(#{42 := _}, Normalized),
?assertEqual(<<"test">>, maps:get(<<"nick">>, maps:get(42, Normalized))).
put_member_multiple_roles_index_test() ->
Data = #{<<"members">> => #{}},
Member = #{<<"user">> => #{<<"id">> => <<"7">>}, <<"roles">> => [<<"10">>, <<"20">>, <<"30">>]},
Updated = put_member(Member, Data),
Index = member_role_index(Updated),
?assertEqual(#{7 => true}, maps:get(10, Index)),
?assertEqual(#{7 => true}, maps:get(20, Index)),
?assertEqual(#{7 => true}, maps:get(30, Index)).
remove_member_cleans_empty_role_entries_test() ->
Data0 = #{<<"members">> => #{}},
Data1 = put_member(
#{<<"user">> => #{<<"id">> => <<"1">>}, <<"roles">> => [<<"10">>]},
Data0
),
Data2 = put_member(
#{<<"user">> => #{<<"id">> => <<"2">>}, <<"roles">> => [<<"10">>, <<"20">>]},
Data1
),
Data3 = remove_member(1, Data2),
Index = member_role_index(Data3),
?assertEqual(#{2 => true}, maps:get(10, Index)),
?assertEqual(#{2 => true}, maps:get(20, Index)),
Data4 = remove_member(2, Data3),
Index2 = member_role_index(Data4),
?assertEqual(undefined, maps:get(10, Index2, undefined)),
?assertEqual(undefined, maps:get(20, Index2, undefined)).
-endif.

View File

@@ -74,6 +74,8 @@ process_dispatch(Event, EventData, State) ->
FilteredSessions = filter_sessions_for_event(
Event, FinalData, SessionIdOpt, Sessions, FilterState
),
logger:info("process_dispatch: event=~p guild_id=~p total_sessions=~p filtered_sessions=~p",
[Event, GuildId, map_size(Sessions), length(FilteredSessions)]),
DispatchSuccess = dispatch_to_sessions(FilteredSessions, Event, FinalData, UpdatedState),
track_dispatch_metrics(Event, DispatchSuccess),
maybe_send_push_notifications(Event, FinalData, GuildId, UpdatedState),
@@ -260,8 +262,10 @@ dispatch_bulk_to_session(_, _, _, _, Acc) ->
-spec dispatch_standard([session_pair()], event(), event_data(), guild_id(), guild_state()) ->
non_neg_integer().
dispatch_standard(FilteredSessions, Event, FinalData, GuildId, State) ->
logger:info("dispatch_standard: event=~p guild_id=~p filtered_sessions=~p member_count=~p",
[Event, GuildId, length(FilteredSessions), maps:get(member_count, State, undefined)]),
SuccessCount = lists:foldl(
fun({_Sid, SessionData}, Acc) ->
fun({Sid, SessionData}, Acc) ->
Pid = maps:get(pid, SessionData),
case
is_pid(Pid) andalso
@@ -277,6 +281,11 @@ dispatch_standard(FilteredSessions, Event, FinalData, GuildId, State) ->
_:_ -> Acc
end;
false ->
logger:info("dispatch_standard skip: sid=~p is_pid=~p passive=~p small=~p",
[Sid,
is_pid(Pid),
session_passive:is_passive(GuildId, SessionData),
session_passive:is_small_guild(State)]),
Acc
end
end,
@@ -1098,4 +1107,188 @@ build_channel_delete_dispatch_state(VisiblePid, HiddenPid) ->
}
}.
should_skip_dispatch_guild_update_never_skipped_test() ->
State = #{
data => #{
<<"guild">> => #{<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]}
}
},
?assertEqual(false, should_skip_dispatch(guild_update, State)).
should_skip_dispatch_unavailable_for_everyone_test() ->
State = #{
data => #{
<<"guild">> => #{<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]}
}
},
?assertEqual(true, should_skip_dispatch(message_create, State)).
should_skip_dispatch_unavailable_for_everyone_but_staff_test() ->
State = #{
data => #{
<<"guild">> => #{<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE_BUT_STAFF">>]}
}
},
?assertEqual(true, should_skip_dispatch(message_create, State)).
should_skip_dispatch_normal_guild_test() ->
State = #{
data => #{
<<"guild">> => #{<<"features">> => []}
}
},
?assertEqual(false, should_skip_dispatch(message_create, State)).
should_skip_dispatch_no_features_test() ->
State = #{data => #{<<"guild">> => #{}}},
?assertEqual(false, should_skip_dispatch(message_create, State)).
filter_sessions_for_event_guild_wide_goes_to_all_sessions_test() ->
S1 = #{session_id => <<"s1">>, user_id => 10, pid => self()},
S2 = #{session_id => <<"s2">>, user_id => 11, pid => self()},
Sessions = #{<<"s1">> => S1, <<"s2">> => S2},
State = #{sessions => Sessions, data => #{<<"members">> => #{}}},
Result = filter_sessions_for_event(guild_member_add, #{}, undefined, Sessions, State),
?assertEqual(2, length(Result)).
extract_channel_id_message_create_uses_channel_id_field_test() ->
Data = #{<<"channel_id">> => <<"42">>},
?assertEqual(42, extract_channel_id(message_create, Data)).
extract_channel_id_channel_create_uses_id_field_test() ->
Data = #{<<"id">> => <<"42">>},
?assertEqual(42, extract_channel_id(channel_create, Data)).
extract_channel_id_channel_update_uses_id_field_test() ->
Data = #{<<"id">> => <<"42">>},
?assertEqual(42, extract_channel_id(channel_update, Data)).
parse_integer_undefined_returns_default_test() ->
?assertEqual(42, parse_integer(undefined, 42)).
parse_integer_integer_test() ->
?assertEqual(7, parse_integer(7, 0)).
parse_integer_valid_binary_test() ->
?assertEqual(123, parse_integer(<<"123">>, 0)).
parse_integer_invalid_binary_test() ->
?assertEqual(0, parse_integer(<<"abc">>, 0)).
parse_integer_other_type_test() ->
?assertEqual(5, parse_integer(3.14, 5)).
is_guild_operation_disabled_test() ->
State = disabled_operations_state(3),
?assertEqual(true, is_guild_operation_disabled(State, 1)),
?assertEqual(true, is_guild_operation_disabled(State, 2)),
?assertEqual(true, is_guild_operation_disabled(State, 3)),
?assertEqual(false, is_guild_operation_disabled(State, 4)).
is_guild_operation_disabled_binary_test() ->
State = disabled_operations_state(<<"5">>),
?assertEqual(true, is_guild_operation_disabled(State, 1)),
?assertEqual(true, is_guild_operation_disabled(State, 4)),
?assertEqual(false, is_guild_operation_disabled(State, 2)).
extract_session_id_if_needed_reaction_remove_test() ->
Data = #{<<"session_id">> => <<"sid">>, <<"emoji">> => #{}},
{SessionId, CleanData} = extract_session_id_if_needed(message_reaction_remove, Data),
?assertEqual(<<"sid">>, SessionId),
?assertNot(maps:is_key(<<"session_id">>, CleanData)).
decorate_member_data_typing_start_test() ->
Member = #{<<"user">> => #{<<"id">> => <<"456">>}, <<"roles">> => []},
State = #{data => #{<<"members">> => [Member]}},
Data = #{<<"user_id">> => <<"456">>},
Decorated = decorate_member_data(typing_start, Data, State),
?assert(maps:is_key(<<"member">>, Decorated)),
?assert(maps:is_key(<<"user">>, maps:get(<<"member">>, Decorated))).
decorate_member_data_guild_event_no_decoration_test() ->
State = #{data => #{<<"members">> => []}},
Data = #{<<"name">> => <<"test">>},
Decorated = decorate_member_data(guild_update, Data, State),
?assertEqual(false, maps:is_key(<<"member">>, Decorated)).
filter_visible_channels_test() ->
GuildId = 42,
UserId = 10,
ViewPerm = constants:view_channel_permission(),
Member = #{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []},
State = #{
id => GuildId,
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(ViewPerm)}
],
<<"members">> => [Member],
<<"channels">> => [
#{<<"id">> => <<"100">>, <<"permission_overwrites">> => []},
#{
<<"id">> => <<"101">>,
<<"permission_overwrites">> => [
#{
<<"id">> => integer_to_binary(GuildId),
<<"type">> => 0,
<<"allow">> => <<"0">>,
<<"deny">> => integer_to_binary(ViewPerm)
}
]
}
]
}
},
Channels = [
#{<<"id">> => <<"100">>},
#{<<"id">> => <<"101">>}
],
Result = filter_visible_channels(Channels, UserId, Member, State),
?assertEqual(1, length(Result)),
?assertEqual(<<"100">>, maps:get(<<"id">>, hd(Result))).
filter_visible_channels_undefined_member_test() ->
State = #{data => #{<<"members">> => []}},
Channels = [#{<<"id">> => <<"100">>}],
Result = filter_visible_channels(Channels, 10, undefined, State),
?assertEqual([], Result).
extract_user_id_from_event_test() ->
EventData = #{<<"user">> => #{<<"id">> => <<"42">>}},
?assertEqual(42, extract_user_id_from_event(EventData)).
extract_user_id_from_event_missing_test() ->
?assertEqual(undefined, extract_user_id_from_event(#{})).
extract_user_id_from_event_invalid_test() ->
EventData = #{<<"user">> => #{<<"id">> => <<"invalid">>}},
?assertEqual(undefined, extract_user_id_from_event(EventData)).
find_channel_name_uses_index_test() ->
Data = #{
<<"channels">> => [
#{<<"id">> => <<"100">>, <<"name">> => <<"general">>}
],
<<"channel_index">> => #{100 => #{<<"id">> => <<"100">>, <<"name">> => <<"general">>}}
},
?assertEqual(<<"general">>, find_channel_name(<<"100">>, Data)).
find_channel_name_invalid_id_test() ->
Data = #{<<"channels">> => []},
?assertEqual(<<"unknown">>, find_channel_name(<<"invalid">>, Data)).
extract_role_ids_test() ->
Member = #{<<"roles">> => [<<"10">>, <<"20">>, <<"invalid">>]},
Result = lists:sort(extract_role_ids(Member)),
?assertEqual([10, 20], Result).
extract_role_ids_empty_test() ->
Member = #{<<"roles">> => []},
?assertEqual([], extract_role_ids(Member)).
extract_role_ids_missing_key_test() ->
Member = #{},
?assertEqual([], extract_role_ids(Member)).
-endif.

View File

@@ -0,0 +1,53 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(guild_ets_utils).
-export([ensure_table/2]).
-spec ensure_table(atom(), list()) -> ok.
ensure_table(TableName, Options) ->
case ets:whereis(TableName) of
undefined ->
try ets:new(TableName, Options) of
_ -> ok
catch
error:badarg -> ok
end;
_ ->
ok
end.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
ensure_table_creates_new_table_test() ->
TableName = guild_ets_utils_test_table,
catch ets:delete(TableName),
ok = ensure_table(TableName, [named_table, public, set]),
?assertNotEqual(undefined, ets:whereis(TableName)),
ets:delete(TableName).
ensure_table_idempotent_test() ->
TableName = guild_ets_utils_test_idempotent,
catch ets:delete(TableName),
ok = ensure_table(TableName, [named_table, public, set]),
ok = ensure_table(TableName, [named_table, public, set]),
?assertNotEqual(undefined, ets:whereis(TableName)),
ets:delete(TableName).
-endif.

View File

@@ -193,8 +193,13 @@ forward_call_to_shard(GuildId, Request, State) ->
Pid = maps:get(pid, ShardMap),
case catch gen_server:call(Pid, Request, ?DEFAULT_GEN_SERVER_TIMEOUT) of
{'EXIT', _} ->
{_Shard, State2} = restart_shard(Index, State1),
forward_call_to_shard(GuildId, Request, State2);
case erlang:is_process_alive(Pid) of
true ->
{{error, timeout}, State1};
false ->
{_Shard, State2} = restart_shard(Index, State1),
forward_call_to_shard(GuildId, Request, State2)
end;
{ok, GuildPid} = Reply ->
ets:insert(?GUILD_PID_CACHE, {GuildId, GuildPid}),
erlang:monitor(process, GuildPid),
@@ -367,4 +372,54 @@ find_shard_by_pid_found_test() ->
Shards = #{0 => #{pid => Pid, ref => make_ref()}},
?assertMatch({ok, 0}, find_shard_by_pid(Pid, Shards)).
forward_call_to_shard_timeout_does_not_restart_shard_test_() ->
{timeout, 15, fun() ->
catch ets:delete(guild_pid_cache),
SlowShardPid = spawn(fun() -> slow_shard_loop() end),
ShardRef = erlang:monitor(process, SlowShardPid),
State = #{
shards => #{0 => #{pid => SlowShardPid, ref => ShardRef}},
shard_count => 1
},
ets:new(guild_pid_cache, [named_table, public, set, {read_concurrency, true}]),
try
GuildId = 99999,
{Reply, NewState} = forward_call_to_shard(GuildId, {start_or_lookup, GuildId}, State),
?assertMatch({error, timeout}, Reply),
?assert(is_process_alive(SlowShardPid)),
NewShards = maps:get(shards, NewState),
#{pid := ShardPidAfter} = maps:get(0, NewShards),
?assertEqual(SlowShardPid, ShardPidAfter)
after
SlowShardPid ! stop,
catch ets:delete(guild_pid_cache)
end
end}.
slow_shard_loop() ->
receive
{'$gen_call', _From, _Msg} ->
timer:sleep(10000),
slow_shard_loop();
stop ->
ok;
_ ->
slow_shard_loop()
end.
cleanup_guild_from_cache_does_not_remove_new_pid_test() ->
catch ets:delete(guild_pid_cache),
ets:new(guild_pid_cache, [named_table, public, set, {read_concurrency, true}]),
try
OldPid = spawn(fun() -> ok end),
timer:sleep(10),
NewPid = spawn(fun() -> timer:sleep(1000) end),
ets:insert(guild_pid_cache, {42, NewPid}),
cleanup_guild_from_cache(OldPid),
[{42, FoundPid}] = ets:lookup(guild_pid_cache, 42),
?assertEqual(NewPid, FoundPid)
after
catch ets:delete(guild_pid_cache)
end.
-endif.

View File

@@ -20,7 +20,6 @@
-include_lib("fluxer_gateway/include/timeout_config.hrl").
-define(GUILD_API_CANARY_PERCENTAGE, 5).
-define(BATCH_SIZE, 10).
-define(BATCH_DELAY_MS, 100).
@@ -33,8 +32,6 @@
-type fetch_result() :: {ok, guild_data()} | {error, term()}.
-type state() :: #{
guilds := #{guild_id() => guild_ref() | loading},
api_host := string(),
api_canary_host := undefined | string(),
pending_requests := #{guild_id() => [gen_server:from()]},
shard_index := non_neg_integer()
}.
@@ -47,13 +44,9 @@ start_link(ShardIndex) ->
init(Args) ->
process_flag(trap_exit, true),
fluxer_gateway_env:load(),
ApiHost = fluxer_gateway_env:get(api_host),
ApiCanaryHost = fluxer_gateway_env:get(api_canary_host),
ShardIndex = maps:get(shard_index, Args, 0),
{ok, #{
guilds => #{},
api_host => ApiHost,
api_canary_host => ApiCanaryHost,
pending_requests => #{},
shard_index => ShardIndex
}}.
@@ -149,12 +142,11 @@ start_fetch(GuildId, From, State) ->
{noreply, NewState}.
-spec spawn_fetch(guild_id(), state()) -> pid().
spawn_fetch(GuildId, State) ->
spawn_fetch(GuildId, _State) ->
Manager = self(),
ApiHostInfo = select_api_host(State),
spawn(fun() ->
try
Result = fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State),
Result = fetch_guild_data(GuildId),
gen_server:cast(Manager, {guild_data_fetched, GuildId, Result})
catch
_:_:_ ->
@@ -258,12 +250,11 @@ do_reload_guild(GuildId, From, State) ->
end.
-spec spawn_reload(guild_id(), pid(), gen_server:from(), state()) -> pid().
spawn_reload(GuildId, Pid, From, State) ->
spawn_reload(GuildId, Pid, From, _State) ->
Manager = self(),
ApiHostInfo = select_api_host(State),
spawn(fun() ->
try
Result = fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State),
Result = fetch_guild_data(GuildId),
gen_server:cast(Manager, {guild_data_reloaded, GuildId, Pid, From, Result})
catch
_:_:_ ->
@@ -319,13 +310,12 @@ reload_guilds_in_batches(Guilds, State) ->
end.
-spec reload_batch([{guild_id(), pid()}], state()) -> ok.
reload_batch(Batch, State) ->
ApiHostInfo = select_api_host(State),
reload_batch(Batch, _State) ->
lists:foreach(
fun({GuildId, Pid}) ->
spawn(fun() ->
try
case fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State) of
case fetch_guild_data(GuildId) of
{ok, Data} ->
catch gen_server:call(Pid, {reload, Data}, ?GUILD_CALL_TIMEOUT);
{error, _Reason} ->
@@ -385,18 +375,23 @@ start_new_guild(GuildId, Data, GuildName, State) ->
true -> very_large_guild;
false -> guild
end,
case GuildModule:start_link(GuildState) of
{ok, Pid} ->
case process_registry:register_and_monitor(GuildName, Pid, Guilds) of
{ok, RegisteredPid, Ref, NewGuilds0} ->
CleanGuilds = maps:remove(GuildName, NewGuilds0),
NewGuilds = maps:put(GuildId, {RegisteredPid, Ref}, CleanGuilds),
{ok, RegisteredPid, State#{guilds => NewGuilds}};
{error, Reason} ->
{error, Reason}
case whereis(GuildName) of
undefined ->
case GuildModule:start_link(GuildState) of
{ok, Pid} ->
case process_registry:register_and_monitor(GuildName, Pid, Guilds) of
{ok, RegisteredPid, Ref, NewGuilds0} ->
CleanGuilds = maps:remove(GuildName, NewGuilds0),
NewGuilds = maps:put(GuildId, {RegisteredPid, Ref}, CleanGuilds),
{ok, RegisteredPid, State#{guilds => NewGuilds}};
{error, Reason} ->
{error, Reason}
end;
Error ->
Error
end;
Error ->
Error
_AlreadyRegistered ->
lookup_existing_guild(GuildId, GuildName, State)
end.
-spec is_very_large_guild(guild_data()) -> boolean().
@@ -417,81 +412,18 @@ lookup_existing_guild(GuildId, GuildName, State) ->
{error, process_died}
end.
-spec fetch_guild_data(guild_id(), string()) -> fetch_result().
fetch_guild_data(GuildId, ApiHost) ->
-spec fetch_guild_data(guild_id()) -> fetch_result().
fetch_guild_data(GuildId) ->
RpcRequest = #{
<<"type">> => <<"guild">>,
<<"guild_id">> => type_conv:to_binary(GuildId),
<<"version">> => 1
},
Url = rpc_client:get_rpc_url(ApiHost),
Headers = rpc_client:get_rpc_headers() ++ [{<<"content-type">>, <<"application/json">>}],
Body = json:encode(RpcRequest),
case gateway_http_client:request(rpc, post, Url, Headers, Body) of
{ok, 200, _RespHeaders, RespBody} ->
handle_fetch_response(RespBody);
{ok, StatusCode, _RespHeaders, _RespBody} ->
handle_fetch_error(StatusCode);
{error, Reason} ->
{error, {request_failed, Reason}}
end.
-spec handle_fetch_response(binary()) -> fetch_result().
handle_fetch_response(RespBody) ->
Response = json:decode(RespBody),
Data = maps:get(<<"data">>, Response, #{}),
{ok, Data}.
-spec handle_fetch_error(integer()) -> {error, {http_status, integer()}}.
handle_fetch_error(StatusCode) ->
{error, {http_status, StatusCode}}.
-spec select_api_host(state()) -> {string(), boolean()}.
select_api_host(State) ->
case maps:get(api_canary_host, State) of
undefined ->
{maps:get(api_host, State), false};
_ ->
case should_use_canary_api() of
true -> {maps:get(api_canary_host, State), true};
false -> {maps:get(api_host, State), false}
end
end.
-spec should_use_canary_api() -> boolean().
should_use_canary_api() ->
erlang:unique_integer([positive]) rem 100 < ?GUILD_API_CANARY_PERCENTAGE.
-spec fetch_guild_data_with_fallback(guild_id(), {string(), boolean()}, state()) -> fetch_result().
fetch_guild_data_with_fallback(GuildId, {ApiHost, false}, _State) ->
fetch_guild_data(GuildId, ApiHost);
fetch_guild_data_with_fallback(GuildId, {ApiHost, true}, State) ->
case fetch_guild_data(GuildId, ApiHost) of
{ok, Data} ->
{ok, Data};
Error ->
StableHost = maps:get(api_host, State),
case StableHost == ApiHost of
true ->
Error;
false ->
fetch_guild_data(GuildId, StableHost)
end
end.
rpc_client:call(RpcRequest).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
select_api_host_no_canary_test() ->
State = #{api_host => "http://api.local", api_canary_host => undefined},
{Host, IsCanary} = select_api_host(State),
?assertEqual("http://api.local", Host),
?assertEqual(false, IsCanary).
should_use_canary_api_returns_boolean_test() ->
Result = should_use_canary_api(),
?assert(is_boolean(Result)).
select_guilds_to_reload_empty_ids_test() ->
Guilds = #{1 => {self(), make_ref()}, 2 => {self(), make_ref()}},
Result = select_guilds_to_reload([], Guilds),
@@ -513,8 +445,6 @@ do_start_or_lookup_loading_deduplicates_requests_test() ->
From2 = {self(), make_ref()},
State0 = #{
guilds => #{GuildId => loading},
api_host => "http://api.local",
api_canary_host => undefined,
pending_requests => #{},
shard_index => 0
},
@@ -528,4 +458,76 @@ do_start_or_lookup_loading_deduplicates_requests_test() ->
?assert(lists:member(From1, Requests)),
?assert(lists:member(From2, Requests)).
start_new_guild_skips_start_when_already_registered_test() ->
GuildId = 77777,
GuildName = process_registry:build_process_name(guild, GuildId),
ExistingPid = spawn(fun() -> mock_guild_loop() end),
register(GuildName, ExistingPid),
try
State0 = #{
guilds => #{},
pending_requests => #{},
shard_index => 0
},
Data = #{<<"guild">> => #{<<"id">> => <<"77777">>, <<"features">> => []}},
Result = start_new_guild(GuildId, Data, GuildName, State0),
?assertMatch({ok, ExistingPid, _}, Result),
{ok, RetPid, _NewState} = Result,
?assertEqual(ExistingPid, RetPid)
after
catch unregister(GuildName),
ExistingPid ! stop
end.
start_guild_returns_existing_when_registered_test() ->
GuildId = 88888,
GuildName = process_registry:build_process_name(guild, GuildId),
ExistingPid = spawn(fun() -> mock_guild_loop() end),
register(GuildName, ExistingPid),
try
State0 = #{
guilds => #{},
pending_requests => #{},
shard_index => 0
},
Data = #{<<"guild">> => #{<<"id">> => <<"88888">>, <<"features">> => []}},
Result = start_guild(GuildId, Data, State0),
?assertMatch({ok, ExistingPid, _}, Result)
after
catch unregister(GuildName),
ExistingPid ! stop
end.
register_and_monitor_race_kills_duplicate_test_() ->
{timeout, 15, fun() ->
GuildId = 66666,
GuildName = process_registry:build_process_name(guild, GuildId),
WinnerPid = spawn(fun() -> mock_guild_loop() end),
register(GuildName, WinnerPid),
LoserPid = spawn(fun() -> mock_guild_loop() end),
try
Guilds = #{},
Result = process_registry:register_and_monitor(GuildName, LoserPid, Guilds),
?assertMatch({ok, WinnerPid, _, _}, Result),
timer:sleep(200),
?assertEqual(false, is_process_alive(LoserPid)),
?assert(is_process_alive(WinnerPid))
after
catch unregister(GuildName),
catch (WinnerPid ! stop),
catch (LoserPid ! stop)
end
end}.
mock_guild_loop() ->
receive
{'$gen_call', From, _Msg} ->
gen_server:reply(From, ok),
mock_guild_loop();
stop ->
ok;
_ ->
mock_guild_loop()
end.
-endif.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -769,6 +769,449 @@ role_ids_from_roles_test() ->
],
?assertEqual([100, 200], role_ids_from_roles(Roles)).
check_target_member_owner_can_manage_anyone_test() ->
State = test_state(),
{reply, #{can_manage := CanManage}, _} =
check_target_member(#{user_id => 1, target_user_id => 2}, State),
?assertEqual(true, CanManage).
check_target_member_cannot_manage_owner_test() ->
State = test_state(),
{reply, #{can_manage := CanManage}, _} =
check_target_member(#{user_id => 2, target_user_id => 1}, State),
?assertEqual(false, CanManage).
check_target_member_higher_role_can_manage_lower_test() ->
State = test_state(),
{reply, #{can_manage := CanManage}, _} =
check_target_member(#{user_id => 3, target_user_id => 2}, State),
?assertEqual(true, CanManage).
check_target_member_lower_role_cannot_manage_higher_test() ->
State = test_state(),
{reply, #{can_manage := CanManage}, _} =
check_target_member(#{user_id => 2, target_user_id => 3}, State),
?assertEqual(false, CanManage).
can_manage_roles_owner_always_true_test() ->
State = test_state(),
{reply, #{can_manage := CanManage}, _} =
can_manage_roles(#{user_id => 1, role_id => 201}, State),
?assertEqual(true, CanManage).
can_manage_roles_member_lower_role_test() ->
State = test_state(),
{reply, #{can_manage := CanManage}, _} =
can_manage_roles(#{user_id => 2, role_id => 201}, State),
?assertEqual(false, CanManage).
can_manage_roles_unknown_role_test() ->
State = test_state(),
{reply, #{can_manage := CanManage}, _} =
can_manage_roles(#{user_id => 2, role_id => 999}, State),
?assertEqual(false, CanManage).
get_viewable_channels_unknown_user_test() ->
State = test_state(),
{reply, #{channel_ids := ChannelIds}, _} =
get_viewable_channels(#{user_id => 999}, State),
?assertEqual([], ChannelIds).
get_viewable_channels_owner_sees_all_test() ->
State = test_state(),
{reply, #{channel_ids := ChannelIds}, _} =
get_viewable_channels(#{user_id => 1}, State),
?assert(length(ChannelIds) >= 2).
get_viewable_channels_with_deny_overwrite_test() ->
ViewPerm = constants:view_channel_permission(),
ManageRoles = constants:manage_roles_permission(),
State = #{
id => 100,
data => #{
<<"guild">> => #{<<"owner_id">> => <<"1">>},
<<"roles">> => [
#{<<"id">> => <<"100">>, <<"permissions">> => integer_to_binary(ViewPerm bor ManageRoles), <<"position">> => 0}
],
<<"channels">> => [
#{
<<"id">> => <<"500">>,
<<"type">> => 0,
<<"permission_overwrites">> => [
#{<<"id">> => <<"100">>, <<"type">> => 0, <<"allow">> => <<"0">>, <<"deny">> => integer_to_binary(ViewPerm)}
]
},
#{<<"id">> => <<"501">>, <<"type">> => 0, <<"permission_overwrites">> => []}
],
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"1">>}, <<"roles">> => [<<"100">>]},
#{<<"user">> => #{<<"id">> => <<"2">>}, <<"roles">> => [<<"100">>]}
]
}
},
{reply, #{channel_ids := OwnerChannels}, _} =
get_viewable_channels(#{user_id => 1}, State),
?assert(lists:member(500, OwnerChannels)),
{reply, #{channel_ids := MemberChannels}, _} =
get_viewable_channels(#{user_id => 2}, State),
?assertNot(lists:member(500, MemberChannels)),
?assert(lists:member(501, MemberChannels)).
resolve_all_mentions_mention_everyone_test() ->
State = test_state(),
Request = #{
channel_id => 500,
author_id => 1,
mention_everyone => true,
mention_here => false,
role_ids => [],
user_ids => []
},
{reply, #{user_ids := UserIds}, _} = resolve_all_mentions(Request, State),
?assert(lists:member(2, UserIds)),
?assert(lists:member(3, UserIds)),
?assertNot(lists:member(1, UserIds)).
resolve_all_mentions_mention_here_only_connected_test() ->
State0 = test_state(),
State = State0#{
sessions => #{
<<"sess_2">> => #{user_id => 2}
}
},
Request = #{
channel_id => 500,
author_id => 1,
mention_everyone => false,
mention_here => true,
role_ids => [],
user_ids => []
},
{reply, #{user_ids := UserIds}, _} = resolve_all_mentions(Request, State),
?assert(lists:member(2, UserIds)),
?assertNot(lists:member(3, UserIds)).
resolve_all_mentions_excludes_bots_test() ->
ViewPerm = constants:view_channel_permission(),
ManageRoles = constants:manage_roles_permission(),
State = #{
id => 100,
data => #{
<<"guild">> => #{<<"owner_id">> => <<"1">>},
<<"roles">> => [
#{<<"id">> => <<"100">>, <<"permissions">> => integer_to_binary(ViewPerm bor ManageRoles), <<"position">> => 0}
],
<<"channels">> => [
#{<<"id">> => <<"500">>, <<"type">> => 0, <<"permission_overwrites">> => []}
],
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"1">>}, <<"roles">> => [<<"100">>]},
#{<<"user">> => #{<<"id">> => <<"2">>, <<"bot">> => true}, <<"roles">> => [<<"100">>]},
#{<<"user">> => #{<<"id">> => <<"3">>}, <<"roles">> => [<<"100">>]}
]
},
sessions => #{}
},
Request = #{
channel_id => 500,
author_id => 1,
mention_everyone => true,
mention_here => false,
role_ids => [],
user_ids => []
},
{reply, #{user_ids := UserIds}, _} = resolve_all_mentions(Request, State),
?assertNot(lists:member(2, UserIds)),
?assert(lists:member(3, UserIds)).
resolve_all_mentions_author_always_excluded_test() ->
State = test_state(),
Request = #{
channel_id => 500,
author_id => 2,
mention_everyone => false,
mention_here => false,
role_ids => [],
user_ids => [2]
},
{reply, #{user_ids := UserIds}, _} = resolve_all_mentions(Request, State),
?assertNot(lists:member(2, UserIds)).
get_all_users_to_mention_excludes_author_test() ->
State = test_state(),
{reply, #{user_ids := UserIds}, _} =
get_all_users_to_mention(#{channel_id => 500, author_id => 1}, State),
?assertNot(lists:member(1, UserIds)),
?assert(lists:member(2, UserIds)),
?assert(lists:member(3, UserIds)).
get_members_with_role_undefined_role_id_test() ->
State = test_state(),
{reply, #{user_ids := UserIds}, _} =
get_members_with_role(#{role_id => undefined}, State),
?assertEqual([], UserIds).
get_members_with_role_nonexistent_role_test() ->
State = test_state(),
{reply, #{user_ids := UserIds}, _} =
get_members_with_role(#{role_id => 999}, State),
?assertEqual([], UserIds).
get_assignable_roles_non_member_test() ->
State = test_state(),
{reply, #{role_ids := RoleIds}, _} =
get_assignable_roles(#{user_id => 999}, State),
?assertEqual([], RoleIds).
member_user_id_missing_user_test() ->
?assertEqual(undefined, member_user_id(#{})).
member_user_id_missing_id_test() ->
?assertEqual(undefined, member_user_id(#{<<"user">> => #{}})).
member_user_id_valid_test() ->
?assertEqual(42, member_user_id(#{<<"user">> => #{<<"id">> => <<"42">>}})).
member_roles_empty_test() ->
?assertEqual([], member_roles(#{})).
member_roles_binary_ids_test() ->
Member = #{<<"roles">> => [<<"100">>, <<"200">>]},
?assertEqual([100, 200], member_roles(Member)).
is_member_bot_missing_user_test() ->
?assertEqual(false, is_member_bot(#{})).
is_member_bot_missing_bot_field_test() ->
?assertEqual(false, is_member_bot(#{<<"user">> => #{}})).
role_position_default_test() ->
?assertEqual(0, role_position(#{})).
role_position_explicit_test() ->
?assertEqual(5, role_position(#{<<"position">> => 5})).
role_ids_from_roles_empty_test() ->
?assertEqual([], role_ids_from_roles([])).
role_ids_from_roles_skips_undefined_ids_test() ->
Roles = [#{}, #{<<"id">> => <<"100">>}],
?assertEqual([100], role_ids_from_roles(Roles)).
normalize_int_list_mixed_types_test() ->
?assertEqual([1, 2], normalize_int_list([<<"1">>, <<"invalid">>, 2])).
normalize_int_list_non_list_input_test() ->
?assertEqual([], normalize_int_list(not_a_list)).
member_has_any_role_set_empty_roles_test() ->
Member = #{<<"roles">> => []},
RoleSet = gb_sets:from_list([100]),
?assertEqual(false, member_has_any_role_set(Member, RoleSet)).
member_has_any_role_set_empty_set_test() ->
Member = #{<<"roles">> => [<<"100">>]},
RoleSet = gb_sets:empty(),
?assertEqual(false, member_has_any_role_set(Member, RoleSet)).
compare_roles_first_undefined_test() ->
Role = #{<<"position">> => 5, <<"id">> => <<"10">>},
?assertEqual(Role, compare_roles(Role, undefined)).
compare_roles_higher_position_wins_test() ->
RoleA = #{<<"position">> => 5, <<"id">> => <<"10">>},
RoleB = #{<<"position">> => 10, <<"id">> => <<"20">>},
?assertEqual(RoleB, compare_roles(RoleB, RoleA)).
compare_roles_same_position_lower_id_wins_test() ->
RoleA = #{<<"position">> => 5, <<"id">> => <<"20">>},
RoleB = #{<<"position">> => 5, <<"id">> => <<"10">>},
?assertEqual(RoleB, compare_roles(RoleB, RoleA)).
compare_roles_same_position_higher_id_loses_test() ->
RoleA = #{<<"position">> => 5, <<"id">> => <<"10">>},
RoleB = #{<<"position">> => 5, <<"id">> => <<"20">>},
?assertEqual(RoleA, compare_roles(RoleB, RoleA)).
get_highest_role_empty_roles_test() ->
?assertEqual(undefined, get_highest_role([], #{})).
get_highest_role_no_matching_roles_test() ->
Roles = #{999 => #{<<"id">> => <<"999">>, <<"position">> => 5}},
?assertEqual(undefined, get_highest_role([100], Roles)).
get_highest_role_picks_highest_position_test() ->
RoleLow = #{<<"id">> => <<"100">>, <<"position">> => 5},
RoleHigh = #{<<"id">> => <<"200">>, <<"position">> => 10},
Roles = #{100 => RoleLow, 200 => RoleHigh},
?assertEqual(RoleHigh, get_highest_role([100, 200], Roles)).
build_connected_user_ids_false_returns_empty_test() ->
Sessions = #{<<"s1">> => #{user_id => 1}},
?assertEqual(gb_sets:empty(), build_connected_user_ids(false, Sessions)).
build_connected_user_ids_true_collects_user_ids_test() ->
Sessions = #{
<<"s1">> => #{user_id => 1},
<<"s2">> => #{user_id => 2},
<<"s3">> => #{other => data}
},
Result = build_connected_user_ids(true, Sessions),
?assert(gb_sets:is_member(1, Result)),
?assert(gb_sets:is_member(2, Result)),
?assertEqual(2, gb_sets:size(Result)).
build_connected_user_ids_empty_sessions_test() ->
Result = build_connected_user_ids(true, #{}),
?assertEqual(gb_sets:empty(), Result).
check_should_mention_everyone_true_test() ->
Member = #{<<"roles">> => []},
?assertEqual(true, check_should_mention(
1, Member, true, false, false, false,
gb_sets:empty(), gb_sets:empty(), gb_sets:empty()
)).
check_should_mention_here_connected_test() ->
Member = #{<<"roles">> => []},
Connected = gb_sets:from_list([1]),
?assertEqual(true, check_should_mention(
1, Member, false, true, false, false,
gb_sets:empty(), gb_sets:empty(), Connected
)).
check_should_mention_here_not_connected_test() ->
Member = #{<<"roles">> => []},
Connected = gb_sets:from_list([2]),
?assertEqual(false, check_should_mention(
1, Member, false, true, false, false,
gb_sets:empty(), gb_sets:empty(), Connected
)).
check_should_mention_role_match_test() ->
Member = #{<<"roles">> => [<<"100">>]},
RoleSet = gb_sets:from_list([100]),
?assertEqual(true, check_should_mention(
1, Member, false, false, true, false,
RoleSet, gb_sets:empty(), gb_sets:empty()
)).
check_should_mention_direct_id_match_test() ->
Member = #{<<"roles">> => []},
DirectSet = gb_sets:from_list([1]),
?assertEqual(true, check_should_mention(
1, Member, false, false, false, true,
gb_sets:empty(), DirectSet, gb_sets:empty()
)).
check_should_mention_nothing_matches_test() ->
Member = #{<<"roles">> => []},
?assertEqual(false, check_should_mention(
1, Member, false, false, false, false,
gb_sets:empty(), gb_sets:empty(), gb_sets:empty()
)).
member_can_view_channel_non_integer_channel_id_test() ->
?assertEqual(false, member_can_view_channel(1, undefined, #{}, #{})).
collect_mentions_excludes_author_test() ->
ViewPerm = constants:view_channel_permission(),
State = #{
id => 100,
data => #{
<<"guild">> => #{<<"owner_id">> => <<"1">>},
<<"roles">> => [
#{<<"id">> => <<"100">>, <<"permissions">> => integer_to_binary(ViewPerm), <<"position">> => 0}
],
<<"channels">> => [
#{<<"id">> => <<"500">>, <<"type">> => 0, <<"permission_overwrites">> => []}
],
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"1">>}, <<"roles">> => [<<"100">>]},
#{<<"user">> => #{<<"id">> => <<"2">>}, <<"roles">> => [<<"100">>]}
]
}
},
Members = [
#{<<"user">> => #{<<"id">> => <<"1">>}, <<"roles">> => [<<"100">>]},
#{<<"user">> => #{<<"id">> => <<"2">>}, <<"roles">> => [<<"100">>]}
],
UserIds = collect_mentions(Members, 1, 500, State, fun(_) -> true end),
?assertNot(lists:member(1, UserIds)),
?assert(lists:member(2, UserIds)).
collect_mentions_skips_members_without_user_id_test() ->
ViewPerm = constants:view_channel_permission(),
State = #{
id => 100,
data => #{
<<"guild">> => #{<<"owner_id">> => <<"1">>},
<<"roles">> => [
#{<<"id">> => <<"100">>, <<"permissions">> => integer_to_binary(ViewPerm), <<"position">> => 0}
],
<<"channels">> => [
#{<<"id">> => <<"500">>, <<"type">> => 0, <<"permission_overwrites">> => []}
],
<<"members">> => []
}
},
Members = [#{}, #{<<"user">> => #{}}],
UserIds = collect_mentions(Members, 1, 500, State, fun(_) -> true end),
?assertEqual([], UserIds).
filter_assignable_role_below_position_test() ->
Role = #{<<"id">> => <<"100">>, <<"position">> => 5},
?assertEqual({true, 100}, filter_assignable_role(Role, 10)).
filter_assignable_role_at_position_test() ->
Role = #{<<"id">> => <<"100">>, <<"position">> => 10},
?assertEqual(false, filter_assignable_role(Role, 10)).
filter_assignable_role_above_position_test() ->
Role = #{<<"id">> => <<"100">>, <<"position">> => 15},
?assertEqual(false, filter_assignable_role(Role, 10)).
filter_assignable_role_no_id_test() ->
Role = #{<<"position">> => 5},
?assertEqual(false, filter_assignable_role(Role, 10)).
user_ids_for_any_role_empty_roles_test() ->
State = test_state(),
?assertEqual([], user_ids_for_any_role([], State)).
user_ids_for_any_role_nonexistent_role_test() ->
State = test_state(),
?assertEqual([], user_ids_for_any_role([999], State)).
user_ids_for_any_role_multiple_roles_test() ->
State = test_state(),
UserIds = lists:sort(user_ids_for_any_role([200, 201], State)),
?assertEqual([2, 3], UserIds).
owner_id_valid_test() ->
State = test_state(),
?assertEqual(1, owner_id(State)).
owner_id_missing_guild_test() ->
State = #{data => #{}},
?assertEqual(0, owner_id(State)).
guild_data_missing_data_test() ->
State = #{},
?assertEqual(#{}, guild_data(State)).
guild_members_empty_test() ->
State = #{data => #{<<"members">> => []}},
?assertEqual([], guild_members(State)).
guild_roles_empty_test() ->
State = #{data => #{<<"roles">> => []}},
?assertEqual([], guild_roles(State)).
guild_channels_empty_test() ->
State = #{data => #{<<"channels">> => []}},
?assertEqual([], guild_channels(State)).
test_state() ->
GuildId = 100,
OwnerId = 1,

View File

@@ -70,35 +70,35 @@ send_passive_updates_to_sessions(State) ->
0 ->
State;
_ ->
UpdatedSessions = process_passive_sessions(
maps:to_list(PassiveSessions), GuildId, Sessions, Channels, State
process_passive_sessions(
maps:to_list(PassiveSessions), GuildId, Channels, State
),
maps:put(sessions, UpdatedSessions, State)
State
end.
-spec process_passive_sessions([{binary(), map()}], integer(), map(), [map()], guild_state()) ->
map().
process_passive_sessions(PassiveSessionList, GuildId, Sessions, Channels, State) ->
lists:foldl(
fun({SessionId, SessionData}, AccSessions) ->
-spec process_passive_sessions([{binary(), map()}], integer(), [map()], guild_state()) ->
ok.
process_passive_sessions(PassiveSessionList, GuildId, Channels, State) ->
lists:foreach(
fun({SessionId, SessionData}) ->
process_single_passive_session(
SessionId, SessionData, GuildId, Channels, State, AccSessions
SessionId, SessionData, GuildId, Channels, State
)
end,
Sessions,
PassiveSessionList
).
-spec process_single_passive_session(binary(), map(), integer(), [map()], guild_state(), map()) ->
map().
process_single_passive_session(SessionId, SessionData, GuildId, Channels, State, AccSessions) ->
-spec process_single_passive_session(binary(), map(), integer(), [map()], guild_state()) ->
ok.
process_single_passive_session(SessionId, SessionData, GuildId, Channels, State) ->
Pid = maps:get(pid, SessionData),
UserId = maps:get(user_id, SessionData),
Member = guild_permissions:find_member_by_user_id(UserId, State),
CurrentLastMessageIds = build_last_message_ids(Channels, UserId, Member, State),
PreviousLastMessageIds = maps:get(previous_passive_updates, SessionData, #{}),
RegState = passive_sync_registry:lookup(SessionId, GuildId),
PreviousLastMessageIds = maps:get(previous_passive_updates, RegState, #{}),
Delta = compute_delta(CurrentLastMessageIds, PreviousLastMessageIds),
PreviousChannelVersions = maps:get(previous_passive_channel_versions, SessionData, #{}),
PreviousChannelVersions = maps:get(previous_passive_channel_versions, RegState, #{}),
{CurrentChannelVersions, CurrentChannelsById} =
build_viewable_channel_snapshots(Channels, UserId, Member, State),
{CreatedChannelIds, UpdatedChannelIds, DeletedChannelIds} =
@@ -107,12 +107,10 @@ process_single_passive_session(SessionId, SessionData, GuildId, Channels, State,
UpdatedChannels = [maps:get(Id, CurrentChannelsById) || Id <- UpdatedChannelIds],
ViewableChannels = guild_visibility:viewable_channel_set(UserId, State),
CurrentVoiceStates = build_current_voice_state_map(ViewableChannels, State),
PreviousVoiceStates = maps:get(previous_passive_voice_states, SessionData, #{}),
PreviousVoiceStates = maps:get(previous_passive_voice_states, RegState, #{}),
VoiceStateUpdates = compute_voice_state_updates(
CurrentVoiceStates, PreviousVoiceStates, GuildId
),
UpdatedSessionDataBase =
maps:put(previous_passive_voice_states, CurrentVoiceStates, SessionData),
HasChannelDelta = map_size(Delta) > 0,
HasVoiceUpdates = VoiceStateUpdates =/= [],
HasCreatedChannels = CreatedChannels =/= [],
@@ -134,21 +132,21 @@ process_single_passive_session(SessionId, SessionData, GuildId, Channels, State,
gen_server:cast(Pid, {dispatch, passive_updates, EventData}),
PreviousLastMessageIds1 = maps:without(DeletedChannelIds, PreviousLastMessageIds),
MergedLastMessageIds = maps:merge(PreviousLastMessageIds1, Delta),
UpdatedSessionData0 =
maps:put(previous_passive_updates, MergedLastMessageIds, UpdatedSessionDataBase),
UpdatedSessionData =
maps:put(
previous_passive_channel_versions, CurrentChannelVersions, UpdatedSessionData0
),
maps:put(SessionId, UpdatedSessionData, AccSessions);
NewRegState = #{
previous_passive_updates => MergedLastMessageIds,
previous_passive_channel_versions => CurrentChannelVersions,
previous_passive_voice_states => CurrentVoiceStates
},
passive_sync_registry:store(SessionId, GuildId, NewRegState),
ok;
_ ->
UpdatedSessionData =
maps:put(
previous_passive_channel_versions,
CurrentChannelVersions,
UpdatedSessionDataBase
),
maps:put(SessionId, UpdatedSessionData, AccSessions)
NewRegState = #{
previous_passive_updates => PreviousLastMessageIds,
previous_passive_channel_versions => CurrentChannelVersions,
previous_passive_voice_states => CurrentVoiceStates
},
passive_sync_registry:store(SessionId, GuildId, NewRegState),
ok
end.
-spec build_passive_event_data(integer(), map(), [map()], [map()], [binary()], [map()]) -> map().

View File

@@ -20,9 +20,12 @@
-export([
put_state/1,
put_data/2,
put_normalized_data/2,
delete/1,
get_permissions/3,
get_snapshot/1
get_snapshot/1,
has_member/2,
get_member/2
]).
-type guild_id() :: integer().
@@ -39,7 +42,7 @@ put_state(State) when is_map(State) ->
Data = maps:get(data, State, #{}),
case is_integer(GuildId) of
true ->
put_data(GuildId, Data);
put_normalized_data(GuildId, Data);
false ->
ok
end;
@@ -48,12 +51,18 @@ put_state(_) ->
-spec put_data(guild_id(), guild_data()) -> ok.
put_data(GuildId, Data) when is_integer(GuildId), is_map(Data) ->
ensure_table(),
NormalizedData = guild_data_index:normalize_data(Data),
put_normalized_data(GuildId, NormalizedData);
put_data(_, _) ->
ok.
-spec put_normalized_data(guild_id(), guild_data()) -> ok.
put_normalized_data(GuildId, NormalizedData) when is_integer(GuildId), is_map(NormalizedData) ->
ensure_table(),
Snapshot = #{id => GuildId, data => NormalizedData},
true = ets:insert(?TABLE, {GuildId, Snapshot}),
ok;
put_data(_, _) ->
put_normalized_data(_, _) ->
ok.
-spec delete(guild_id()) -> ok.
@@ -77,6 +86,29 @@ get_permissions(GuildId, UserId, ChannelId) when is_integer(GuildId), is_integer
get_permissions(_, _, _) ->
{error, not_found}.
-spec has_member(guild_id(), user_id()) -> {ok, boolean()} | {error, not_found}.
has_member(GuildId, UserId) when is_integer(GuildId), is_integer(UserId) ->
case get_snapshot(GuildId) of
{ok, Snapshot} ->
Member = guild_permissions:find_member_by_user_id(UserId, Snapshot),
{ok, Member =/= undefined};
{error, not_found} ->
{error, not_found}
end;
has_member(_, _) ->
{error, not_found}.
-spec get_member(guild_id(), user_id()) -> {ok, map() | undefined} | {error, not_found}.
get_member(GuildId, UserId) when is_integer(GuildId), is_integer(UserId) ->
case get_snapshot(GuildId) of
{ok, Snapshot} ->
{ok, guild_permissions:find_member_by_user_id(UserId, Snapshot)};
{error, not_found} ->
{error, not_found}
end;
get_member(_, _) ->
{error, not_found}.
-spec get_snapshot(guild_id()) -> {ok, guild_state()} | {error, not_found}.
get_snapshot(GuildId) when is_integer(GuildId) ->
ensure_table(),
@@ -91,16 +123,7 @@ get_snapshot(_) ->
-spec ensure_table() -> ok.
ensure_table() ->
case ets:whereis(?TABLE) of
undefined ->
try ets:new(?TABLE, [named_table, public, set, {read_concurrency, true}]) of
_ -> ok
catch
error:badarg -> ok
end;
_ ->
ok
end.
guild_ets_utils:ensure_table(?TABLE, [named_table, public, set, {read_concurrency, true}]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
@@ -129,4 +152,89 @@ put_and_get_permissions_test() ->
missing_guild_returns_not_found_test() ->
?assertEqual({error, not_found}, get_permissions(999999, 1, undefined)).
put_normalized_data_skips_renormalization_test() ->
GuildId = 102,
UserId = 45,
ViewPermission = constants:view_channel_permission(),
Data = guild_data_index:normalize_data(#{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(ViewPermission)}
],
<<"members">> => [
#{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
],
<<"channels">> => [
#{<<"id">> => <<"600">>, <<"permission_overwrites">> => []}
]
}),
ok = put_normalized_data(GuildId, Data),
{ok, Permissions} = get_permissions(GuildId, UserId, 600),
?assert((Permissions band ViewPermission) =/= 0),
ok = delete(GuildId).
put_state_uses_fast_path_test() ->
GuildId = 103,
UserId = 46,
ViewPermission = constants:view_channel_permission(),
NormalizedData = guild_data_index:normalize_data(#{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(ViewPermission)}
],
<<"members">> => [
#{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
],
<<"channels">> => [
#{<<"id">> => <<"700">>, <<"permission_overwrites">> => []}
]
}),
State = #{id => GuildId, data => NormalizedData},
ok = put_state(State),
{ok, Permissions} = get_permissions(GuildId, UserId, 700),
?assert((Permissions band ViewPermission) =/= 0),
ok = delete(GuildId).
has_member_returns_true_when_member_exists_test() ->
GuildId = 104,
UserId = 47,
Data = guild_data_index:normalize_data(#{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [],
<<"members">> => [
#{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
],
<<"channels">> => []
}),
ok = put_normalized_data(GuildId, Data),
?assertEqual({ok, true}, has_member(GuildId, UserId)),
?assertEqual({ok, false}, has_member(GuildId, 99999)),
ok = delete(GuildId).
has_member_returns_not_found_when_no_snapshot_test() ->
?assertEqual({error, not_found}, has_member(999998, 1)).
get_member_returns_member_data_test() ->
GuildId = 105,
UserId = 48,
MemberData = #{
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
<<"roles">> => [],
<<"nick">> => <<"TestNick">>
},
Data = guild_data_index:normalize_data(#{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [],
<<"members">> => [MemberData],
<<"channels">> => []
}),
ok = put_normalized_data(GuildId, Data),
{ok, Result} = get_member(GuildId, UserId),
?assertEqual(<<"TestNick">>, maps:get(<<"nick">>, Result)),
{ok, undefined} = get_member(GuildId, 99999),
ok = delete(GuildId).
get_member_returns_not_found_when_no_snapshot_test() ->
?assertEqual({error, not_found}, get_member(999997, 1)).
-endif.

View File

@@ -110,10 +110,10 @@ broadcast_presence_update(UserId, Payload, State) ->
Sessions = maps:get(sessions, State, #{}),
MemberSubs = maps:get(member_subscriptions, State, guild_subscriptions:init_state()),
SubscribedSessionIds = guild_subscriptions:get_subscribed_sessions(UserId, MemberSubs),
TargetChannels = guild_visibility:viewable_channel_set(UserId, State),
TargetChannelMap = get_user_viewable_channel_map(UserId, Sessions, State),
{ValidSessionIds, InvalidSessionIds} =
partition_subscribed_sessions(
SubscribedSessionIds, Sessions, TargetChannels, UserId, State
SubscribedSessionIds, Sessions, TargetChannelMap, UserId, State
),
StateAfterInvalidRemovals =
lists:foldl(
@@ -245,9 +245,9 @@ member_id(Member) ->
User = map_utils:ensure_map(maps:get(<<"user">>, Member, #{})),
map_utils:get_integer(User, <<"id">>, undefined).
-spec partition_subscribed_sessions([binary()], map(), sets:set(), user_id(), guild_state()) ->
-spec partition_subscribed_sessions([binary()], map(), map(), user_id(), guild_state()) ->
{[binary()], [binary()]}.
partition_subscribed_sessions(SessionIds, Sessions, TargetChannels, TargetUserId, State) ->
partition_subscribed_sessions(SessionIds, Sessions, TargetChannelMap, TargetUserId, State) ->
lists:foldl(
fun(SessionId, {Valids, Invalids}) ->
case maps:get(SessionId, Sessions, undefined) of
@@ -262,11 +262,8 @@ partition_subscribed_sessions(SessionIds, Sessions, TargetChannels, TargetUserId
UserId when UserId =:= TargetUserId ->
false;
_ ->
SessionChannels = guild_visibility:viewable_channel_set(
SessionUserId, State
),
not sets:is_empty(
sets:intersection(SessionChannels, TargetChannels)
session_shares_channels(
SessionData, SessionUserId, TargetChannelMap, State
)
end,
case Shared of
@@ -279,6 +276,71 @@ partition_subscribed_sessions(SessionIds, Sessions, TargetChannels, TargetUserId
SessionIds
).
-spec session_shares_channels(map(), user_id(), map(), guild_state()) -> boolean().
session_shares_channels(SessionData, SessionUserId, TargetChannelMap, State) ->
case maps:get(viewable_channels, SessionData, undefined) of
ViewableMap when is_map(ViewableMap) ->
maps_share_any_key(ViewableMap, TargetChannelMap);
_ ->
SessionChannels = guild_visibility:viewable_channel_set(SessionUserId, State),
TargetChannels = sets:from_list(maps:keys(TargetChannelMap)),
not sets:is_empty(sets:intersection(SessionChannels, TargetChannels))
end.
-spec maps_share_any_key(map(), map()) -> boolean().
maps_share_any_key(MapA, MapB) ->
{Smaller, Larger} =
case map_size(MapA) =< map_size(MapB) of
true -> {MapA, MapB};
false -> {MapB, MapA}
end,
maps_share_any_key_iter(maps:iterator(Smaller), Larger).
-spec maps_share_any_key_iter(maps:iterator(), map()) -> boolean().
maps_share_any_key_iter(Iterator, LargerMap) ->
case maps:next(Iterator) of
none ->
false;
{Key, _, NextIterator} ->
case maps:is_key(Key, LargerMap) of
true -> true;
false -> maps_share_any_key_iter(NextIterator, LargerMap)
end
end.
-spec get_user_viewable_channel_map(user_id(), map(), guild_state()) -> map().
get_user_viewable_channel_map(UserId, Sessions, State) ->
case find_session_viewable_channels_for_user(UserId, Sessions) of
undefined ->
ChannelList = guild_visibility:get_user_viewable_channels(UserId, State),
maps:from_list([{Ch, true} || Ch <- ChannelList]);
ViewableMap ->
ViewableMap
end.
-spec find_session_viewable_channels_for_user(user_id(), map()) -> map() | undefined.
find_session_viewable_channels_for_user(UserId, Sessions) ->
find_session_viewable_channels_iter(UserId, maps:iterator(Sessions)).
-spec find_session_viewable_channels_iter(user_id(), maps:iterator()) -> map() | undefined.
find_session_viewable_channels_iter(UserId, Iterator) ->
case maps:next(Iterator) of
none ->
undefined;
{_, SessionData, NextIterator} ->
case maps:get(user_id, SessionData, undefined) of
UserId ->
case maps:get(viewable_channels, SessionData, undefined) of
ViewableChannels when is_map(ViewableChannels) ->
ViewableChannels;
_ ->
find_session_viewable_channels_iter(UserId, NextIterator)
end;
_ ->
find_session_viewable_channels_iter(UserId, NextIterator)
end
end.
-spec remove_session_member_subscription(binary(), user_id(), guild_state()) -> guild_state().
remove_session_member_subscription(SessionId, UserId, State) ->
MemberSubs = maps:get(member_subscriptions, State, guild_subscriptions:init_state()),
@@ -334,6 +396,96 @@ normalize_presence_status_test() ->
?assertEqual(<<"idle">>, normalize_presence_status(<<"idle">>)),
?assertEqual(<<"offline">>, normalize_presence_status(undefined)).
handle_bus_presence_invisible_normalized_test() ->
State = presence_test_state(),
Payload = #{
<<"status">> => <<"invisible">>,
<<"mobile">> => false,
<<"afk">> => false,
<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"Alpha">>}
},
{noreply, NewState} = handle_bus_presence(1, Payload, State),
MemberPresence = maps:get(member_presence, NewState, #{}),
UserPresence = maps:get(1, MemberPresence),
?assertEqual(<<"offline">>, maps:get(<<"status">>, UserPresence)).
maps_share_any_key_empty_test() ->
?assertEqual(false, maps_share_any_key(#{}, #{})),
?assertEqual(false, maps_share_any_key(#{1 => true}, #{})),
?assertEqual(false, maps_share_any_key(#{}, #{1 => true})).
maps_share_any_key_overlap_test() ->
?assertEqual(true, maps_share_any_key(#{1 => true, 2 => true}, #{2 => true, 3 => true})),
?assertEqual(true, maps_share_any_key(#{5 => true}, #{5 => true})).
maps_share_any_key_no_overlap_test() ->
?assertEqual(false, maps_share_any_key(#{1 => true, 2 => true}, #{3 => true, 4 => true})).
get_user_viewable_channel_map_uses_session_cache_test() ->
Sessions = #{
<<"s1">> => #{user_id => 10, viewable_channels => #{100 => true, 200 => true}},
<<"s2">> => #{user_id => 20, viewable_channels => #{300 => true}}
},
State = #{sessions => Sessions, data => #{<<"members">> => #{}}},
Result = get_user_viewable_channel_map(10, Sessions, State),
?assertEqual(#{100 => true, 200 => true}, Result).
get_user_viewable_channel_map_skips_session_without_cache_test() ->
Sessions = #{
<<"s1">> => #{user_id => 10},
<<"s2">> => #{user_id => 10, viewable_channels => #{100 => true}}
},
State = #{sessions => Sessions, data => #{<<"members">> => #{}}},
Result = get_user_viewable_channel_map(10, Sessions, State),
?assertEqual(#{100 => true}, Result).
session_shares_channels_uses_cached_viewable_test() ->
SessionData = #{user_id => 20, viewable_channels => #{100 => true, 200 => true}},
TargetChannelMap = #{200 => true, 300 => true},
State = #{sessions => #{}, data => #{<<"members">> => #{}}},
?assertEqual(true, session_shares_channels(SessionData, 20, TargetChannelMap, State)).
session_shares_channels_no_overlap_test() ->
SessionData = #{user_id => 20, viewable_channels => #{100 => true}},
TargetChannelMap = #{200 => true, 300 => true},
State = #{sessions => #{}, data => #{<<"members">> => #{}}},
?assertEqual(false, session_shares_channels(SessionData, 20, TargetChannelMap, State)).
partition_subscribed_sessions_uses_cached_channels_test() ->
Sessions = #{
<<"s1">> => #{user_id => 20, pid => self(), viewable_channels => #{100 => true}},
<<"s2">> => #{user_id => 30, pid => self(), viewable_channels => #{200 => true}}
},
TargetChannelMap = #{100 => true, 300 => true},
State = #{sessions => Sessions, data => #{<<"members">> => #{}}},
{Valid, Invalid} = partition_subscribed_sessions(
[<<"s1">>, <<"s2">>], Sessions, TargetChannelMap, 10, State
),
?assertEqual([<<"s1">>], Valid),
?assertEqual([<<"s2">>], Invalid).
partition_subscribed_sessions_excludes_target_user_test() ->
Sessions = #{
<<"s1">> => #{user_id => 10, pid => self(), viewable_channels => #{100 => true}}
},
TargetChannelMap = #{100 => true},
State = #{sessions => Sessions, data => #{<<"members">> => #{}}},
{Valid, Invalid} = partition_subscribed_sessions(
[<<"s1">>], Sessions, TargetChannelMap, 10, State
),
?assertEqual([], Valid),
?assertEqual([<<"s1">>], Invalid).
partition_subscribed_sessions_missing_session_test() ->
Sessions = #{},
TargetChannelMap = #{100 => true},
State = #{sessions => Sessions, data => #{<<"members">> => #{}}},
{Valid, Invalid} = partition_subscribed_sessions(
[<<"s1">>], Sessions, TargetChannelMap, 10, State
),
?assertEqual([], Valid),
?assertEqual([<<"s1">>], Invalid).
presence_test_state() ->
#{
id => 42,

View File

@@ -0,0 +1,251 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(guild_query_handler).
-export([handle_call/3]).
-type guild_state() :: map().
-type user_id() :: integer().
-spec handle_call(term(), gen_server:from(), guild_state()) ->
{reply, term(), guild_state()}
| {noreply, guild_state()}.
handle_call({very_large_guild_prime_member, Member}, _From, State) when is_map(Member) ->
Data0 = maps:get(data, State, #{}),
Data = guild_data_index:put_member(Member, Data0),
{reply, ok, maps:put(data, Data, State)};
handle_call({very_large_guild_prime_member, _}, _From, State) ->
{reply, ok, State};
handle_call({very_large_guild_get_members, UserIds}, _From, State) when is_list(UserIds) ->
Data = maps:get(data, State, #{}),
MemberMap = guild_data_index:member_map(Data),
Reply = lists:foldl(
fun(UserId, Acc) ->
case maps:get(UserId, MemberMap, undefined) of
Member when is_map(Member) -> maps:put(UserId, Member, Acc);
_ -> Acc
end
end,
#{},
UserIds
),
{reply, Reply, State};
handle_call({get_counts}, _From, State) ->
MemberCount = maps:get(member_count, State, 0),
OnlineCount = guild_member_list:get_online_count(State),
GuildId = maps:get(id, State, undefined),
case is_integer(GuildId) of
true -> guild_counts_cache:update(GuildId, MemberCount, OnlineCount);
false -> ok
end,
{reply, #{member_count => MemberCount, presence_count => OnlineCount}, State};
handle_call({get_large_guild_metadata}, _From, State) ->
MemberCount = maps:get(member_count, State, 0),
Data = maps:get(data, State, #{}),
Guild = maps:get(<<"guild">>, Data, #{}),
Features = maps:get(<<"features">>, Guild, []),
{reply, #{member_count => MemberCount, features => Features}, State};
handle_call({get_users_to_mention_by_roles, Request}, _From, State) ->
spawn_async_reply(
_From,
fun() ->
{reply, Reply, _} = guild_members:get_users_to_mention_by_roles(Request, State),
Reply
end
),
{noreply, State};
handle_call({get_users_to_mention_by_user_ids, Request}, _From, State) ->
spawn_async_reply(
_From,
fun() ->
{reply, Reply, _} = guild_members:get_users_to_mention_by_user_ids(Request, State),
Reply
end
),
{noreply, State};
handle_call({get_all_users_to_mention, Request}, _From, State) ->
spawn_async_reply(
_From,
fun() ->
{reply, Reply, _} = guild_members:get_all_users_to_mention(Request, State),
Reply
end
),
{noreply, State};
handle_call({resolve_all_mentions, Request}, _From, State) ->
spawn_async_reply(
_From,
fun() ->
{reply, Reply, _} = guild_members:resolve_all_mentions(Request, State),
Reply
end
),
{noreply, State};
handle_call({get_members_with_role, Request}, _From, State) ->
spawn_async_reply(
_From,
fun() ->
{reply, Reply, _} = guild_members:get_members_with_role(Request, State),
Reply
end
),
{noreply, State};
handle_call({check_permission, Request}, _From, State) ->
spawn_async_reply(
_From,
fun() ->
#{user_id := UserId, permission := Permission, channel_id := ChannelId} = Request,
true = is_integer(Permission),
HasPermission =
case owner_id(State) =:= UserId of
true ->
true;
false ->
Permissions = guild_permissions:get_member_permissions(
UserId, ChannelId, State
),
(Permissions band Permission) =:= Permission
end,
#{has_permission => HasPermission}
end
),
{noreply, State};
handle_call({get_user_permissions, Request}, _From, State) ->
spawn_async_reply(
_From,
fun() ->
#{user_id := UserId, channel_id := ChannelId} = Request,
Permissions = guild_permissions:get_member_permissions(UserId, ChannelId, State),
#{permissions => Permissions}
end
),
{noreply, State};
handle_call({can_manage_roles, Request}, _From, State) ->
guild_members:can_manage_roles(Request, State);
handle_call({can_manage_role, Request}, _From, State) ->
guild_members:can_manage_role(Request, State);
handle_call({get_guild_data, Request}, _From, State) ->
guild_data:get_guild_data(Request, State);
handle_call({get_assignable_roles, Request}, _From, State) ->
guild_members:get_assignable_roles(Request, State);
handle_call({get_user_max_role_position, Request}, _From, State) ->
#{user_id := UserId} = Request,
Position = guild_permissions:get_max_role_position(UserId, State),
{reply, #{position => Position}, State};
handle_call({check_target_member, Request}, _From, State) ->
guild_members:check_target_member(Request, State);
handle_call({get_viewable_channels, Request}, _From, State) ->
spawn_async_reply(
_From,
fun() ->
{reply, Reply, _} = guild_members:get_viewable_channels(Request, State),
Reply
end
),
{noreply, State};
handle_call({get_guild_member, Request}, _From, State) ->
guild_data:get_guild_member(Request, State);
handle_call({has_member, Request}, _From, State) ->
guild_data:has_member(Request, State);
handle_call({list_guild_members, Request}, _From, State) ->
guild_data:list_guild_members(Request, State);
handle_call({list_guild_members_cursor, Request}, _From, State) ->
guild_member_list:get_members_cursor(Request, State);
handle_call({get_vanity_url_channel}, _From, State) ->
guild_data:get_vanity_url_channel(State);
handle_call({get_first_viewable_text_channel}, _From, State) ->
guild_data:get_first_viewable_text_channel(State);
handle_call({get_category_channel_count, Request}, _From, State) ->
#{category_id := CategoryId} = Request,
Data = maps:get(data, State),
Channels = maps:get(<<"channels">>, Data, []),
Count = length([
Ch
|| Ch <- Channels,
map_utils:get_integer(Ch, <<"parent_id">>, undefined) =:= CategoryId
]),
{reply, #{count => Count}, State};
handle_call({get_channel_count}, _From, State) ->
Data = maps:get(data, State),
Channels = maps:get(<<"channels">>, Data, []),
Count = length(Channels),
{reply, #{count => Count}, State};
handle_call({get_sessions}, _From, State) ->
{reply, State, State};
handle_call({get_push_base_state}, _From, State) ->
{reply,
#{
id => maps:get(id, State, 0),
data => maps:get(data, State, #{}),
virtual_channel_access => maps:get(virtual_channel_access, State, #{})
},
State};
handle_call({get_cluster_merge_state}, _From, State) ->
{reply,
#{
sessions => maps:get(sessions, State, #{}),
voice_states => maps:get(voice_states, State, #{}),
virtual_channel_access => maps:get(virtual_channel_access, State, #{}),
virtual_channel_access_pending => maps:get(virtual_channel_access_pending, State, #{}),
virtual_channel_access_preserve => maps:get(virtual_channel_access_preserve, State, #{}),
virtual_channel_access_move_pending =>
maps:get(virtual_channel_access_move_pending, State, #{})
},
State}.
-spec spawn_async_reply(gen_server:from(), fun(() -> term())) -> ok.
spawn_async_reply(From, ReplyFun) ->
spawn(fun() ->
Reply =
try
ReplyFun()
catch
_:_ ->
#{error => async_handler_failed}
end,
gen_server:reply(From, Reply)
end),
ok.
-spec owner_id(guild_state()) -> user_id().
owner_id(State) ->
case resolve_data_map(State) of
undefined ->
0;
Data ->
Guild = maps:get(<<"guild">>, Data, #{}),
type_conv:to_integer(maps:get(<<"owner_id">>, Guild, <<"0">>))
end.
-spec resolve_data_map(guild_state() | map()) -> map() | undefined.
resolve_data_map(State) when is_map(State) ->
case maps:find(data, State) of
{ok, Data} when is_map(Data) ->
Data;
{ok, Data} when is_map(Data) =:= false ->
undefined;
error ->
case State of
#{<<"members">> := _} ->
State;
_ ->
undefined
end
end;
resolve_data_map(_) ->
undefined.

View File

@@ -460,4 +460,321 @@ normalize_nonce_test() ->
?assertEqual(null, normalize_nonce(<<"this_nonce_is_way_too_long_to_be_valid">>)),
?assertEqual(null, normalize_nonce(undefined)).
validate_user_ids_too_many_test() ->
UserIds = lists:seq(1, 101),
?assertEqual({error, too_many_user_ids}, validate_user_ids(UserIds)).
validate_user_ids_exactly_max_test() ->
UserIds = lists:seq(1, 100),
{ok, Parsed} = validate_user_ids(UserIds),
?assertEqual(100, length(Parsed)).
validate_user_ids_non_list_test() ->
{ok, []} = validate_user_ids(not_a_list).
validate_user_ids_filters_invalid_test() ->
{ok, Parsed} = validate_user_ids([<<"1">>, <<"invalid">>, 3, -5, 0]),
?assertEqual([1, 3], Parsed).
validate_user_ids_empty_test() ->
{ok, []} = validate_user_ids([]).
parse_user_id_integer_test() ->
?assertEqual({ok, 42}, parse_user_id(42)).
parse_user_id_binary_test() ->
?assertEqual({ok, 123}, parse_user_id(<<"123">>)).
parse_user_id_zero_test() ->
?assertEqual(error, parse_user_id(0)).
parse_user_id_negative_test() ->
?assertEqual(error, parse_user_id(-1)).
parse_user_id_invalid_binary_test() ->
?assertEqual(error, parse_user_id(<<"abc">>)).
parse_user_id_other_type_test() ->
?assertEqual(error, parse_user_id(1.5)).
ensure_binary_binary_test() ->
?assertEqual(<<"hello">>, ensure_binary(<<"hello">>)).
ensure_binary_integer_test() ->
?assertEqual(<<>>, ensure_binary(42)).
ensure_binary_undefined_test() ->
?assertEqual(<<>>, ensure_binary(undefined)).
ensure_limit_valid_test() ->
?assertEqual(10, ensure_limit(10)).
ensure_limit_zero_test() ->
?assertEqual(0, ensure_limit(0)).
ensure_limit_negative_test() ->
?assertEqual(0, ensure_limit(-1)).
ensure_limit_non_integer_test() ->
?assertEqual(0, ensure_limit(<<"10">>)).
validate_guild_id_integer_test() ->
?assertEqual({ok, 123}, validate_guild_id(123)).
validate_guild_id_zero_test() ->
?assertEqual({error, invalid_guild_id}, validate_guild_id(0)).
validate_guild_id_negative_test() ->
?assertEqual({error, invalid_guild_id}, validate_guild_id(-1)).
validate_guild_id_atom_test() ->
?assertEqual({error, invalid_guild_id}, validate_guild_id(undefined)).
build_chunk_data_basic_test() ->
Members = [#{<<"user">> => #{<<"id">> => <<"1">>}}],
Result = build_chunk_data(Members, [], 0, 1, null),
?assertEqual(Members, maps:get(<<"members">>, Result)),
?assertEqual(0, maps:get(<<"chunk_index">>, Result)),
?assertEqual(1, maps:get(<<"chunk_count">>, Result)),
?assertNot(maps:is_key(<<"presences">>, Result)),
?assertNot(maps:is_key(<<"nonce">>, Result)).
build_chunk_data_with_presences_test() ->
Members = [#{<<"user">> => #{<<"id">> => <<"1">>}}],
Presences = [#{<<"user">> => #{<<"id">> => <<"1">>}, <<"status">> => <<"online">>}],
Result = build_chunk_data(Members, Presences, 0, 1, null),
?assertEqual(Presences, maps:get(<<"presences">>, Result)),
?assertNot(maps:is_key(<<"nonce">>, Result)).
build_chunk_data_with_nonce_test() ->
Members = [],
Result = build_chunk_data(Members, [], 0, 1, <<"my_nonce">>),
?assertEqual(<<"my_nonce">>, maps:get(<<"nonce">>, Result)).
build_chunk_data_with_presences_and_nonce_test() ->
Members = [#{<<"user">> => #{<<"id">> => <<"1">>}}],
Presences = [#{<<"user">> => #{<<"id">> => <<"1">>}, <<"status">> => <<"online">>}],
Result = build_chunk_data(Members, Presences, 2, 5, <<"nonce1">>),
?assertEqual(Presences, maps:get(<<"presences">>, Result)),
?assertEqual(<<"nonce1">>, maps:get(<<"nonce">>, Result)),
?assertEqual(2, maps:get(<<"chunk_index">>, Result)),
?assertEqual(5, maps:get(<<"chunk_count">>, Result)).
chunk_presences_aligns_with_member_chunks_test() ->
Members1 = [
#{<<"user">> => #{<<"id">> => <<"1">>}},
#{<<"user">> => #{<<"id">> => <<"2">>}}
],
Members2 = [
#{<<"user">> => #{<<"id">> => <<"3">>}}
],
Presences = [
#{<<"user">> => #{<<"id">> => <<"1">>}, <<"status">> => <<"online">>},
#{<<"user">> => #{<<"id">> => <<"3">>}, <<"status">> => <<"idle">>}
],
Result = chunk_presences(Presences, [Members1, Members2]),
?assertEqual(2, length(Result)),
[P1, P2] = Result,
?assertEqual(1, length(P1)),
?assertEqual(1, length(P2)).
chunk_presences_empty_presences_test() ->
Members = [#{<<"user">> => #{<<"id">> => <<"1">>}}],
Result = chunk_presences([], [Members]),
?assertEqual([[]], Result).
chunk_presences_no_matching_presences_test() ->
Members = [#{<<"user">> => #{<<"id">> => <<"1">>}}],
Presences = [#{<<"user">> => #{<<"id">> => <<"999">>}, <<"status">> => <<"online">>}],
Result = chunk_presences(Presences, [Members]),
?assertEqual([[]], Result).
filter_members_by_ids_basic_test() ->
Members = [
#{<<"user">> => #{<<"id">> => <<"1">>}},
#{<<"user">> => #{<<"id">> => <<"2">>}},
#{<<"user">> => #{<<"id">> => <<"3">>}}
],
Result = filter_members_by_ids(Members, [1, 3]),
?assertEqual(2, length(Result)).
filter_members_by_ids_empty_ids_test() ->
Members = [#{<<"user">> => #{<<"id">> => <<"1">>}}],
Result = filter_members_by_ids(Members, []),
?assertEqual([], Result).
filter_members_by_ids_no_match_test() ->
Members = [#{<<"user">> => #{<<"id">> => <<"1">>}}],
Result = filter_members_by_ids(Members, [999]),
?assertEqual([], Result).
filter_members_by_ids_skips_invalid_members_test() ->
Members = [#{}, #{<<"user">> => #{}}, #{<<"user">> => #{<<"id">> => <<"1">>}}],
Result = filter_members_by_ids(Members, [1]),
?assertEqual(1, length(Result)).
filter_members_by_query_case_insensitive_test() ->
Members = [
#{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"Alice">>}},
#{<<"user">> => #{<<"id">> => <<"2">>, <<"username">> => <<"bob">>}}
],
Results = filter_members_by_query(Members, <<"ALICE">>, 10),
?assertEqual(1, length(Results)).
filter_members_by_query_respects_limit_test() ->
Members = [
#{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"alice1">>}},
#{<<"user">> => #{<<"id">> => <<"2">>, <<"username">> => <<"alice2">>}},
#{<<"user">> => #{<<"id">> => <<"3">>, <<"username">> => <<"alice3">>}}
],
Results = filter_members_by_query(Members, <<"alice">>, 2),
?assertEqual(2, length(Results)).
filter_members_by_query_empty_query_matches_all_test() ->
Members = [
#{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"alice">>}},
#{<<"user">> => #{<<"id">> => <<"2">>, <<"username">> => <<"bob">>}}
],
Results = filter_members_by_query(Members, <<>>, 10),
?assertEqual(2, length(Results)).
filter_members_by_query_no_match_test() ->
Members = [
#{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"alice">>}}
],
Results = filter_members_by_query(Members, <<"zzz">>, 10),
?assertEqual(0, length(Results)).
filter_members_by_query_matches_nick_test() ->
Members = [
#{
<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"alice">>},
<<"nick">> => <<"SuperNick">>
}
],
Results = filter_members_by_query(Members, <<"super">>, 10),
?assertEqual(1, length(Results)).
get_display_name_null_nick_test() ->
Member = #{
<<"user">> => #{<<"username">> => <<"user">>},
<<"nick">> => null
},
?assertEqual(<<"user">>, get_display_name(Member)).
get_display_name_non_binary_nick_test() ->
Member = #{
<<"user">> => #{<<"username">> => <<"user">>},
<<"nick">> => 12345
},
?assertEqual(<<"user">>, get_display_name(Member)).
get_display_name_non_map_test() ->
?assertEqual(<<>>, get_display_name(not_a_map)).
get_display_name_null_global_name_test() ->
Member = #{<<"user">> => #{<<"username">> => <<"user">>, <<"global_name">> => null}},
?assertEqual(<<"user">>, get_display_name(Member)).
get_display_name_non_binary_global_name_test() ->
Member = #{<<"user">> => #{<<"username">> => <<"user">>, <<"global_name">> => 12345}},
?assertEqual(<<"user">>, get_display_name(Member)).
get_username_null_test() ->
?assertEqual(<<>>, get_username(#{<<"username">> => null})).
get_username_undefined_test() ->
?assertEqual(<<>>, get_username(#{<<"username">> => undefined})).
get_username_non_binary_test() ->
?assertEqual(<<>>, get_username(#{<<"username">> => 12345})).
get_username_missing_test() ->
?assertEqual(<<>>, get_username(#{})).
extract_user_id_valid_test() ->
?assertEqual(42, extract_user_id(#{<<"user">> => #{<<"id">> => <<"42">>}})).
extract_user_id_missing_user_test() ->
?assertEqual(undefined, extract_user_id(#{})).
extract_user_id_non_map_test() ->
?assertEqual(undefined, extract_user_id(not_a_map)).
presence_visible_online_test() ->
?assertEqual(true, presence_visible(#{<<"status">> => <<"online">>})).
presence_visible_idle_test() ->
?assertEqual(true, presence_visible(#{<<"status">> => <<"idle">>})).
presence_visible_dnd_test() ->
?assertEqual(true, presence_visible(#{<<"status">> => <<"dnd">>})).
presence_visible_offline_test() ->
?assertEqual(false, presence_visible(#{<<"status">> => <<"offline">>})).
presence_visible_invisible_test() ->
?assertEqual(false, presence_visible(#{<<"status">> => <<"invisible">>})).
presence_visible_missing_status_test() ->
?assertEqual(false, presence_visible(#{})).
normalize_nonce_exactly_max_length_test() ->
Nonce = list_to_binary(lists:duplicate(32, $a)),
?assertEqual(Nonce, normalize_nonce(Nonce)).
normalize_nonce_one_over_max_test() ->
Nonce = list_to_binary(lists:duplicate(33, $a)),
?assertEqual(null, normalize_nonce(Nonce)).
normalize_nonce_empty_binary_test() ->
?assertEqual(<<>>, normalize_nonce(<<>>)).
normalize_nonce_integer_test() ->
?assertEqual(null, normalize_nonce(42)).
normalize_nonce_null_atom_test() ->
?assertEqual(null, normalize_nonce(null)).
parse_request_defaults_test() ->
Data = #{<<"guild_id">> => 12345},
{ok, Request} = parse_request(Data),
?assertEqual(12345, maps:get(guild_id, Request)),
?assertEqual(<<>>, maps:get(query, Request)),
?assertEqual(0, maps:get(limit, Request)),
?assertEqual([], maps:get(user_ids, Request)),
?assertEqual(false, maps:get(presences, Request)),
?assertEqual(null, maps:get(nonce, Request)).
parse_request_non_binary_query_test() ->
Data = #{<<"guild_id">> => 123, <<"query">> => 42},
{ok, Request} = parse_request(Data),
?assertEqual(<<>>, maps:get(query, Request)).
parse_request_negative_limit_test() ->
Data = #{<<"guild_id">> => 123, <<"limit">> => -5},
{ok, Request} = parse_request(Data),
?assertEqual(0, maps:get(limit, Request)).
parse_request_presences_not_true_test() ->
Data = #{<<"guild_id">> => 123, <<"presences">> => <<"yes">>},
{ok, Request} = parse_request(Data),
?assertEqual(false, maps:get(presences, Request)).
parse_request_missing_guild_id_test() ->
Data = #{<<"query">> => <<"test">>},
?assertEqual({error, invalid_guild_id}, parse_request(Data)).
handle_request_invalid_data_test() ->
?assertEqual({error, invalid_request}, handle_request(not_a_map, self(), #{})).
chunk_list_single_element_test() ->
?assertEqual([[1]], chunk_list([1], 1)).
chunk_list_exact_multiple_test() ->
?assertEqual([[1, 2], [3, 4]], chunk_list([1, 2, 3, 4], 2)).
chunk_list_large_size_test() ->
?assertEqual([[1, 2, 3]], chunk_list([1, 2, 3], 1000)).
-endif.

View File

@@ -82,11 +82,14 @@ register_new_session(Request, Pid, UserId, SessionId, State) ->
user_roles => UserRoles,
bot => Bot,
is_staff => maps:get(is_staff, Request, false),
previous_passive_updates => InitialLastMessageIds,
previous_passive_channel_versions => InitialChannelVersions,
previous_passive_voice_states => #{},
viewable_channels => InitialViewableChannels
},
InitialPassiveState = #{
previous_passive_updates => InitialLastMessageIds,
previous_passive_channel_versions => InitialChannelVersions,
previous_passive_voice_states => #{}
},
passive_sync_registry:store(SessionId, GuildId, InitialPassiveState),
NewSessions = maps:put(SessionId, SessionData, Sessions),
State1 = maps:put(sessions, NewSessions, State),
State2 = subscribe_to_user_presence(UserId, State1),
@@ -199,7 +202,9 @@ cleanup_disconnecting_session(undefined, State) ->
cleanup_disconnecting_session(Session, State) ->
UserId = maps:get(user_id, Session),
SessionId = maps:get(session_id, Session),
GuildId = maps:get(id, State),
_ = maybe_notify_coordinator(session_disconnected, SessionId, UserId, State),
passive_sync_registry:delete(SessionId, GuildId),
StateAfterPresence = unsubscribe_from_user_presence(UserId, State),
StateAfterMemberList = guild_member_list:unsubscribe_session(SessionId, StateAfterPresence),
MemberSubs = maps:get(
@@ -675,12 +680,10 @@ remove_session_removes_entry_test() ->
mref => make_ref(),
active_guilds => sets:new(),
user_roles => [],
bot => false,
previous_passive_updates => #{},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
bot => false
},
State = #{
id => 42,
sessions => #{SessionId => SessionData},
presence_subscriptions => #{1 => 1},
member_list_subscriptions => #{},
@@ -699,12 +702,10 @@ remove_session_cleans_connect_pending_test() ->
mref => make_ref(),
active_guilds => sets:new(),
user_roles => [],
bot => false,
previous_passive_updates => #{},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
bot => false
},
State = #{
id => 42,
sessions => #{SessionId => SessionData},
presence_subscriptions => #{1 => 1},
member_list_subscriptions => #{},
@@ -804,4 +805,222 @@ pending_connect_filtered_from_channel_sessions_test() ->
[{ResultSid, _}] = Result,
?assertEqual(<<"s1">>, ResultSid).
set_session_viewable_channels_test() ->
Sessions = #{<<"s1">> => #{user_id => 1, pid => self()}},
State = #{sessions => Sessions},
ViewableChannels = #{100 => true, 200 => true},
UpdatedState = set_session_viewable_channels(<<"s1">>, ViewableChannels, State),
UpdatedSession = maps:get(<<"s1">>, maps:get(sessions, UpdatedState)),
?assertEqual(ViewableChannels, maps:get(viewable_channels, UpdatedSession)).
set_session_viewable_channels_missing_session_test() ->
State = #{sessions => #{}},
Result = set_session_viewable_channels(<<"nonexistent">>, #{100 => true}, State),
?assertEqual(State, Result).
set_session_active_guild_missing_session_test() ->
State = #{sessions => #{}},
Result = set_session_active_guild(<<"nonexistent">>, 42, State),
?assertEqual(State, Result).
set_session_passive_guild_missing_session_test() ->
State = #{sessions => #{}},
Result = set_session_passive_guild(<<"nonexistent">>, 42, State),
?assertEqual(State, Result).
is_session_active_missing_session_test() ->
State = #{id => 42, sessions => #{}},
?assertEqual(false, is_session_active(<<"nonexistent">>, State)).
filter_sessions_for_channel_excludes_specified_session_test() ->
S1 = #{session_id => <<"s1">>, user_id => 10, pid => self(), viewable_channels => #{200 => true}},
S2 = #{session_id => <<"s2">>, user_id => 11, pid => self(), viewable_channels => #{200 => true}},
Sessions = #{<<"s1">> => S1, <<"s2">> => S2},
State = #{sessions => Sessions, data => #{<<"members">> => #{}}},
Result = filter_sessions_for_channel(Sessions, 200, <<"s1">>, State),
?assertEqual(1, length(Result)),
[{ResultSid, _}] = Result,
?assertEqual(<<"s2">>, ResultSid).
filter_sessions_for_channel_falls_back_to_permission_check_test() ->
GuildId = 42,
UserId = 10,
ChannelId = 200,
ViewPerm = constants:view_channel_permission(),
SessionData = #{
session_id => <<"s1">>,
user_id => UserId,
pid => self()
},
Sessions = #{<<"s1">> => SessionData},
State = #{
id => GuildId,
sessions => Sessions,
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(ViewPerm)}
],
<<"members">> => #{
UserId => #{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
},
<<"channels">> => [
#{<<"id">> => integer_to_binary(ChannelId), <<"permission_overwrites">> => []}
]
}
},
Result = filter_sessions_for_channel(Sessions, ChannelId, undefined, State),
?assertEqual(1, length(Result)).
filter_sessions_for_channel_no_member_returns_empty_test() ->
SessionData = #{
session_id => <<"s1">>,
user_id => 999,
pid => self()
},
Sessions = #{<<"s1">> => SessionData},
State = #{
sessions => Sessions,
data => #{
<<"guild">> => #{<<"owner_id">> => <<"888">>},
<<"members">> => [],
<<"roles">> => [],
<<"channels">> => []
}
},
Result = filter_sessions_for_channel(Sessions, 200, undefined, State),
?assertEqual([], Result).
filter_sessions_exclude_session_filters_pending_test() ->
Sessions = #{
<<"s1">> => #{pending_connect => true},
<<"s2">> => #{},
<<"s3">> => #{pending_connect => false}
},
Result = filter_sessions_exclude_session(Sessions, undefined),
ResultIds = lists:sort([Sid || {Sid, _} <- Result]),
?assertEqual([<<"s2">>, <<"s3">>], ResultIds).
subscribe_unsubscribe_presence_test_() ->
{setup,
fun() -> ensure_test_deps() end,
fun(_) -> stop_test_deps() end,
fun(_) ->
[fun() ->
State0 = #{presence_subscriptions => #{}},
State1 = subscribe_to_user_presence(10, State0),
Subs1 = maps:get(presence_subscriptions, State1),
?assertEqual(1, maps:get(10, Subs1)),
State2 = subscribe_to_user_presence(10, State1),
Subs2 = maps:get(presence_subscriptions, State2),
?assertEqual(2, maps:get(10, Subs2)),
State3 = unsubscribe_from_user_presence(10, State2),
Subs3 = maps:get(presence_subscriptions, State3),
?assertEqual(1, maps:get(10, Subs3)),
State4 = unsubscribe_from_user_presence(10, State3),
Subs4 = maps:get(presence_subscriptions, State4),
?assertEqual(0, maps:get(10, Subs4))
end]
end}.
unsubscribe_from_user_presence_zero_count_noop_test() ->
State = #{presence_subscriptions => #{10 => 0}},
Result = unsubscribe_from_user_presence(10, State),
?assertEqual(State, Result).
unsubscribe_from_user_presence_missing_user_noop_test() ->
State = #{presence_subscriptions => #{}},
Result = unsubscribe_from_user_presence(999, State),
?assertEqual(State, Result).
handle_user_offline_nonzero_count_noop_test() ->
State = #{
presence_subscriptions => #{10 => 1},
member_presence => #{10 => #{<<"status">> => <<"online">>}}
},
Result = handle_user_offline(10, State),
?assertEqual(State, Result).
handle_user_offline_missing_user_noop_test() ->
State = #{presence_subscriptions => #{}},
Result = handle_user_offline(999, State),
?assertEqual(State, Result).
should_auto_stop_on_empty_default_test() ->
State = #{},
?assertEqual(true, should_auto_stop_on_empty(State)).
should_auto_stop_on_empty_disabled_test() ->
State = #{disable_auto_stop_on_empty => true},
?assertEqual(false, should_auto_stop_on_empty(State)).
should_auto_stop_on_empty_vlg_coordinator_test() ->
State = #{very_large_guild_coordinator_pid => self()},
?assertEqual(false, should_auto_stop_on_empty(State)).
build_viewable_channel_map_test() ->
Map = build_viewable_channel_map([100, 200, 300]),
?assertEqual(3, map_size(Map)),
?assertEqual(true, maps:get(100, Map)),
?assertEqual(true, maps:get(200, Map)),
?assertEqual(true, maps:get(300, Map)).
build_viewable_channel_map_empty_test() ->
?assertEqual(#{}, build_viewable_channel_map([])).
normalize_connect_queue_list_test() ->
List = [#{a => 1}, #{a => 2}],
Queue = normalize_connect_queue(List),
?assert(queue:is_queue(Queue)),
?assertEqual(2, queue:len(Queue)).
normalize_connect_queue_queue_test() ->
Q = queue:from_list([1, 2, 3]),
?assertEqual(Q, normalize_connect_queue(Q)).
normalize_connect_queue_undefined_test() ->
?assertEqual(undefined, normalize_connect_queue(undefined)),
?assertEqual(undefined, normalize_connect_queue(42)).
ensure_test_deps() ->
ensure_mock_registered(presence_bus),
ensure_mock_registered(presence_cache).
stop_test_deps() ->
stop_mock_registered(presence_bus),
stop_mock_registered(presence_cache).
ensure_mock_registered(Name) ->
case whereis(Name) of
undefined ->
Pid = spawn(fun() -> mock_gen_server_loop() end),
register(Name, Pid),
Pid;
Pid ->
Pid
end.
stop_mock_registered(Name) ->
case whereis(Name) of
undefined -> ok;
Pid ->
catch unregister(Name),
Pid ! stop,
ok
end.
mock_gen_server_loop() ->
receive
{'$gen_call', From, {get, _}} ->
gen_server:reply(From, not_found),
mock_gen_server_loop();
{'$gen_call', From, _Msg} ->
gen_server:reply(From, ok),
mock_gen_server_loop();
stop ->
ok;
_ ->
mock_gen_server_loop()
end.
-endif.

View File

@@ -703,4 +703,292 @@ guild_update_syncs_unavailability_cache_test() ->
_ = guild_availability:update_unavailability_cache_for_state(CleanupState)
end.
handle_guild_update_merges_fields_test() ->
Data = #{
<<"guild">> => #{<<"name">> => <<"Old">>, <<"icon">> => <<"abc">>},
<<"roles">> => [],
<<"members">> => [],
<<"channels">> => []
},
EventData = #{<<"name">> => <<"New">>, <<"description">> => <<"desc">>},
Result = handle_guild_update(EventData, Data),
Guild = maps:get(<<"guild">>, Result),
?assertEqual(<<"New">>, maps:get(<<"name">>, Guild)),
?assertEqual(<<"abc">>, maps:get(<<"icon">>, Guild)),
?assertEqual(<<"desc">>, maps:get(<<"description">>, Guild)).
handle_member_update_ignores_non_member_test() ->
Data = #{
<<"members">> => #{
1 => #{<<"user">> => #{<<"id">> => <<"1">>}, <<"nick">> => <<"nick">>}
}
},
EventData = #{<<"user">> => #{<<"id">> => <<"999">>}, <<"nick">> => <<"new">>},
Result = handle_member_update(EventData, Data),
?assertEqual(1, map_size(maps:get(<<"members">>, Result))),
?assertEqual(undefined, guild_data_index:get_member(999, Result)).
handle_role_create_test() ->
Data = #{
<<"roles">> => [#{<<"id">> => <<"1">>, <<"name">> => <<"Everyone">>}],
<<"members">> => [],
<<"channels">> => []
},
EventData = #{<<"role">> => #{<<"id">> => <<"2">>, <<"name">> => <<"New">>}},
Result = handle_role_create(EventData, Data),
Roles = guild_data_index:role_list(Result),
?assertEqual(2, length(Roles)),
RoleIndex = guild_data_index:role_index(Result),
?assertMatch(#{2 := _}, RoleIndex).
handle_role_update_replaces_role_test() ->
Data = #{
<<"roles">> => [
#{<<"id">> => <<"1">>, <<"name">> => <<"Old">>},
#{<<"id">> => <<"2">>, <<"name">> => <<"Keep">>}
],
<<"members">> => [],
<<"channels">> => []
},
EventData = #{<<"role">> => #{<<"id">> => <<"1">>, <<"name">> => <<"Updated">>}},
Result = handle_role_update(EventData, Data),
Roles = guild_data_index:role_list(Result),
[R1, R2] = Roles,
?assertEqual(<<"Updated">>, maps:get(<<"name">>, R1)),
?assertEqual(<<"Keep">>, maps:get(<<"name">>, R2)).
handle_role_update_bulk_test() ->
Data = #{
<<"roles">> => [
#{<<"id">> => <<"1">>, <<"name">> => <<"A">>},
#{<<"id">> => <<"2">>, <<"name">> => <<"B">>},
#{<<"id">> => <<"3">>, <<"name">> => <<"C">>}
],
<<"members">> => [],
<<"channels">> => []
},
EventData = #{
<<"roles">> => [
#{<<"id">> => <<"1">>, <<"name">> => <<"A2">>},
#{<<"id">> => <<"3">>, <<"name">> => <<"C2">>}
]
},
Result = handle_role_update_bulk(EventData, Data),
Roles = guild_data_index:role_list(Result),
[R1, R2, R3] = Roles,
?assertEqual(<<"A2">>, maps:get(<<"name">>, R1)),
?assertEqual(<<"B">>, maps:get(<<"name">>, R2)),
?assertEqual(<<"C2">>, maps:get(<<"name">>, R3)).
handle_role_delete_removes_role_from_list_test() ->
Data = #{
<<"roles">> => [
#{<<"id">> => <<"1">>, <<"name">> => <<"Keep">>},
#{<<"id">> => <<"2">>, <<"name">> => <<"Delete">>}
],
<<"members">> => [],
<<"channels">> => []
},
EventData = #{<<"role_id">> => <<"2">>},
Result = handle_role_delete(EventData, Data),
Roles = guild_data_index:role_list(Result),
?assertEqual(1, length(Roles)),
?assertEqual(<<"Keep">>, maps:get(<<"name">>, hd(Roles))).
handle_channel_update_test() ->
Data = #{
<<"channels">> => [
#{<<"id">> => <<"100">>, <<"name">> => <<"old">>},
#{<<"id">> => <<"101">>, <<"name">> => <<"keep">>}
]
},
EventData = #{<<"id">> => <<"100">>, <<"name">> => <<"updated">>},
Result = handle_channel_update(EventData, Data),
Channels = guild_data_index:channel_list(Result),
[C1, C2] = Channels,
?assertEqual(<<"updated">>, maps:get(<<"name">>, C1)),
?assertEqual(<<"keep">>, maps:get(<<"name">>, C2)).
handle_channel_update_bulk_test() ->
Data = #{
<<"channels">> => [
#{<<"id">> => <<"1">>, <<"name">> => <<"A">>},
#{<<"id">> => <<"2">>, <<"name">> => <<"B">>}
]
},
EventData = #{
<<"channels">> => [
#{<<"id">> => <<"2">>, <<"name">> => <<"B2">>}
]
},
Result = handle_channel_update_bulk(EventData, Data),
Channels = guild_data_index:channel_list(Result),
[C1, C2] = Channels,
?assertEqual(<<"A">>, maps:get(<<"name">>, C1)),
?assertEqual(<<"B2">>, maps:get(<<"name">>, C2)).
handle_channel_delete_test() ->
Data = #{
<<"channels">> => [
#{<<"id">> => <<"100">>, <<"name">> => <<"general">>},
#{<<"id">> => <<"101">>, <<"name">> => <<"random">>}
]
},
EventData = #{<<"id">> => <<"100">>},
Result = handle_channel_delete(EventData, Data),
Channels = guild_data_index:channel_list(Result),
?assertEqual(1, length(Channels)),
?assertEqual(<<"random">>, maps:get(<<"name">>, hd(Channels))).
handle_message_create_updates_last_message_id_test() ->
Data = #{
<<"channels">> => [
#{<<"id">> => <<"100">>, <<"last_message_id">> => <<"500">>},
#{<<"id">> => <<"101">>, <<"last_message_id">> => <<"600">>}
]
},
EventData = #{<<"channel_id">> => <<"100">>, <<"id">> => <<"700">>},
Result = handle_message_create(EventData, Data),
Channels = guild_data_index:channel_list(Result),
[C1, C2] = Channels,
?assertEqual(<<"700">>, maps:get(<<"last_message_id">>, C1)),
?assertEqual(<<"600">>, maps:get(<<"last_message_id">>, C2)).
handle_channel_pins_update_test() ->
Data = #{
<<"channels">> => [
#{<<"id">> => <<"100">>}
]
},
EventData = #{<<"channel_id">> => <<"100">>, <<"last_pin_timestamp">> => <<"2024-01-01T00:00:00Z">>},
Result = handle_channel_pins_update(EventData, Data),
[Ch] = guild_data_index:channel_list(Result),
?assertEqual(<<"2024-01-01T00:00:00Z">>, maps:get(<<"last_pin_timestamp">>, Ch)).
handle_emojis_update_test() ->
Data = #{<<"emojis">> => []},
EventData = #{<<"emojis">> => [#{<<"id">> => <<"1">>}]},
Result = handle_emojis_update(EventData, Data),
?assertEqual([#{<<"id">> => <<"1">>}], maps:get(<<"emojis">>, Result)).
handle_stickers_update_test() ->
Data = #{<<"stickers">> => []},
EventData = #{<<"stickers">> => [#{<<"id">> => <<"1">>}]},
Result = handle_stickers_update(EventData, Data),
?assertEqual([#{<<"id">> => <<"1">>}], maps:get(<<"stickers">>, Result)).
replace_item_by_id_test() ->
Items = [
#{<<"id">> => <<"1">>, <<"v">> => <<"a">>},
#{<<"id">> => <<"2">>, <<"v">> => <<"b">>}
],
Result = replace_item_by_id(Items, <<"1">>, #{<<"id">> => <<"1">>, <<"v">> => <<"c">>}),
[R1, R2] = Result,
?assertEqual(<<"c">>, maps:get(<<"v">>, R1)),
?assertEqual(<<"b">>, maps:get(<<"v">>, R2)).
replace_item_by_id_no_match_test() ->
Items = [#{<<"id">> => <<"1">>, <<"v">> => <<"a">>}],
Result = replace_item_by_id(Items, <<"999">>, #{<<"id">> => <<"999">>}),
?assertEqual(Items, Result).
remove_item_by_id_test() ->
Items = [
#{<<"id">> => <<"1">>},
#{<<"id">> => <<"2">>},
#{<<"id">> => <<"3">>}
],
Result = remove_item_by_id(Items, <<"2">>),
?assertEqual(2, length(Result)),
Ids = [maps:get(<<"id">>, I) || I <- Result],
?assertEqual([<<"1">>, <<"3">>], Ids).
remove_item_by_id_no_match_test() ->
Items = [#{<<"id">> => <<"1">>}],
?assertEqual(Items, remove_item_by_id(Items, <<"999">>)).
bulk_update_items_no_updates_test() ->
Items = [#{<<"id">> => <<"1">>, <<"v">> => <<"a">>}],
?assertEqual(Items, bulk_update_items(Items, [])).
bulk_update_items_missing_id_in_bulk_ignored_test() ->
Items = [#{<<"id">> => <<"1">>, <<"v">> => <<"a">>}],
BulkItems = [#{<<"v">> => <<"b">>}],
?assertEqual(Items, bulk_update_items(Items, BulkItems)).
extract_role_ids_from_role_update_test() ->
EventData = #{<<"role">> => #{<<"id">> => <<"42">>}},
?assertEqual([42], extract_role_ids_from_role_update(EventData)).
extract_role_ids_from_role_update_missing_id_test() ->
EventData = #{<<"role">> => #{}},
?assertEqual([], extract_role_ids_from_role_update(EventData)).
extract_role_ids_from_role_update_missing_role_test() ->
?assertEqual([], extract_role_ids_from_role_update(#{})).
extract_role_ids_from_role_update_bulk_test() ->
EventData = #{<<"roles">> => [#{<<"id">> => <<"1">>}, #{<<"id">> => <<"2">>}]},
?assertEqual([1, 2], extract_role_ids_from_role_update_bulk(EventData)).
extract_role_ids_from_role_update_bulk_empty_test() ->
?assertEqual([], extract_role_ids_from_role_update_bulk(#{})).
extract_role_ids_from_role_delete_test() ->
EventData = #{<<"role_id">> => <<"55">>},
?assertEqual([55], extract_role_ids_from_role_delete(EventData)).
extract_role_ids_from_role_delete_missing_test() ->
?assertEqual([], extract_role_ids_from_role_delete(#{})).
update_data_for_event_unknown_returns_data_unchanged_test() ->
Data = #{<<"test">> => true},
?assertEqual(Data, update_data_for_event(unknown_event, #{}, Data, #{})).
strip_role_from_members_no_affected_users_test() ->
Data = #{
<<"roles">> => [],
<<"members">> => #{
1 => #{<<"user">> => #{<<"id">> => <<"1">>}, <<"roles">> => [<<"100">>]}
},
<<"channels">> => []
},
Result = strip_role_from_members(<<"999">>, Data),
M1 = guild_data_index:get_member(1, Result),
?assertEqual([<<"100">>], maps:get(<<"roles">>, M1)).
strip_role_from_channel_overwrites_preserves_user_overwrites_test() ->
Data = #{
<<"channels">> => [
#{
<<"id">> => <<"500">>,
<<"permission_overwrites">> => [
#{<<"id">> => <<"100">>, <<"type">> => 0, <<"allow">> => <<"0">>, <<"deny">> => <<"0">>},
#{<<"id">> => <<"1">>, <<"type">> => 1, <<"allow">> => <<"1024">>, <<"deny">> => <<"0">>}
]
}
]
},
Result = strip_role_from_channel_overwrites(<<"100">>, Data),
[Ch] = guild_data_index:channel_list(Result),
Overwrites = maps:get(<<"permission_overwrites">>, Ch),
?assertEqual(1, length(Overwrites)),
?assertEqual(1, maps:get(<<"type">>, hd(Overwrites))).
cleanup_removed_member_sessions_removes_non_members_test() ->
Data = #{
<<"members">> => #{
1 => #{<<"user">> => #{<<"id">> => <<"1">>}}
}
},
Sessions = #{
<<"s1">> => #{user_id => 1},
<<"s2">> => #{user_id => 999}
},
State = #{data => Data, sessions => Sessions},
Result = cleanup_removed_member_sessions(State),
NewSessions = maps:get(sessions, Result),
?assertEqual(1, map_size(NewSessions)),
?assert(maps:is_key(<<"s1">>, NewSessions)).
-endif.

View File

@@ -0,0 +1,327 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(guild_subscription_handler).
-export([
handle_call/3,
handle_cast/2
]).
-type guild_state() :: map().
-type user_id() :: integer().
-type session_id() :: binary().
-type channel_id() :: integer().
-spec handle_call(term(), gen_server:from(), guild_state()) ->
{reply, term(), guild_state()}.
handle_call({lazy_subscribe, Request}, _From, State) ->
handle_lazy_subscribe(Request, State).
-spec handle_cast(term(), guild_state()) -> {noreply, guild_state()}.
handle_cast({update_member_subscriptions, SessionId, MemberIds}, State) ->
NewState0 = handle_update_member_subscriptions(SessionId, MemberIds, State),
NewState = maybe_prune_very_large_guild_members(NewState0),
{noreply, NewState};
handle_cast({very_large_guild_member_list_deliver, Deliveries}, State) when is_list(Deliveries) ->
_ = deliver_member_list_updates(Deliveries, State),
{noreply, State}.
-spec handle_lazy_subscribe(map(), guild_state()) -> {reply, ok, guild_state()}.
handle_lazy_subscribe(Request, State) ->
case maps:get(disable_member_list_updates, State, false) of
true ->
{reply, ok, State};
false ->
#{session_id := SessionId, channel_id := ChannelId, ranges := Ranges} = Request,
Sessions0 = maps:get(sessions, State, #{}),
SessionUserId = get_session_user_id(SessionId, Sessions0),
case
is_integer(SessionUserId) andalso
guild_permissions:can_view_channel(SessionUserId, ChannelId, undefined, State)
of
true ->
GuildId = maps:get(id, State),
ListId = guild_member_list:calculate_list_id(ChannelId, State),
{NewState, ShouldSendSync, NormalizedRanges} =
guild_member_list:subscribe_ranges(SessionId, ListId, Ranges, State),
handle_lazy_subscribe_sync(
ShouldSendSync, NormalizedRanges, GuildId, ListId, ChannelId, SessionId, NewState
);
false ->
{reply, ok, State}
end
end.
-spec handle_lazy_subscribe_sync(
boolean(), list(), integer(), term(), channel_id(), session_id(), guild_state()
) ->
{reply, ok, guild_state()}.
handle_lazy_subscribe_sync(true, [], _GuildId, _ListId, _ChannelId, _SessionId, State) ->
{reply, ok, State};
handle_lazy_subscribe_sync(true, RangesToSend, GuildId, ListId, ChannelId, SessionId, State) ->
SyncResponse = guild_member_list:build_sync_response(GuildId, ListId, RangesToSend, State),
SyncResponseWithChannel = maps:put(
<<"channel_id">>, integer_to_binary(ChannelId), SyncResponse
),
Sessions = maps:get(sessions, State, #{}),
case maps:get(SessionId, Sessions, undefined) of
#{pid := SessionPid} when is_pid(SessionPid) ->
gen_server:cast(
SessionPid, {dispatch, guild_member_list_update, SyncResponseWithChannel}
);
_ ->
ok
end,
{reply, ok, State};
handle_lazy_subscribe_sync(_, _, _GuildId, _ListId, _ChannelId, _SessionId, State) ->
{reply, ok, State}.
-spec get_session_user_id(session_id(), map()) -> user_id() | undefined.
get_session_user_id(SessionId, Sessions) ->
case maps:get(SessionId, Sessions, undefined) of
#{user_id := Uid} -> Uid;
_ -> undefined
end.
-spec handle_update_member_subscriptions(session_id(), [user_id()], guild_state()) -> guild_state().
handle_update_member_subscriptions(SessionId, MemberIds, State) ->
MemberSubs = maps:get(member_subscriptions, State, guild_subscriptions:init_state()),
Sessions = maps:get(sessions, State, #{}),
SessionUserId = get_session_user_id(SessionId, Sessions),
StateWithPrimedMembers = maybe_prime_very_large_guild_members(MemberIds, State),
FilteredMemberIds = filter_member_ids_with_mutual_channels(
SessionUserId, MemberIds, StateWithPrimedMembers
),
OldSubscriptions = guild_subscriptions:get_user_ids_for_session(SessionId, MemberSubs),
NewMemberSubs = guild_subscriptions:update_subscriptions(
SessionId, FilteredMemberIds, MemberSubs
),
NewSubscriptions = guild_subscriptions:get_user_ids_for_session(SessionId, NewMemberSubs),
Added = sets:to_list(sets:subtract(NewSubscriptions, OldSubscriptions)),
Removed = sets:to_list(sets:subtract(OldSubscriptions, NewSubscriptions)),
State1 = maps:put(member_subscriptions, NewMemberSubs, StateWithPrimedMembers),
State2 = handle_added_subscriptions(Added, SessionId, State1),
handle_removed_subscriptions(Removed, State2).
-spec handle_added_subscriptions([user_id()], session_id(), guild_state()) -> guild_state().
handle_added_subscriptions(Added, SessionId, State) ->
lists:foldl(
fun(UserId, Acc) ->
StateWithPresence = guild_sessions:subscribe_to_user_presence(UserId, Acc),
guild_presence:send_cached_presence_to_session(UserId, SessionId, StateWithPresence)
end,
State,
Added
).
-spec handle_removed_subscriptions([user_id()], guild_state()) -> guild_state().
handle_removed_subscriptions(Removed, State) ->
lists:foldl(
fun(UserId, Acc) -> guild_sessions:unsubscribe_from_user_presence(UserId, Acc) end,
State,
Removed
).
-spec filter_member_ids_with_mutual_channels(user_id() | undefined, [user_id()], guild_state()) ->
[user_id()].
filter_member_ids_with_mutual_channels(undefined, _, _) ->
[];
filter_member_ids_with_mutual_channels(SessionUserId, MemberIds, State) ->
SessionChannels = guild_visibility:viewable_channel_set(SessionUserId, State),
lists:filtermap(
fun(MemberId) ->
case MemberId =:= SessionUserId of
true ->
false;
false ->
case has_shared_channels(SessionChannels, MemberId, State) of
true -> {true, MemberId};
false -> false
end
end
end,
MemberIds
).
-spec has_shared_channels(sets:set(), user_id(), guild_state()) -> boolean().
has_shared_channels(_, MemberId, _) when not is_integer(MemberId) ->
false;
has_shared_channels(SessionChannels, MemberId, State) ->
CandidateChannels = guild_visibility:viewable_channel_set(MemberId, State),
not sets:is_empty(sets:intersection(SessionChannels, CandidateChannels)).
-spec maybe_prime_very_large_guild_members([user_id()], guild_state()) -> guild_state().
maybe_prime_very_large_guild_members(UserIds, State) when is_list(UserIds) ->
case
{
maps:get(very_large_guild_coordinator_pid, State, undefined),
maps:get(very_large_guild_shard_index, State, undefined)
}
of
{CoordPid, ShardIndex} when is_pid(CoordPid), is_integer(ShardIndex), ShardIndex =/= 0 ->
UniqueUserIds = lists:usort([U || U <- UserIds, is_integer(U), U > 0]),
case UniqueUserIds of
[] ->
State;
_ ->
MembersReply =
try gen_server:call(
CoordPid, {very_large_guild_get_members, UniqueUserIds}, 10000
) of
Reply -> Reply
catch
_:_ -> #{}
end,
prime_members_from_reply(MembersReply, State)
end;
_ ->
State
end;
maybe_prime_very_large_guild_members(_, State) ->
State.
-spec prime_members_from_reply(term(), guild_state()) -> guild_state().
prime_members_from_reply(MembersReply, State) when is_map(MembersReply) ->
Data0 = maps:get(data, State, #{}),
Data = maps:fold(
fun(_UserId, Member, AccData) ->
case is_map(Member) of
true -> guild_data_index:put_member(Member, AccData);
false -> AccData
end
end,
Data0,
MembersReply
),
maps:put(data, Data, State);
prime_members_from_reply(_, State) ->
State.
-spec maybe_prune_very_large_guild_members(guild_state()) -> guild_state().
maybe_prune_very_large_guild_members(State) ->
case
{
maps:get(very_large_guild_coordinator_pid, State, undefined),
maps:get(very_large_guild_shard_index, State, undefined)
}
of
{CoordPid, ShardIndex} when is_pid(CoordPid), is_integer(ShardIndex), ShardIndex =/= 0 ->
prune_member_cache_to_needed_users(State);
_ ->
State
end.
-spec prune_member_cache_to_needed_users(guild_state()) -> guild_state().
prune_member_cache_to_needed_users(State) ->
Data0 = maps:get(data, State, #{}),
Members0 = maps:get(<<"members">>, Data0, #{}),
case is_map(Members0) of
false ->
State;
true ->
NeededUserIds = needed_member_cache_user_ids(State),
NeededSet = sets:from_list(NeededUserIds),
FilteredMembers = maps:filter(
fun(UserId, _Member) -> sets:is_element(UserId, NeededSet) end,
Members0
),
case map_size(FilteredMembers) =:= map_size(Members0) of
true ->
State;
false ->
Data1 = guild_data_index:put_member_map(FilteredMembers, Data0),
maps:put(data, Data1, State)
end
end.
-spec needed_member_cache_user_ids(guild_state()) -> [user_id()].
needed_member_cache_user_ids(State) ->
Sessions = maps:get(sessions, State, #{}),
SessionUserIds = maps:fold(
fun(_SessionId, SessionData, Acc) ->
case maps:get(user_id, SessionData, undefined) of
UserId when is_integer(UserId), UserId > 0 -> [UserId | Acc];
_ -> Acc
end
end,
[],
Sessions
),
MemberSubs = maps:get(member_subscriptions, State, guild_subscriptions:init_state()),
SubscribedUserIds = maps:keys(MemberSubs),
lists:usort(SessionUserIds ++ SubscribedUserIds).
-spec deliver_member_list_updates([{session_id(), map()}], guild_state()) -> ok.
deliver_member_list_updates(Deliveries, State) ->
Sessions = maps:get(sessions, State, #{}),
lists:foreach(
fun({SessionId, Payload}) ->
case maps:get(SessionId, Sessions, undefined) of
#{pid := SessionPid} = SessionData when is_pid(SessionPid), is_map(Payload) ->
case maps:get(pending_connect, SessionData, false) of
true ->
ok;
false ->
ChannelId = member_list_payload_channel_id(Payload),
case can_session_view_channel(SessionData, ChannelId, State) of
true ->
gen_server:cast(
SessionPid, {dispatch, guild_member_list_update, Payload}
);
false ->
ok
end
end;
_ ->
ok
end
end,
Deliveries
),
ok.
-spec member_list_payload_channel_id(map()) -> channel_id().
member_list_payload_channel_id(Payload) ->
ChannelIdBin = maps:get(<<"channel_id">>, Payload, undefined),
ListIdBin = maps:get(<<"id">>, Payload, <<"0">>),
case ChannelIdBin of
Bin when is_binary(Bin) ->
case type_conv:to_integer(Bin) of
undefined -> 0;
Id -> Id
end;
_ ->
case type_conv:to_integer(ListIdBin) of
undefined -> 0;
Id -> Id
end
end.
-spec can_session_view_channel(map(), channel_id(), guild_state()) -> boolean().
can_session_view_channel(_SessionData, ChannelId, _State) when not is_integer(ChannelId); ChannelId =< 0 ->
false;
can_session_view_channel(SessionData, ChannelId, State) ->
case {maps:get(user_id, SessionData, undefined), maps:get(viewable_channels, SessionData, undefined)} of
{UserId, ViewableChannels} when is_integer(UserId), is_map(ViewableChannels) ->
maps:is_key(ChannelId, ViewableChannels) orelse
guild_permissions:can_view_channel(UserId, ChannelId, undefined, State);
{UserId, _} when is_integer(UserId) ->
guild_permissions:can_view_channel(UserId, ChannelId, undefined, State);
_ ->
false
end.

View File

@@ -1,24 +0,0 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(guild_sync).
-export([send_guild_sync/2]).
-spec send_guild_sync(pid(), binary()) -> ok.
send_guild_sync(GuildPid, SessionId) ->
gen_server:cast(GuildPid, {send_guild_sync, SessionId}).

View File

@@ -89,7 +89,7 @@ process_sync_flag(GuildSubData, _GuildId, GuildPid, SessionId, ActiveChanged) ->
ShouldSync = maps:get(<<"sync">>, GuildSubData, false) =:= true orelse ActiveChanged,
case ShouldSync of
true ->
guild_sync:send_guild_sync(GuildPid, SessionId);
gen_server:cast(GuildPid, {send_guild_sync, SessionId});
false ->
ok
end.

View File

@@ -700,4 +700,348 @@ filter_connected_session_entries_excludes_pending_test() ->
ResultIds = lists:sort([Sid || {Sid, _} <- Result]),
?assertEqual([<<"s1">>, <<"s3">>], ResultIds).
administrator_sees_all_channels_test() ->
GuildId = 50,
UserId = 10,
ChannelId = 100,
Admin = constants:administrator_permission(),
State = #{
id => GuildId,
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(Admin)}
],
<<"members">> => [
#{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
],
<<"channels">> => [
#{
<<"id">> => integer_to_binary(ChannelId),
<<"permission_overwrites">> => [
#{
<<"id">> => integer_to_binary(GuildId),
<<"type">> => 0,
<<"allow">> => <<"0">>,
<<"deny">> => integer_to_binary(constants:view_channel_permission())
}
]
}
]
}
},
Channels = get_user_viewable_channels(UserId, State),
?assertEqual([ChannelId], Channels).
owner_sees_all_channels_test() ->
GuildId = 60,
OwnerId = 10,
ChannelId = 200,
State = #{
id => GuildId,
data => #{
<<"guild">> => #{<<"owner_id">> => integer_to_binary(OwnerId)},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => <<"0">>}
],
<<"members">> => [
#{<<"user">> => #{<<"id">> => integer_to_binary(OwnerId)}, <<"roles">> => []}
],
<<"channels">> => [
#{<<"id">> => integer_to_binary(ChannelId), <<"permission_overwrites">> => []}
]
}
},
Channels = get_user_viewable_channels(OwnerId, State),
?assertEqual([ChannelId], Channels).
everyone_role_grants_view_test() ->
GuildId = 70,
UserId = 10,
ChannelId = 300,
ViewPerm = constants:view_channel_permission(),
State = #{
id => GuildId,
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(ViewPerm)}
],
<<"members">> => [
#{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
],
<<"channels">> => [
#{<<"id">> => integer_to_binary(ChannelId), <<"permission_overwrites">> => []}
]
}
},
Channels = get_user_viewable_channels(UserId, State),
?assertEqual([ChannelId], Channels).
channel_overwrite_denies_view_test() ->
GuildId = 80,
UserId = 10,
RoleId = 200,
ChannelId = 400,
ViewPerm = constants:view_channel_permission(),
State = #{
id => GuildId,
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(ViewPerm)},
#{<<"id">> => integer_to_binary(RoleId), <<"permissions">> => <<"0">>}
],
<<"members">> => [
#{
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
<<"roles">> => [integer_to_binary(RoleId)]
}
],
<<"channels">> => [
#{
<<"id">> => integer_to_binary(ChannelId),
<<"permission_overwrites">> => [
#{
<<"id">> => integer_to_binary(RoleId),
<<"type">> => 0,
<<"allow">> => <<"0">>,
<<"deny">> => integer_to_binary(ViewPerm)
}
]
}
]
}
},
Channels = get_user_viewable_channels(UserId, State),
?assertEqual([], Channels).
role_overwrite_allows_view_test() ->
GuildId = 90,
UserId = 10,
RoleId = 300,
ChannelId = 500,
ViewPerm = constants:view_channel_permission(),
State = #{
id => GuildId,
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(ViewPerm)},
#{<<"id">> => integer_to_binary(RoleId), <<"permissions">> => <<"0">>}
],
<<"members">> => [
#{
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
<<"roles">> => [integer_to_binary(RoleId)]
}
],
<<"channels">> => [
#{
<<"id">> => integer_to_binary(ChannelId),
<<"permission_overwrites">> => [
#{
<<"id">> => integer_to_binary(GuildId),
<<"type">> => 0,
<<"allow">> => <<"0">>,
<<"deny">> => integer_to_binary(ViewPerm)
},
#{
<<"id">> => integer_to_binary(RoleId),
<<"type">> => 0,
<<"allow">> => integer_to_binary(ViewPerm),
<<"deny">> => <<"0">>
}
]
}
]
}
},
Channels = get_user_viewable_channels(UserId, State),
?assertEqual([ChannelId], Channels).
user_overwrite_denies_view_test() ->
GuildId = 91,
UserId = 10,
RoleId = 301,
ChannelId = 501,
ViewPerm = constants:view_channel_permission(),
State = #{
id => GuildId,
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(ViewPerm)},
#{<<"id">> => integer_to_binary(RoleId), <<"permissions">> => <<"0">>}
],
<<"members">> => [
#{
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
<<"roles">> => [integer_to_binary(RoleId)]
}
],
<<"channels">> => [
#{
<<"id">> => integer_to_binary(ChannelId),
<<"permission_overwrites">> => [
#{
<<"id">> => integer_to_binary(RoleId),
<<"type">> => 0,
<<"allow">> => integer_to_binary(ViewPerm),
<<"deny">> => <<"0">>
},
#{
<<"id">> => integer_to_binary(UserId),
<<"type">> => 1,
<<"allow">> => <<"0">>,
<<"deny">> => integer_to_binary(ViewPerm)
}
]
}
]
}
},
Channels = get_user_viewable_channels(UserId, State),
?assertEqual([], Channels).
viewable_channel_set_uses_cached_session_data_test() ->
UserId = 10,
State = #{
sessions => #{
<<"s1">> => #{
user_id => UserId,
viewable_channels => #{100 => true, 200 => true}
}
},
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"members">> => [],
<<"channels">> => [],
<<"roles">> => []
}
},
ChannelSet = viewable_channel_set(UserId, State),
?assertEqual(true, sets:is_element(100, ChannelSet)),
?assertEqual(true, sets:is_element(200, ChannelSet)),
?assertEqual(false, sets:is_element(999, ChannelSet)).
have_shared_viewable_channel_shared_test() ->
GuildId = 100,
ViewPerm = constants:view_channel_permission(),
State = #{
id => GuildId,
sessions => #{},
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(ViewPerm)}
],
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"10">>}, <<"roles">> => []},
#{<<"user">> => #{<<"id">> => <<"20">>}, <<"roles">> => []}
],
<<"channels">> => [
#{<<"id">> => <<"500">>, <<"permission_overwrites">> => []}
]
}
},
?assertEqual(true, have_shared_viewable_channel(10, 20, State)).
have_shared_viewable_channel_no_shared_test() ->
GuildId = 101,
ViewPerm = constants:view_channel_permission(),
RoleId = 200,
State = #{
id => GuildId,
sessions => #{},
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => <<"0">>},
#{<<"id">> => integer_to_binary(RoleId), <<"permissions">> => integer_to_binary(ViewPerm)}
],
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"10">>}, <<"roles">> => [integer_to_binary(RoleId)]},
#{<<"user">> => #{<<"id">> => <<"20">>}, <<"roles">> => []}
],
<<"channels">> => [
#{
<<"id">> => <<"500">>,
<<"permission_overwrites">> => [
#{
<<"id">> => integer_to_binary(GuildId),
<<"type">> => 0,
<<"allow">> => <<"0">>,
<<"deny">> => integer_to_binary(ViewPerm)
}
]
}
]
}
},
?assertEqual(false, have_shared_viewable_channel(10, 20, State)).
update_viewable_map_for_channel_add_test() ->
Map = #{100 => true},
Result = update_viewable_map_for_channel(Map, 200, true),
?assertEqual(true, maps:is_key(200, Result)),
?assertEqual(true, maps:is_key(100, Result)).
update_viewable_map_for_channel_remove_test() ->
Map = #{100 => true, 200 => true},
Result = update_viewable_map_for_channel(Map, 100, false),
?assertEqual(false, maps:is_key(100, Result)),
?assertEqual(true, maps:is_key(200, Result)).
multiple_channels_partial_visibility_test() ->
GuildId = 110,
UserId = 10,
ViewPerm = constants:view_channel_permission(),
RoleId = 300,
State = #{
id => GuildId,
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(ViewPerm)},
#{<<"id">> => integer_to_binary(RoleId), <<"permissions">> => <<"0">>}
],
<<"members">> => [
#{
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
<<"roles">> => []
}
],
<<"channels">> => [
#{<<"id">> => <<"100">>, <<"permission_overwrites">> => []},
#{
<<"id">> => <<"101">>,
<<"permission_overwrites">> => [
#{
<<"id">> => integer_to_binary(GuildId),
<<"type">> => 0,
<<"allow">> => <<"0">>,
<<"deny">> => integer_to_binary(ViewPerm)
}
]
},
#{<<"id">> => <<"102">>, <<"permission_overwrites">> => []}
]
}
},
Channels = lists:sort(get_user_viewable_channels(UserId, State)),
?assertEqual([100, 102], Channels).
viewable_channel_map_test() ->
Set = sets:from_list([10, 20, 30]),
Map = viewable_channel_map(Set),
?assertEqual(3, map_size(Map)),
?assertEqual(true, maps:get(10, Map)),
?assertEqual(true, maps:get(20, Map)),
?assertEqual(true, maps:get(30, Map)).
viewable_channel_map_empty_test() ->
Map = viewable_channel_map(sets:new()),
?assertEqual(#{}, Map).
-endif.

View File

@@ -0,0 +1,159 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(guild_voice_handler).
-export([
handle_call/3,
handle_cast/2
]).
-type guild_state() :: map().
-spec handle_call(term(), gen_server:from(), guild_state()) ->
{reply, term(), guild_state()}
| {noreply, guild_state()}.
handle_call({voice_state_update, Request}, _From, State) ->
guild_voice:voice_state_update(Request, State);
handle_call({get_voice_state, Request}, _From, State) ->
guild_voice:get_voice_state(Request, State);
handle_call({update_member_voice, Request}, _From, State) ->
guild_voice:update_member_voice(Request, State);
handle_call({disconnect_voice_user, Request}, _From, State) ->
guild_voice:disconnect_voice_user(Request, State);
handle_call({disconnect_voice_user_if_in_channel, Request}, _From, State) ->
guild_voice:disconnect_voice_user_if_in_channel(Request, State);
handle_call({disconnect_all_voice_users_in_channel, Request}, _From, State) ->
guild_voice:disconnect_all_voice_users_in_channel(Request, State);
handle_call({confirm_voice_connection_from_livekit, Request}, _From, State) ->
guild_voice:confirm_voice_connection_from_livekit(Request, State);
handle_call({move_member, Request}, _From, State) ->
guild_voice:move_member(Request, State);
handle_call({switch_voice_region, Request}, _From, State) ->
guild_voice:switch_voice_region_handler(Request, State);
handle_call({add_virtual_channel_access, UserId, ChannelId}, _From, State) ->
NewState = guild_virtual_channel_access:add_virtual_access(UserId, ChannelId, State),
guild_virtual_channel_access:dispatch_channel_visibility_change(
UserId, ChannelId, add, NewState
),
{reply, ok, NewState};
handle_call({store_pending_connection, ConnectionId, Metadata}, _From, State) ->
PendingConnections = maps:get(pending_voice_connections, State, #{}),
NewPendingConnections = maps:put(ConnectionId, Metadata, PendingConnections),
NewState = maps:put(pending_voice_connections, NewPendingConnections, State),
{reply, ok, NewState};
handle_call({get_voice_states_for_channel, ChannelIdBin}, _From, State) ->
VoiceStates = maps:get(voice_states, State, #{}),
Filtered = maps:fold(
fun(ConnId, VS, Acc) ->
case maps:get(<<"channel_id">>, VS, null) of
ChannelIdBin ->
[#{
connection_id => ConnId,
user_id => maps:get(<<"user_id">>, VS, null),
channel_id => ChannelIdBin
} | Acc];
_ ->
Acc
end
end,
[],
VoiceStates
),
{reply, #{voice_states => Filtered}, State};
handle_call({get_pending_joins_for_channel, ChannelIdBin}, _From, State) ->
PendingConnections = maps:get(pending_voice_connections, State, #{}),
ChannelIdInt = binary_to_integer(ChannelIdBin),
Filtered = maps:fold(
fun(ConnId, Metadata, Acc) ->
case maps:get(channel_id, Metadata, undefined) of
ChannelIdInt ->
[#{
connection_id => ConnId,
user_id => integer_to_binary(maps:get(user_id, Metadata, 0)),
token_nonce => maps:get(token_nonce, Metadata, null),
expires_at => maps:get(expires_at, Metadata, 0)
} | Acc];
_ ->
Acc
end
end,
[],
PendingConnections
),
{reply, #{pending_joins => Filtered}, State}.
-spec handle_cast(term(), guild_state()) -> {noreply, guild_state()}.
handle_cast({relay_voice_state_update, VoiceState, OldChannelIdBin}, State) ->
State1 = relay_upsert_voice_state(VoiceState, State),
StateNoRelay = maps:remove(very_large_guild_coordinator_pid, State1),
_ = guild_voice_broadcast:broadcast_voice_state_update(VoiceState, StateNoRelay, OldChannelIdBin),
{noreply, State1};
handle_cast(
{relay_voice_server_update, GuildId, ChannelId, SessionId, Token, Endpoint, ConnectionId},
State
) ->
StateNoRelay = maps:remove(very_large_guild_coordinator_pid, State),
_ = guild_voice_broadcast:broadcast_voice_server_update_to_session(
GuildId,
ChannelId,
SessionId,
Token,
Endpoint,
ConnectionId,
StateNoRelay
),
{noreply, State};
handle_cast({store_pending_connection, ConnectionId, Metadata}, State) ->
PendingConnections = maps:get(pending_voice_connections, State, #{}),
NewPendingConnections = maps:put(ConnectionId, Metadata, PendingConnections),
NewState = maps:put(pending_voice_connections, NewPendingConnections, State),
{noreply, NewState};
handle_cast({add_virtual_channel_access, UserId, ChannelId}, State) ->
NewState = guild_virtual_channel_access:add_virtual_access(UserId, ChannelId, State),
guild_virtual_channel_access:dispatch_channel_visibility_change(
UserId, ChannelId, add, NewState
),
{noreply, NewState};
handle_cast({remove_virtual_channel_access, UserId, ChannelId}, State) ->
guild_virtual_channel_access:dispatch_channel_visibility_change(
UserId, ChannelId, remove, State
),
NewState = guild_virtual_channel_access:remove_virtual_access(UserId, ChannelId, State),
{noreply, NewState};
handle_cast({cleanup_virtual_access_for_user, UserId}, State) ->
NewState = guild_voice_disconnect:cleanup_virtual_channel_access_for_user(UserId, State),
{noreply, NewState}.
-spec relay_upsert_voice_state(map(), guild_state()) -> guild_state().
relay_upsert_voice_state(VoiceState, State) when is_map(VoiceState) ->
ConnectionId = maps:get(<<"connection_id">>, VoiceState, undefined),
case ConnectionId of
undefined ->
State;
_ ->
VoiceStates0 = maps:get(voice_states, State, #{}),
ChannelId = maps:get(<<"channel_id">>, VoiceState, null),
VoiceStates =
case ChannelId of
null -> maps:remove(ConnectionId, VoiceStates0);
_ -> maps:put(ConnectionId, VoiceState, VoiceStates0)
end,
maps:put(voice_states, VoiceStates, State)
end;
relay_upsert_voice_state(_, State) ->
State.

View File

@@ -0,0 +1,225 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(passive_sync_registry).
-export([
init/0,
store/3,
lookup/2,
delete/2,
delete_all_for_session/1
]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-define(TABLE, passive_sync_registry).
-type session_id() :: binary().
-type guild_id() :: integer().
-type passive_state() :: #{
previous_passive_updates := #{binary() => binary()},
previous_passive_channel_versions := #{binary() => integer()},
previous_passive_voice_states := #{binary() => map()}
}.
-spec init() -> ok.
init() ->
case ets:whereis(?TABLE) of
undefined ->
_ = ets:new(?TABLE, [
named_table,
public,
set,
{read_concurrency, true},
{write_concurrency, true}
]),
ok;
_ ->
ok
end.
-spec store(session_id(), guild_id(), passive_state()) -> ok.
store(SessionId, GuildId, PassiveState) ->
ensure_table(),
Key = {SessionId, GuildId},
ets:insert(?TABLE, {Key, PassiveState, SessionId}),
ok.
-spec lookup(session_id(), guild_id()) -> passive_state().
lookup(SessionId, GuildId) ->
ensure_table(),
Key = {SessionId, GuildId},
case ets:lookup(?TABLE, Key) of
[{Key, PassiveState, _SessionId}] ->
PassiveState;
[] ->
#{
previous_passive_updates => #{},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
}
end.
-spec delete(session_id(), guild_id()) -> ok.
delete(SessionId, GuildId) ->
ensure_table(),
Key = {SessionId, GuildId},
ets:delete(?TABLE, Key),
ok.
-spec delete_all_for_session(session_id()) -> ok.
delete_all_for_session(SessionId) ->
ensure_table(),
ets:match_delete(?TABLE, {'_', '_', SessionId}),
ok.
-spec ensure_table() -> ok.
ensure_table() ->
case ets:whereis(?TABLE) of
undefined -> init();
_ -> ok
end.
-ifdef(TEST).
init_creates_table_test() ->
cleanup_table(),
ok = init(),
?assertNotEqual(undefined, ets:whereis(?TABLE)),
cleanup_table().
init_idempotent_test() ->
cleanup_table(),
ok = init(),
ok = init(),
?assertNotEqual(undefined, ets:whereis(?TABLE)),
cleanup_table().
store_and_lookup_test() ->
cleanup_table(),
ok = init(),
SessionId = <<"session_1">>,
GuildId = 100,
PassiveState = #{
previous_passive_updates => #{<<"ch1">> => <<"msg1">>},
previous_passive_channel_versions => #{<<"ch1">> => 5},
previous_passive_voice_states => #{}
},
ok = store(SessionId, GuildId, PassiveState),
Result = lookup(SessionId, GuildId),
?assertEqual(PassiveState, Result),
cleanup_table().
lookup_missing_returns_defaults_test() ->
cleanup_table(),
ok = init(),
Result = lookup(<<"nonexistent">>, 999),
?assertEqual(#{
previous_passive_updates => #{},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
}, Result),
cleanup_table().
delete_removes_entry_test() ->
cleanup_table(),
ok = init(),
SessionId = <<"session_1">>,
GuildId = 100,
PassiveState = #{
previous_passive_updates => #{<<"ch1">> => <<"msg1">>},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
},
ok = store(SessionId, GuildId, PassiveState),
ok = delete(SessionId, GuildId),
Result = lookup(SessionId, GuildId),
?assertEqual(#{
previous_passive_updates => #{},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
}, Result),
cleanup_table().
delete_all_for_session_removes_all_guilds_test() ->
cleanup_table(),
ok = init(),
SessionId = <<"session_1">>,
OtherSessionId = <<"session_2">>,
State1 = #{
previous_passive_updates => #{<<"ch1">> => <<"msg1">>},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
},
State2 = #{
previous_passive_updates => #{<<"ch2">> => <<"msg2">>},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
},
OtherState = #{
previous_passive_updates => #{<<"ch3">> => <<"msg3">>},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
},
ok = store(SessionId, 100, State1),
ok = store(SessionId, 200, State2),
ok = store(OtherSessionId, 100, OtherState),
ok = delete_all_for_session(SessionId),
?assertEqual(#{
previous_passive_updates => #{},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
}, lookup(SessionId, 100)),
?assertEqual(#{
previous_passive_updates => #{},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
}, lookup(SessionId, 200)),
?assertEqual(OtherState, lookup(OtherSessionId, 100)),
cleanup_table().
store_overwrites_existing_test() ->
cleanup_table(),
ok = init(),
SessionId = <<"session_1">>,
GuildId = 100,
State1 = #{
previous_passive_updates => #{<<"ch1">> => <<"msg1">>},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
},
State2 = #{
previous_passive_updates => #{<<"ch1">> => <<"msg2">>},
previous_passive_channel_versions => #{<<"ch1">> => 3},
previous_passive_voice_states => #{}
},
ok = store(SessionId, GuildId, State1),
ok = store(SessionId, GuildId, State2),
Result = lookup(SessionId, GuildId),
?assertEqual(State2, Result),
cleanup_table().
cleanup_table() ->
case ets:whereis(?TABLE) of
undefined -> ok;
_ -> ets:delete(?TABLE), ok
end.
-endif.

View File

@@ -265,8 +265,9 @@ handle_cast({very_large_guild_member_list_deliver, DeliveriesByShard}, State) wh
DeliveriesByShard
),
{noreply, State};
handle_cast({dispatch, _Event, _EventData} = Msg, State) ->
broadcast_cast(Msg, State),
handle_cast({dispatch, #{event := Event, data := EventData} = Request}, State) ->
broadcast_cast({dispatch, Request}, State),
maybe_trigger_push(Event, EventData, State),
{noreply, State};
handle_cast(_Msg, State) ->
{noreply, State}.
@@ -363,13 +364,6 @@ terminate(_Reason, State) ->
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
-spec strip_members(map()) -> map().
strip_members(Data) when is_map(Data) ->
Data1 = maps:remove(<<"members">>, Data),
maps:remove(<<"member_role_index">>, Data1);
strip_members(Data) ->
Data.
-spec determine_shard_count(map()) -> pos_integer().
determine_shard_count(GuildState) ->
Override = maps:get(very_large_guild_shard_count, GuildState, undefined),
@@ -392,32 +386,9 @@ determine_shard_count(GuildState) ->
-spec start_shards(guild_id(), map(), pos_integer()) -> {#{shard_index() => shard_entry()}, ok}.
start_shards(GuildId, Data, ShardCount) ->
MemberCount = guild_data_index:member_count(Data),
Shards = lists:foldl(
fun(Index, Acc) ->
DisableCache = Index =/= 0,
ShardData =
case DisableCache of
true -> strip_members(Data);
false -> Data
end,
ShardState0 = #{
id => GuildId,
data => ShardData,
sessions => #{},
member_count => MemberCount,
disable_push_notifications => true,
disable_member_list_updates => DisableCache,
disable_auto_stop_on_empty => true,
very_large_guild_coordinator_pid => self(),
very_large_guild_shard_count => ShardCount,
very_large_guild_shard_index => Index
},
ShardState =
case DisableCache of
true -> maps:put(disable_permission_cache_updates, true, ShardState0);
false -> ShardState0
end,
ShardState = guild_common:build_shard_state(GuildId, Data, ShardCount, Index),
case guild:start_link(ShardState) of
{ok, Pid} ->
MRef = monitor(process, Pid),
@@ -439,37 +410,15 @@ restart_shard(ShardIndex, State) ->
Data =
case maps:get(0, Shards0, undefined) of
#{pid := Pid0} ->
case safe_call(Pid0, {get_push_base_state}, 5000) of
case guild_common:safe_call(Pid0, {get_push_base_state}, 5000) of
#{data := D} -> D;
_ -> Fallback
end;
_ ->
Fallback
end,
DisableCache = ShardIndex =/= 0,
ShardData =
case DisableCache of
true -> strip_members(Data);
false -> Data
end,
MemberCount = guild_data_index:member_count(Data),
ShardState0 = #{
id => GuildId,
data => ShardData,
sessions => #{},
member_count => MemberCount,
disable_push_notifications => true,
disable_member_list_updates => DisableCache,
disable_auto_stop_on_empty => true,
very_large_guild_coordinator_pid => self(),
very_large_guild_shard_count => maps:get(shard_count, State, 1),
very_large_guild_shard_index => ShardIndex
},
ShardState =
case DisableCache of
true -> maps:put(disable_permission_cache_updates, true, ShardState0);
false -> ShardState0
end,
ShardCount = maps:get(shard_count, State, 1),
ShardState = guild_common:build_shard_state(GuildId, Data, ShardCount, ShardIndex),
case guild:start_link(ShardState) of
{ok, NewPid} ->
MRef = monitor(process, NewPid),
@@ -586,9 +535,9 @@ reload_shards(NewData, State) ->
Payload =
case Index of
0 -> NewData;
_ -> strip_members(NewData)
_ -> guild_common:strip_members(NewData)
end,
_ = safe_call(Pid, {reload, Payload}, 20000),
_ = guild_common:safe_call(Pid, {reload, Payload}, 20000),
ok
end,
Shards
@@ -635,15 +584,8 @@ do_prime_connected_members(State) ->
ok.
-spec safe_call(pid(), term(), timeout()) -> term().
safe_call(Pid, Msg, Timeout) when is_pid(Pid) ->
try gen_server:call(Pid, Msg, Timeout) of
Reply -> Reply
catch
exit:{timeout, _} -> {error, timeout};
exit:{noproc, _} -> {error, noproc};
exit:{normal, _} -> {error, noproc};
_:Reason -> {error, Reason}
end.
safe_call(Pid, Msg, Timeout) ->
guild_common:safe_call(Pid, Msg, Timeout).
-spec safe_call_to_session_shard(session_id(), term(), timeout(), state()) -> term().
safe_call_to_session_shard(SessionId, Msg, Timeout, State) ->
@@ -764,41 +706,7 @@ maybe_notify_member_list_virtual_access_cleanup(UserId, State) ->
-spec merge_cluster_state(map(), map()) -> map().
merge_cluster_state(Acc, Frag) ->
SessionsAcc = maps:get(sessions, Acc, #{}),
SessionsFrag = maps:get(sessions, Frag, #{}),
VoiceAcc = maps:get(voice_states, Acc, #{}),
VoiceFrag = maps:get(voice_states, Frag, #{}),
VAAcc = maps:get(virtual_channel_access, Acc, #{}),
VAFrag = maps:get(virtual_channel_access, Frag, #{}),
PendingAcc = maps:get(virtual_channel_access_pending, Acc, #{}),
PendingFrag = maps:get(virtual_channel_access_pending, Frag, #{}),
PreserveAcc = maps:get(virtual_channel_access_preserve, Acc, #{}),
PreserveFrag = maps:get(virtual_channel_access_preserve, Frag, #{}),
MoveAcc = maps:get(virtual_channel_access_move_pending, Acc, #{}),
MoveFrag = maps:get(virtual_channel_access_move_pending, Frag, #{}),
Acc#{
sessions => maps:merge(SessionsAcc, SessionsFrag),
voice_states => maps:merge(VoiceAcc, VoiceFrag),
virtual_channel_access => merge_user_set_maps(VAAcc, VAFrag),
virtual_channel_access_pending => merge_user_set_maps(PendingAcc, PendingFrag),
virtual_channel_access_preserve => merge_user_set_maps(PreserveAcc, PreserveFrag),
virtual_channel_access_move_pending => merge_user_set_maps(MoveAcc, MoveFrag)
}.
-spec merge_user_set_maps(map(), map()) -> map().
merge_user_set_maps(A, B) ->
maps:fold(
fun(UserId, SetB, Acc) ->
case maps:get(UserId, Acc, undefined) of
undefined ->
maps:put(UserId, SetB, Acc);
SetA ->
maps:put(UserId, sets:union(SetA, SetB), Acc)
end
end,
A,
B
).
guild_common:merge_cluster_state(Acc, Frag).
-spec relay_to_other_shards(shard_index(), term(), state()) -> ok.
relay_to_other_shards(SourceIndex, Msg, State) ->
@@ -1382,13 +1290,13 @@ strip_members_removes_members_and_role_index_test() ->
<<"channels">> => [#{<<"id">> => <<"10">>}],
<<"roles">> => [#{<<"id">> => <<"role1">>}]
},
Stripped = strip_members(Data),
Stripped = guild_common:strip_members(Data),
?assertEqual(false, maps:is_key(<<"members">>, Stripped)),
?assertEqual(false, maps:is_key(<<"member_role_index">>, Stripped)),
?assertEqual([#{<<"id">> => <<"10">>}], maps:get(<<"channels">>, Stripped)),
?assertEqual([#{<<"id">> => <<"role1">>}], maps:get(<<"roles">>, Stripped)),
?assertEqual(#{}, strip_members(#{})),
?assertEqual(not_a_map, strip_members(not_a_map)),
?assertEqual(#{}, guild_common:strip_members(#{})),
?assertEqual(not_a_map, guild_common:strip_members(not_a_map)),
ok.
coordinator_stops_on_last_disconnect_test() ->
@@ -1421,7 +1329,7 @@ merge_user_set_maps_test() ->
SetB = sets:from_list([2, 3]),
MapA = #{10 => SetA},
MapB = #{10 => SetB, 20 => SetB},
Merged = merge_user_set_maps(MapA, MapB),
Merged = guild_common:merge_user_set_maps(MapA, MapB),
?assert(maps:is_key(10, Merged)),
?assert(maps:is_key(20, Merged)),
MergedSet10 = maps:get(10, Merged),
@@ -1430,7 +1338,7 @@ merge_user_set_maps_test() ->
?assert(sets:is_element(3, MergedSet10)),
?assertEqual(3, sets:size(MergedSet10)),
?assertEqual(SetB, maps:get(20, Merged)),
EmptyMerge = merge_user_set_maps(#{}, #{}),
EmptyMerge = guild_common:merge_user_set_maps(#{}, #{}),
?assertEqual(#{}, EmptyMerge),
ok.

View File

@@ -199,22 +199,10 @@ handle_subscribe(SessionId, ChannelId, Ranges, State) ->
NormalizedRanges = guild_member_list:normalize_ranges(Ranges),
ListId = guild_member_list:calculate_list_id(ChannelId, Snapshot0),
Subs0 = maps:get(subscriptions, State, #{}),
ListSubs0 = maps:get(ListId, Subs0, #{}),
OldRanges = maps:get(SessionId, ListSubs0, []),
Subs =
case NormalizedRanges of
[] ->
ListSubs1 = maps:remove(SessionId, ListSubs0),
case map_size(ListSubs1) of
0 -> maps:remove(ListId, Subs0);
_ -> maps:put(ListId, ListSubs1, Subs0)
end;
_ ->
ListSubs1 = maps:put(SessionId, NormalizedRanges, ListSubs0),
maps:put(ListId, ListSubs1, Subs0)
end,
{Subs, _OldRanges, ShouldSync} =
guild_member_list_common:update_subscriptions(SessionId, ListId, NormalizedRanges, Subs0),
State1 = maps:put(subscriptions, Subs, State),
case NormalizedRanges =/= [] andalso NormalizedRanges =/= OldRanges of
case ShouldSync of
true ->
self() ! {send_initial_sync, SessionId, ChannelId, ListId, NormalizedRanges},
State1;
@@ -524,8 +512,7 @@ maybe_normalize_data(_Data0) ->
-spec member_user_id(map()) -> user_id().
member_user_id(MemberData) ->
User = maps:get(<<"user">>, MemberData, #{}),
map_utils:get_integer(User, <<"id">>, 0).
guild_member_list_common:get_member_user_id(MemberData).
-spec upsert_item_by_id(term(), map(), [map()]) -> [map()].
upsert_item_by_id(Id, NewItem, Items) ->
@@ -758,17 +745,7 @@ cleanup_virtual_access(UserId, State) ->
-spec remove_session_subscriptions(session_id(), state()) -> state().
remove_session_subscriptions(SessionId, State) ->
Subs0 = maps:get(subscriptions, State, #{}),
Subs = maps:fold(
fun(ListId, ListSubs0, Acc) ->
ListSubs = maps:remove(SessionId, ListSubs0),
case map_size(ListSubs) of
0 -> Acc;
_ -> maps:put(ListId, ListSubs, Acc)
end
end,
#{},
Subs0
),
Subs = guild_member_list_common:remove_session_from_subscriptions(SessionId, Subs0),
maps:put(subscriptions, Subs, State).
-spec safe_call(pid(), term(), timeout()) -> term().
@@ -1485,5 +1462,640 @@ notify_channel_update_triggers_channel_sync_test() ->
exit(Shard0Pid, shutdown),
ok.
session_connected_first_session_triggers_delta_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(compute_inflight, State, true) =:= false
end),
gen_server:cast(Pid, {session_connected, <<"s1">>, 0, 42}),
ok = await(fun() ->
State = sys:get_state(Pid),
Counts = maps:get(user_session_counts, State, #{}),
maps:get(42, Counts, 0) =:= 1
end),
State = sys:get_state(Pid),
Routes = maps:get(session_routes, State, #{}),
?assertEqual(0, maps:get(<<"s1">>, Routes)),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
session_connected_second_session_no_delta_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
gen_server:cast(Pid, {session_connected, <<"s1">>, 0, 42}),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(compute_inflight, State, true) =:= false
end),
gen_server:cast(Pid, {session_connected, <<"s2">>, 1, 42}),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(user_session_counts, State, #{}) =:= #{42 => 2}
end),
State = sys:get_state(Pid),
Routes = maps:get(session_routes, State, #{}),
?assertEqual(0, maps:get(<<"s1">>, Routes)),
?assertEqual(1, maps:get(<<"s2">>, Routes)),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
session_disconnected_last_session_triggers_delta_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
gen_server:cast(Pid, {session_connected, <<"s1">>, 0, 42}),
ok = await(fun() ->
State = sys:get_state(Pid),
Counts = maps:get(user_session_counts, State, #{}),
maps:get(42, Counts, 0) =:= 1
end),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(compute_inflight, State, true) =:= false
end),
gen_server:cast(Pid, {session_disconnected, <<"s1">>, 42}),
ok = await(fun() ->
State = sys:get_state(Pid),
Counts = maps:get(user_session_counts, State, #{}),
not maps:is_key(42, Counts)
end),
State = sys:get_state(Pid),
Routes = maps:get(session_routes, State, #{}),
?assertNot(maps:is_key(<<"s1">>, Routes)),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
session_disconnected_not_last_no_delta_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
gen_server:cast(Pid, {session_connected, <<"s1">>, 0, 42}),
gen_server:cast(Pid, {session_connected, <<"s2">>, 1, 42}),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(user_session_counts, State, #{}) =:= #{42 => 2}
end),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(compute_inflight, State, true) =:= false
end),
gen_server:cast(Pid, {session_disconnected, <<"s1">>, 42}),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(user_session_counts, State, #{}) =:= #{42 => 1}
end),
State = sys:get_state(Pid),
?assertNot(maps:is_key(<<"s1">>, maps:get(session_routes, State, #{}))),
?assert(maps:is_key(<<"s2">>, maps:get(session_routes, State, #{}))),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
virtual_access_added_and_removed_restores_state_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
gen_server:cast(Pid, {virtual_access_added, 10, 50}),
ok = await(fun() ->
State = sys:get_state(Pid),
VA = maps:get(virtual_channel_access, State, #{}),
maps:is_key(10, VA)
end),
gen_server:cast(Pid, {virtual_access_removed, 10, 50}),
ok = await(fun() ->
State = sys:get_state(Pid),
VA = maps:get(virtual_channel_access, State, #{}),
not maps:is_key(10, VA)
end),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
virtual_access_multiple_channels_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
gen_server:cast(Pid, {virtual_access_added, 10, 50}),
gen_server:cast(Pid, {virtual_access_added, 10, 60}),
ok = await(fun() ->
State = sys:get_state(Pid),
VA = maps:get(virtual_channel_access, State, #{}),
case maps:get(10, VA, undefined) of
undefined -> false;
Channels -> sets:size(Channels) =:= 2
end
end),
gen_server:cast(Pid, {virtual_access_removed, 10, 50}),
ok = await(fun() ->
State = sys:get_state(Pid),
VA = maps:get(virtual_channel_access, State, #{}),
case maps:get(10, VA, undefined) of
undefined -> false;
Channels -> sets:size(Channels) =:= 1
end
end),
State = sys:get_state(Pid),
VA = maps:get(virtual_channel_access, State, #{}),
Channels = maps:get(10, VA),
?assertEqual(true, sets:is_element(60, Channels)),
?assertEqual(false, sets:is_element(50, Channels)),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
virtual_access_cleanup_removes_all_channels_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
gen_server:cast(Pid, {virtual_access_added, 10, 50}),
gen_server:cast(Pid, {virtual_access_added, 10, 60}),
ok = await(fun() ->
State = sys:get_state(Pid),
VA = maps:get(virtual_channel_access, State, #{}),
maps:is_key(10, VA)
end),
gen_server:cast(Pid, {virtual_access_cleanup, 10}),
ok = await(fun() ->
State = sys:get_state(Pid),
VA = maps:get(virtual_channel_access, State, #{}),
not maps:is_key(10, VA)
end),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
virtual_access_removed_nonexistent_user_noop_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(compute_inflight, State, true) =:= false
end),
gen_server:cast(Pid, {virtual_access_removed, 999, 50}),
timer:sleep(100),
State = sys:get_state(Pid),
VA = maps:get(virtual_channel_access, State, #{}),
?assertNot(maps:is_key(999, VA)),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
presence_update_multiple_users_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
gen_server:cast(Pid, {presence_update, 10, #{<<"status">> => <<"online">>}}),
gen_server:cast(Pid, {presence_update, 20, #{<<"status">> => <<"idle">>}}),
gen_server:cast(Pid, {presence_update, 30, #{<<"status">> => <<"dnd">>}}),
ok = await(fun() ->
State = sys:get_state(Pid),
Presence = maps:get(member_presence, State, #{}),
maps:is_key(10, Presence) andalso maps:is_key(20, Presence) andalso maps:is_key(30, Presence)
end),
State = sys:get_state(Pid),
Presence = maps:get(member_presence, State, #{}),
?assertEqual(#{<<"status">> => <<"online">>}, maps:get(10, Presence)),
?assertEqual(#{<<"status">> => <<"idle">>}, maps:get(20, Presence)),
?assertEqual(#{<<"status">> => <<"dnd">>}, maps:get(30, Presence)),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
presence_update_overwrite_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
gen_server:cast(Pid, {presence_update, 10, #{<<"status">> => <<"online">>}}),
ok = await(fun() ->
State = sys:get_state(Pid),
Presence = maps:get(member_presence, State, #{}),
maps:get(10, Presence, #{}) =:= #{<<"status">> => <<"online">>}
end),
gen_server:cast(Pid, {presence_update, 10, #{<<"status">> => <<"offline">>}}),
ok = await(fun() ->
State = sys:get_state(Pid),
Presence = maps:get(member_presence, State, #{}),
maps:get(10, Presence, #{}) =:= #{<<"status">> => <<"offline">>}
end),
gen_server:cast(Pid, {presence_update, 10, #{<<"status">> => <<"online">>}}),
ok = await(fun() ->
State = sys:get_state(Pid),
Presence = maps:get(member_presence, State, #{}),
maps:get(10, Presence, #{}) =:= #{<<"status">> => <<"online">>}
end),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
subscribe_multiple_sessions_to_same_channel_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
gen_server:cast(Pid, {session_connected, <<"s1">>, 0, 10}),
gen_server:cast(Pid, {session_connected, <<"s2">>, 0, 20}),
gen_server:cast(Pid, {subscribe, <<"s1">>, 10, [{0, 50}]}),
gen_server:cast(Pid, {subscribe, <<"s2">>, 10, [{25, 75}]}),
ok = await(fun() ->
State = sys:get_state(Pid),
Subs = maps:get(subscriptions, State, #{}),
case maps:get(<<"10">>, Subs, undefined) of
undefined -> false;
ListSubs ->
maps:is_key(<<"s1">>, ListSubs) andalso maps:is_key(<<"s2">>, ListSubs)
end
end),
State = sys:get_state(Pid),
Subs = maps:get(subscriptions, State, #{}),
ListSubs = maps:get(<<"10">>, Subs),
?assertEqual([{0, 50}], maps:get(<<"s1">>, ListSubs)),
?assertEqual([{25, 75}], maps:get(<<"s2">>, ListSubs)),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
subscribe_invalid_ranges_filtered_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
gen_server:cast(Pid, {session_connected, <<"s1">>, 0, 10}),
gen_server:cast(Pid, {subscribe, <<"s1">>, 10, [{100, 50}, {-1, 10}]}),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(compute_inflight, State, true) =:= false
end),
State = sys:get_state(Pid),
Subs = maps:get(subscriptions, State, #{}),
?assertEqual(#{}, Subs),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
upsert_member_with_no_user_field_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
MemberData = #{<<"nick">> => <<"orphan">>},
gen_server:cast(Pid, {upsert_member, MemberData}),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(compute_inflight, State, true) =:= false
end),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
upsert_member_then_update_roles_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
MemberData = #{
<<"user">> => #{<<"id">> => <<"42">>, <<"username">> => <<"alice">>},
<<"roles">> => [<<"100">>]
},
gen_server:cast(Pid, {upsert_member, MemberData}),
ok = await(fun() ->
State = sys:get_state(Pid),
Base = maps:get(base_data, State, #{}),
Members = guild_data_index:member_map(Base),
maps:is_key(42, Members)
end),
UpdatedMemberData = #{
<<"user">> => #{<<"id">> => <<"42">>, <<"username">> => <<"alice">>},
<<"roles">> => [<<"200">>, <<"300">>]
},
gen_server:cast(Pid, {upsert_member, UpdatedMemberData}),
ok = await(fun() ->
State = sys:get_state(Pid),
Base = maps:get(base_data, State, #{}),
Members = guild_data_index:member_map(Base),
case maps:get(42, Members, undefined) of
undefined -> false;
M -> maps:get(<<"roles">>, M, []) =:= [<<"200">>, <<"300">>]
end
end),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
remove_member_then_upsert_again_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
MemberData = #{
<<"user">> => #{<<"id">> => <<"42">>, <<"username">> => <<"alice">>},
<<"roles">> => []
},
gen_server:cast(Pid, {upsert_member, MemberData}),
ok = await(fun() ->
State = sys:get_state(Pid),
Base = maps:get(base_data, State, #{}),
maps:is_key(42, guild_data_index:member_map(Base))
end),
gen_server:cast(Pid, {remove_member, 42}),
ok = await(fun() ->
State = sys:get_state(Pid),
Base = maps:get(base_data, State, #{}),
not maps:is_key(42, guild_data_index:member_map(Base))
end),
gen_server:cast(Pid, {upsert_member, MemberData}),
ok = await(fun() ->
State = sys:get_state(Pid),
Base = maps:get(base_data, State, #{}),
maps:is_key(42, guild_data_index:member_map(Base))
end),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
multiple_rapid_presence_updates_coalesce_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(compute_inflight, State, true) =:= false
end),
gen_server:cast(Pid, {presence_update, 10, #{<<"status">> => <<"online">>}}),
gen_server:cast(Pid, {presence_update, 10, #{<<"status">> => <<"idle">>}}),
gen_server:cast(Pid, {presence_update, 10, #{<<"status">> => <<"dnd">>}}),
gen_server:cast(Pid, {presence_update, 10, #{<<"status">> => <<"offline">>}}),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(compute_inflight, State, true) =:= false
end),
State = sys:get_state(Pid),
Presence = maps:get(member_presence, State, #{}),
?assertEqual(#{<<"status">> => <<"offline">>}, maps:get(10, Presence)),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
unknown_cast_ignored_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
gen_server:cast(Pid, {unknown_message, <<"data">>}),
timer:sleep(50),
?assertEqual(true, is_process_alive(Pid)),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
unknown_info_ignored_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
Pid ! {unknown_info_message},
timer:sleep(50),
?assertEqual(true, is_process_alive(Pid)),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
snapshot_includes_presence_and_sessions_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
gen_server:cast(Pid, {session_connected, <<"s1">>, 0, 10}),
gen_server:cast(Pid, {presence_update, 10, #{<<"status">> => <<"online">>}}),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(compute_inflight, State, true) =:= false
end),
State = sys:get_state(Pid),
Snapshot = maps:get(snapshot, State),
?assertEqual(1, maps:get(id, Snapshot)),
SnapshotPresence = maps:get(member_presence, Snapshot),
?assertEqual(#{<<"status">> => <<"online">>}, maps:get(10, SnapshotPresence)),
SnapshotSessions = maps:get(sessions, Snapshot),
?assert(map_size(SnapshotSessions) > 0),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
channels_bulk_update_invalid_input_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(compute_inflight, State, true) =:= false
end),
gen_server:cast(Pid, {channels_bulk_update, not_a_list}),
timer:sleep(100),
?assertEqual(true, is_process_alive(Pid)),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
roles_bulk_update_invalid_input_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(compute_inflight, State, true) =:= false
end),
gen_server:cast(Pid, {roles_bulk_update, not_a_list}),
timer:sleep(100),
?assertEqual(true, is_process_alive(Pid)),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
upsert_channel_new_channel_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(compute_inflight, State, true) =:= false
end),
NewChannel = #{<<"id">> => <<"999">>, <<"name">> => <<"new-channel">>},
gen_server:cast(Pid, {upsert_channel, NewChannel}),
ok = await(fun() ->
State = sys:get_state(Pid),
Base = maps:get(base_data, State, #{}),
Channels = guild_data_index:channel_list(Base),
lists:any(fun(C) -> maps:get(<<"id">>, C, undefined) =:= <<"999">> end, Channels)
end),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
member_removed_while_presence_update_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
MemberData = #{
<<"user">> => #{<<"id">> => <<"42">>, <<"username">> => <<"alice">>},
<<"roles">> => []
},
gen_server:cast(Pid, {upsert_member, MemberData}),
ok = await(fun() ->
State = sys:get_state(Pid),
Base = maps:get(base_data, State, #{}),
maps:is_key(42, guild_data_index:member_map(Base))
end),
gen_server:cast(Pid, {presence_update, 42, #{<<"status">> => <<"online">>}}),
gen_server:cast(Pid, {remove_member, 42}),
ok = await(fun() ->
State = sys:get_state(Pid),
Base = maps:get(base_data, State, #{}),
not maps:is_key(42, guild_data_index:member_map(Base))
end),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(compute_inflight, State, true) =:= false
end),
?assertEqual(true, is_process_alive(Pid)),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
upsert_member_non_map_ignored_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(compute_inflight, State, true) =:= false
end),
gen_server:cast(Pid, {upsert_member, not_a_map}),
timer:sleep(100),
?assertEqual(true, is_process_alive(Pid)),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
full_sync_all_clears_pending_test() ->
Shard0Pid = start_stub_shard0(),
{ok, Pid} = start_link(#{
id => 1,
coordinator_pid => self(),
shard0_pid => Shard0Pid
}),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(compute_inflight, State, true) =:= false
end),
gen_server:cast(Pid, {notify_role_update}),
ok = await(fun() ->
State = sys:get_state(Pid),
maps:get(compute_inflight, State, true) =:= false
end),
State = sys:get_state(Pid),
?assertEqual(false, maps:get(pending_full_sync_all, State)),
?assertEqual(false, maps:get(pending_delta, State)),
?assertEqual(false, maps:get(pending_refresh_base_data, State)),
catch gen_server:stop(Pid),
unlink(Shard0Pid),
exit(Shard0Pid, shutdown),
ok.
-endif.

View File

@@ -389,7 +389,7 @@ get_voice_token(ChannelId, UserId, _SessionId, SessionPid, Latitude, Longitude)
connection_id => ConnectionId
}},
ok;
{error, {http_error, _Status, Body}} ->
{error, {rpc_error, _Status, Body}} ->
case parse_unclaimed_error(Body) of
true -> SessionPid ! {voice_error, voice_unclaimed_account};
false -> SessionPid ! {voice_error, voice_token_failed}
@@ -448,7 +448,7 @@ get_dm_voice_token_and_create_state(
IsMobile,
State
);
{error, {http_error, _Status, Body}} ->
{error, {rpc_error, _Status, Body}} ->
case parse_unclaimed_error(Body) of
true -> {reply, gateway_errors:error(voice_unclaimed_account), State};
false -> {reply, gateway_errors:error(voice_token_failed), State}

View File

@@ -149,4 +149,26 @@ handle_virtual_channel_access_for_move(UserId, ChannelId, _ConnectionsToMove, Gu
-spec cleanup_virtual_access_on_disconnect(integer(), pid()) -> ok.
cleanup_virtual_access_on_disconnect(UserId, GuildPid) ->
gen_server:cast(GuildPid, {cleanup_virtual_access_for_user, UserId}).
GuildId = resolve_guild_id_from_pid(GuildPid),
TargetPid = resolve_voice_server(GuildId, GuildPid),
gen_server:cast(TargetPid, {cleanup_virtual_access_for_user, UserId}).
-spec resolve_guild_id_from_pid(pid()) -> integer() | undefined.
resolve_guild_id_from_pid(GuildPid) ->
try gen_server:call(GuildPid, {get_sessions}, 5000) of
State when is_map(State) ->
maps:get(id, State, undefined);
_ ->
undefined
catch
_:_ -> undefined
end.
-spec resolve_voice_server(integer() | undefined, pid()) -> pid().
resolve_voice_server(undefined, FallbackPid) ->
FallbackPid;
resolve_voice_server(GuildId, FallbackPid) ->
case guild_voice_server:lookup(GuildId) of
{ok, VoicePid} -> VoicePid;
{error, not_found} -> FallbackPid
end.

View File

@@ -527,12 +527,6 @@ get_voice_token_and_create_state(Context, Member, ParsedViewerStreamKey, State)
),
VoiceState1 = maybe_attach_session_id(VoiceState0, SessionIdBin),
VoiceState = maybe_attach_member(VoiceState1, Member),
%% NOTE: We intentionally do NOT add the voice state to voice_states
%% or broadcast it yet. The voice state will only be added and
%% broadcast when LiveKit confirms the user has actually connected
%% via confirm_voice_connection_from_livekit/2.
%% This prevents users from appearing in voice channels before
%% they're actually connected and ready to communicate.
Now = erlang:system_time(millisecond),
PendingMetadata = #{
user_id => UserId,
@@ -1059,7 +1053,7 @@ request_voice_token(GuildId, ChannelId, UserId, ConnectionId, VoicePermissions,
endpoint => maps:get(<<"endpoint">>, Data),
connection_id => maps:get(<<"connectionId">>, Data)
}};
{error, {http_error, _Status, Body}} ->
{error, {rpc_error, _Status, Body}} ->
case parse_unclaimed_error(Body) of
true ->
{error, voice_unclaimed_account};

View File

@@ -48,8 +48,6 @@ handle_voice_disconnect(ConnectionId, _SessionId, UserId, VoiceStates0, State) -
VoiceStates = voice_state_utils:ensure_voice_states(VoiceStates0),
case maps:get(ConnectionId, VoiceStates, undefined) of
undefined ->
%% Voice state not in voice_states - check if it's still pending
%% (user disconnected before LiveKit confirmation)
State1 = clear_pending_voice_connection(ConnectionId, State),
{reply, #{success => true}, State1};
OldVoiceState ->
@@ -118,7 +116,6 @@ disconnect_voice_user(#{user_id := UserId} = Request, State) ->
end),
case maps:size(UserVoiceStates) of
0 ->
%% No active voice states - also clean up any pending connections
State1 = clear_pending_voice_connections_for_user(UserId, State),
{reply, #{success => true}, State1};
_ ->
@@ -141,7 +138,6 @@ disconnect_voice_user(#{user_id := UserId} = Request, State) ->
SpecificConnection ->
case maps:get(SpecificConnection, VoiceStates, undefined) of
undefined ->
%% Not found in voice_states - also clean up pending connection
State1 = clear_pending_voice_connection(SpecificConnection, State),
{reply, #{success => true}, State1};
VoiceState ->
@@ -191,8 +187,6 @@ disconnect_voice_user_if_in_channel(
end),
case maps:size(UserVoiceStates) of
0 ->
%% Not found in voice_states - also clean up any pending connections
%% for this user/channel (user disconnected before LiveKit confirmation)
State1 = clear_pending_voice_connections_for_user_channel(
UserId, ExpectedChannelId, State
),
@@ -215,8 +209,6 @@ disconnect_voice_user_if_in_channel(
ConnId ->
case maps:get(ConnId, VoiceStates, undefined) of
undefined ->
%% Not found in voice_states - also clean up pending connection
%% (user disconnected before LiveKit confirmation)
State1 = clear_pending_voice_connection(ConnId, State),
{reply,
#{success => true, ignored => true, reason => <<"connection_not_found">>},
@@ -282,8 +274,6 @@ disconnect_all_voice_users_in_channel(#{channel_id := ChannelId}, State) ->
ChannelVoiceStates = voice_state_utils:filter_voice_states(VoiceStates, fun(_, V) ->
voice_state_utils:voice_state_channel_id(V) =:= ChannelId
end),
%% Also clean up any pending connections for this channel
%% (users that requested tokens but haven't confirmed via LiveKit yet)
State1 = clear_pending_voice_connections_for_channel(ChannelId, State),
case maps:size(ChannelVoiceStates) of
0 ->
@@ -646,8 +636,6 @@ disconnect_voice_user_if_in_channel_skips_force_disconnect_test() ->
ok
end.
%% Tests for clear_pending_voice_connection/2
clear_pending_voice_connection_removes_connection_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 1, channel_id => 100},
@@ -670,8 +658,6 @@ clear_pending_voice_connection_handles_empty_pending_test() ->
NewState = clear_pending_voice_connection(<<"conn">>, State),
?assertEqual(#{}, maps:get(pending_voice_connections, NewState, #{})).
%% Tests for clear_pending_voice_connections_for_user/2
clear_pending_voice_connections_for_user_removes_all_user_connections_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 5, channel_id => 100},
@@ -685,8 +671,6 @@ clear_pending_voice_connections_for_user_removes_all_user_connections_test() ->
?assertNot(maps:is_key(<<"conn2">>, NewPending)),
?assert(maps:is_key(<<"conn3">>, NewPending)).
%% Tests for clear_pending_voice_connections_for_user_channel/3
clear_pending_voice_connections_for_user_channel_removes_matching_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 5, channel_id => 100},
@@ -700,8 +684,6 @@ clear_pending_voice_connections_for_user_channel_removes_matching_test() ->
?assert(maps:is_key(<<"conn2">>, NewPending)),
?assert(maps:is_key(<<"conn3">>, NewPending)).
%% Tests for clear_pending_voice_connections_for_channel/2
clear_pending_voice_connections_for_channel_removes_all_channel_connections_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 5, channel_id => 100},
@@ -715,8 +697,6 @@ clear_pending_voice_connections_for_channel_removes_all_channel_connections_test
?assertNot(maps:is_key(<<"conn2">>, NewPending)),
?assert(maps:is_key(<<"conn3">>, NewPending)).
%% Tests for disconnect handlers cleaning up pending connections
handle_voice_disconnect_cleans_pending_when_not_in_voice_states_test() ->
PendingConnections = #{<<"conn1">> => #{user_id => 5, channel_id => 100}},
State = #{

View File

@@ -419,10 +419,9 @@ send_single_voice_server_update(GuildId, ChannelId, SessionInfo, GuildPid) ->
<<"server_deaf">> => ServerDeaf,
<<"member">> => Member
},
_ = gen_server:call(
GuildPid,
{store_pending_connection, NewConnectionId, PendingMetadata},
10000
_ = store_pending_connection(
GuildId, GuildPid,
NewConnectionId, PendingMetadata
),
guild_voice_broadcast:broadcast_voice_server_update_to_session(
GuildId,
@@ -441,6 +440,22 @@ send_single_voice_server_update(GuildId, ChannelId, SessionInfo, GuildPid) ->
end
end.
-spec store_pending_connection(integer(), pid(), binary(), map()) -> ok.
store_pending_connection(GuildId, GuildPid, ConnectionId, Metadata) ->
TargetPid = resolve_voice_server(GuildId, GuildPid),
gen_server:call(
TargetPid,
{store_pending_connection, ConnectionId, Metadata},
10000
).
-spec resolve_voice_server(integer(), pid()) -> pid().
resolve_voice_server(GuildId, FallbackPid) ->
case guild_voice_server:lookup(GuildId) of
{ok, VoicePid} -> VoicePid;
{error, not_found} -> FallbackPid
end.
-ifdef(TEST).
move_member_user_not_in_voice_test() ->

View File

@@ -113,8 +113,8 @@ send_voice_server_update_for_region_switch(
PendingMetadata = build_pending_metadata(
UserId, GuildId, ChannelId, SessionId, ExistingVoiceState, TokenNonce
),
_ = gen_server:call(
GuildPid, {store_pending_connection, ConnectionId, PendingMetadata}, 10000
_ = store_pending_connection(
GuildId, GuildPid, ConnectionId, PendingMetadata
),
guild_voice_broadcast:broadcast_voice_server_update_to_session(
GuildId,
@@ -154,6 +154,22 @@ build_pending_metadata(UserId, GuildId, ChannelId, SessionId, ExistingVoiceState
expires_at => Now + 30000
}.
-spec store_pending_connection(integer(), pid(), binary(), map()) -> ok.
store_pending_connection(GuildId, GuildPid, ConnectionId, Metadata) ->
TargetPid = resolve_voice_server(GuildId, GuildPid),
gen_server:call(
TargetPid,
{store_pending_connection, ConnectionId, Metadata},
10000
).
-spec resolve_voice_server(integer(), pid()) -> pid().
resolve_voice_server(GuildId, FallbackPid) ->
case guild_voice_server:lookup(GuildId) of
{ok, VoicePid} -> VoicePid;
{error, not_found} -> FallbackPid
end.
-ifdef(TEST).
switch_voice_region_handler_not_found_test() ->

View File

@@ -0,0 +1,360 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(guild_voice_server).
-behaviour(gen_server).
-export([
start_link/2,
stop/1,
lookup/1
]).
-export([
init/1,
handle_call/3,
handle_cast/2,
handle_info/2,
terminate/2,
code_change/3
]).
-define(REGISTRY_TABLE, guild_voice_registry).
-define(SWEEP_INTERVAL_MS, 10000).
-define(GUILD_CALL_TIMEOUT, 10000).
-type voice_state() :: map().
-type voice_state_map() :: #{binary() => voice_state()}.
-type server_state() :: #{
guild_id := integer(),
guild_pid := pid(),
voice_states := voice_state_map(),
pending_voice_connections := map(),
recently_disconnected_voice_states := map()
}.
-spec start_link(integer(), pid()) -> {ok, pid()} | {error, term()}.
start_link(GuildId, GuildPid) ->
gen_server:start_link(?MODULE, #{guild_id => GuildId, guild_pid => GuildPid}, []).
-spec stop(pid()) -> ok.
stop(Pid) ->
gen_server:stop(Pid, normal, 5000).
-spec lookup(integer()) -> {ok, pid()} | {error, not_found}.
lookup(GuildId) ->
ensure_registry(),
case ets:lookup(?REGISTRY_TABLE, GuildId) of
[{_, Pid}] when is_pid(Pid) ->
case is_process_alive(Pid) of
true -> {ok, Pid};
false ->
ets:delete(?REGISTRY_TABLE, GuildId),
{error, not_found}
end;
_ ->
{error, not_found}
end.
-spec init(map()) -> {ok, server_state()}.
init(#{guild_id := GuildId, guild_pid := GuildPid}) ->
process_flag(trap_exit, true),
ensure_registry(),
ets:insert(?REGISTRY_TABLE, {GuildId, self()}),
erlang:send_after(?SWEEP_INTERVAL_MS, self(), sweep_pending_joins),
{ok, #{
guild_id => GuildId,
guild_pid => GuildPid,
voice_states => #{},
pending_voice_connections => #{},
recently_disconnected_voice_states => #{}
}}.
-spec handle_call(term(), gen_server:from(), server_state()) ->
{reply, term(), server_state()}.
handle_call({voice_state_update, Request}, _From, State) ->
GuildState = build_guild_state(State),
case guild_voice:voice_state_update(Request, GuildState) of
{reply, Reply, NewGuildState} ->
{reply, Reply, apply_guild_state(NewGuildState, State)}
end;
handle_call({get_voice_state, Request}, _From, State) ->
GuildState = build_guild_state(State),
case guild_voice:get_voice_state(Request, GuildState) of
{reply, Reply, NewGuildState} ->
{reply, Reply, apply_guild_state(NewGuildState, State)}
end;
handle_call({update_member_voice, Request}, _From, State) ->
GuildState = build_guild_state(State),
case guild_voice:update_member_voice(Request, GuildState) of
{reply, Reply, NewGuildState} ->
{reply, Reply, apply_guild_state(NewGuildState, State)}
end;
handle_call({disconnect_voice_user, Request}, _From, State) ->
GuildState = build_guild_state(State),
case guild_voice:disconnect_voice_user(Request, GuildState) of
{reply, Reply, NewGuildState} ->
{reply, Reply, apply_guild_state(NewGuildState, State)}
end;
handle_call({disconnect_voice_user_if_in_channel, Request}, _From, State) ->
GuildState = build_guild_state(State),
case guild_voice:disconnect_voice_user_if_in_channel(Request, GuildState) of
{reply, Reply, NewGuildState} ->
{reply, Reply, apply_guild_state(NewGuildState, State)}
end;
handle_call({disconnect_all_voice_users_in_channel, Request}, _From, State) ->
GuildState = build_guild_state(State),
case guild_voice:disconnect_all_voice_users_in_channel(Request, GuildState) of
{reply, Reply, NewGuildState} ->
{reply, Reply, apply_guild_state(NewGuildState, State)}
end;
handle_call({confirm_voice_connection_from_livekit, Request}, _From, State) ->
GuildState = build_guild_state(State),
case guild_voice:confirm_voice_connection_from_livekit(Request, GuildState) of
{reply, Reply, NewGuildState} ->
{reply, Reply, apply_guild_state(NewGuildState, State)}
end;
handle_call({move_member, Request}, _From, State) ->
GuildState = build_guild_state(State),
case guild_voice:move_member(Request, GuildState) of
{reply, Reply, NewGuildState} ->
{reply, Reply, apply_guild_state(NewGuildState, State)}
end;
handle_call({switch_voice_region, Request}, _From, State) ->
GuildState = build_guild_state(State),
case guild_voice:switch_voice_region_handler(Request, GuildState) of
{reply, Reply, NewGuildState} ->
{reply, Reply, apply_guild_state(NewGuildState, State)}
end;
handle_call({store_pending_connection, ConnectionId, Metadata}, _From, State) ->
PendingConnections = maps:get(pending_voice_connections, State, #{}),
NewPendingConnections = maps:put(ConnectionId, Metadata, PendingConnections),
NewState = maps:put(pending_voice_connections, NewPendingConnections, State),
{reply, ok, NewState};
handle_call({get_voice_states_for_channel, ChannelIdBin}, _From, State) ->
VoiceStates = maps:get(voice_states, State, #{}),
Filtered = maps:fold(
fun(ConnId, VS, Acc) ->
case maps:get(<<"channel_id">>, VS, null) of
ChannelIdBin ->
[#{
connection_id => ConnId,
user_id => maps:get(<<"user_id">>, VS, null),
channel_id => ChannelIdBin
} | Acc];
_ ->
Acc
end
end,
[],
VoiceStates
),
{reply, #{voice_states => Filtered}, State};
handle_call({get_pending_joins_for_channel, ChannelIdBin}, _From, State) ->
PendingConnections = maps:get(pending_voice_connections, State, #{}),
ChannelIdInt = binary_to_integer(ChannelIdBin),
Filtered = maps:fold(
fun(ConnId, Metadata, Acc) ->
case maps:get(channel_id, Metadata, undefined) of
ChannelIdInt ->
[#{
connection_id => ConnId,
user_id => integer_to_binary(maps:get(user_id, Metadata, 0)),
token_nonce => maps:get(token_nonce, Metadata, null),
expires_at => maps:get(expires_at, Metadata, 0)
} | Acc];
_ ->
Acc
end
end,
[],
PendingConnections
),
{reply, #{pending_joins => Filtered}, State};
handle_call({get_voice_states_list}, _From, State) ->
VoiceStates = maps:get(voice_states, State, #{}),
{reply, maps:values(VoiceStates), State};
handle_call({get_voice_states_map}, _From, State) ->
{reply, maps:get(voice_states, State, #{}), State};
handle_call({set_voice_states, VoiceStates}, _From, State) ->
{reply, ok, maps:put(voice_states, VoiceStates, State)};
handle_call(_, _From, State) ->
{reply, ok, State}.
-spec handle_cast(term(), server_state()) -> {noreply, server_state()}.
handle_cast({store_pending_connection, ConnectionId, Metadata}, State) ->
PendingConnections = maps:get(pending_voice_connections, State, #{}),
NewPendingConnections = maps:put(ConnectionId, Metadata, PendingConnections),
NewState = maps:put(pending_voice_connections, NewPendingConnections, State),
{noreply, NewState};
handle_cast({relay_voice_state_update, VoiceState, OldChannelIdBin}, State) ->
GuildState = build_guild_state(State),
State1 = relay_upsert_voice_state(VoiceState, State),
GuildStateNoRelay = maps:remove(very_large_guild_coordinator_pid, GuildState),
_ = guild_voice_broadcast:broadcast_voice_state_update(
VoiceState, GuildStateNoRelay, OldChannelIdBin
),
{noreply, State1};
handle_cast(
{relay_voice_server_update, GuildId, ChannelId, SessionId, Token, Endpoint, ConnectionId},
State
) ->
GuildState = build_guild_state(State),
GuildStateNoRelay = maps:remove(very_large_guild_coordinator_pid, GuildState),
_ = guild_voice_broadcast:broadcast_voice_server_update_to_session(
GuildId,
ChannelId,
SessionId,
Token,
Endpoint,
ConnectionId,
GuildStateNoRelay
),
{noreply, State};
handle_cast({cleanup_virtual_access_for_user, UserId}, State) ->
GuildState = build_guild_state(State),
NewGuildState = guild_voice_disconnect:cleanup_virtual_channel_access_for_user(
UserId, GuildState
),
{noreply, apply_guild_state(NewGuildState, State)};
handle_cast(_, State) ->
{noreply, State}.
-spec handle_info(term(), server_state()) ->
{noreply, server_state()} | {stop, normal, server_state()}.
handle_info(sweep_pending_joins, State) ->
GuildState = build_guild_state(State),
NewGuildState = guild_voice_connection:sweep_expired_pending_joins(GuildState),
erlang:send_after(?SWEEP_INTERVAL_MS, self(), sweep_pending_joins),
{noreply, apply_guild_state(NewGuildState, State)};
handle_info({'EXIT', Pid, Reason}, #{guild_pid := GuildPid} = State) when Pid =:= GuildPid ->
logger:info(
"Voice server shutting down because guild process exited",
#{guild_id => maps:get(guild_id, State), reason => Reason}
),
{stop, normal, State};
handle_info(_, State) ->
{noreply, State}.
-spec terminate(term(), server_state()) -> ok.
terminate(_Reason, #{guild_id := GuildId}) ->
catch ets:delete(?REGISTRY_TABLE, GuildId),
ok;
terminate(_Reason, _State) ->
ok.
-spec code_change(term(), server_state(), term()) -> {ok, server_state()}.
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
-spec build_guild_state(server_state()) -> map().
build_guild_state(#{guild_pid := GuildPid} = State) ->
GuildData = fetch_guild_data(GuildPid),
maps:merge(GuildData, #{
voice_states => maps:get(voice_states, State, #{}),
pending_voice_connections => maps:get(pending_voice_connections, State, #{}),
recently_disconnected_voice_states =>
maps:get(recently_disconnected_voice_states, State, #{})
}).
-spec apply_guild_state(map(), server_state()) -> server_state().
apply_guild_state(GuildState, State) ->
State#{
voice_states => maps:get(voice_states, GuildState, maps:get(voice_states, State, #{})),
pending_voice_connections =>
maps:get(
pending_voice_connections,
GuildState,
maps:get(pending_voice_connections, State, #{})
),
recently_disconnected_voice_states =>
maps:get(
recently_disconnected_voice_states,
GuildState,
maps:get(recently_disconnected_voice_states, State, #{})
)
}.
-spec fetch_guild_data(pid()) -> map().
fetch_guild_data(GuildPid) ->
try gen_server:call(GuildPid, {get_sessions}, ?GUILD_CALL_TIMEOUT) of
GuildState when is_map(GuildState) ->
GuildState;
_ ->
#{}
catch
exit:{timeout, _} ->
logger:warning("Voice server timed out fetching guild state", #{}),
#{};
exit:{noproc, _} ->
#{};
exit:{normal, _} ->
#{}
end.
-spec relay_upsert_voice_state(map(), server_state()) -> server_state().
relay_upsert_voice_state(VoiceState, State) when is_map(VoiceState) ->
ConnectionId = maps:get(<<"connection_id">>, VoiceState, undefined),
case ConnectionId of
undefined ->
State;
_ ->
VoiceStates0 = maps:get(voice_states, State, #{}),
ChannelId = maps:get(<<"channel_id">>, VoiceState, null),
VoiceStates =
case ChannelId of
null -> maps:remove(ConnectionId, VoiceStates0);
_ -> maps:put(ConnectionId, VoiceState, VoiceStates0)
end,
maps:put(voice_states, VoiceStates, State)
end;
relay_upsert_voice_state(_, State) ->
State.
-spec ensure_registry() -> ok.
ensure_registry() ->
guild_ets_utils:ensure_table(?REGISTRY_TABLE, [
named_table,
public,
set,
{read_concurrency, true},
{write_concurrency, true}
]).

View File

@@ -646,18 +646,40 @@ dispatch_global_presence(TargetId, Payload, State) ->
true ->
{noreply, State};
false ->
cache_if_visible(TargetId, Payload),
Sessions = maps:get(sessions, State),
SessionPids = [maps:get(pid, S) || S <- maps:values(Sessions)],
lists:foreach(
fun(Pid) when is_pid(Pid) ->
gen_server:cast(Pid, {dispatch, presence_update, Payload})
end,
SessionPids
),
case maps:get(<<"user_update">>, Payload, false) of
true ->
dispatch_global_user_update(TargetId, Payload, State);
false ->
cache_if_visible(TargetId, Payload),
dispatch_to_sessions(Payload, State),
{noreply, State}
end
end.
-spec dispatch_global_user_update(user_id(), map(), state()) -> {noreply, state()}.
dispatch_global_user_update(TargetId, Payload, State) ->
NewUserData = maps:get(<<"user">>, Payload, #{}),
case presence_cache:get(TargetId) of
{ok, CachedPresence} ->
MergedPresence = maps:put(<<"user">>, NewUserData, CachedPresence),
presence_cache:put(TargetId, MergedPresence),
dispatch_to_sessions(MergedPresence, State),
{noreply, State};
_ ->
{noreply, State}
end.
-spec dispatch_to_sessions(map(), state()) -> ok.
dispatch_to_sessions(Payload, State) ->
Sessions = maps:get(sessions, State),
SessionPids = [maps:get(pid, S) || S <- maps:values(Sessions)],
lists:foreach(
fun(Pid) when is_pid(Pid) ->
gen_server:cast(Pid, {dispatch, presence_update, Payload})
end,
SessionPids
).
-spec sync_friend_subscriptions([user_id()], state()) -> state().
sync_friend_subscriptions(FriendIds, State) ->
case maps:get(is_bot, State, false) of

View File

@@ -38,7 +38,8 @@ ensure_status_binary(online) -> <<"online">>;
ensure_status_binary(offline) -> <<"offline">>;
ensure_status_binary(idle) -> <<"idle">>;
ensure_status_binary(dnd) -> <<"dnd">>;
ensure_status_binary(invisible) -> <<"invisible">>;
ensure_status_binary(invisible) -> <<"offline">>;
ensure_status_binary(<<"invisible">>) -> <<"offline">>;
ensure_status_binary(Status) when is_binary(Status) -> Status;
ensure_status_binary(_) -> <<"offline">>.
@@ -69,10 +70,11 @@ ensure_status_binary_atom_test() ->
?assertEqual(<<"offline">>, ensure_status_binary(offline)),
?assertEqual(<<"idle">>, ensure_status_binary(idle)),
?assertEqual(<<"dnd">>, ensure_status_binary(dnd)),
?assertEqual(<<"invisible">>, ensure_status_binary(invisible)).
?assertEqual(<<"offline">>, ensure_status_binary(invisible)).
ensure_status_binary_binary_test() ->
?assertEqual(<<"online">>, ensure_status_binary(<<"online">>)),
?assertEqual(<<"offline">>, ensure_status_binary(<<"invisible">>)),
?assertEqual(<<"custom">>, ensure_status_binary(<<"custom">>)).
ensure_status_binary_unknown_test() ->
@@ -98,4 +100,18 @@ normalize_custom_status_test() ->
?assertEqual(#{<<"text">> => <<"hi">>}, normalize_custom_status(#{<<"text">> => <<"hi">>})),
?assertEqual(null, normalize_custom_status(<<"invalid">>)),
?assertEqual(null, normalize_custom_status(123)).
build_invisible_atom_normalized_to_offline_test() ->
User = #{<<"id">> => <<"1">>, <<"username">> => <<"Test">>},
CustomStatus = #{<<"text">> => <<"hello">>},
Result = build(User, invisible, false, false, CustomStatus),
?assertEqual(<<"offline">>, maps:get(<<"status">>, Result)),
?assertEqual(null, maps:get(<<"custom_status">>, Result)).
build_invisible_binary_normalized_to_offline_test() ->
User = #{<<"id">> => <<"1">>, <<"username">> => <<"Test">>},
CustomStatus = #{<<"text">> => <<"hello">>},
Result = build(User, <<"invisible">>, false, false, CustomStatus),
?assertEqual(<<"offline">>, maps:get(<<"status">>, Result)),
?assertEqual(null, maps:get(<<"custom_status">>, Result)).
-endif.

View File

@@ -24,7 +24,6 @@
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-export_type([session_data/0, user_id/0]).
-define(IDENTIFY_FLAG_USE_CANARY_API, 16#1).
-define(IDENTIFY_FLAG_DEBOUNCE_MESSAGE_REACTIONS, 16#2).
-type session_id() :: binary().
@@ -60,8 +59,6 @@
-type state() :: #{
sessions := #{session_id() => session_ref()},
api_host := string(),
api_canary_host := undefined | string(),
identify_attempts := [identify_timestamp()],
pending_identifies := #{session_id() => pending_identify()},
identify_workers := #{reference() => session_id()},
@@ -80,7 +77,7 @@
| {error, rate_limited}
| {error, identify_rate_limited}
| {error, {server_error, non_neg_integer()}}
| {error, {http_error, non_neg_integer()}}
| {error, {rpc_error, non_neg_integer(), binary()}}
| {error, {network_error, term()}}
| {error, registration_failed}
| {error, term()}.
@@ -95,13 +92,9 @@ start_link(ShardIndex) ->
init(Args) ->
fluxer_gateway_env:load(),
process_flag(trap_exit, true),
ApiHost = fluxer_gateway_env:get(api_host),
ApiCanaryHost = fluxer_gateway_env:get(api_canary_host),
ShardIndex = maps:get(shard_index, Args, 0),
{ok, #{
sessions => #{},
api_host => ApiHost,
api_canary_host => ApiCanaryHost,
identify_attempts => [],
pending_identifies => #{},
identify_workers => #{},
@@ -153,7 +146,7 @@ handle_call({start, Request, SocketPid}, From, State) ->
{reply, {success, Pid}, maps:put(sessions, NewSessions, State)}
end
end;
handle_call({lookup, SessionId}, _From, State) ->
handle_call({lookup, SessionId}, _From, State) when is_binary(SessionId) ->
Sessions = maps:get(sessions, State),
case maps:get(SessionId, Sessions, undefined) of
{Pid, _Ref} ->
@@ -169,6 +162,8 @@ handle_call({lookup, SessionId}, _From, State) ->
{reply, {ok, Pid}, maps:put(sessions, NewSessions, State)}
end
end;
handle_call({lookup, _InvalidSessionId}, _From, State) ->
{reply, {error, not_found}, State};
handle_call(get_local_count, _From, State) ->
Sessions = maps:get(sessions, State),
{reply, {ok, maps:size(Sessions)}, State};
@@ -299,14 +294,11 @@ handle_cast(_, State) ->
-spec start_identify_fetch(identify_request(), pid(), session_id(), gen_server:from(), state()) ->
{noreply, state()}.
start_identify_fetch(Request, SocketPid, SessionId, From, State) ->
IdentifyData = maps:get(identify_data, Request),
UseCanary = should_use_canary_api(IdentifyData),
{_UsedCanary, RpcClient} = select_rpc_client(State, UseCanary),
ManagerPid = self(),
{_WorkerPid, WorkerRef} =
spawn_monitor(fun() ->
PeerIP = maps:get(peer_ip, Request),
FetchResult = fetch_rpc_data(Request, PeerIP, RpcClient),
FetchResult = fetch_rpc_data(Request, PeerIP),
ManagerPid ! {identify_fetch_result, SessionId, FetchResult}
end),
PendingIdentifies = maps:get(pending_identifies, State),
@@ -322,26 +314,6 @@ start_identify_fetch(Request, SocketPid, SessionId, From, State) ->
identify_workers := NewWorkers
}}.
-spec select_rpc_client(state(), boolean()) -> {boolean(), string()}.
select_rpc_client(State, true) ->
case maps:get(api_canary_host, State) of
undefined ->
{false, maps:get(api_host, State)};
CanaryHost ->
{true, CanaryHost}
end;
select_rpc_client(State, false) ->
{false, maps:get(api_host, State)}.
-spec should_use_canary_api(map()) -> boolean().
should_use_canary_api(IdentifyData) ->
case map_utils:get_safe(IdentifyData, flags, 0) of
Flags when is_integer(Flags), Flags >= 0 ->
(Flags band ?IDENTIFY_FLAG_USE_CANARY_API) =/= 0;
_ ->
false
end.
-spec should_debounce_reactions(map()) -> boolean().
should_debounce_reactions(IdentifyData) ->
case map_utils:get_safe(IdentifyData, flags, 0) of
@@ -460,13 +432,9 @@ code_change(_OldVsn, State, _Extra) when is_map(State) ->
}};
code_change(_OldVsn, State, _Extra) when is_tuple(State), element(1, State) =:= state ->
Sessions = element(2, State),
ApiHost = element(3, State),
ApiCanaryHost = element(4, State),
IdentifyAttempts = element(5, State),
{ok, #{
sessions => Sessions,
api_host => ApiHost,
api_canary_host => ApiCanaryHost,
identify_attempts => IdentifyAttempts,
pending_identifies => #{},
identify_workers => #{},
@@ -475,31 +443,27 @@ code_change(_OldVsn, State, _Extra) when is_tuple(State), element(1, State) =:=
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
-spec fetch_rpc_data(map(), term(), string()) ->
-spec fetch_rpc_data(map(), term()) ->
{ok, map()}
| {error, invalid_token}
| {error, rate_limited}
| {error, {server_error, non_neg_integer()}}
| {error, {http_error, non_neg_integer()}}
| {error, {network_error, term()}}.
fetch_rpc_data(Request, PeerIP, ApiHost) ->
fetch_rpc_data(Request, PeerIP) ->
StartTime = erlang:system_time(millisecond),
Result = do_fetch_rpc_data(Request, PeerIP, ApiHost),
Result = do_fetch_rpc_data(Request, PeerIP),
EndTime = erlang:system_time(millisecond),
LatencyMs = EndTime - StartTime,
gateway_metrics_collector:record_rpc_latency(LatencyMs),
Result.
-spec do_fetch_rpc_data(map(), term(), string()) ->
-spec do_fetch_rpc_data(map(), term()) ->
{ok, map()}
| {error, invalid_token}
| {error, rate_limited}
| {error, {server_error, non_neg_integer()}}
| {error, {http_error, non_neg_integer()}}
| {error, {network_error, term()}}.
do_fetch_rpc_data(Request, PeerIP, ApiHost) ->
Url = rpc_client:get_rpc_url(ApiHost),
Headers = rpc_client:get_rpc_headers() ++ [{<<"content-type">>, <<"application/json">>}],
do_fetch_rpc_data(Request, PeerIP) ->
IdentifyData = maps:get(identify_data, Request),
Properties = map_utils:get_safe(IdentifyData, properties, #{}),
LatitudeRaw = map_utils:get_safe(Properties, <<"latitude">>, undefined),
@@ -513,8 +477,18 @@ do_fetch_rpc_data(Request, PeerIP, ApiHost) ->
<<"ip">> => PeerIP
},
RpcRequestWithCoords = add_coordinates(RpcRequest, Latitude, Longitude),
Body = json:encode(RpcRequestWithCoords),
execute_rpc_request(Url, Headers, Body).
case rpc_client:call(RpcRequestWithCoords) of
{ok, Data} ->
{ok, Data};
{error, {rpc_error, 401, _}} ->
{error, invalid_token};
{error, {rpc_error, 429, _}} ->
{error, rate_limited};
{error, {rpc_error, StatusCode, _}} when StatusCode >= 500 ->
{error, {server_error, StatusCode}};
{error, Reason} ->
{error, {network_error, Reason}}
end.
-spec normalize_coordinate(term()) -> term() | undefined.
normalize_coordinate(undefined) -> undefined;
@@ -531,31 +505,6 @@ add_coordinates(Request, undefined, Lon) ->
add_coordinates(Request, Lat, Lon) ->
maps:merge(Request, #{<<"latitude">> => Lat, <<"longitude">> => Lon}).
-spec execute_rpc_request(iodata(), list(), binary()) ->
{ok, map()}
| {error, invalid_token}
| {error, rate_limited}
| {error, {server_error, non_neg_integer()}}
| {error, {http_error, non_neg_integer()}}
| {error, {network_error, term()}}.
execute_rpc_request(Url, Headers, Body) ->
case gateway_http_client:request(rpc, post, Url, Headers, Body) of
{ok, 200, _RespHeaders, ResponseBody} ->
ResponseData = json:decode(ResponseBody),
{ok, maps:get(<<"data">>, ResponseData)};
{ok, 401, _RespHeaders, _ResponseBody} ->
{error, invalid_token};
{ok, 429, _RespHeaders, _ResponseBody} ->
{error, rate_limited};
{ok, StatusCode, _RespHeaders, _ResponseBody} when StatusCode >= 500 ->
{error, {server_error, StatusCode}};
{ok, StatusCode, _RespHeaders, _ResponseBody} when StatusCode >= 400 ->
{error, {http_error, StatusCode}};
{ok, StatusCode, _RespHeaders, _ResponseBody} ->
{error, {http_error, StatusCode}};
{error, Reason} ->
{error, {network_error, Reason}}
end.
-spec parse_presence(map(), map()) -> status().
parse_presence(Data, IdentifyData) ->
@@ -689,15 +638,6 @@ add_coordinates_test() ->
),
ok.
should_use_canary_api_test() ->
?assertEqual(false, should_use_canary_api(#{})),
?assertEqual(false, should_use_canary_api(#{flags => 0})),
?assertEqual(true, should_use_canary_api(#{flags => 1})),
?assertEqual(true, should_use_canary_api(#{flags => 17})),
?assertEqual(false, should_use_canary_api(#{flags => 2})),
?assertEqual(false, should_use_canary_api(#{flags => -1})),
ok.
should_debounce_reactions_test() ->
?assertEqual(false, should_debounce_reactions(#{})),
?assertEqual(false, should_debounce_reactions(#{flags => 0})),

View File

@@ -71,9 +71,6 @@ handle_presence_down(State) ->
handle_guild_down(GuildId, killed, State, _Guilds) ->
gen_server:cast(self(), {guild_leave, GuildId}),
{noreply, State};
handle_guild_down(GuildId, normal, State, _Guilds) ->
gen_server:cast(self(), {guild_leave, GuildId}),
{noreply, State};
handle_guild_down(GuildId, _Reason, State, Guilds) ->
GuildDeleteData = #{
<<"id">> => integer_to_binary(GuildId),
@@ -154,4 +151,161 @@ find_call_by_ref_test() ->
?assertEqual(not_found, find_call_by_ref(make_ref(), Calls)),
ok.
-spec build_test_session_state(guild_id(), #{guild_id() => guild_ref()}) -> session_state().
build_test_session_state(GuildId, Guilds) ->
#{
id => <<"session-monitor-test">>,
user_id => 1,
user_data => #{},
custom_status => null,
version => 1,
token_hash => <<>>,
auth_session_id_hash => <<>>,
buffer => [],
seq => 0,
ack_seq => 0,
properties => #{},
status => online,
afk => false,
mobile => false,
presence_pid => undefined,
presence_mref => undefined,
socket_pid => undefined,
socket_mref => undefined,
guilds => Guilds,
calls => #{},
channels => #{},
ready => undefined,
bot => false,
ignored_events => #{},
initial_guild_id => GuildId,
collected_guild_states => [],
collected_sessions => [],
collected_presences => [],
relationships => #{},
suppress_presence_updates => false,
pending_presences => [],
guild_connect_inflight => #{},
voice_queue => queue:new(),
voice_queue_timer => undefined,
debounce_reactions => false,
reaction_buffer => [],
reaction_buffer_timer => undefined
}.
handle_guild_down_normal_marks_unavailable_and_schedules_reconnect_test() ->
GuildId = 50001,
GuildRef = make_ref(),
GuildPid = spawn(fun() -> receive stop -> ok end end),
Guilds = #{GuildId => {GuildPid, GuildRef}},
State0 = build_test_session_state(GuildId, Guilds),
{noreply, State1} = handle_guild_down(GuildId, normal, State0, Guilds),
UpdatedGuilds = maps:get(guilds, State1),
?assertEqual(undefined, maps:get(GuildId, UpdatedGuilds)),
Buffer = maps:get(buffer, State1),
?assertEqual(1, length(Buffer)),
[Event] = Buffer,
?assertEqual(guild_delete, maps:get(event, Event)),
EventData = maps:get(data, Event),
?assertEqual(integer_to_binary(GuildId), maps:get(<<"id">>, EventData)),
?assertEqual(true, maps:get(<<"unavailable">>, EventData)),
receive
{guild_connect, GuildId, 0} -> ok
after 2000 ->
?assert(false)
end,
GuildPid ! stop.
handle_guild_down_shutdown_marks_unavailable_and_schedules_reconnect_test() ->
GuildId = 50002,
GuildRef = make_ref(),
GuildPid = spawn(fun() -> receive stop -> ok end end),
Guilds = #{GuildId => {GuildPid, GuildRef}},
State0 = build_test_session_state(GuildId, Guilds),
{noreply, State1} = handle_guild_down(GuildId, shutdown, State0, Guilds),
UpdatedGuilds = maps:get(guilds, State1),
?assertEqual(undefined, maps:get(GuildId, UpdatedGuilds)),
Buffer = maps:get(buffer, State1),
?assertEqual(1, length(Buffer)),
[Event] = Buffer,
?assertEqual(guild_delete, maps:get(event, Event)),
EventData = maps:get(data, Event),
?assertEqual(true, maps:get(<<"unavailable">>, EventData)),
receive
{guild_connect, GuildId, 0} -> ok
after 2000 ->
?assert(false)
end,
GuildPid ! stop.
handle_guild_down_crash_marks_unavailable_and_schedules_reconnect_test() ->
GuildId = 50003,
GuildRef = make_ref(),
GuildPid = spawn(fun() -> receive stop -> ok end end),
Guilds = #{GuildId => {GuildPid, GuildRef}},
State0 = build_test_session_state(GuildId, Guilds),
{noreply, State1} = handle_guild_down(GuildId, {error, something_went_wrong}, State0, Guilds),
UpdatedGuilds = maps:get(guilds, State1),
?assertEqual(undefined, maps:get(GuildId, UpdatedGuilds)),
Buffer = maps:get(buffer, State1),
?assertEqual(1, length(Buffer)),
[Event] = Buffer,
?assertEqual(guild_delete, maps:get(event, Event)),
EventData = maps:get(data, Event),
?assertEqual(true, maps:get(<<"unavailable">>, EventData)),
receive
{guild_connect, GuildId, 0} -> ok
after 2000 ->
?assert(false)
end,
GuildPid ! stop.
handle_guild_down_killed_sends_permanent_guild_leave_test() ->
GuildId = 50004,
GuildRef = make_ref(),
GuildPid = spawn(fun() -> receive stop -> ok end end),
Guilds = #{GuildId => {GuildPid, GuildRef}},
State0 = build_test_session_state(GuildId, Guilds),
{noreply, State1} = handle_guild_down(GuildId, killed, State0, Guilds),
?assertEqual([], maps:get(buffer, State1, [])),
UpdatedGuilds = maps:get(guilds, State1),
?assertEqual({GuildPid, GuildRef}, maps:get(GuildId, UpdatedGuilds)),
receive
{guild_connect, GuildId, 0} ->
?assert(false)
after 200 ->
ok
end,
receive
{'$gen_cast', {guild_leave, GuildId}} -> ok
after 200 ->
?assert(false)
end,
GuildPid ! stop.
handle_process_down_guild_normal_exit_dispatches_unavailable_test() ->
GuildId = 50005,
GuildPid = spawn(fun() -> receive stop -> ok end end),
GuildRef = monitor(process, GuildPid),
Guilds = #{GuildId => {GuildPid, GuildRef}},
State0 = build_test_session_state(GuildId, Guilds),
GuildPid ! stop,
receive
{'DOWN', GuildRef, process, GuildPid, Reason} ->
{noreply, State1} = handle_process_down(GuildRef, Reason, State0),
Buffer = maps:get(buffer, State1),
?assertEqual(1, length(Buffer)),
[Event] = Buffer,
?assertEqual(guild_delete, maps:get(event, Event)),
EventData = maps:get(data, Event),
?assertEqual(true, maps:get(<<"unavailable">>, EventData)),
receive
{guild_connect, GuildId, 0} -> ok
after 2000 ->
?assert(false)
end
after 2000 ->
?assert(false)
end.
-endif.

View File

@@ -19,6 +19,7 @@
-export([
is_passive/2,
is_small_guild/1,
set_active/2,
set_passive/2,
should_receive_event/5,

View File

@@ -359,7 +359,7 @@ handle_guild_voice_state_update(
_Longitude,
SessionPid
) ->
case guild_client:voice_state_update(GuildPid, Request, 12000) of
case guild_client:voice_state_update(GuildPid, GuildId, Request, 12000) of
{ok, Reply} when is_map(Reply) ->
maybe_dispatch_voice_server_update_from_reply(Reply, GuildId, ChannelId, SessionPid),
ok;
@@ -427,9 +427,9 @@ handle_voice_disconnect(State) ->
dispatch_guild_voice_disconnects(Guilds, Request) ->
lists:foreach(
fun
({_GuildId, {GuildPid, _Ref}}) when is_pid(GuildPid) ->
({GuildId, {GuildPid, _Ref}}) when is_pid(GuildPid) ->
spawn(fun() ->
_ = guild_client:voice_state_update(GuildPid, Request, 10000),
_ = guild_client:voice_state_update(GuildPid, GuildId, Request, 10000),
ok
end),
ok;

View File

@@ -57,7 +57,7 @@ register_and_monitor(Name, Pid, ProcessMap) ->
{ok, Pid, Ref, NewMap}
catch
error:badarg ->
catch gen_server:stop(Pid, normal, 5000),
force_stop_process(Pid),
case whereis(Name) of
undefined ->
{error, registration_race_condition};
@@ -70,6 +70,22 @@ register_and_monitor(Name, Pid, ProcessMap) ->
{error, {Error, Reason}}
end.
-spec force_stop_process(pid()) -> ok.
force_stop_process(Pid) ->
MRef = monitor(process, Pid),
exit(Pid, shutdown),
receive
{'DOWN', MRef, process, Pid, _} -> ok
after 3000 ->
exit(Pid, kill),
receive
{'DOWN', MRef, process, Pid, _} -> ok
after 2000 ->
demonitor(MRef, [flush]),
ok
end
end.
-spec lookup_or_monitor(atom(), term(), process_map()) -> lookup_result().
lookup_or_monitor(Name, Key, ProcessMap) ->
case whereis(Name) of
@@ -661,4 +677,54 @@ integration_rapid_cycles_test_() ->
)
end}.
force_stop_process_normal_test_() ->
{timeout, 15, fun() ->
Pid = spawn(fun() ->
receive stop -> ok end
end),
?assert(is_process_alive(Pid)),
force_stop_process(Pid),
timer:sleep(50),
?assertEqual(false, is_process_alive(Pid))
end}.
force_stop_process_already_dead_test() ->
Pid = spawn(fun() -> ok end),
timer:sleep(10),
?assertEqual(false, is_process_alive(Pid)),
force_stop_process(Pid).
force_stop_process_kills_unresponsive_test_() ->
{timeout, 15, fun() ->
Pid = spawn(fun() ->
process_flag(trap_exit, true),
receive
never_arrives -> ok
end
end),
?assert(is_process_alive(Pid)),
force_stop_process(Pid),
timer:sleep(100),
?assertEqual(false, is_process_alive(Pid))
end}.
register_and_monitor_duplicate_stops_loser_test_() ->
{timeout, 15, fun() ->
Name = test_reg_dup_stops_loser,
WinnerPid = spawn(fun() -> timer:sleep(5000) end),
register(Name, WinnerPid),
LoserPid = spawn(fun() ->
process_flag(trap_exit, true),
receive
never_arrives -> ok
end
end),
?assert(is_process_alive(LoserPid)),
Result = register_and_monitor(Name, LoserPid, #{}),
?assertMatch({ok, WinnerPid, _, _}, Result),
timer:sleep(100),
?assertEqual(false, is_process_alive(LoserPid)),
unregister(Name)
end}.
-endif.