refactor progress

This commit is contained in:
Hampus Kraft
2026-02-17 12:22:36 +00:00
parent cb31608523
commit d5abd1a7e4
8257 changed files with 1190207 additions and 761040 deletions

View File

@@ -21,12 +21,64 @@
-export([start_link/1]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-type session_id() :: binary().
-type user_id() :: integer().
-type guild_id() :: integer().
-type channel_id() :: integer().
-type seq() :: non_neg_integer().
-type status() :: online | offline | idle | dnd.
-type guild_ref() :: {pid(), reference()} | undefined | cached_unavailable.
-type call_ref() :: {pid(), reference()}.
-type session_state() :: #{
id := session_id(),
user_id := user_id(),
user_data := map(),
custom_status := map() | null,
version := non_neg_integer(),
token_hash := binary(),
auth_session_id_hash := binary(),
buffer := [map()],
seq := seq(),
ack_seq := seq(),
properties := map(),
status := status(),
afk := boolean(),
mobile := boolean(),
presence_pid := pid() | undefined,
presence_mref := reference() | undefined,
socket_pid := pid() | undefined,
socket_mref := reference() | undefined,
guilds := #{guild_id() => guild_ref()},
calls := #{channel_id() => call_ref()},
channels := #{channel_id() => map()},
ready := map() | undefined,
bot := boolean(),
ignored_events := #{binary() => true},
initial_guild_id := guild_id() | undefined,
collected_guild_states := [map()],
collected_sessions := [map()],
collected_presences := [map()],
relationships := #{user_id() => integer()},
suppress_presence_updates := boolean(),
pending_presences := [map()],
guild_connect_inflight := #{guild_id() => non_neg_integer()},
voice_queue := queue:queue(),
voice_queue_timer := reference() | undefined,
debounce_reactions := boolean(),
reaction_buffer := [map()],
reaction_buffer_timer := reference() | undefined
}.
-export_type([session_state/0, session_id/0, user_id/0, guild_id/0, channel_id/0, seq/0]).
-spec start_link(map()) -> {ok, pid()} | {error, term()}.
start_link(SessionData) ->
gen_server:start_link(?MODULE, SessionData, []).
-spec init(map()) -> {ok, session_state()}.
init(SessionData) ->
process_flag(trap_exit, true),
Id = maps:get(id, SessionData),
UserId = maps:get(user_id, SessionData),
UserData = maps:get(user_data, SessionData),
@@ -48,13 +100,9 @@ init(SessionData) ->
false -> Ready0
end,
IgnoredEvents = build_ignored_events_map(maps:get(ignored_events, SessionData, [])),
DebounceReactions = maps:get(debounce_reactions, SessionData, false),
Channels = load_private_channels(Ready),
logger:debug("[session] Loaded ~p private channels into session state for user ~p", [
maps:size(Channels),
UserId
]),
VoiceQueueState = session_voice:init_voice_queue(),
State = #{
id => Id,
user_id => UserId,
@@ -87,9 +135,12 @@ init(SessionData) ->
relationships => load_relationships(Ready),
suppress_presence_updates => true,
pending_presences => [],
guild_connect_inflight => #{}
guild_connect_inflight => #{},
debounce_reactions => DebounceReactions,
reaction_buffer => [],
reaction_buffer_timer => undefined
},
StateWithVoiceQueue = maps:merge(State, VoiceQueueState),
self() ! {presence_connect, 0},
case Bot of
true -> self() ! bot_initial_ready;
@@ -98,9 +149,19 @@ init(SessionData) ->
lists:foreach(fun(Gid) -> self() ! {guild_connect, Gid, 0} end, GuildIds),
erlang:send_after(3000, self(), premature_readiness),
erlang:send_after(200, self(), enable_presence_updates),
{ok, StateWithVoiceQueue}.
{ok, State}.
-spec handle_call(Request, From, State) -> Result when
Request ::
{token_verify, binary()}
| {heartbeat_ack, seq()}
| {resume, seq(), pid()}
| {get_state}
| {voice_state_update, map()}
| term(),
From :: gen_server:from(),
State :: session_state(),
Result :: {reply, term(), session_state()}.
handle_call({token_verify, Token}, _From, State) ->
TokenHash = maps:get(token_hash, State),
HashedInput = utils:hash_token(Token),
@@ -109,7 +170,6 @@ handle_call({token_verify, Token}, _From, State) ->
handle_call({heartbeat_ack, Seq}, _From, State) ->
AckSeq = maps:get(ack_seq, State),
Buffer = maps:get(buffer, State),
if
Seq < AckSeq ->
{reply, false, State};
@@ -125,7 +185,6 @@ handle_call({resume, Seq, SocketPid}, _From, State) ->
Status = maps:get(status, State),
Afk = maps:get(afk, State),
Mobile = maps:get(mobile, State),
if
Seq > CurrentSeq ->
{reply, invalid_seq, State};
@@ -135,23 +194,23 @@ handle_call({resume, Seq, SocketPid}, _From, State) ->
socket_pid => SocketPid,
socket_mref => monitor(process, SocketPid)
}),
case PresencePid of
undefined ->
ok;
Pid when is_pid(Pid) ->
gen_server:call(
Pid,
{session_connect, #{
session_id => SessionId,
status => Status,
afk => Afk,
mobile => Mobile
}},
10000
)
spawn(fun() ->
gen_server:call(
Pid,
{session_connect, #{
session_id => SessionId,
status => Status,
afk => Afk,
mobile => Mobile
}},
10000
)
end)
end,
{reply, {ok, MissedEvents}, NewState}
end;
handle_call({get_state}, _From, State) ->
@@ -162,17 +221,31 @@ handle_call({voice_state_update, Data}, _From, State) ->
handle_call(_, _From, State) ->
{reply, ok, State}.
-spec handle_cast(Request, State) -> Result when
Request ::
{presence_update, map()}
| {dispatch, atom(), map()}
| {initial_global_presences, [map()]}
| {guild_join, guild_id()}
| {guild_leave, guild_id()}
| {guild_leave, guild_id(), forced_unavailable}
| {terminate, [binary()]}
| {terminate_force}
| {call_monitor, channel_id(), pid()}
| {call_unmonitor, channel_id()}
| {call_force_disconnect, channel_id(), binary() | undefined}
| term(),
State :: session_state(),
Result :: {noreply, session_state()} | {stop, normal, session_state()}.
handle_cast({presence_update, Update}, State) ->
PresencePid = maps:get(presence_pid, State, undefined),
SessionId = maps:get(id, State),
Status = maps:get(status, State),
Afk = maps:get(afk, State),
Mobile = maps:get(mobile, State),
NewStatus = maps:get(status, Update, Status),
NewAfk = maps:get(afk, Update, Afk),
NewMobile = maps:get(mobile, Update, Mobile),
NewState = maps:merge(State, #{status => NewStatus, afk => NewAfk, mobile => NewMobile}),
case PresencePid of
undefined ->
@@ -189,26 +262,41 @@ handle_cast({presence_update, Update}, State) ->
handle_cast({dispatch, Event, Data}, State) ->
session_dispatch:handle_dispatch(Event, Data, State);
handle_cast({initial_global_presences, Presences}, State) ->
NewState =
lists:foldl(
fun(Presence, AccState) ->
{noreply, UpdatedState} = session_dispatch:handle_dispatch(
presence_update, Presence, AccState
),
UpdatedState
end,
State,
Presences
),
NewState = lists:foldl(
fun(Presence, AccState) ->
{noreply, UpdatedState} = session_dispatch:handle_dispatch(
presence_update, Presence, AccState
),
UpdatedState
end,
State,
Presences
),
{noreply, NewState};
handle_cast({guild_join, GuildId}, State) ->
self() ! {guild_connect, GuildId, 0},
{noreply, State};
handle_cast({guild_leave, GuildId, forced_unavailable}, State) ->
Guilds = maps:get(guilds, State),
case maps:get(GuildId, Guilds, undefined) of
{Pid, Ref} when is_pid(Pid) ->
demonitor(Ref, [flush]);
_ ->
ok
end,
NewGuilds = maps:put(GuildId, cached_unavailable, Guilds),
{noreply, State1} = session_dispatch:handle_dispatch(
guild_delete,
#{<<"id">> => integer_to_binary(GuildId), <<"unavailable">> => true},
State
),
self() ! {guild_connect, GuildId, 0},
{noreply, maps:put(guilds, NewGuilds, State1)};
handle_cast({guild_leave, GuildId}, State) ->
Guilds = maps:get(guilds, State),
case maps:get(GuildId, Guilds, undefined) of
{Pid, Ref} when is_pid(Pid) ->
demonitor(Ref),
demonitor(Ref, [flush]),
NewGuilds = maps:put(GuildId, undefined, Guilds),
session_dispatch:handle_dispatch(
guild_delete, #{<<"id">> => integer_to_binary(GuildId)}, State
@@ -226,24 +314,6 @@ handle_cast({terminate, SessionIdHashes}, State) ->
end;
handle_cast({terminate_force}, State) ->
{stop, normal, State};
handle_cast({call_connect, ChannelIdBin}, State) ->
case validation:validate_snowflake(<<"channel_id">>, ChannelIdBin) of
{ok, ChannelId} ->
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
{ok, CallPid} ->
case gen_server:call(CallPid, {get_state}, 5000) of
{ok, CallData} ->
session_dispatch:handle_dispatch(call_create, CallData, State);
_ ->
{noreply, State}
end;
not_found ->
{noreply, State}
end;
{error, _, Reason} ->
logger:warning("[session] Invalid channel_id for call_connect: ~p", [Reason]),
{noreply, State}
end;
handle_cast({call_monitor, ChannelId, CallPid}, State) ->
Calls = maps:get(calls, State, #{}),
case maps:get(ChannelId, Calls, undefined) of
@@ -269,9 +339,29 @@ handle_cast({call_unmonitor, ChannelId}, State) ->
undefined ->
{noreply, State}
end;
handle_cast({call_force_disconnect, ChannelId, ConnectionId}, State) ->
NewState = force_disconnect_dm_call(ChannelId, ConnectionId, State),
{noreply, NewState};
handle_cast(_, State) ->
{noreply, State}.
-spec handle_info(Info, State) -> Result when
Info ::
{presence_connect, non_neg_integer()}
| {guild_connect, guild_id(), non_neg_integer()}
| {guild_connect_result, guild_id(), non_neg_integer(), term()}
| {guild_connect_timeout, guild_id(), non_neg_integer()}
| {call_reconnect, channel_id(), non_neg_integer()}
| enable_presence_updates
| premature_readiness
| bot_initial_ready
| resume_timeout
| flush_reaction_buffer
| {process_voice_queue}
| {'DOWN', reference(), process, pid(), term()}
| term(),
State :: session_state(),
Result :: {noreply, session_state()} | {stop, normal, session_state()}.
handle_info({presence_connect, Attempt}, State) ->
PresencePid = maps:get(presence_pid, State, undefined),
case PresencePid of
@@ -282,6 +372,8 @@ handle_info({guild_connect, GuildId, Attempt}, State) ->
session_connection:handle_guild_connect(GuildId, Attempt, State);
handle_info({guild_connect_result, GuildId, Attempt, Result}, State) ->
session_connection:handle_guild_connect_result(GuildId, Attempt, Result, State);
handle_info({guild_connect_timeout, GuildId, Attempt}, State) ->
session_connection:handle_guild_connect_timeout(GuildId, Attempt, State);
handle_info({call_reconnect, ChannelId, Attempt}, State) ->
session_connection:handle_call_reconnect(ChannelId, Attempt, State);
handle_info(enable_presence_updates, State) ->
@@ -305,17 +397,99 @@ handle_info(resume_timeout, State) ->
undefined -> {stop, normal, State};
_ -> {noreply, State}
end;
handle_info({process_voice_queue}, State) ->
NewState = session_voice:process_voice_queue(State),
VoiceQueue = maps:get(voice_queue, NewState, queue:new()),
case queue:is_empty(VoiceQueue) of
false ->
Timer = erlang:send_after(100, self(), {process_voice_queue}),
{noreply, maps:put(voice_queue_timer, Timer, NewState)};
true ->
{noreply, NewState}
end;
handle_info(flush_reaction_buffer, State) ->
NewState = session_dispatch:flush_reaction_buffer(State),
{noreply, NewState};
handle_info({'DOWN', Ref, process, _Pid, Reason}, State) ->
session_monitor:handle_process_down(Ref, Reason, State);
handle_info(_Info, State) ->
{noreply, State}.
-spec terminate(term(), session_state()) -> ok.
terminate(_Reason, _State) ->
ok.
-spec code_change(term(), session_state(), term()) -> {ok, session_state()}.
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
-spec force_disconnect_dm_call(channel_id(), binary() | undefined, session_state()) -> session_state().
force_disconnect_dm_call(ChannelId, ConnectionId, State) ->
UserId = maps:get(user_id, State),
SessionId = maps:get(id, State),
EffectiveConnectionId = resolve_dm_connection_id_for_channel(ChannelId, ConnectionId, UserId, State),
gen_server:cast(self(), {call_unmonitor, ChannelId}),
case EffectiveConnectionId of
undefined ->
State;
_ ->
Request = #{
user_id => UserId,
channel_id => null,
session_id => SessionId,
connection_id => EffectiveConnectionId,
self_mute => false,
self_deaf => false,
self_video => false,
self_stream => false,
viewer_stream_keys => [],
is_mobile => false,
latitude => null,
longitude => null
},
StateWithSessionPid = maps:put(session_pid, self(), State),
case dm_voice:voice_state_update(Request, StateWithSessionPid) of
{reply, #{success := true}, NewState} ->
maps:remove(session_pid, NewState);
_ ->
{reply, #{success := true}, FallbackState} =
dm_voice:disconnect_voice_user(UserId, StateWithSessionPid),
maps:remove(session_pid, FallbackState)
end
end.
-spec resolve_dm_connection_id_for_channel(
channel_id(), binary() | undefined, user_id(), session_state()
) -> binary() | undefined.
resolve_dm_connection_id_for_channel(_ChannelId, ConnectionId, _UserId, _State)
when is_binary(ConnectionId) ->
ConnectionId;
resolve_dm_connection_id_for_channel(ChannelId, _ConnectionId, UserId, State) ->
VoiceStates = maps:get(dm_voice_states, State, #{}),
UserIdBin = integer_to_binary(UserId),
ChannelIdBin = integer_to_binary(ChannelId),
maps:fold(
fun
(ConnId, VoiceState, undefined) ->
case
{
maps:get(<<"user_id">>, VoiceState, undefined),
maps:get(<<"channel_id">>, VoiceState, undefined)
}
of
{UserIdBin, ChannelIdBin} ->
ConnId;
_ ->
undefined
end;
(_ConnId, _VoiceState, ExistingConnId) ->
ExistingConnId
end,
undefined,
VoiceStates
).
-spec serialize_state(session_state()) -> map().
serialize_state(State) ->
#{
id => maps:get(id, State),
@@ -337,11 +511,13 @@ serialize_state(State) ->
collected_presences => maps:get(collected_presences, State, [])
}.
-spec build_ignored_events_map([binary()]) -> #{binary() => true}.
build_ignored_events_map(Events) when is_list(Events) ->
maps:from_list([{Event, true} || Event <- Events]);
build_ignored_events_map(_) ->
#{}.
-spec load_private_channels(map() | undefined) -> #{channel_id() => map()}.
load_private_channels(Ready) when is_map(Ready) ->
PrivateChannels = maps:get(<<"private_channels">>, Ready, []),
maps:from_list([
@@ -351,6 +527,7 @@ load_private_channels(Ready) when is_map(Ready) ->
load_private_channels(_) ->
#{}.
-spec load_relationships(map() | undefined) -> #{user_id() => integer()}.
load_relationships(Ready) when is_map(Ready) ->
Relationships = maps:get(<<"relationships">>, Ready, []),
maps:from_list(
@@ -362,9 +539,144 @@ load_relationships(Ready) when is_map(Ready) ->
load_relationships(_) ->
#{}.
-spec ensure_bot_ready_map(map() | undefined) -> map().
ensure_bot_ready_map(undefined) ->
#{<<"guilds">> => []};
ensure_bot_ready_map(Ready) when is_map(Ready) ->
maps:merge(Ready, #{<<"guilds">> => []});
ensure_bot_ready_map(_) ->
#{<<"guilds">> => []}.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
build_ignored_events_map_test() ->
?assertEqual(#{}, build_ignored_events_map([])),
?assertEqual(#{<<"TYPING_START">> => true}, build_ignored_events_map([<<"TYPING_START">>])),
?assertEqual(
#{<<"TYPING_START">> => true, <<"PRESENCE_UPDATE">> => true},
build_ignored_events_map([<<"TYPING_START">>, <<"PRESENCE_UPDATE">>])
),
?assertEqual(#{}, build_ignored_events_map(not_a_list)),
ok.
load_private_channels_test() ->
?assertEqual(#{}, load_private_channels(undefined)),
?assertEqual(#{}, load_private_channels(#{})),
Ready = #{
<<"private_channels">> => [
#{<<"id">> => <<"123">>, <<"type">> => 1},
#{<<"id">> => <<"456">>, <<"type">> => 3}
]
},
Channels = load_private_channels(Ready),
?assertEqual(2, maps:size(Channels)),
?assert(maps:is_key(123, Channels)),
?assert(maps:is_key(456, Channels)),
ok.
load_relationships_test() ->
?assertEqual(#{}, load_relationships(undefined)),
?assertEqual(#{}, load_relationships(#{})),
Ready = #{
<<"relationships">> => [
#{<<"id">> => <<"100">>, <<"type">> => 1},
#{<<"id">> => <<"200">>, <<"type">> => 3}
]
},
Rels = load_relationships(Ready),
?assertEqual(2, maps:size(Rels)),
?assertEqual(1, maps:get(100, Rels)),
?assertEqual(3, maps:get(200, Rels)),
ok.
ensure_bot_ready_map_test() ->
?assertEqual(#{<<"guilds">> => []}, ensure_bot_ready_map(undefined)),
?assertEqual(
#{<<"guilds">> => [], <<"user">> => #{}}, ensure_bot_ready_map(#{<<"user">> => #{}})
),
?assertEqual(#{<<"guilds">> => []}, ensure_bot_ready_map(not_a_map)),
ok.
serialize_state_test() ->
State = #{
id => <<"session123">>,
user_id => 12345,
user_data => #{<<"username">> => <<"test">>},
version => 9,
seq => 10,
ack_seq => 5,
properties => #{},
status => online,
afk => false,
mobile => false,
buffer => [],
ready => undefined,
guilds => #{},
collected_guild_states => [],
collected_sessions => [],
collected_presences => []
},
Serialized = serialize_state(State),
?assertEqual(<<"session123">>, maps:get(id, Serialized)),
?assertEqual(<<"12345">>, maps:get(user_id, Serialized)),
?assertEqual(10, maps:get(seq, Serialized)),
ok.
handle_cast_forced_unavailable_guild_leave_schedules_retry_test() ->
GuildId = 123,
State0 = #{
id => <<"session-force-unavailable">>,
user_id => 1,
user_data => #{},
custom_status => null,
version => 1,
token_hash => <<>>,
auth_session_id_hash => <<>>,
buffer => [],
seq => 0,
ack_seq => 0,
properties => #{},
status => online,
afk => false,
mobile => false,
presence_pid => undefined,
presence_mref => undefined,
socket_pid => undefined,
socket_mref => undefined,
guilds => #{GuildId => {self(), make_ref()}},
calls => #{},
channels => #{},
ready => undefined,
bot => false,
ignored_events => #{},
initial_guild_id => undefined,
collected_guild_states => [],
collected_sessions => [],
collected_presences => [],
relationships => #{},
suppress_presence_updates => false,
pending_presences => [],
guild_connect_inflight => #{},
voice_queue => queue:new(),
voice_queue_timer => undefined,
debounce_reactions => false,
reaction_buffer => [],
reaction_buffer_timer => undefined
},
{noreply, State1} = handle_cast({guild_leave, GuildId, forced_unavailable}, State0),
Guilds = maps:get(guilds, State1),
?assertEqual(cached_unavailable, maps:get(GuildId, Guilds)),
Buffer = maps:get(buffer, State1),
?assertEqual(1, length(Buffer)),
[LastEvent] = Buffer,
?assertEqual(guild_delete, maps:get(event, LastEvent)),
EventData = maps:get(data, LastEvent),
?assertEqual(true, maps:get(<<"unavailable">>, EventData)),
receive
{guild_connect, GuildId, 0} -> ok
after 100 ->
?assert(false, forced_unavailable_retry_not_scheduled)
end.
-endif.

View File

@@ -21,11 +21,31 @@
handle_presence_connect/2,
handle_guild_connect/3,
handle_guild_connect_result/4,
handle_guild_connect_timeout/3,
handle_call_reconnect/3
]).
-define(GUILD_CONNECT_MAX_INFLIGHT, 8).
-define(MAX_RETRY_ATTEMPTS, 25).
-define(MAX_CALL_RETRY_ATTEMPTS, 15).
-define(GUILD_CONNECT_ASYNC_TIMEOUT_MS, 30000).
-define(MAX_GUILD_UNAVAILABLE_RETRY_DELAY_MS, 30000).
-define(MAX_GUILD_UNAVAILABLE_BACKOFF_ATTEMPT, 5).
-define(GUILD_UNAVAILABLE_JITTER_DIVISOR, 5).
-type session_state() :: session:session_state().
-type guild_id() :: session:guild_id().
-type channel_id() :: session:channel_id().
-type attempt() :: non_neg_integer().
-type guild_connect_result() ::
{ok, pid(), map()}
| {ok_unavailable, pid(), map()}
| {ok_cached_unavailable, map()}
| {error, term()}.
-spec handle_presence_connect(attempt(), session_state()) ->
{noreply, session_state()}.
handle_presence_connect(Attempt, State) ->
UserId = maps:get(user_id, State),
UserData = maps:get(user_data, State),
@@ -37,7 +57,6 @@ handle_presence_connect(Attempt, State) ->
SocketPid = maps:get(socket_pid, State, undefined),
FriendIds = presence_targets:friend_ids_from_state(State),
GroupDmRecipients = presence_targets:group_dm_recipients_from_state(State),
Message =
{start_or_lookup, #{
user_id => UserId,
@@ -48,83 +67,135 @@ handle_presence_connect(Attempt, State) ->
group_dm_recipients => GroupDmRecipients,
custom_status => maps:get(custom_status, State, null)
}},
case gen_server:call(presence_manager, Message, 5000) of
{ok, Pid} ->
try
case
gen_server:call(
Pid,
{session_connect, #{
session_id => SessionId,
status => Status,
afk => Afk,
mobile => Mobile,
socket_pid => SocketPid
}},
10000
)
of
{ok, Sessions} ->
gen_server:cast(Pid, {sync_friends, FriendIds}),
gen_server:cast(Pid, {sync_group_dm_recipients, GroupDmRecipients}),
NewState = maps:merge(State, #{
presence_pid => Pid,
presence_mref => monitor(process, Pid),
collected_sessions => Sessions
}),
session_ready:check_readiness(NewState);
_ ->
case Attempt < 25 of
true ->
erlang:send_after(
backoff_utils:calculate(Attempt),
self(),
{presence_connect, Attempt + 1}
),
{noreply, State};
false ->
{noreply, State}
end
end
catch
exit:{noproc, _} when Attempt < 25 ->
erlang:send_after(
backoff_utils:calculate(Attempt), self(), {presence_connect, Attempt + 1}
),
{noreply, State};
exit:{normal, _} when Attempt < 25 ->
erlang:send_after(
backoff_utils:calculate(Attempt), self(), {presence_connect, Attempt + 1}
),
{noreply, State};
_:_ ->
{noreply, State}
end;
try_presence_session_connect(
Pid,
SessionId,
Status,
Afk,
Mobile,
SocketPid,
FriendIds,
GroupDmRecipients,
Attempt,
State
);
_ ->
case Attempt < 25 of
true ->
erlang:send_after(
backoff_utils:calculate(Attempt), self(), {presence_connect, Attempt + 1}
),
{noreply, State};
false ->
{noreply, State}
end
schedule_presence_retry(Attempt, State)
end.
-spec try_presence_session_connect(
pid(),
binary(),
atom(),
boolean(),
boolean(),
pid() | undefined,
[integer()],
map(),
attempt(),
session_state()
) ->
{noreply, session_state()}.
try_presence_session_connect(
Pid, SessionId, Status, Afk, Mobile, SocketPid, FriendIds, GroupDmRecipients, Attempt, State
) ->
try
case
gen_server:call(
Pid,
{session_connect, #{
session_id => SessionId,
status => Status,
afk => Afk,
mobile => Mobile,
socket_pid => SocketPid
}},
10000
)
of
{ok, Sessions} ->
gen_server:cast(Pid, {sync_friends, FriendIds}),
gen_server:cast(Pid, {sync_group_dm_recipients, GroupDmRecipients}),
NewState = maps:merge(State, #{
presence_pid => Pid,
presence_mref => monitor(process, Pid),
collected_sessions => Sessions
}),
session_ready:check_readiness(NewState);
_ ->
schedule_presence_retry(Attempt, State)
end
catch
exit:{noproc, _} ->
schedule_presence_retry(Attempt, State);
exit:{normal, _} ->
schedule_presence_retry(Attempt, State);
_:_ ->
{noreply, State}
end.
-spec schedule_presence_retry(attempt(), session_state()) -> {noreply, session_state()}.
schedule_presence_retry(Attempt, State) when Attempt < ?MAX_RETRY_ATTEMPTS ->
erlang:send_after(backoff_utils:calculate(Attempt), self(), {presence_connect, Attempt + 1}),
{noreply, State};
schedule_presence_retry(_Attempt, State) ->
{noreply, State}.
-spec handle_guild_connect(guild_id(), attempt(), session_state()) ->
{noreply, session_state()}.
handle_guild_connect(GuildId, Attempt, State) ->
Guilds = maps:get(guilds, State),
SessionId = maps:get(id, State),
UserId = maps:get(user_id, State),
case maps:get(GuildId, Guilds, undefined) of
{_Pid, _Ref} ->
{noreply, State};
cached_unavailable ->
maybe_handle_cached_unavailability(GuildId, Attempt, SessionId, UserId, State);
_ ->
maybe_spawn_guild_connect(GuildId, Attempt, SessionId, UserId, State)
end.
-spec maybe_handle_cached_unavailability(
guild_id(), attempt(), binary(), integer(), session_state()
) ->
{noreply, session_state()}.
maybe_handle_cached_unavailability(GuildId, Attempt, SessionId, UserId, State) ->
UserData = maps:get(user_data, State, #{}),
case guild_availability:is_guild_unavailable_for_user_from_cache(GuildId, UserData) of
true ->
mark_cached_guild_unavailable_and_retry(GuildId, Attempt, State);
false ->
Guilds = maps:get(guilds, State, #{}),
ResetGuilds = maps:put(GuildId, undefined, Guilds),
ResetState = maps:put(guilds, ResetGuilds, State),
maybe_spawn_guild_connect(GuildId, 0, SessionId, UserId, ResetState)
end.
-spec mark_cached_guild_unavailable(guild_id(), session_state()) ->
{noreply, session_state()}.
mark_cached_guild_unavailable(GuildId, State) ->
Guilds = maps:get(guilds, State, #{}),
case maps:get(GuildId, Guilds, undefined) of
cached_unavailable ->
{noreply, State};
_ ->
UpdatedGuilds = maps:put(GuildId, cached_unavailable, Guilds),
StateWithGuild = maps:put(guilds, UpdatedGuilds, State),
{noreply, MarkedState} = session_ready:mark_guild_unavailable(GuildId, StateWithGuild),
session_ready:check_readiness(MarkedState)
end.
-spec mark_cached_guild_unavailable_and_retry(guild_id(), attempt(), session_state()) ->
{noreply, session_state()}.
mark_cached_guild_unavailable_and_retry(GuildId, Attempt, State) ->
{noreply, MarkedState} = mark_cached_guild_unavailable(GuildId, State),
schedule_cached_unavailable_retry(GuildId, Attempt, MarkedState).
-spec handle_guild_connect_result(guild_id(), attempt(), guild_connect_result(), session_state()) ->
{noreply, session_state()}.
handle_guild_connect_result(GuildId, Attempt, Result, State) ->
Inflight = maps:get(guild_connect_inflight, State, #{}),
case maps:get(GuildId, Inflight, undefined) of
@@ -136,10 +207,25 @@ handle_guild_connect_result(GuildId, Attempt, Result, State) ->
{noreply, State}
end.
-spec handle_guild_connect_timeout(guild_id(), attempt(), session_state()) -> {noreply, session_state()}.
handle_guild_connect_timeout(GuildId, Attempt, State) ->
Inflight0 = maps:get(guild_connect_inflight, State, #{}),
case maps:get(GuildId, Inflight0, undefined) of
Attempt ->
Inflight = maps:remove(GuildId, Inflight0),
State1 = maps:put(guild_connect_inflight, Inflight, State),
retry_or_fail(GuildId, Attempt, State1, fun(GId, St) ->
session_ready:mark_guild_unavailable(GId, St)
end);
_ ->
{noreply, State}
end.
-spec handle_call_reconnect(channel_id(), attempt(), session_state()) ->
{noreply, session_state()}.
handle_call_reconnect(ChannelId, Attempt, State) ->
Calls = maps:get(calls, State, #{}),
SessionId = maps:get(id, State),
case maps:get(ChannelId, Calls, undefined) of
{_Pid, _Ref} ->
{noreply, State};
@@ -147,6 +233,8 @@ handle_call_reconnect(ChannelId, Attempt, State) ->
attempt_call_reconnect(ChannelId, Attempt, SessionId, State)
end.
-spec maybe_spawn_guild_connect(guild_id(), attempt(), binary(), integer(), session_state()) ->
{noreply, session_state()}.
maybe_spawn_guild_connect(GuildId, Attempt, SessionId, UserId, State) ->
Inflight0 = maps:get(guild_connect_inflight, State, #{}),
AlreadyInflight = maps:is_key(GuildId, Inflight0),
@@ -163,33 +251,49 @@ maybe_spawn_guild_connect(GuildId, Attempt, SessionId, UserId, State) ->
State1 = maps:put(guild_connect_inflight, Inflight, State),
SessionPid = self(),
InitialGuildId = maps:get(initial_guild_id, State, undefined),
UserData = maps:get(user_data, State, #{}),
spawn(fun() ->
do_guild_connect(SessionPid, GuildId, Attempt, SessionId, UserId, Bot, InitialGuildId)
do_guild_connect(
SessionPid, GuildId, Attempt, SessionId, UserId, Bot, InitialGuildId, UserData
)
end),
{noreply, State1}
end.
do_guild_connect(SessionPid, GuildId, Attempt, SessionId, UserId, Bot, InitialGuildId) ->
-spec do_guild_connect(
pid(), guild_id(), attempt(), binary(), integer(), boolean(), guild_id() | undefined, map()
) -> ok.
do_guild_connect(SessionPid, GuildId, Attempt, SessionId, UserId, Bot, InitialGuildId, UserData) ->
Result =
try
case gen_server:call(guild_manager, {start_or_lookup, GuildId}, 5000) of
{ok, GuildPid} ->
ActiveGuilds = build_initial_active_guilds(InitialGuildId, GuildId),
Request = #{
session_id => SessionId,
user_id => UserId,
session_pid => SessionPid,
bot => Bot,
initial_guild_id => InitialGuildId,
active_guilds => ActiveGuilds
},
case gen_server:call(GuildPid, {session_connect, Request}, 10000) of
{ok, unavailable, UnavailableResponse} ->
{ok_unavailable, GuildPid, UnavailableResponse};
{ok, GuildState} ->
{ok, GuildPid, GuildState};
Error ->
{error, {session_connect_failed, Error}}
case maybe_build_unavailable_response_from_cache(GuildId, UserData) of
{ok, UnavailableResponse} ->
{ok_cached_unavailable, UnavailableResponse};
not_unavailable ->
ActiveGuilds = build_initial_active_guilds(InitialGuildId, GuildId),
IsStaff = maps:get(<<"is_staff">>, UserData, false),
Request = #{
session_id => SessionId,
user_id => UserId,
session_pid => SessionPid,
bot => Bot,
is_staff => IsStaff,
initial_guild_id => InitialGuildId,
active_guilds => ActiveGuilds
},
gen_server:cast(GuildPid, {session_connect_async, #{
guild_id => GuildId,
attempt => Attempt,
request => Request
}}),
_ = erlang:send_after(
?GUILD_CONNECT_ASYNC_TIMEOUT_MS,
SessionPid,
{guild_connect_timeout, GuildId, Attempt}
),
pending
end;
Error ->
{error, {guild_manager_failed, Error}}
@@ -202,15 +306,41 @@ do_guild_connect(SessionPid, GuildId, Attempt, SessionId, UserId, Bot, InitialGu
_:Reason ->
{error, {exception, Reason}}
end,
SessionPid ! {guild_connect_result, GuildId, Attempt, Result},
case Result of
pending ->
ok;
_ ->
SessionPid ! {guild_connect_result, GuildId, Attempt, Result}
end,
ok.
-spec maybe_build_unavailable_response_from_cache(guild_id(), map()) ->
{ok, map()} | not_unavailable.
maybe_build_unavailable_response_from_cache(GuildId, UserData) ->
case guild_availability:is_guild_unavailable_for_user_from_cache(GuildId, UserData) of
true ->
{ok, #{
<<"id">> => integer_to_binary(GuildId),
<<"unavailable">> => true
}};
false ->
not_unavailable
end.
-spec handle_guild_connect_result_internal(
guild_id(), attempt(), guild_connect_result(), session_state()
) ->
{noreply, session_state()}.
handle_guild_connect_result_internal(
GuildId, _Attempt, {ok_unavailable, GuildPid, UnavailableResponse}, State
) ->
finalize_guild_connection(GuildId, GuildPid, State, fun(St) ->
session_ready:process_guild_state(UnavailableResponse, St)
end);
handle_guild_connect_result_internal(
GuildId, Attempt, {ok_cached_unavailable, _UnavailableResponse}, State
) ->
mark_cached_guild_unavailable_and_retry(GuildId, Attempt, State);
handle_guild_connect_result_internal(GuildId, _Attempt, {ok, GuildPid, GuildState}, State) ->
finalize_guild_connection(GuildId, GuildPid, State, fun(St) ->
session_ready:process_guild_state(GuildState, St)
@@ -222,6 +352,10 @@ handle_guild_connect_result_internal(GuildId, Attempt, {error, _Reason}, State)
session_ready:mark_guild_unavailable(GId, St)
end).
-spec finalize_guild_connection(guild_id(), pid(), session_state(), fun(
(session_state()) -> {noreply, session_state()}
)) ->
{noreply, session_state()}.
finalize_guild_connection(GuildId, GuildPid, State, ReadyFun) ->
Guilds0 = maps:get(guilds, State),
case maps:get(GuildId, Guilds0, undefined) of
@@ -234,36 +368,58 @@ finalize_guild_connection(GuildId, GuildPid, State, ReadyFun) ->
ReadyFun(State1)
end.
retry_or_fail(GuildId, Attempt, State, FailureFun) ->
case Attempt < 25 of
-spec retry_or_fail(guild_id(), attempt(), session_state(), fun(
(guild_id(), session_state()) -> {noreply, session_state()}
)) ->
{noreply, session_state()}.
retry_or_fail(GuildId, Attempt, State, _FailureFun) when Attempt < ?MAX_RETRY_ATTEMPTS ->
BackoffMs = backoff_utils:calculate(Attempt),
erlang:send_after(BackoffMs, self(), {guild_connect, GuildId, Attempt + 1}),
{noreply, State};
retry_or_fail(GuildId, _Attempt, State, FailureFun) ->
FailureFun(GuildId, State).
-spec schedule_cached_unavailable_retry(guild_id(), attempt(), session_state()) ->
{noreply, session_state()}.
schedule_cached_unavailable_retry(GuildId, Attempt, State) ->
SessionId = maps:get(id, State, <<>>),
DelayMs = cached_unavailable_retry_delay_ms(GuildId, SessionId, Attempt),
NextAttempt = Attempt + 1,
erlang:send_after(DelayMs, self(), {guild_connect, GuildId, NextAttempt}),
{noreply, State}.
-spec cached_unavailable_retry_delay_ms(guild_id(), binary(), attempt()) -> non_neg_integer().
cached_unavailable_retry_delay_ms(GuildId, SessionId, Attempt) ->
CappedAttempt = min(Attempt, ?MAX_GUILD_UNAVAILABLE_BACKOFF_ATTEMPT),
BaseDelay = backoff_utils:calculate(CappedAttempt, ?MAX_GUILD_UNAVAILABLE_RETRY_DELAY_MS),
case BaseDelay >= ?MAX_GUILD_UNAVAILABLE_RETRY_DELAY_MS of
true ->
BackoffMs = backoff_utils:calculate(Attempt),
erlang:send_after(BackoffMs, self(), {guild_connect, GuildId, Attempt + 1}),
{noreply, State};
?MAX_GUILD_UNAVAILABLE_RETRY_DELAY_MS;
false ->
logger:error(
"[session_connection] Guild ~p connect failed after ~p attempts",
[GuildId, Attempt]
),
FailureFun(GuildId, State)
MaxJitter = max(1, BaseDelay div ?GUILD_UNAVAILABLE_JITTER_DIVISOR),
Jitter = erlang:phash2({GuildId, SessionId, Attempt}, MaxJitter + 1),
min(?MAX_GUILD_UNAVAILABLE_RETRY_DELAY_MS, BaseDelay + Jitter)
end.
-spec attempt_call_reconnect(channel_id(), attempt(), binary(), session_state()) ->
{noreply, session_state()}.
attempt_call_reconnect(ChannelId, Attempt, _SessionId, State) ->
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
{ok, CallPid} ->
connect_to_call_process(CallPid, ChannelId, State);
not_found ->
handle_call_not_found(ChannelId, Attempt, State);
retry_call_or_remove(ChannelId, Attempt, State);
_Error ->
handle_call_lookup_error(ChannelId, Attempt, State)
retry_call_or_remove(ChannelId, Attempt, State)
end.
-spec connect_to_call_process(pid(), channel_id(), session_state()) ->
{noreply, session_state()}.
connect_to_call_process(CallPid, ChannelId, State) ->
Calls = maps:get(calls, State, #{}),
MonitorRef = monitor(process, CallPid),
NewCalls = maps:put(ChannelId, {CallPid, MonitorRef}, Calls),
StateWithCall = maps:put(calls, NewCalls, State),
case gen_server:call(CallPid, {get_state}, 5000) of
{ok, CallData} ->
session_dispatch:handle_dispatch(call_create, CallData, StateWithCall);
@@ -272,30 +428,295 @@ connect_to_call_process(CallPid, ChannelId, State) ->
{noreply, State}
end.
handle_call_not_found(ChannelId, Attempt, State) ->
retry_call_or_remove(ChannelId, Attempt, State).
handle_call_lookup_error(ChannelId, Attempt, State) ->
retry_call_or_remove(ChannelId, Attempt, State).
retry_call_or_remove(ChannelId, Attempt, State) ->
case Attempt < 15 of
true ->
erlang:send_after(
backoff_utils:calculate(Attempt),
self(),
{call_reconnect, ChannelId, Attempt + 1}
),
{noreply, State};
false ->
Calls = maps:get(calls, State, #{}),
NewCalls = maps:remove(ChannelId, Calls),
{noreply, maps:put(calls, NewCalls, State)}
end.
-spec retry_call_or_remove(channel_id(), attempt(), session_state()) ->
{noreply, session_state()}.
retry_call_or_remove(ChannelId, Attempt, State) when Attempt < ?MAX_CALL_RETRY_ATTEMPTS ->
erlang:send_after(
backoff_utils:calculate(Attempt), self(), {call_reconnect, ChannelId, Attempt + 1}
),
{noreply, State};
retry_call_or_remove(ChannelId, _Attempt, State) ->
Calls = maps:get(calls, State, #{}),
NewCalls = maps:remove(ChannelId, Calls),
{noreply, maps:put(calls, NewCalls, State)}.
-spec build_initial_active_guilds(guild_id() | undefined, guild_id()) -> sets:set(guild_id()).
build_initial_active_guilds(undefined, _GuildId) ->
sets:new();
build_initial_active_guilds(GuildId, GuildId) ->
sets:from_list([GuildId]);
build_initial_active_guilds(_, _) ->
sets:new().
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
build_initial_active_guilds_test() ->
?assertEqual(sets:new(), build_initial_active_guilds(undefined, 123)),
?assertEqual(sets:from_list([123]), build_initial_active_guilds(123, 123)),
?assertEqual(sets:new(), build_initial_active_guilds(456, 123)),
ok.
mark_cached_guild_unavailable_test() ->
GuildId = 2001,
CacheState = #{
id => GuildId,
data => #{
<<"guild">> => #{
<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]
}
}
},
_ = guild_availability:update_unavailability_cache_for_state(CacheState),
State0 = #{
id => <<"session-1">>,
user_id => 55,
user_data => #{<<"flags">> => <<"0">>},
guilds => #{GuildId => undefined},
collected_guild_states => [],
ready => undefined
},
{noreply, State1} = mark_cached_guild_unavailable(GuildId, State0),
Guilds = maps:get(guilds, State1),
?assertEqual(cached_unavailable, maps:get(GuildId, Guilds)),
Collected = maps:get(collected_guild_states, State1, []),
?assertEqual(1, length(Collected)),
?assertMatch(
#{<<"id">> := _, <<"unavailable">> := true},
hd(Collected)
),
{noreply, State2} = mark_cached_guild_unavailable(GuildId, State1),
?assertEqual(1, length(maps:get(collected_guild_states, State2, []))),
CacheCleanupState = #{
id => GuildId,
data => #{
<<"guild">> => #{
<<"features">> => []
}
}
},
_ = guild_availability:update_unavailability_cache_for_state(CacheCleanupState),
ok.
do_guild_connect_skips_session_connect_when_cached_unavailable_test() ->
GuildId = 2002,
Attempt = 3,
SessionId = <<"session-2">>,
UserId = 77,
CacheState = #{
id => GuildId,
data => #{
<<"guild">> => #{
<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]
}
}
},
_ = guild_availability:update_unavailability_cache_for_state(CacheState),
Parent = self(),
TestRef = make_ref(),
GuildPid = spawn(fun() -> guild_stub_loop(Parent, TestRef) end),
ManagerPid = spawn(fun() -> manager_stub_loop(GuildId, GuildPid) end),
?assertEqual(undefined, whereis(guild_manager)),
true = register(guild_manager, ManagerPid),
try
ok = do_guild_connect(
Parent, GuildId, Attempt, SessionId, UserId, false, undefined, #{<<"flags">> => <<"0">>}
),
case await_guild_connect_unavailable_result(GuildId, Attempt) of
{ok, Response} ->
?assertEqual(integer_to_binary(GuildId), maps:get(<<"id">>, Response)),
?assertEqual(true, maps:get(<<"unavailable">>, Response));
timeout ->
?assert(false, guild_connect_result_not_received)
end,
case saw_guild_stub_call(TestRef, 200) of
true ->
?assert(false, should_not_call_session_connect_when_cache_unavailable);
false ->
ok
end
after
case whereis(guild_manager) of
ManagerPid ->
unregister(guild_manager);
_ ->
ok
end,
ManagerPid ! stop,
GuildPid ! stop,
CacheCleanupState = #{
id => GuildId,
data => #{
<<"guild">> => #{
<<"features">> => []
}
}
},
_ = guild_availability:update_unavailability_cache_for_state(CacheCleanupState),
ok
end.
do_guild_connect_uses_session_connect_async_cast_test() ->
GuildId = 2004,
Attempt = 0,
SessionId = <<"session-async-1">>,
UserId = 77,
Parent = self(),
TestRef = make_ref(),
GuildPid = spawn(fun() -> guild_stub_loop(Parent, TestRef) end),
ManagerPid = spawn(fun() -> manager_stub_loop(GuildId, GuildPid) end),
?assertEqual(undefined, whereis(guild_manager)),
true = register(guild_manager, ManagerPid),
SessionPid = spawn(fun() -> session_capture_loop() end),
try
ok = do_guild_connect(
SessionPid,
GuildId,
Attempt,
SessionId,
UserId,
false,
undefined,
#{<<"flags">> => <<"0">>, <<"is_staff">> => false}
),
?assertMatch({session_connect_async, _}, await_guild_stub_cast(TestRef, 1000))
after
SessionPid ! stop,
case whereis(guild_manager) of
ManagerPid -> unregister(guild_manager);
_ -> ok
end,
ManagerPid ! stop,
GuildPid ! stop
end.
session_capture_loop() ->
receive
stop -> ok;
_ -> session_capture_loop()
end.
await_guild_stub_cast(TestRef, TimeoutMs) ->
receive
{guild_stub_cast, TestRef, Msg} ->
Msg;
_Other ->
await_guild_stub_cast(TestRef, TimeoutMs)
after TimeoutMs ->
timeout
end.
-spec await_guild_connect_unavailable_result(guild_id(), attempt()) ->
{ok, map()} | timeout.
await_guild_connect_unavailable_result(GuildId, Attempt) ->
receive
{guild_connect_result, GuildId, Attempt, {ok_cached_unavailable, Response}} ->
{ok, Response};
_Other ->
await_guild_connect_unavailable_result(GuildId, Attempt)
after 1000 ->
timeout
end.
cached_unavailable_retry_delay_ms_cap_test() ->
Delay = cached_unavailable_retry_delay_ms(123, <<"session-cap">>, 500),
?assertEqual(30000, Delay).
cached_unavailable_retry_delay_ms_uses_jitter_test() ->
Delay = cached_unavailable_retry_delay_ms(123, <<"session-jitter">>, 0),
?assert(Delay >= 1000),
?assert(Delay =< 1200).
maybe_handle_cached_unavailability_retries_when_cache_available_again_test() ->
GuildId = 2003,
CacheState = #{
id => GuildId,
data => #{
<<"guild">> => #{
<<"features">> => []
}
}
},
_ = guild_availability:update_unavailability_cache_for_state(CacheState),
Inflight = maps:from_list([{N, N} || N <- lists:seq(3000, 3007)]),
State0 = #{
id => <<"session-3">>,
user_id => 88,
user_data => #{<<"flags">> => <<"0">>},
guilds => #{GuildId => cached_unavailable},
guild_connect_inflight => Inflight
},
{noreply, State1} = maybe_handle_cached_unavailability(
GuildId, 42, <<"session-3">>, 88, State0
),
Guilds = maps:get(guilds, State1),
?assertEqual(undefined, maps:get(GuildId, Guilds)),
receive
{guild_connect, GuildId, 0} -> ok
after 300 ->
?assert(false, guild_connect_retry_not_scheduled_with_reset_attempt)
end.
guild_connect_timeout_exhaustion_marks_unavailable_test() ->
GuildId = 9001,
Attempt = ?MAX_RETRY_ATTEMPTS,
State0 = #{
id => <<"session-timeout-1">>,
user_id => 100,
guilds => #{GuildId => undefined},
guild_connect_inflight => #{GuildId => Attempt},
collected_guild_states => [],
ready => undefined
},
{noreply, State1} = handle_guild_connect_timeout(GuildId, Attempt, State0),
Collected = maps:get(collected_guild_states, State1, []),
?assertEqual(1, length(Collected)),
[UnavailableEntry] = Collected,
?assertEqual(integer_to_binary(GuildId), maps:get(<<"id">>, UnavailableEntry)),
?assertEqual(true, maps:get(<<"unavailable">>, UnavailableEntry)),
Inflight = maps:get(guild_connect_inflight, State1, #{}),
?assertEqual(false, maps:is_key(GuildId, Inflight)).
-spec saw_guild_stub_call(reference(), non_neg_integer()) -> boolean.
saw_guild_stub_call(TestRef, TimeoutMs) ->
receive
{guild_stub_called, TestRef, _Request} ->
true;
_Other ->
saw_guild_stub_call(TestRef, TimeoutMs)
after TimeoutMs ->
false
end.
-spec manager_stub_loop(guild_id(), pid()) -> ok.
manager_stub_loop(GuildId, GuildPid) ->
receive
stop ->
ok;
{'$gen_call', From, {start_or_lookup, GuildId}} ->
gen_server:reply(From, {ok, GuildPid}),
manager_stub_loop(GuildId, GuildPid);
{'$gen_call', From, _Request} ->
gen_server:reply(From, {error, unsupported}),
manager_stub_loop(GuildId, GuildPid);
_Other ->
manager_stub_loop(GuildId, GuildPid)
end.
-spec guild_stub_loop(pid(), reference()) -> ok.
guild_stub_loop(Parent, TestRef) ->
receive
stop ->
ok;
{'$gen_cast', Msg} ->
Parent ! {guild_stub_cast, TestRef, Msg},
guild_stub_loop(Parent, TestRef);
{'$gen_call', From, Request} ->
Parent ! {guild_stub_called, TestRef, Request},
gen_server:reply(From, {ok, #{}}),
guild_stub_loop(Parent, TestRef);
_Other ->
guild_stub_loop(Parent, TestRef)
end.
-endif.

View File

@@ -19,61 +19,203 @@
-export([
handle_dispatch/3,
flush_all_pending_presences/1
flush_all_pending_presences/1,
flush_reaction_buffer/1
]).
-type session_state() :: session:session_state().
-type event() :: atom() | binary().
-type user_id() :: session:user_id().
-define(MAX_EVENT_BUFFER_SIZE, 4096).
-define(REACTION_BUFFER_INTERVAL_MS, 650).
-define(MAX_REACTION_BUFFER_SIZE, 512).
-define(MAX_PENDING_PRESENCE_BUFFER_SIZE, 2048).
-spec handle_dispatch(event(), map(), session_state()) -> {noreply, session_state()}.
handle_dispatch(Event, Data, State) ->
case should_ignore_event(Event, State) of
true ->
{noreply, State};
false ->
case should_buffer_presence(Event, Data, State) of
case should_buffer_reaction(Event, State) of
true ->
{noreply, buffer_presence(Event, Data, State)};
{noreply, buffer_reaction(Data, State)};
false ->
Seq = maps:get(seq, State),
Buffer = maps:get(buffer, State),
SocketPid = maps:get(socket_pid, State, undefined),
NewSeq = Seq + 1,
Request = #{event => Event, data => Data, seq => NewSeq},
NewBuffer =
case Event of
message_reaction_add ->
Buffer;
message_reaction_remove ->
Buffer;
_ ->
Buffer ++ [Request]
end,
case SocketPid of
undefined ->
ok;
Pid when is_pid(Pid) ->
case erlang:is_process_alive(Pid) of
case maybe_cancel_buffered_reaction(Event, Data, State) of
{cancelled, NewState} ->
{noreply, NewState};
not_applicable ->
case should_buffer_presence(Event, Data, State) of
true ->
Pid ! {dispatch, Event, Data, NewSeq},
ok;
{noreply, buffer_presence(Event, Data, State)};
false ->
ok
FanoutSpanCtx = start_fanout_span(Event, State),
Result = do_handle_dispatch(Event, Data, State),
end_fanout_span(FanoutSpanCtx, Event),
Result
end
end,
StateWithChannels = update_channels_map(Event, Data, State),
StateWithRelationships0 = update_relationships_map(
Event, Data, StateWithChannels
),
StateAfterMain = maps:merge(StateWithRelationships0, #{
seq => NewSeq, buffer => NewBuffer
}),
StateWithPending = maybe_flush_pending_presences(Event, Data, StateAfterMain),
FinalState = sync_presence_targets(StateWithPending),
{noreply, FinalState}
end
end
end.
-spec do_handle_dispatch(event(), map(), session_state()) -> {noreply, session_state()}.
do_handle_dispatch(Event, Data, State) ->
Seq = maps:get(seq, State),
Buffer = maps:get(buffer, State),
SocketPid = maps:get(socket_pid, State, undefined),
NewSeq = Seq + 1,
Request = #{event => Event, data => Data, seq => NewSeq},
case append_or_fail(Buffer, Request, ?MAX_EVENT_BUFFER_SIZE, event_ack_buffer, State) of
{ok, NewBuffer} ->
send_to_socket(SocketPid, Event, Data, NewSeq),
StateWithChannels = update_channels_map(Event, Data, State),
StateWithRelationships = update_relationships_map(Event, Data, StateWithChannels),
StateAfterMain = maps:merge(StateWithRelationships, #{seq => NewSeq, buffer => NewBuffer}),
StateWithPending = maybe_flush_pending_presences(Event, Data, StateAfterMain),
FinalState = sync_presence_targets(StateWithPending),
{noreply, FinalState};
overflow ->
{noreply, State}
end.
-spec send_to_socket(pid() | undefined, event(), map(), non_neg_integer()) -> ok.
send_to_socket(undefined, _Event, _Data, _Seq) ->
ok;
send_to_socket(Pid, Event, Data, Seq) when is_pid(Pid) ->
case erlang:is_process_alive(Pid) of
true ->
Pid ! {dispatch, Event, Data, Seq},
ok;
false ->
ok,
ok
end.
-spec should_buffer_reaction(event(), session_state()) -> boolean().
should_buffer_reaction(message_reaction_add, State) ->
maps:get(debounce_reactions, State, false);
should_buffer_reaction(_, _) ->
false.
-spec buffer_reaction(map(), session_state()) -> session_state().
buffer_reaction(Data, State) ->
Buffer = maps:get(reaction_buffer, State, []),
case append_or_fail(Buffer, Data, ?MAX_REACTION_BUFFER_SIZE, reaction_buffer, State) of
{ok, NewBuffer} ->
Timer = maps:get(reaction_buffer_timer, State, undefined),
NewTimer =
case Timer of
undefined ->
erlang:send_after(?REACTION_BUFFER_INTERVAL_MS, self(), flush_reaction_buffer);
Existing ->
Existing
end,
State#{reaction_buffer => NewBuffer, reaction_buffer_timer => NewTimer};
overflow ->
State
end.
-spec maybe_cancel_buffered_reaction(event(), map(), session_state()) ->
{cancelled, session_state()} | not_applicable.
maybe_cancel_buffered_reaction(message_reaction_remove, Data, State) ->
case maps:get(reaction_buffer, State, []) of
[] ->
not_applicable;
Buffer ->
MessageId = maps:get(<<"message_id">>, Data, undefined),
UserId = maps:get(<<"user_id">>, Data, undefined),
Emoji = maps:get(<<"emoji">>, Data, #{}),
case remove_matching_reaction(Buffer, MessageId, UserId, Emoji) of
{found, NewBuffer} ->
{cancelled, State#{reaction_buffer => NewBuffer}};
not_found ->
not_applicable
end
end;
maybe_cancel_buffered_reaction(_, _, _) ->
not_applicable.
-spec remove_matching_reaction([map()], term(), term(), map()) ->
{found, [map()]} | not_found.
remove_matching_reaction(Buffer, MessageId, UserId, Emoji) ->
EmojiId = maps:get(<<"id">>, Emoji, undefined),
EmojiName = maps:get(<<"name">>, Emoji, undefined),
remove_matching_reaction(Buffer, MessageId, UserId, EmojiId, EmojiName, []).
remove_matching_reaction([], _MessageId, _UserId, _EmojiId, _EmojiName, _Acc) ->
not_found;
remove_matching_reaction([Entry | Rest], MessageId, UserId, EmojiId, EmojiName, Acc) ->
EntryMessageId = maps:get(<<"message_id">>, Entry, undefined),
EntryUserId = maps:get(<<"user_id">>, Entry, undefined),
EntryEmoji = maps:get(<<"emoji">>, Entry, #{}),
EntryEmojiId = maps:get(<<"id">>, EntryEmoji, undefined),
EntryEmojiName = maps:get(<<"name">>, EntryEmoji, undefined),
case
EntryMessageId =:= MessageId andalso
EntryUserId =:= UserId andalso
EntryEmojiId =:= EmojiId andalso
EntryEmojiName =:= EmojiName
of
true ->
{found, lists:reverse(Acc) ++ Rest};
false ->
remove_matching_reaction(Rest, MessageId, UserId, EmojiId, EmojiName, [Entry | Acc])
end.
-spec flush_reaction_buffer(session_state()) -> session_state().
flush_reaction_buffer(State) ->
Buffer = maps:get(reaction_buffer, State, []),
Timer = maps:get(reaction_buffer_timer, State, undefined),
case Timer of
undefined -> ok;
_ -> erlang:cancel_timer(Timer)
end,
StateCleared = State#{reaction_buffer => [], reaction_buffer_timer => undefined},
case Buffer of
[] ->
StateCleared;
[Single] ->
{noreply, FinalState} = do_handle_dispatch(message_reaction_add, Single, StateCleared),
FinalState;
_ ->
dispatch_reaction_add_many(Buffer, StateCleared)
end.
-spec dispatch_reaction_add_many([map()], session_state()) -> session_state().
dispatch_reaction_add_many(Buffer, State) ->
First = hd(Buffer),
ChannelId = maps:get(<<"channel_id">>, First, undefined),
MessageId = maps:get(<<"message_id">>, First, undefined),
GuildId = maps:get(<<"guild_id">>, First, undefined),
Reactions = lists:map(
fun(Entry) ->
Base = #{
<<"user_id">> => maps:get(<<"user_id">>, Entry, undefined),
<<"emoji">> => maps:get(<<"emoji">>, Entry, #{})
},
case maps:get(<<"member">>, Entry, undefined) of
undefined -> Base;
null -> Base;
Member -> Base#{<<"member">> => Member}
end
end,
Buffer
),
Data0 = #{
<<"channel_id">> => ChannelId,
<<"message_id">> => MessageId,
<<"reactions">> => Reactions
},
Data = case GuildId of
undefined -> Data0;
null -> Data0;
_ -> Data0#{<<"guild_id">> => GuildId}
end,
{noreply, FinalState} = do_handle_dispatch(message_reaction_add_many, Data, State),
FinalState.
-spec should_buffer_presence(event(), map(), session_state()) -> boolean().
should_buffer_presence(presence_update, Data, State) ->
case maps:get(suppress_presence_updates, State, true) of
true ->
@@ -100,6 +242,7 @@ should_buffer_presence(presence_update, Data, State) ->
should_buffer_presence(_, _, _) ->
false.
-spec relationship_allows_presence(user_id(), #{user_id() => integer()}) -> boolean().
relationship_allows_presence(UserId, Relationships) when
is_integer(UserId), is_map(Relationships)
->
@@ -111,6 +254,7 @@ relationship_allows_presence(UserId, Relationships) when
relationship_allows_presence(_, _) ->
false.
-spec is_group_dm_recipient(user_id(), session_state()) -> boolean().
is_group_dm_recipient(UserId, State) ->
GroupDmRecipients = presence_targets:group_dm_recipients_from_state(State),
lists:any(
@@ -120,13 +264,30 @@ is_group_dm_recipient(UserId, State) ->
maps:to_list(GroupDmRecipients)
).
-spec buffer_presence(event(), map(), session_state()) -> session_state().
buffer_presence(Event, Data, State) ->
Pending = maps:get(pending_presences, State, []),
UserId = presence_user_id(Data),
maps:put(
pending_presences, Pending ++ [#{event => Event, data => Data, user_id => UserId}], State
).
case
append_or_fail(
Pending,
#{event => Event, data => Data, user_id => UserId},
?MAX_PENDING_PRESENCE_BUFFER_SIZE,
pending_presence_buffer,
State
)
of
{ok, NewPending} ->
maps:put(
pending_presences,
NewPending,
State
);
overflow ->
State
end.
-spec maybe_flush_pending_presences(event(), map(), session_state()) -> session_state().
maybe_flush_pending_presences(relationship_add, Data, State) ->
maybe_flush_relationship_pending_presences(Data, State);
maybe_flush_pending_presences(relationship_update, Data, State) ->
@@ -134,120 +295,144 @@ maybe_flush_pending_presences(relationship_update, Data, State) ->
maybe_flush_pending_presences(_, _, State) ->
State.
maybe_flush_relationship_pending_presences(Data, State) when is_map(Data) ->
-spec maybe_flush_relationship_pending_presences(map(), session_state()) -> session_state().
maybe_flush_relationship_pending_presences(Data, State) ->
case maps:get(<<"type">>, Data, 0) of
1 ->
flush_pending_presences(relationship_target_id(Data), State);
3 ->
flush_pending_presences(relationship_target_id(Data), State);
_ ->
State
end;
maybe_flush_relationship_pending_presences(_Data, State) ->
State.
1 -> flush_pending_presences(relationship_target_id(Data), State);
3 -> flush_pending_presences(relationship_target_id(Data), State);
_ -> State
end.
-spec flush_pending_presences(user_id() | undefined, session_state()) -> session_state().
flush_pending_presences(undefined, State) ->
State;
flush_pending_presences(UserId, State) ->
Pending = maps:get(pending_presences, State, []),
{ToSend, Remaining} =
lists:partition(fun(P) -> maps:get(user_id, P, undefined) =:= UserId end, Pending),
FlushedState =
lists:foldl(
fun(P, AccState) ->
dispatch_presence_now(P, AccState)
end,
State,
ToSend
),
{ToSend, Remaining} = lists:partition(
fun(P) -> maps:get(user_id, P, undefined) =:= UserId end,
Pending
),
FlushedState = lists:foldl(
fun(P, AccState) -> dispatch_presence_now(P, AccState) end, State, ToSend
),
maps:put(pending_presences, Remaining, FlushedState).
-spec dispatch_presence_now(map(), session_state()) -> session_state().
dispatch_presence_now(P, State) ->
Event = maps:get(event, P),
Data = maps:get(data, P),
Seq = maps:get(seq, State),
Buffer = maps:get(buffer, State),
SocketPid = maps:get(socket_pid, State, undefined),
NewSeq = Seq + 1,
Request = #{event => Event, data => Data, seq => NewSeq},
NewBuffer = Buffer ++ [Request],
case append_or_fail(Buffer, Request, ?MAX_EVENT_BUFFER_SIZE, event_ack_buffer, State) of
{ok, NewBuffer} ->
send_to_socket(SocketPid, Event, Data, NewSeq),
maps:merge(State, #{seq => NewSeq, buffer => NewBuffer});
overflow ->
State
end.
-spec append_or_fail([term()], term(), pos_integer(), atom(), session_state()) ->
{ok, [term()]} | overflow.
append_or_fail(Buffer, Entry, MaxSize, BufferKind, State) ->
case length(Buffer) >= MaxSize of
true ->
report_buffer_overflow(BufferKind, length(Buffer), MaxSize, State),
overflow;
false ->
{ok, Buffer ++ [Entry]}
end.
-spec report_buffer_overflow(atom(), non_neg_integer(), pos_integer(), session_state()) -> ok.
report_buffer_overflow(BufferKind, CurrentSize, MaxSize, State) ->
AckSeq = maps:get(ack_seq, State, 0),
Seq = maps:get(seq, State, 0),
UnackedEvents = max(0, Seq - AckSeq),
KindBin = buffer_kind_to_binary(BufferKind),
SocketPid = maps:get(socket_pid, State, undefined),
Details = #{
kind => KindBin,
current_size => CurrentSize,
limit => MaxSize,
seq => Seq,
ack_seq => AckSeq,
unacked_events => UnackedEvents
},
logger:warning(
"Session backpressure overflow. kind=~ts current=~B limit=~B seq=~B ack_seq=~B unacked=~B",
[KindBin, CurrentSize, MaxSize, Seq, AckSeq, UnackedEvents]
),
otel_metrics:counter(<<"gateway.session.backpressure_overflow">>, 1, #{<<"kind">> => KindBin}),
otel_metrics:gauge(<<"gateway.session.unacked_events">>, UnackedEvents, #{<<"kind">> => KindBin}),
case SocketPid of
undefined ->
ok;
Pid when is_pid(Pid) ->
case erlang:is_process_alive(Pid) of
true ->
Pid ! {dispatch, Event, Data, NewSeq},
ok;
false ->
ok
end
Pid ! {session_backpressure_error, Details};
_ ->
ok
end,
gen_server:cast(self(), {terminate_force}),
ok.
maps:merge(State, #{seq => NewSeq, buffer => NewBuffer}).
-spec buffer_kind_to_binary(atom()) -> binary().
buffer_kind_to_binary(event_ack_buffer) -> <<"event_ack_buffer">>;
buffer_kind_to_binary(reaction_buffer) -> <<"reaction_buffer">>;
buffer_kind_to_binary(pending_presence_buffer) -> <<"pending_presence_buffer">>.
presence_user_id(Data) when is_map(Data) ->
User = maps:get(<<"user">>, Data, #{}),
map_utils:get_integer(User, <<"id">>, undefined);
presence_user_id(_) ->
undefined.
-spec presence_user_id(map()) -> user_id() | undefined.
presence_user_id(Data) ->
case maps:find(<<"user">>, Data) of
{ok, User} when is_map(User) ->
map_utils:get_integer(User, <<"id">>, undefined);
_ ->
undefined
end.
-spec relationship_target_id(map()) -> user_id() | undefined.
relationship_target_id(Data) when is_map(Data) ->
type_conv:extract_id(Data, <<"id">>).
-spec flush_all_pending_presences(session_state()) -> session_state().
flush_all_pending_presences(State) ->
Pending = maps:get(pending_presences, State, []),
FlushedState =
lists:foldl(
fun(P, AccState) ->
dispatch_presence_now(P, AccState)
end,
State,
Pending
),
FlushedState = lists:foldl(
fun(P, AccState) -> dispatch_presence_now(P, AccState) end, State, Pending
),
maps:put(pending_presences, [], FlushedState).
-spec should_ignore_event(event(), session_state()) -> boolean().
should_ignore_event(Event, State) ->
IgnoredEvents = maps:get(ignored_events, State, #{}),
case event_name(Event) of
undefined ->
false;
EventName ->
maps:is_key(EventName, IgnoredEvents)
undefined -> false;
EventName -> maps:is_key(EventName, IgnoredEvents)
end.
-spec event_name(event()) -> binary() | undefined.
event_name(Event) when is_binary(Event) ->
Event;
event_name(Event) when is_atom(Event) ->
try constants:dispatch_event_atom(Event) of
Name when is_binary(Name) ->
Name
Name when is_binary(Name) -> Name
catch
_:_ ->
undefined
_:_ -> undefined
end;
event_name(_) ->
undefined.
-spec update_channels_map(event(), map(), session_state()) -> session_state().
update_channels_map(channel_create, Data, State) when is_map(Data) ->
case maps:get(<<"type">>, Data, undefined) of
1 ->
add_channel_to_state(Data, State);
3 ->
add_channel_to_state(Data, State);
_ ->
State
1 -> add_channel_to_state(Data, State);
3 -> add_channel_to_state(Data, State);
_ -> State
end;
update_channels_map(channel_update, Data, State) when is_map(Data) ->
case maps:get(<<"type">>, Data, undefined) of
1 ->
add_channel_to_state(Data, State);
3 ->
add_channel_to_state(Data, State);
_ ->
State
1 -> add_channel_to_state(Data, State);
3 -> add_channel_to_state(Data, State);
_ -> State
end;
update_channels_map(channel_delete, Data, State) when is_map(Data) ->
case maps:get(<<"id">>, Data, undefined) of
@@ -270,6 +455,7 @@ update_channels_map(channel_recipient_remove, Data, State) when is_map(Data) ->
update_channels_map(_Event, _Data, State) ->
State.
-spec add_channel_to_state(map(), session_state()) -> session_state().
add_channel_to_state(Data, State) ->
case maps:get(<<"id">>, Data, undefined) of
undefined ->
@@ -279,17 +465,13 @@ add_channel_to_state(Data, State) ->
{ok, ChannelId} ->
Channels = maps:get(channels, State, #{}),
NewChannels = maps:put(ChannelId, Data, Channels),
UserId = maps:get(user_id, State),
logger:info(
"[session_dispatch] Added/updated channel ~p for user ~p, type: ~p",
[ChannelId, UserId, maps:get(<<"type">>, Data, 0)]
),
maps:put(channels, NewChannels, State);
{error, _, _} ->
State
end
end.
-spec update_recipient_membership(add | remove, map(), session_state()) -> session_state().
update_recipient_membership(Action, Data, State) ->
ChannelIdBin = maps:get(<<"channel_id">>, Data, undefined),
case validation:validate_snowflake(<<"channel_id">>, ChannelIdBin) of
@@ -321,6 +503,7 @@ update_recipient_membership(Action, Data, State) ->
State
end.
-spec update_channel_recipient(map(), user_id(), map(), add | remove) -> map().
update_channel_recipient(Channel, RecipientId, UserMap, add) ->
RecipientIds = maps:get(<<"recipient_ids">>, Channel, []),
Recipients = maps:get(<<"recipients">>, Channel, []),
@@ -345,28 +528,26 @@ update_channel_recipient(Channel, RecipientId, _UserMap, remove) ->
),
Channel#{<<"recipient_ids">> => NewRecipientIds, <<"recipients">> => NewRecipients}.
-spec add_unique_id(user_id(), [binary() | user_id()]) -> [binary() | user_id()].
add_unique_id(Id, List) ->
case lists:member(Id, List) orelse lists:member(integer_to_binary(Id), List) of
true -> List;
false -> [Id | List]
end.
-spec add_unique_user(map(), [map()]) -> [map()].
add_unique_user(UserMap, List) when is_map(UserMap) ->
case type_conv:extract_id(UserMap, <<"id">>) of
undefined ->
List;
Id ->
case
lists:any(
fun(R) -> type_conv:extract_id(R, <<"id">>) =:= Id end,
List
)
of
case lists:any(fun(R) -> type_conv:extract_id(R, <<"id">>) =:= Id end, List) of
true -> List;
false -> [UserMap | List]
end
end.
-spec update_relationships_map(event(), map(), session_state()) -> session_state().
update_relationships_map(relationship_add, Data, State) ->
upsert_relationship(Data, State);
update_relationships_map(relationship_update, Data, State) ->
@@ -383,6 +564,7 @@ update_relationships_map(relationship_remove, Data, State) ->
update_relationships_map(_, _, State) ->
State.
-spec upsert_relationship(map(), session_state()) -> session_state().
upsert_relationship(Data, State) ->
case type_conv:extract_id(Data, <<"id">>) of
undefined ->
@@ -394,6 +576,7 @@ upsert_relationship(Data, State) ->
maps:put(relationships, NewRelationships, State)
end.
-spec sync_presence_targets(session_state()) -> session_state().
sync_presence_targets(State) ->
PresencePid = maps:get(presence_pid, State, undefined),
case PresencePid of
@@ -407,6 +590,42 @@ sync_presence_targets(State) ->
State
end.
-spec start_fanout_span(event(), session_state()) -> {term(), term()} | undefined.
start_fanout_span(Event, State) ->
case event_name(Event) of
undefined ->
undefined;
EventName ->
SpanName = websocket_fanout,
Attributes = build_fanout_attributes(EventName, State),
gateway_tracing:start_event_span(?MODULE, SpanName, Attributes)
end.
-spec end_fanout_span({term(), term()} | undefined, event()) -> ok.
end_fanout_span(Context, Event) ->
EventName = event_name(Event),
Attributes = #{<<"event.name">> => EventName},
Outcome = gateway_tracing:end_event_span(Context, Attributes),
end_fanout_metrics(Outcome),
ok.
build_fanout_attributes(EventName, State) ->
case maps:get(user_id, State, undefined) of
undefined -> #{<<"event.name">> => EventName};
UserId -> #{<<"event.name">> => EventName, <<"user.id">> => UserId}
end.
-ifdef(HAS_OPENTELEMETRY).
end_fanout_metrics(Outcome) ->
case Outcome of
ok -> gateway_metrics_collector:inc_fanout(1);
_ -> gateway_metrics_collector:inc_fanout(0)
end.
-else.
end_fanout_metrics(_Outcome) ->
gateway_metrics_collector:inc_fanout(1).
-endif.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
@@ -483,4 +702,27 @@ presence_update_without_guild_id_not_buffered_for_incoming_request_relationship_
?assertEqual(1, length(maps:get(buffer, State1, []))),
ok.
relationship_allows_presence_test() ->
?assertEqual(true, relationship_allows_presence(1, #{1 => 1})),
?assertEqual(true, relationship_allows_presence(1, #{1 => 3})),
?assertEqual(false, relationship_allows_presence(1, #{1 => 0})),
?assertEqual(false, relationship_allows_presence(1, #{1 => 2})),
?assertEqual(false, relationship_allows_presence(1, #{1 => 4})),
?assertEqual(false, relationship_allows_presence(1, #{})),
?assertEqual(false, relationship_allows_presence(not_integer, #{})),
ok.
presence_user_id_test() ->
?assertEqual(123, presence_user_id(#{<<"user">> => #{<<"id">> => <<"123">>}})),
?assertEqual(undefined, presence_user_id(#{<<"user">> => #{}})),
?assertEqual(undefined, presence_user_id(#{})),
?assertEqual(undefined, presence_user_id(#{<<"user">> => not_a_map})),
ok.
add_unique_id_test() ->
?assertEqual([1, 2, 3], add_unique_id(1, [2, 3])),
?assertEqual([1, 2, 3], add_unique_id(1, [1, 2, 3])),
?assertEqual([<<"1">>, 2, 3], add_unique_id(1, [<<"1">>, 2, 3])),
ok.
-endif.

View File

@@ -20,540 +20,395 @@
-include_lib("fluxer_gateway/include/timeout_config.hrl").
-export([start_link/0]).
-define(SHARD_TABLE, session_manager_shard_table).
-define(START_TIMEOUT, 10000).
-define(LOOKUP_TIMEOUT, 5000).
-export([start_link/0, start/2, lookup/1]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-export_type([session_data/0, user_id/0]).
-type session_id() :: binary().
-type user_id() :: integer().
-type session_ref() :: {pid(), reference()}.
-type status() :: online | offline | idle | dnd.
-type identify_timestamp() :: integer().
-define(IDENTIFY_FLAG_USE_CANARY_API, 16#1).
-type identify_request() :: #{
session_id := session_id(),
identify_data := map(),
version := non_neg_integer(),
peer_ip := term(),
token := binary()
}.
-type session_data() :: #{
id := session_id(),
user_id := user_id(),
user_data := map(),
version := non_neg_integer(),
token_hash := binary(),
auth_session_id_hash := binary(),
properties := map(),
status := status(),
afk := boolean(),
mobile := boolean(),
socket_pid := pid(),
guilds := [integer()],
ready := map(),
ignored_events := [binary()]
}.
-type state() :: #{
sessions := #{session_id() => session_ref()},
api_host := string(),
api_canary_host := undefined | string(),
identify_attempts := [identify_timestamp()]
}.
-type shard() :: #{pid := pid(), ref := reference()}.
-type state() :: #{shards := #{non_neg_integer() => shard()}, shard_count := pos_integer()}.
-spec start_link() -> {ok, pid()} | {error, term()}.
start_link() ->
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
-spec start(map(), pid()) -> term().
start(Request, SocketPid) ->
SessionId = maps:get(session_id, Request),
call_shard(SessionId, {start, Request, SocketPid}, ?START_TIMEOUT).
-spec lookup(session_id()) -> {ok, pid()} | {error, not_found}.
lookup(SessionId) ->
call_shard(SessionId, {lookup, SessionId}, ?LOOKUP_TIMEOUT).
-spec init([]) -> {ok, state()}.
init([]) ->
fluxer_gateway_env:load(),
process_flag(trap_exit, true),
ApiHost = fluxer_gateway_env:get(api_host),
ApiCanaryHost = fluxer_gateway_env:get(api_canary_host),
{ok, #{
sessions => #{},
api_host => ApiHost,
api_canary_host => ApiCanaryHost,
identify_attempts => []
}}.
-spec handle_call(Request, From, State) -> Result when
Request ::
{start, identify_request(), pid()}
| {lookup, session_id()}
| get_local_count
| get_global_count
| term(),
From :: gen_server:from(),
State :: state(),
Result :: {reply, Reply, state()},
Reply ::
{success, pid()}
| {ok, pid()}
| {error, not_found}
| {error, identify_rate_limited}
| {error, invalid_token}
| {error, rate_limited}
| {error, {server_error, non_neg_integer()}}
| {error, {http_error, non_neg_integer()}}
| {error, {network_error, term()}}
| {error, registration_failed}
| {error, term()}
| {ok, non_neg_integer()}
| ok.
handle_call(
{start, Request, SocketPid},
_From,
State
) ->
Sessions = maps:get(sessions, State),
Attempts = maps:get(identify_attempts, State),
SessionId = maps:get(session_id, Request),
case maps:get(SessionId, Sessions, undefined) of
{Pid, _Ref} ->
{reply, {success, Pid}, State};
undefined ->
SessionName = process_registry:build_process_name(session, SessionId),
case whereis(SessionName) of
undefined ->
case check_identify_rate_limit(Attempts) of
{ok, NewAttempts} ->
handle_identify_request(
Request,
SocketPid,
SessionId,
Sessions,
maps:put(identify_attempts, NewAttempts, State)
);
{error, rate_limited} ->
{reply, {error, identify_rate_limited}, State}
end;
Pid ->
Ref = monitor(process, Pid),
NewSessions = maps:put(SessionId, {Pid, Ref}, Sessions),
{reply, {success, Pid}, maps:put(sessions, NewSessions, State)}
end
end;
handle_call({lookup, SessionId}, _From, State) ->
Sessions = maps:get(sessions, State),
case maps:get(SessionId, Sessions, undefined) of
{Pid, _Ref} ->
{reply, {ok, Pid}, State};
undefined ->
SessionName = process_registry:build_process_name(session, SessionId),
case whereis(SessionName) of
undefined ->
{reply, {error, not_found}, State};
Pid ->
Ref = monitor(process, Pid),
NewSessions = maps:put(SessionId, {Pid, Ref}, Sessions),
{reply, {ok, Pid}, maps:put(sessions, NewSessions, State)}
end
end;
handle_call(get_local_count, _From, State) ->
Sessions = maps:get(sessions, State),
{reply, {ok, maps:size(Sessions)}, State};
handle_call(get_global_count, _From, State) ->
Sessions = maps:get(sessions, State),
{reply, {ok, maps:size(Sessions)}, State};
handle_call(_, _From, State) ->
{reply, ok, State}.
-spec handle_identify_request(
identify_request(),
pid(),
session_id(),
#{session_id() => session_ref()},
state()
) ->
{reply,
{success, pid()}
| {error, invalid_token}
| {error, rate_limited}
| {error, {server_error, non_neg_integer()}}
| {error, {http_error, non_neg_integer()}}
| {error, {network_error, term()}}
| {error, registration_failed}
| {error, term()},
state()}.
handle_identify_request(
Request, SocketPid, SessionId, Sessions, State
) ->
IdentifyData = maps:get(identify_data, Request),
Version = maps:get(version, Request),
PeerIP = maps:get(peer_ip, Request),
UseCanary = should_use_canary_api(IdentifyData),
{_UsedCanary, RpcClient} = select_rpc_client(State, UseCanary),
case fetch_rpc_data(Request, PeerIP, RpcClient) of
{ok, Data} ->
UserDataMap = maps:get(<<"user">>, Data),
UserId = type_conv:extract_id(UserDataMap, <<"id">>),
AuthSessionIdHashEncoded = maps:get(<<"auth_session_id_hash">>, Data, undefined),
AuthSessionIdHash =
case AuthSessionIdHashEncoded of
undefined -> <<>>;
null -> <<>>;
_ -> base64url:decode(AuthSessionIdHashEncoded)
end,
Status = parse_presence(Data, IdentifyData),
GuildIds = parse_guild_ids(Data),
Properties = maps:get(properties, IdentifyData),
Presence = map_utils:get_safe(IdentifyData, presence, null),
IgnoredEvents = map_utils:get_safe(IdentifyData, ignored_events, []),
InitialGuildId = map_utils:get_safe(IdentifyData, initial_guild_id, undefined),
Bot = map_utils:get_safe(UserDataMap, <<"bot">>, false),
ReadyData =
case Bot of
true -> maps:merge(Data, #{<<"guilds">> => []});
false -> Data
end,
UserSettingsMap = map_utils:get_safe(Data, <<"user_settings">>, #{}),
CustomStatusFromSettings = map_utils:get_safe(
UserSettingsMap, <<"custom_status">>, null
),
PresenceCustomStatus = get_presence_custom_status(Presence),
CustomStatus =
case CustomStatusFromSettings of
null -> PresenceCustomStatus;
_ -> CustomStatusFromSettings
end,
Mobile =
case Presence of
null -> map_utils:get_safe(Properties, <<"mobile">>, false);
P when is_map(P) -> map_utils:get_safe(P, <<"mobile">>, false);
_ -> false
end,
Afk =
case Presence of
null -> false;
P2 when is_map(P2) -> map_utils:get_safe(P2, <<"afk">>, false);
_ -> false
end,
UserData0 = #{
<<"id">> => maps:get(<<"id">>, UserDataMap),
<<"username">> => maps:get(<<"username">>, UserDataMap),
<<"discriminator">> => maps:get(<<"discriminator">>, UserDataMap),
<<"avatar">> => maps:get(<<"avatar">>, UserDataMap),
<<"avatar_color">> => map_utils:get_safe(
UserDataMap, <<"avatar_color">>, undefined
),
<<"bot">> => map_utils:get_safe(UserDataMap, <<"bot">>, undefined),
<<"system">> => map_utils:get_safe(UserDataMap, <<"system">>, undefined),
<<"flags">> => maps:get(<<"flags">>, UserDataMap)
},
UserData = user_utils:normalize_user(UserData0),
SessionData = #{
id => SessionId,
user_id => UserId,
user_data => UserData,
custom_status => CustomStatus,
version => Version,
token_hash => utils:hash_token(maps:get(token, IdentifyData)),
auth_session_id_hash => AuthSessionIdHash,
properties => Properties,
status => Status,
afk => Afk,
mobile => Mobile,
socket_pid => SocketPid,
guilds => GuildIds,
ready => ReadyData,
bot => Bot,
ignored_events => IgnoredEvents,
initial_guild_id => InitialGuildId
},
SessionName = process_registry:build_process_name(session, SessionId),
case whereis(SessionName) of
undefined ->
case session:start_link(SessionData) of
{ok, Pid} ->
case
process_registry:register_and_monitor(SessionName, Pid, Sessions)
of
{ok, RegisteredPid, Ref, NewSessions0} ->
CleanSessions = maps:remove(SessionName, NewSessions0),
NewSessions = maps:put(
SessionId, {RegisteredPid, Ref}, CleanSessions
),
{reply, {success, RegisteredPid}, maps:put(
sessions, NewSessions, State
)};
{error, registration_race_condition} ->
{reply, {error, registration_failed}, State};
{error, _Reason} ->
{reply, {error, registration_failed}, State}
end;
Error ->
{reply, Error, State}
end;
ExistingPid ->
Ref = monitor(process, ExistingPid),
CleanSessions = maps:remove(SessionName, Sessions),
NewSessions = maps:put(SessionId, {ExistingPid, Ref}, CleanSessions),
{reply, {success, ExistingPid}, maps:put(sessions, NewSessions, State)}
end;
{error, invalid_token} ->
{reply, {error, invalid_token}, State};
{error, rate_limited} ->
{reply, {error, rate_limited}, State};
{error, Reason} ->
{reply, {error, Reason}, State}
end.
-spec handle_cast(term(), state()) -> {noreply, state()}.
handle_cast(_, State) ->
{noreply, State}.
select_rpc_client(State, true) ->
case maps:get(api_canary_host, State) of
undefined ->
logger:warning(
"[session_manager] Canary API requested but not configured, falling back to stable API"
),
{false, maps:get(api_host, State)};
CanaryHost ->
{true, CanaryHost}
end;
select_rpc_client(State, false) ->
{false, maps:get(api_host, State)}.
should_use_canary_api(IdentifyData) ->
case map_utils:get_safe(IdentifyData, flags, 0) of
Flags when is_integer(Flags), Flags >= 0 ->
(Flags band ?IDENTIFY_FLAG_USE_CANARY_API) =/= 0;
_ ->
false
end.
-spec handle_info(Info, State) -> {noreply, state()} when
Info :: {'DOWN', reference(), process, pid(), term()} | term(),
State :: state().
handle_info({'DOWN', _Ref, process, Pid, _Reason}, State) ->
Sessions = maps:get(sessions, State),
NewSessions = process_registry:cleanup_on_down(Pid, Sessions),
{noreply, maps:put(sessions, NewSessions, State)};
handle_info(_, State) ->
{noreply, State}.
-spec terminate(Reason, State) -> ok when
Reason :: term(),
State :: state().
terminate(_Reason, _State) ->
ok.
-spec code_change(OldVsn, State, Extra) -> {ok, state()} when
OldVsn :: term(),
State :: state() | tuple(),
Extra :: term().
code_change(_OldVsn, State, _Extra) when is_map(State) ->
{ok, State};
code_change(_OldVsn, State, _Extra) when is_tuple(State), element(1, State) =:= state ->
Sessions = element(2, State),
ApiHost = element(3, State),
ApiCanaryHost = element(4, State),
IdentifyAttempts = element(5, State),
{ok, #{
sessions => Sessions,
api_host => ApiHost,
api_canary_host => ApiCanaryHost,
identify_attempts => IdentifyAttempts
}};
code_change(_OldVsn, State, _Extra) ->
fluxer_gateway_env:load(),
ensure_shard_table(),
{ShardCount, _Source} = determine_shard_count(),
Shards = start_shards(ShardCount),
State = #{shards => Shards, shard_count => ShardCount},
sync_shard_table(State),
{ok, State}.
-spec fetch_rpc_data(map(), term(), string()) ->
{ok, map()}
| {error, invalid_token}
| {error, rate_limited}
| {error, {server_error, non_neg_integer()}}
| {error, {http_error, non_neg_integer()}}
| {error, {network_error, term()}}.
fetch_rpc_data(Request, PeerIP, ApiHost) ->
StartTime = erlang:system_time(millisecond),
Result = do_fetch_rpc_data(Request, PeerIP, ApiHost),
EndTime = erlang:system_time(millisecond),
LatencyMs = EndTime - StartTime,
gateway_metrics_collector:record_rpc_latency(LatencyMs),
Result.
-spec handle_call(term(), gen_server:from(), state()) -> {reply, term(), state()}.
handle_call({proxy_call, SessionId, Request, Timeout}, _From, State) ->
{Reply, NewState} = forward_call(SessionId, Request, Timeout, State),
{reply, Reply, NewState};
handle_call({start, Request, SocketPid}, _From, State) ->
SessionId = maps:get(session_id, Request),
{Reply, NewState} = forward_call(SessionId, {start, Request, SocketPid}, ?START_TIMEOUT, State),
{reply, Reply, NewState};
handle_call({lookup, SessionId}, _From, State) ->
{Reply, NewState} = forward_call(SessionId, {lookup, SessionId}, ?LOOKUP_TIMEOUT, State),
{reply, Reply, NewState};
handle_call(get_local_count, _From, State) ->
{Count, NewState} = aggregate_counts(get_local_count, State),
{reply, {ok, Count}, NewState};
handle_call(get_global_count, _From, State) ->
{Count, NewState} = aggregate_counts(get_global_count, State),
{reply, {ok, Count}, NewState};
handle_call(_Request, _From, State) ->
{reply, ok, State}.
-spec do_fetch_rpc_data(map(), term(), string()) ->
{ok, map()}
| {error, invalid_token}
| {error, rate_limited}
| {error, {server_error, non_neg_integer()}}
| {error, {http_error, non_neg_integer()}}
| {error, {network_error, term()}}.
do_fetch_rpc_data(Request, PeerIP, ApiHost) ->
Url = rpc_client:get_rpc_url(ApiHost),
Headers = rpc_client:get_rpc_headers() ++ [{<<"content-type">>, <<"application/json">>}],
IdentifyData = maps:get(identify_data, Request),
Properties = map_utils:get_safe(IdentifyData, properties, #{}),
LatitudeRaw = map_utils:get_safe(Properties, <<"latitude">>, undefined),
LongitudeRaw = map_utils:get_safe(Properties, <<"longitude">>, undefined),
Latitude =
case LatitudeRaw of
undefined -> undefined;
null -> undefined;
SafeLatitude -> SafeLatitude
end,
Longitude =
case LongitudeRaw of
undefined -> undefined;
null -> undefined;
SafeLongitude -> SafeLongitude
end,
RpcRequest = #{
<<"type">> => <<"session">>,
<<"token">> => maps:get(token, IdentifyData),
<<"version">> => maps:get(version, Request),
<<"ip">> => PeerIP
},
RpcRequestWithLatitude =
case Latitude of
undefined -> RpcRequest;
LatitudeValue -> maps:put(<<"latitude">>, LatitudeValue, RpcRequest)
end,
RpcRequestWithLongitude =
case Longitude of
undefined -> RpcRequestWithLatitude;
LongitudeValue -> maps:put(<<"longitude">>, LongitudeValue, RpcRequestWithLatitude)
end,
Body = jsx:encode(RpcRequestWithLongitude),
case hackney:request(post, Url, Headers, Body, []) of
{ok, 200, _RespHeaders, ClientRef} ->
case hackney:body(ClientRef) of
{ok, ResponseBody} ->
hackney:close(ClientRef),
ResponseData = jsx:decode(ResponseBody, [{return_maps, true}]),
{ok, maps:get(<<"data">>, ResponseData)};
{error, BodyError} ->
hackney:close(ClientRef),
logger:error("[session_manager] Failed to read response body: ~p", [BodyError]),
{error, {network_error, BodyError}}
end;
{ok, 401, _, ClientRef} ->
hackney:close(ClientRef),
logger:info("[session_manager] RPC authentication failed (401)"),
{error, invalid_token};
{ok, 429, _, ClientRef} ->
hackney:close(ClientRef),
logger:warning("[session_manager] RPC rate limited (429)"),
{error, rate_limited};
{ok, StatusCode, _, ClientRef} when StatusCode >= 500 ->
ErrorBody =
case hackney:body(ClientRef) of
{ok, Body2} -> Body2;
{error, _} -> <<"<unable to read error body>">>
end,
hackney:close(ClientRef),
logger:error("[session_manager] RPC server error ~p: ~s", [StatusCode, ErrorBody]),
{error, {server_error, StatusCode}};
{ok, StatusCode, _, ClientRef} when StatusCode >= 400 ->
ErrorBody =
case hackney:body(ClientRef) of
{ok, Body2} -> Body2;
{error, _} -> <<"<unable to read error body>">>
end,
hackney:close(ClientRef),
logger:warning("[session_manager] RPC client error ~p: ~s", [StatusCode, ErrorBody]),
{error, {http_error, StatusCode}};
{ok, StatusCode, _, ClientRef} ->
hackney:close(ClientRef),
logger:warning("[session_manager] RPC unexpected status: ~p", [StatusCode]),
{error, {http_error, StatusCode}};
{error, Reason} ->
logger:error("[session_manager] RPC request failed: ~p", [Reason]),
{error, {network_error, Reason}}
end.
-spec handle_cast(term(), state()) -> {noreply, state()}.
handle_cast(_Msg, State) ->
{noreply, State}.
-spec parse_presence(map(), map()) -> status().
parse_presence(Data, IdentifyData) ->
StoredStatus = get_stored_status(Data),
PresenceStatus =
case map_utils:get_safe(IdentifyData, presence, null) of
null ->
undefined;
Presence when is_map(Presence) ->
map_utils:get_safe(Presence, status, <<"online">>);
_ ->
undefined
-spec handle_info(term(), state()) -> {noreply, state()}.
handle_info({'DOWN', Ref, process, Pid, _Reason}, State) ->
Shards = maps:get(shards, State),
case find_shard_by_ref(Ref, Shards) of
{ok, Index} ->
{_Shard, NewState} = restart_shard(Index, State),
{noreply, NewState};
not_found ->
case find_shard_by_pid(Pid, Shards) of
{ok, Index} ->
{_Shard, NewState} = restart_shard(Index, State),
{noreply, NewState};
not_found ->
{noreply, State}
end
end;
handle_info({'EXIT', Pid, _Reason}, State) ->
Shards = maps:get(shards, State),
case find_shard_by_pid(Pid, Shards) of
{ok, Index} ->
{_Shard, NewState} = restart_shard(Index, State),
{noreply, NewState};
not_found ->
{noreply, State}
end;
handle_info(_Info, State) ->
{noreply, State}.
-spec terminate(term(), state()) -> ok.
terminate(_Reason, State) ->
Shards = maps:get(shards, State),
lists:foreach(
fun(#{pid := Pid}) ->
catch gen_server:stop(Pid, shutdown, 5000)
end,
SelectedStatus = select_initial_status(PresenceStatus, StoredStatus),
utils:parse_status(SelectedStatus).
maps:values(Shards)
),
catch ets:delete(?SHARD_TABLE),
ok.
-spec parse_guild_ids(map()) -> [integer()].
parse_guild_ids(Data) ->
GuildIds = map_utils:get_safe(Data, <<"guild_ids">>, []),
[utils:binary_to_integer_safe(Id) || Id <- GuildIds, Id =/= undefined].
-spec check_identify_rate_limit(list()) -> {ok, list()} | {error, rate_limited}.
check_identify_rate_limit(Attempts) ->
case fluxer_gateway_env:get(identify_rate_limit_enabled) of
true ->
Now = erlang:system_time(millisecond),
WindowDuration = 5000,
AttemptsInWindow = [T || T <- Attempts, (Now - T) < WindowDuration],
AttemptsCount = length(AttemptsInWindow),
MaxIdentifiesPerWindow = 1,
case AttemptsCount >= MaxIdentifiesPerWindow of
true ->
{error, rate_limited};
false ->
NewAttempts = [Now | AttemptsInWindow],
{ok, NewAttempts}
end;
-spec code_change(term(), term(), term()) -> {ok, state()}.
code_change(_OldVsn, State, _Extra) when is_map(State) ->
case {maps:is_key(shards, State), maps:is_key(shard_count, State)} of
{true, true} ->
sync_shard_table(State),
{ok, State};
_ ->
{ok, Attempts}
{ok, rebuild_state()}
end;
code_change(_OldVsn, _State, _Extra) ->
{ok, rebuild_state()}.
-spec call_shard(session_id(), term(), pos_integer()) -> term().
call_shard(SessionId, Request, Timeout) ->
case shard_pid_from_table(SessionId) of
{ok, Pid} ->
case catch gen_server:call(Pid, Request, Timeout) of
{'EXIT', _} ->
call_via_manager(SessionId, Request, Timeout);
Reply ->
Reply
end;
error ->
call_via_manager(SessionId, Request, Timeout)
end.
-spec get_presence_custom_status(term()) -> map() | null.
get_presence_custom_status(Presence) ->
case Presence of
null -> null;
Map when is_map(Map) -> map_utils:get_safe(Map, <<"custom_status">>, null);
_ -> null
-spec call_via_manager(session_id(), term(), pos_integer()) -> term().
call_via_manager(SessionId, Request, Timeout) ->
gen_server:call(?MODULE, {proxy_call, SessionId, Request, Timeout}, Timeout + 1000).
-spec forward_call(session_id(), term(), pos_integer(), state()) -> {term(), state()}.
forward_call(SessionId, Request, Timeout, State) ->
{Index, State1} = ensure_shard(SessionId, State),
Shards = maps:get(shards, State1),
#{pid := Pid} = maps:get(Index, Shards),
case catch gen_server:call(Pid, Request, Timeout) of
{'EXIT', _} ->
{_Shard, State2} = restart_shard(Index, State1),
Shards2 = maps:get(shards, State2),
#{pid := RetryPid} = maps:get(Index, Shards2),
case catch gen_server:call(RetryPid, Request, Timeout) of
{'EXIT', _} ->
{{error, unavailable}, State2};
Reply ->
{Reply, State2}
end;
Reply ->
{Reply, State1}
end.
-spec get_stored_status(map()) -> binary().
get_stored_status(Data) ->
case map_utils:get_safe(Data, <<"user_settings">>, null) of
null ->
<<"online">>;
UserSettings ->
case normalize_status(map_utils:get_safe(UserSettings, <<"status">>, <<"online">>)) of
undefined -> <<"online">>;
Value -> Value
-spec rebuild_state() -> state().
rebuild_state() ->
ensure_shard_table(),
{ShardCount, _Source} = determine_shard_count(),
Shards = start_shards(ShardCount),
State = #{shards => Shards, shard_count => ShardCount},
sync_shard_table(State),
State.
-spec determine_shard_count() -> {pos_integer(), configured | auto}.
determine_shard_count() ->
case fluxer_gateway_env:get(session_shards) of
Value when is_integer(Value), Value > 0 ->
{Value, configured};
_ ->
{default_shard_count(), auto}
end.
-spec default_shard_count() -> pos_integer().
default_shard_count() ->
Candidates = [
erlang:system_info(logical_processors_available),
erlang:system_info(schedulers_online)
],
lists:max([C || C <- Candidates, is_integer(C), C > 0] ++ [1]).
-spec start_shards(pos_integer()) -> #{non_neg_integer() => shard()}.
start_shards(Count) ->
lists:foldl(
fun(Index, Acc) ->
case start_shard(Index) of
{ok, Shard} ->
maps:put(Index, Shard, Acc);
{error, _Reason} ->
Acc
end
end,
#{},
lists:seq(0, Count - 1)
).
-spec start_shard(non_neg_integer()) -> {ok, shard()} | {error, term()}.
start_shard(Index) ->
case session_manager_shard:start_link(Index) of
{ok, Pid} ->
Ref = erlang:monitor(process, Pid),
put_shard_pid(Index, Pid),
{ok, #{pid => Pid, ref => Ref}};
Error ->
Error
end.
-spec restart_shard(non_neg_integer(), state()) -> {shard(), state()}.
restart_shard(Index, State) ->
case start_shard(Index) of
{ok, Shard} ->
Shards = maps:get(shards, State),
NewState = State#{shards := maps:put(Index, Shard, Shards)},
sync_shard_table(NewState),
{Shard, NewState};
{error, _Reason} ->
Dummy = #{pid => spawn(fun() -> exit(normal) end), ref => make_ref()},
{Dummy, State}
end.
-spec ensure_shard(session_id(), state()) -> {non_neg_integer(), state()}.
ensure_shard(SessionId, State) ->
Count = maps:get(shard_count, State),
Shards = maps:get(shards, State),
Index = select_shard(SessionId, Count),
case maps:get(Index, Shards, undefined) of
undefined ->
{_Shard, NewState} = restart_shard(Index, State),
{Index, NewState};
#{pid := Pid} ->
case erlang:is_process_alive(Pid) of
true ->
{Index, State};
false ->
{_Shard, NewState} = restart_shard(Index, State),
{Index, NewState}
end
end.
-spec select_initial_status(binary() | undefined, binary()) -> binary().
select_initial_status(PresenceStatus, StoredStatus) ->
NormalizedPresence = normalize_status(PresenceStatus),
case {NormalizedPresence, StoredStatus} of
{undefined, Stored} ->
Stored;
{<<"unknown">>, Stored} ->
Stored;
{<<"online">>, Stored} when Stored =/= <<"online">> ->
Stored;
{Presence, _} ->
Presence
-spec aggregate_counts(term(), state()) -> {non_neg_integer(), state()}.
aggregate_counts(Request, State) ->
Shards = maps:get(shards, State),
Counts =
lists:map(
fun(#{pid := Pid}) ->
case catch gen_server:call(Pid, Request, ?DEFAULT_GEN_SERVER_TIMEOUT) of
{ok, Count} when is_integer(Count) ->
Count;
Count when is_integer(Count) ->
Count;
_ ->
0
end
end,
maps:values(Shards)
),
{lists:sum(Counts), State}.
-spec ensure_shard_table() -> ok.
ensure_shard_table() ->
case ets:whereis(?SHARD_TABLE) of
undefined ->
_ = ets:new(?SHARD_TABLE, [named_table, public, set, {read_concurrency, true}]),
ok;
_ ->
ok
end.
-spec normalize_status(term()) -> binary() | undefined.
normalize_status(undefined) ->
undefined;
normalize_status(null) ->
undefined;
normalize_status(Status) when is_binary(Status) ->
Status;
normalize_status(Status) when is_atom(Status) ->
try constants:status_type_atom(Status) of
Value when is_binary(Value) -> Value
-spec sync_shard_table(state()) -> ok.
sync_shard_table(State) ->
ensure_shard_table(),
_ = ets:delete_all_objects(?SHARD_TABLE),
ShardCount = maps:get(shard_count, State),
ets:insert(?SHARD_TABLE, {shard_count, ShardCount}),
Shards = maps:get(shards, State),
lists:foreach(
fun({Index, #{pid := Pid}}) ->
put_shard_pid(Index, Pid)
end,
maps:to_list(Shards)
),
ok.
-spec put_shard_pid(non_neg_integer(), pid()) -> ok.
put_shard_pid(Index, Pid) ->
ets:insert(?SHARD_TABLE, {{shard_pid, Index}, Pid}),
ok.
-spec shard_pid_from_table(session_id()) -> {ok, pid()} | error.
shard_pid_from_table(SessionId) ->
try
case ets:lookup(?SHARD_TABLE, shard_count) of
[{shard_count, ShardCount}] when is_integer(ShardCount), ShardCount > 0 ->
Index = select_shard(SessionId, ShardCount),
case ets:lookup(?SHARD_TABLE, {shard_pid, Index}) of
[{{shard_pid, Index}, Pid}] when is_pid(Pid) ->
case erlang:is_process_alive(Pid) of
true -> {ok, Pid};
false -> error
end;
_ ->
error
end;
_ ->
error
end
catch
_:_ -> undefined
end;
normalize_status(_) ->
undefined.
error:badarg ->
error
end.
-spec select_shard(session_id(), pos_integer()) -> non_neg_integer().
select_shard(SessionId, Count) when Count > 0 ->
rendezvous_router:select(SessionId, Count).
-spec find_shard_by_ref(reference(), #{non_neg_integer() => shard()}) ->
{ok, non_neg_integer()} | not_found.
find_shard_by_ref(Ref, Shards) ->
maps:fold(
fun
(_Index, _Shard, {ok, _} = Found) ->
Found;
(Index, #{ref := ExistingRef}, not_found) ->
case ExistingRef =:= Ref of
true -> {ok, Index};
false -> not_found
end
end,
not_found,
Shards
).
-spec find_shard_by_pid(pid(), #{non_neg_integer() => shard()}) ->
{ok, non_neg_integer()} | not_found.
find_shard_by_pid(Pid, Shards) ->
maps:fold(
fun
(_Index, _Shard, {ok, _} = Found) ->
Found;
(Index, #{pid := ExistingPid}, not_found) ->
case ExistingPid =:= Pid of
true -> {ok, Index};
false -> not_found
end
end,
not_found,
Shards
).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
determine_shard_count_configured_test() ->
with_runtime_config(session_shards, 3, fun() ->
?assertMatch({3, configured}, determine_shard_count())
end).
determine_shard_count_auto_test() ->
with_runtime_config(session_shards, undefined, fun() ->
{Count, auto} = determine_shard_count(),
?assert(Count >= 1)
end).
default_shard_count_positive_test() ->
Count = default_shard_count(),
?assert(Count >= 1).
select_shard_deterministic_test() ->
SessionId = <<"session-abc">>,
ShardCount = 8,
Shard1 = select_shard(SessionId, ShardCount),
Shard2 = select_shard(SessionId, ShardCount),
?assertEqual(Shard1, Shard2).
select_shard_in_range_test() ->
ShardCount = 8,
lists:foreach(
fun(N) ->
SessionId = list_to_binary(integer_to_list(N)),
Shard = select_shard(SessionId, ShardCount),
?assert(Shard >= 0 andalso Shard < ShardCount)
end,
lists:seq(1, 100)
).
with_runtime_config(Key, Value, Fun) ->
Original = fluxer_gateway_env:get(Key),
fluxer_gateway_env:patch(#{Key => Value}),
Result = Fun(),
fluxer_gateway_env:update(fun(Map) ->
case Original of
undefined -> maps:remove(Key, Map);
Existing -> maps:put(Key, Existing, Map)
end
end),
Result.
-endif.

View File

@@ -0,0 +1,721 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(session_manager_shard).
-behaviour(gen_server).
-include_lib("fluxer_gateway/include/timeout_config.hrl").
-export([start_link/1]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-export_type([session_data/0, user_id/0]).
-define(IDENTIFY_FLAG_USE_CANARY_API, 16#1).
-define(IDENTIFY_FLAG_DEBOUNCE_MESSAGE_REACTIONS, 16#2).
-type session_id() :: binary().
-type user_id() :: integer().
-type session_ref() :: {pid(), reference()}.
-type status() :: online | offline | idle | dnd.
-type identify_timestamp() :: integer().
-type identify_request() :: #{
session_id := session_id(),
identify_data := map(),
version := non_neg_integer(),
peer_ip := term(),
token := binary()
}.
-type session_data() :: #{
id := session_id(),
user_id := user_id(),
user_data := map(),
version := non_neg_integer(),
token_hash := binary(),
auth_session_id_hash := binary(),
properties := map(),
status := status(),
afk := boolean(),
mobile := boolean(),
socket_pid := pid(),
guilds := [integer()],
ready := map(),
ignored_events := [binary()]
}.
-type state() :: #{
sessions := #{session_id() => session_ref()},
api_host := string(),
api_canary_host := undefined | string(),
identify_attempts := [identify_timestamp()],
pending_identifies := #{session_id() => pending_identify()},
identify_workers := #{reference() => session_id()},
shard_index := non_neg_integer()
}.
-type pending_identify() :: #{
request := identify_request(),
socket_pid := pid(),
froms := [gen_server:from()]
}.
-type start_reply() ::
{success, pid()}
| {error, invalid_token}
| {error, rate_limited}
| {error, identify_rate_limited}
| {error, {server_error, non_neg_integer()}}
| {error, {http_error, non_neg_integer()}}
| {error, {network_error, term()}}
| {error, registration_failed}
| {error, term()}.
-type lookup_reply() :: {ok, pid()} | {error, not_found}.
-spec start_link(non_neg_integer()) -> {ok, pid()} | {error, term()}.
start_link(ShardIndex) ->
gen_server:start_link(?MODULE, #{shard_index => ShardIndex}, []).
-spec init(map()) -> {ok, state()}.
init(Args) ->
fluxer_gateway_env:load(),
process_flag(trap_exit, true),
ApiHost = fluxer_gateway_env:get(api_host),
ApiCanaryHost = fluxer_gateway_env:get(api_canary_host),
ShardIndex = maps:get(shard_index, Args, 0),
{ok, #{
sessions => #{},
api_host => ApiHost,
api_canary_host => ApiCanaryHost,
identify_attempts => [],
pending_identifies => #{},
identify_workers => #{},
shard_index => ShardIndex
}}.
-spec handle_call(Request, From, State) -> Result when
Request ::
{start, identify_request(), pid()}
| {lookup, session_id()}
| get_local_count
| get_global_count
| term(),
From :: gen_server:from(),
State :: state(),
Result :: {reply, Reply, state()} | {noreply, state()},
Reply :: start_reply() | lookup_reply() | {ok, non_neg_integer()} | ok.
handle_call({start, Request, SocketPid}, From, State) ->
Sessions = maps:get(sessions, State),
Attempts = maps:get(identify_attempts, State),
PendingIdentifies = maps:get(pending_identifies, State),
SessionId = maps:get(session_id, Request),
case maps:get(SessionId, Sessions, undefined) of
{Pid, _Ref} ->
{reply, {success, Pid}, State};
undefined ->
SessionName = process_registry:build_process_name(session, SessionId),
case whereis(SessionName) of
undefined ->
case maps:get(SessionId, PendingIdentifies, undefined) of
undefined ->
case check_identify_rate_limit(Attempts) of
{ok, NewAttempts} ->
NewState = maps:put(identify_attempts, NewAttempts, State),
start_identify_fetch(Request, SocketPid, SessionId, From, NewState);
{error, rate_limited} ->
{reply, {error, identify_rate_limited}, State}
end;
PendingIdentify ->
UpdatedPending = PendingIdentify#{
froms => [From | maps:get(froms, PendingIdentify, [])]
},
NewPending = maps:put(SessionId, UpdatedPending, PendingIdentifies),
{noreply, maps:put(pending_identifies, NewPending, State)}
end;
Pid ->
Ref = monitor(process, Pid),
NewSessions = maps:put(SessionId, {Pid, Ref}, Sessions),
{reply, {success, Pid}, maps:put(sessions, NewSessions, State)}
end
end;
handle_call({lookup, SessionId}, _From, State) ->
Sessions = maps:get(sessions, State),
case maps:get(SessionId, Sessions, undefined) of
{Pid, _Ref} ->
{reply, {ok, Pid}, State};
undefined ->
SessionName = process_registry:build_process_name(session, SessionId),
case whereis(SessionName) of
undefined ->
{reply, {error, not_found}, State};
Pid ->
Ref = monitor(process, Pid),
NewSessions = maps:put(SessionId, {Pid, Ref}, Sessions),
{reply, {ok, Pid}, maps:put(sessions, NewSessions, State)}
end
end;
handle_call(get_local_count, _From, State) ->
Sessions = maps:get(sessions, State),
{reply, {ok, maps:size(Sessions)}, State};
handle_call(get_global_count, _From, State) ->
Sessions = maps:get(sessions, State),
{reply, {ok, maps:size(Sessions)}, State};
handle_call(_, _From, State) ->
{reply, ok, State}.
-spec build_and_start_session(
map(), map(), non_neg_integer(), pid(), session_id(), #{session_id() => session_ref()}, state()
) ->
{reply, start_reply(), state()}.
build_and_start_session(Data, IdentifyData, Version, SocketPid, SessionId, Sessions, State) ->
UserDataMap = maps:get(<<"user">>, Data),
UserId = type_conv:extract_id(UserDataMap, <<"id">>),
AuthSessionIdHashEncoded = maps:get(<<"auth_session_id_hash">>, Data, undefined),
AuthSessionIdHash =
case AuthSessionIdHashEncoded of
undefined -> <<>>;
null -> <<>>;
_ -> base64url:decode(AuthSessionIdHashEncoded)
end,
Status = parse_presence(Data, IdentifyData),
GuildIds = parse_guild_ids(Data),
Properties = maps:get(properties, IdentifyData),
Presence = map_utils:get_safe(IdentifyData, presence, null),
IgnoredEvents = map_utils:get_safe(IdentifyData, ignored_events, []),
InitialGuildId = map_utils:get_safe(IdentifyData, initial_guild_id, undefined),
Bot = map_utils:get_safe(UserDataMap, <<"bot">>, false),
ReadyData =
case Bot of
true -> maps:merge(Data, #{<<"guilds">> => []});
false -> Data
end,
UserSettingsMap = map_utils:get_safe(Data, <<"user_settings">>, #{}),
CustomStatusFromSettings = map_utils:get_safe(UserSettingsMap, <<"custom_status">>, null),
PresenceCustomStatus = get_presence_custom_status(Presence),
CustomStatus =
case CustomStatusFromSettings of
null -> PresenceCustomStatus;
_ -> CustomStatusFromSettings
end,
Mobile =
case Presence of
null -> map_utils:get_safe(Properties, <<"mobile">>, false);
P when is_map(P) -> map_utils:get_safe(P, <<"mobile">>, false);
_ -> false
end,
Afk =
case Presence of
null -> false;
P2 when is_map(P2) -> map_utils:get_safe(P2, <<"afk">>, false);
_ -> false
end,
UserData0 = #{
<<"id">> => maps:get(<<"id">>, UserDataMap),
<<"username">> => maps:get(<<"username">>, UserDataMap),
<<"discriminator">> => maps:get(<<"discriminator">>, UserDataMap),
<<"avatar">> => maps:get(<<"avatar">>, UserDataMap),
<<"avatar_color">> => map_utils:get_safe(UserDataMap, <<"avatar_color">>, undefined),
<<"bot">> => map_utils:get_safe(UserDataMap, <<"bot">>, undefined),
<<"system">> => map_utils:get_safe(UserDataMap, <<"system">>, undefined),
<<"flags">> => maps:get(<<"flags">>, UserDataMap)
},
NormalizedUserData = user_utils:normalize_user(UserData0),
UserData = maps:put(
<<"is_staff">>,
maps:get(<<"is_staff">>, UserDataMap, false),
NormalizedUserData
),
DebounceReactions = should_debounce_reactions(IdentifyData),
SessionData = #{
id => SessionId,
user_id => UserId,
user_data => UserData,
custom_status => CustomStatus,
version => Version,
token_hash => utils:hash_token(maps:get(token, IdentifyData)),
auth_session_id_hash => AuthSessionIdHash,
properties => Properties,
status => Status,
afk => Afk,
mobile => Mobile,
socket_pid => SocketPid,
guilds => GuildIds,
ready => ReadyData,
bot => Bot,
ignored_events => IgnoredEvents,
initial_guild_id => InitialGuildId,
debounce_reactions => DebounceReactions
},
start_session_process(SessionData, SessionId, Sessions, State).
-spec start_session_process(map(), session_id(), #{session_id() => session_ref()}, state()) ->
{reply, start_reply(), state()}.
start_session_process(SessionData, SessionId, Sessions, State) ->
SessionName = process_registry:build_process_name(session, SessionId),
case whereis(SessionName) of
undefined ->
case session:start_link(SessionData) of
{ok, Pid} ->
case process_registry:register_and_monitor(SessionName, Pid, Sessions) of
{ok, RegisteredPid, Ref, NewSessions0} ->
CleanSessions = maps:remove(SessionName, NewSessions0),
NewSessions = maps:put(SessionId, {RegisteredPid, Ref}, CleanSessions),
{reply, {success, RegisteredPid},
maps:put(sessions, NewSessions, State)};
{error, registration_race_condition} ->
{reply, {error, registration_failed}, State};
{error, _Reason} ->
{reply, {error, registration_failed}, State}
end;
Error ->
{reply, Error, State}
end;
ExistingPid ->
Ref = monitor(process, ExistingPid),
CleanSessions = maps:remove(SessionName, Sessions),
NewSessions = maps:put(SessionId, {ExistingPid, Ref}, CleanSessions),
{reply, {success, ExistingPid}, maps:put(sessions, NewSessions, State)}
end.
-spec handle_cast(term(), state()) -> {noreply, state()}.
handle_cast(_, State) ->
{noreply, State}.
-spec start_identify_fetch(identify_request(), pid(), session_id(), gen_server:from(), state()) ->
{noreply, state()}.
start_identify_fetch(Request, SocketPid, SessionId, From, State) ->
IdentifyData = maps:get(identify_data, Request),
UseCanary = should_use_canary_api(IdentifyData),
{_UsedCanary, RpcClient} = select_rpc_client(State, UseCanary),
ManagerPid = self(),
{_WorkerPid, WorkerRef} =
spawn_monitor(fun() ->
PeerIP = maps:get(peer_ip, Request),
FetchResult = fetch_rpc_data(Request, PeerIP, RpcClient),
ManagerPid ! {identify_fetch_result, SessionId, FetchResult}
end),
PendingIdentifies = maps:get(pending_identifies, State),
NewPending = maps:put(SessionId, #{
request => Request,
socket_pid => SocketPid,
froms => [From]
}, PendingIdentifies),
IdentifyWorkers = maps:get(identify_workers, State),
NewWorkers = maps:put(WorkerRef, SessionId, IdentifyWorkers),
{noreply, State#{
pending_identifies := NewPending,
identify_workers := NewWorkers
}}.
-spec select_rpc_client(state(), boolean()) -> {boolean(), string()}.
select_rpc_client(State, true) ->
case maps:get(api_canary_host, State) of
undefined ->
{false, maps:get(api_host, State)};
CanaryHost ->
{true, CanaryHost}
end;
select_rpc_client(State, false) ->
{false, maps:get(api_host, State)}.
-spec should_use_canary_api(map()) -> boolean().
should_use_canary_api(IdentifyData) ->
case map_utils:get_safe(IdentifyData, flags, 0) of
Flags when is_integer(Flags), Flags >= 0 ->
(Flags band ?IDENTIFY_FLAG_USE_CANARY_API) =/= 0;
_ ->
false
end.
-spec should_debounce_reactions(map()) -> boolean().
should_debounce_reactions(IdentifyData) ->
case map_utils:get_safe(IdentifyData, flags, 0) of
Flags when is_integer(Flags), Flags >= 0 ->
(Flags band ?IDENTIFY_FLAG_DEBOUNCE_MESSAGE_REACTIONS) =/= 0;
_ ->
false
end.
-spec handle_info(Info, State) -> {noreply, state()} when
Info :: {'DOWN', reference(), process, pid(), term()} | term(),
State :: state().
handle_info({identify_fetch_result, SessionId, FetchResult}, State) ->
complete_identify_fetch(SessionId, FetchResult, State);
handle_info({'DOWN', Ref, process, Pid, Reason}, State) ->
IdentifyWorkers = maps:get(identify_workers, State),
case maps:take(Ref, IdentifyWorkers) of
{SessionId, RemainingWorkers} ->
StateWithoutWorker = maps:put(identify_workers, RemainingWorkers, State),
maybe_fail_pending_identify(SessionId, Reason, StateWithoutWorker);
error ->
Sessions = maps:get(sessions, State),
NewSessions = process_registry:cleanup_on_down(Pid, Sessions),
{noreply, maps:put(sessions, NewSessions, State)}
end;
handle_info(_, State) ->
{noreply, State}.
-spec complete_identify_fetch(session_id(), term(), state()) -> {noreply, state()}.
complete_identify_fetch(SessionId, FetchResult, State) ->
PendingIdentifies = maps:get(pending_identifies, State),
case maps:take(SessionId, PendingIdentifies) of
error ->
{noreply, State};
{PendingIdentify, RemainingPending} ->
State1 = maps:put(pending_identifies, RemainingPending, State),
State2 = cleanup_identify_worker(SessionId, State1),
{Reply, NewState} = resolve_identify_result(FetchResult, PendingIdentify, SessionId, State2),
reply_to_waiters(maps:get(froms, PendingIdentify, []), Reply),
{noreply, NewState}
end.
-spec maybe_fail_pending_identify(session_id(), term(), state()) -> {noreply, state()}.
maybe_fail_pending_identify(_SessionId, Reason, State) when
Reason =:= normal; Reason =:= shutdown
->
{noreply, State};
maybe_fail_pending_identify(SessionId, Reason, State) ->
PendingIdentifies = maps:get(pending_identifies, State),
case maps:take(SessionId, PendingIdentifies) of
error ->
{noreply, State};
{PendingIdentify, RemainingPending} ->
reply_to_waiters(
maps:get(froms, PendingIdentify, []),
{error, {network_error, Reason}}
),
NewState = maps:put(pending_identifies, RemainingPending, State),
{noreply, NewState}
end.
-spec cleanup_identify_worker(session_id(), state()) -> state().
cleanup_identify_worker(SessionId, State) ->
IdentifyWorkers = maps:get(identify_workers, State),
RemainingWorkers = maps:filter(
fun(_Ref, WorkerSessionId) -> WorkerSessionId =/= SessionId end,
IdentifyWorkers
),
maps:put(identify_workers, RemainingWorkers, State).
-spec resolve_identify_result(term(), pending_identify(), session_id(), state()) ->
{start_reply(), state()}.
resolve_identify_result({ok, Data}, PendingIdentify, SessionId, State) ->
Request = maps:get(request, PendingIdentify),
IdentifyData = maps:get(identify_data, Request),
Version = maps:get(version, Request),
SocketPid = maps:get(socket_pid, PendingIdentify),
Sessions = maps:get(sessions, State),
{reply, Reply, NewState} = build_and_start_session(
Data,
IdentifyData,
Version,
SocketPid,
SessionId,
Sessions,
State
),
{Reply, NewState};
resolve_identify_result({error, invalid_token}, _PendingIdentify, _SessionId, State) ->
{{error, invalid_token}, State};
resolve_identify_result({error, rate_limited}, _PendingIdentify, _SessionId, State) ->
{{error, rate_limited}, State};
resolve_identify_result({error, Reason}, _PendingIdentify, _SessionId, State) ->
{{error, Reason}, State}.
-spec reply_to_waiters([gen_server:from()], start_reply()) -> ok.
reply_to_waiters(Waiters, Reply) ->
lists:foreach(fun(From) -> gen_server:reply(From, Reply) end, Waiters),
ok.
-spec terminate(Reason, State) -> ok when
Reason :: term(),
State :: state().
terminate(_Reason, _State) ->
ok.
-spec code_change(OldVsn, State, Extra) -> {ok, state()} when
OldVsn :: term(),
State :: state() | tuple(),
Extra :: term().
code_change(_OldVsn, State, _Extra) when is_map(State) ->
{ok, State#{
pending_identifies => maps:get(pending_identifies, State, #{}),
identify_workers => maps:get(identify_workers, State, #{}),
shard_index => maps:get(shard_index, State, 0)
}};
code_change(_OldVsn, State, _Extra) when is_tuple(State), element(1, State) =:= state ->
Sessions = element(2, State),
ApiHost = element(3, State),
ApiCanaryHost = element(4, State),
IdentifyAttempts = element(5, State),
{ok, #{
sessions => Sessions,
api_host => ApiHost,
api_canary_host => ApiCanaryHost,
identify_attempts => IdentifyAttempts,
pending_identifies => #{},
identify_workers => #{},
shard_index => 0
}};
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
-spec fetch_rpc_data(map(), term(), string()) ->
{ok, map()}
| {error, invalid_token}
| {error, rate_limited}
| {error, {server_error, non_neg_integer()}}
| {error, {http_error, non_neg_integer()}}
| {error, {network_error, term()}}.
fetch_rpc_data(Request, PeerIP, ApiHost) ->
StartTime = erlang:system_time(millisecond),
Result = do_fetch_rpc_data(Request, PeerIP, ApiHost),
EndTime = erlang:system_time(millisecond),
LatencyMs = EndTime - StartTime,
gateway_metrics_collector:record_rpc_latency(LatencyMs),
Result.
-spec do_fetch_rpc_data(map(), term(), string()) ->
{ok, map()}
| {error, invalid_token}
| {error, rate_limited}
| {error, {server_error, non_neg_integer()}}
| {error, {http_error, non_neg_integer()}}
| {error, {network_error, term()}}.
do_fetch_rpc_data(Request, PeerIP, ApiHost) ->
Url = rpc_client:get_rpc_url(ApiHost),
Headers = rpc_client:get_rpc_headers() ++ [{<<"content-type">>, <<"application/json">>}],
IdentifyData = maps:get(identify_data, Request),
Properties = map_utils:get_safe(IdentifyData, properties, #{}),
LatitudeRaw = map_utils:get_safe(Properties, <<"latitude">>, undefined),
LongitudeRaw = map_utils:get_safe(Properties, <<"longitude">>, undefined),
Latitude = normalize_coordinate(LatitudeRaw),
Longitude = normalize_coordinate(LongitudeRaw),
RpcRequest = #{
<<"type">> => <<"session">>,
<<"token">> => maps:get(token, IdentifyData),
<<"version">> => maps:get(version, Request),
<<"ip">> => PeerIP
},
RpcRequestWithCoords = add_coordinates(RpcRequest, Latitude, Longitude),
Body = json:encode(RpcRequestWithCoords),
execute_rpc_request(Url, Headers, Body).
-spec normalize_coordinate(term()) -> term() | undefined.
normalize_coordinate(undefined) -> undefined;
normalize_coordinate(null) -> undefined;
normalize_coordinate(Value) -> Value.
-spec add_coordinates(map(), term(), term()) -> map().
add_coordinates(Request, undefined, undefined) ->
Request;
add_coordinates(Request, Lat, undefined) ->
maps:put(<<"latitude">>, Lat, Request);
add_coordinates(Request, undefined, Lon) ->
maps:put(<<"longitude">>, Lon, Request);
add_coordinates(Request, Lat, Lon) ->
maps:merge(Request, #{<<"latitude">> => Lat, <<"longitude">> => Lon}).
-spec execute_rpc_request(iodata(), list(), binary()) ->
{ok, map()}
| {error, invalid_token}
| {error, rate_limited}
| {error, {server_error, non_neg_integer()}}
| {error, {http_error, non_neg_integer()}}
| {error, {network_error, term()}}.
execute_rpc_request(Url, Headers, Body) ->
case gateway_http_client:request(rpc, post, Url, Headers, Body) of
{ok, 200, _RespHeaders, ResponseBody} ->
ResponseData = json:decode(ResponseBody),
{ok, maps:get(<<"data">>, ResponseData)};
{ok, 401, _RespHeaders, _ResponseBody} ->
{error, invalid_token};
{ok, 429, _RespHeaders, _ResponseBody} ->
{error, rate_limited};
{ok, StatusCode, _RespHeaders, _ResponseBody} when StatusCode >= 500 ->
{error, {server_error, StatusCode}};
{ok, StatusCode, _RespHeaders, _ResponseBody} when StatusCode >= 400 ->
{error, {http_error, StatusCode}};
{ok, StatusCode, _RespHeaders, _ResponseBody} ->
{error, {http_error, StatusCode}};
{error, Reason} ->
{error, {network_error, Reason}}
end.
-spec parse_presence(map(), map()) -> status().
parse_presence(Data, IdentifyData) ->
StoredStatus = get_stored_status(Data),
PresenceStatus =
case map_utils:get_safe(IdentifyData, presence, null) of
null ->
undefined;
Presence when is_map(Presence) ->
map_utils:get_safe(Presence, status, <<"online">>);
_ ->
undefined
end,
SelectedStatus = select_initial_status(PresenceStatus, StoredStatus),
utils:parse_status(SelectedStatus).
-spec parse_guild_ids(map()) -> [integer()].
parse_guild_ids(Data) ->
GuildIds = map_utils:get_safe(Data, <<"guild_ids">>, []),
[utils:binary_to_integer_safe(Id) || Id <- GuildIds, Id =/= undefined].
-spec check_identify_rate_limit(list()) -> {ok, list()} | {error, rate_limited}.
check_identify_rate_limit(Attempts) ->
case fluxer_gateway_env:get(identify_rate_limit_enabled) of
true ->
Now = erlang:system_time(millisecond),
WindowDuration = 5000,
AttemptsInWindow = [T || T <- Attempts, (Now - T) < WindowDuration],
AttemptsCount = length(AttemptsInWindow),
MaxIdentifiesPerWindow = 1,
case AttemptsCount >= MaxIdentifiesPerWindow of
true ->
{error, rate_limited};
false ->
NewAttempts = [Now | AttemptsInWindow],
{ok, NewAttempts}
end;
_ ->
{ok, Attempts}
end.
-spec get_presence_custom_status(term()) -> map() | null.
get_presence_custom_status(Presence) ->
case Presence of
null -> null;
Map when is_map(Map) -> map_utils:get_safe(Map, <<"custom_status">>, null);
_ -> null
end.
-spec get_stored_status(map()) -> binary().
get_stored_status(Data) ->
case map_utils:get_safe(Data, <<"user_settings">>, null) of
null ->
<<"online">>;
UserSettings ->
case normalize_status(map_utils:get_safe(UserSettings, <<"status">>, <<"online">>)) of
undefined -> <<"online">>;
Value -> Value
end
end.
-spec select_initial_status(binary() | undefined, binary()) -> binary().
select_initial_status(PresenceStatus, StoredStatus) ->
NormalizedPresence = normalize_status(PresenceStatus),
case {NormalizedPresence, StoredStatus} of
{undefined, Stored} ->
Stored;
{<<"unknown">>, Stored} ->
Stored;
{<<"online">>, Stored} when Stored =/= <<"online">> ->
Stored;
{Presence, _} ->
Presence
end.
-spec normalize_status(term()) -> binary() | undefined.
normalize_status(undefined) ->
undefined;
normalize_status(null) ->
undefined;
normalize_status(Status) when is_binary(Status) ->
Status;
normalize_status(Status) when is_atom(Status) ->
try constants:status_type_atom(Status) of
Value when is_binary(Value) -> Value
catch
_:_ -> undefined
end;
normalize_status(_) ->
undefined.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
normalize_status_test() ->
?assertEqual(undefined, normalize_status(undefined)),
?assertEqual(undefined, normalize_status(null)),
?assertEqual(<<"online">>, normalize_status(<<"online">>)),
?assertEqual(<<"idle">>, normalize_status(<<"idle">>)),
?assertEqual(undefined, normalize_status(123)),
ok.
select_initial_status_test() ->
?assertEqual(<<"idle">>, select_initial_status(undefined, <<"idle">>)),
?assertEqual(<<"dnd">>, select_initial_status(<<"unknown">>, <<"dnd">>)),
?assertEqual(<<"idle">>, select_initial_status(<<"online">>, <<"idle">>)),
?assertEqual(<<"online">>, select_initial_status(<<"online">>, <<"online">>)),
?assertEqual(<<"dnd">>, select_initial_status(<<"dnd">>, <<"online">>)),
ok.
normalize_coordinate_test() ->
?assertEqual(undefined, normalize_coordinate(undefined)),
?assertEqual(undefined, normalize_coordinate(null)),
?assertEqual(1.5, normalize_coordinate(1.5)),
?assertEqual(<<"test">>, normalize_coordinate(<<"test">>)),
ok.
add_coordinates_test() ->
Base = #{<<"type">> => <<"session">>},
?assertEqual(Base, add_coordinates(Base, undefined, undefined)),
?assertEqual(
#{<<"type">> => <<"session">>, <<"latitude">> => 1.0}, add_coordinates(Base, 1.0, undefined)
),
?assertEqual(
#{<<"type">> => <<"session">>, <<"longitude">> => 2.0},
add_coordinates(Base, undefined, 2.0)
),
?assertEqual(
#{<<"type">> => <<"session">>, <<"latitude">> => 1.0, <<"longitude">> => 2.0},
add_coordinates(Base, 1.0, 2.0)
),
ok.
should_use_canary_api_test() ->
?assertEqual(false, should_use_canary_api(#{})),
?assertEqual(false, should_use_canary_api(#{flags => 0})),
?assertEqual(true, should_use_canary_api(#{flags => 1})),
?assertEqual(true, should_use_canary_api(#{flags => 17})),
?assertEqual(false, should_use_canary_api(#{flags => 2})),
?assertEqual(false, should_use_canary_api(#{flags => -1})),
ok.
should_debounce_reactions_test() ->
?assertEqual(false, should_debounce_reactions(#{})),
?assertEqual(false, should_debounce_reactions(#{flags => 0})),
?assertEqual(false, should_debounce_reactions(#{flags => 1})),
?assertEqual(true, should_debounce_reactions(#{flags => 2})),
?assertEqual(true, should_debounce_reactions(#{flags => 3})),
?assertEqual(true, should_debounce_reactions(#{flags => 18})),
?assertEqual(false, should_debounce_reactions(#{flags => -1})),
ok.
get_presence_custom_status_test() ->
?assertEqual(null, get_presence_custom_status(null)),
?assertEqual(null, get_presence_custom_status(#{})),
?assertEqual(
#{<<"text">> => <<"hello">>},
get_presence_custom_status(#{<<"custom_status">> => #{<<"text">> => <<"hello">>}})
),
?assertEqual(null, get_presence_custom_status(not_a_map)),
ok.
-endif.

View File

@@ -23,78 +23,96 @@
find_call_by_ref/2
]).
handle_process_down(Ref, _Reason, State) ->
-type session_state() :: session:session_state().
-type guild_id() :: session:guild_id().
-type channel_id() :: session:channel_id().
-type guild_ref() :: {pid(), reference()} | undefined | cached_unavailable.
-type call_ref() :: {pid(), reference()}.
-spec handle_process_down(reference(), term(), session_state()) ->
{noreply, session_state()}.
handle_process_down(Ref, Reason, State) ->
SocketRef = maps:get(socket_mref, State, undefined),
PresenceRef = maps:get(presence_mref, State, undefined),
Guilds = maps:get(guilds, State),
Calls = maps:get(calls, State, #{}),
case Ref of
SocketRef when Ref =:= SocketRef ->
self() ! {presence_update, #{status => offline}},
erlang:send_after(10000, self(), resume_timeout),
{noreply, maps:merge(State, #{socket_pid => undefined, socket_mref => undefined})};
handle_socket_down(State);
PresenceRef when Ref =:= PresenceRef ->
self() ! {presence_connect, 0},
{noreply, maps:put(presence_pid, undefined, State)};
handle_presence_down(State);
_ ->
case find_guild_by_ref(Ref, Guilds) of
{ok, GuildId} ->
handle_guild_down(GuildId, _Reason, State, Guilds);
handle_guild_down(GuildId, Reason, State, Guilds);
not_found ->
case find_call_by_ref(Ref, Calls) of
{ok, ChannelId} ->
handle_call_down(ChannelId, _Reason, State, Calls);
handle_call_down(ChannelId, Reason, State, Calls);
not_found ->
{noreply, State}
end
end
end.
handle_guild_down(GuildId, Reason, State, Guilds) ->
case Reason of
killed ->
gen_server:cast(self(), {guild_leave, GuildId}),
{noreply, State};
_ ->
GuildDeleteData = #{
<<"id">> => integer_to_binary(GuildId),
<<"unavailable">> => true
},
{noreply, UpdatedState} = session_dispatch:handle_dispatch(
guild_delete, GuildDeleteData, State
),
-spec handle_socket_down(session_state()) -> {noreply, session_state()}.
handle_socket_down(State) ->
self() ! {presence_update, #{status => offline}},
erlang:send_after(10000, self(), resume_timeout),
{noreply, maps:merge(State, #{socket_pid => undefined, socket_mref => undefined})}.
NewGuilds = maps:put(GuildId, undefined, Guilds),
erlang:send_after(1000, self(), {guild_connect, GuildId, 0}),
{noreply, maps:put(guilds, NewGuilds, UpdatedState)}
end.
-spec handle_presence_down(session_state()) -> {noreply, session_state()}.
handle_presence_down(State) ->
self() ! {presence_connect, 0},
{noreply, maps:put(presence_pid, undefined, State)}.
handle_call_down(ChannelId, Reason, State, Calls) ->
case Reason of
killed ->
NewCalls = maps:remove(ChannelId, Calls),
{noreply, maps:put(calls, NewCalls, State)};
_ ->
CallDeleteData = #{
<<"channel_id">> => integer_to_binary(ChannelId),
<<"unavailable">> => true
},
{noreply, UpdatedState} = session_dispatch:handle_dispatch(
call_delete, CallDeleteData, State
),
-spec handle_guild_down(guild_id(), term(), session_state(), #{guild_id() => guild_ref()}) ->
{noreply, session_state()}.
handle_guild_down(GuildId, killed, State, _Guilds) ->
gen_server:cast(self(), {guild_leave, GuildId}),
{noreply, State};
handle_guild_down(GuildId, normal, State, _Guilds) ->
gen_server:cast(self(), {guild_leave, GuildId}),
{noreply, State};
handle_guild_down(GuildId, _Reason, State, Guilds) ->
GuildDeleteData = #{
<<"id">> => integer_to_binary(GuildId),
<<"unavailable">> => true
},
{noreply, UpdatedState} = session_dispatch:handle_dispatch(
guild_delete, GuildDeleteData, State
),
NewGuilds = maps:put(GuildId, undefined, Guilds),
erlang:send_after(1000, self(), {guild_connect, GuildId, 0}),
{noreply, maps:put(guilds, NewGuilds, UpdatedState)}.
NewCalls = maps:put(ChannelId, undefined, Calls),
erlang:send_after(1000, self(), {call_reconnect, ChannelId, 0}),
{noreply, maps:put(calls, NewCalls, UpdatedState)}
end.
-spec handle_call_down(channel_id(), term(), session_state(), #{channel_id() => call_ref()}) ->
{noreply, session_state()}.
handle_call_down(ChannelId, killed, State, Calls) ->
NewCalls = maps:remove(ChannelId, Calls),
{noreply, maps:put(calls, NewCalls, State)};
handle_call_down(ChannelId, _Reason, State, Calls) ->
CallDeleteData = #{
<<"channel_id">> => integer_to_binary(ChannelId),
<<"unavailable">> => true
},
{noreply, UpdatedState} = session_dispatch:handle_dispatch(call_delete, CallDeleteData, State),
NewCalls = maps:put(ChannelId, undefined, Calls),
erlang:send_after(1000, self(), {call_reconnect, ChannelId, 0}),
{noreply, maps:put(calls, NewCalls, UpdatedState)}.
-spec find_guild_by_ref(reference(), #{guild_id() => guild_ref()}) ->
{ok, guild_id()} | not_found.
find_guild_by_ref(Ref, Guilds) ->
find_by_ref(Ref, Guilds).
-spec find_call_by_ref(reference(), #{channel_id() => call_ref()}) ->
{ok, channel_id()} | not_found.
find_call_by_ref(Ref, Calls) ->
find_by_ref(Ref, Calls).
-spec find_by_ref(reference(), #{integer() => {pid(), reference()} | undefined}) ->
{ok, integer()} | not_found.
find_by_ref(Ref, Map) ->
maps:fold(
fun
@@ -104,3 +122,36 @@ find_by_ref(Ref, Map) ->
not_found,
Map
).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
find_by_ref_test() ->
Ref1 = make_ref(),
Ref2 = make_ref(),
Ref3 = make_ref(),
Map = #{
123 => {self(), Ref1},
456 => {self(), Ref2},
789 => undefined
},
?assertEqual({ok, 123}, find_by_ref(Ref1, Map)),
?assertEqual({ok, 456}, find_by_ref(Ref2, Map)),
?assertEqual(not_found, find_by_ref(Ref3, Map)),
ok.
find_guild_by_ref_test() ->
Ref = make_ref(),
Guilds = #{100 => {self(), Ref}, 200 => undefined},
?assertEqual({ok, 100}, find_guild_by_ref(Ref, Guilds)),
?assertEqual(not_found, find_guild_by_ref(make_ref(), Guilds)),
ok.
find_call_by_ref_test() ->
Ref = make_ref(),
Calls = #{300 => {self(), Ref}},
?assertEqual({ok, 300}, find_call_by_ref(Ref, Calls)),
?assertEqual(not_found, find_call_by_ref(make_ref(), Calls)),
ok.
-endif.

View File

@@ -31,10 +31,13 @@
clear_guild_synced/2
]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type guild_id() :: session:guild_id().
-type user_id() :: session:user_id().
-type event() :: atom().
-type session_data() :: map().
-type guild_state() :: map().
-spec is_passive(guild_id(), session_data()) -> boolean().
is_passive(GuildId, SessionData) ->
case maps:get(bot, SessionData, false) of
true ->
@@ -44,16 +47,19 @@ is_passive(GuildId, SessionData) ->
not sets:is_element(GuildId, ActiveGuilds)
end.
-spec set_active(guild_id(), session_data()) -> session_data().
set_active(GuildId, SessionData) ->
ActiveGuilds = maps:get(active_guilds, SessionData, sets:new()),
NewActiveGuilds = sets:add_element(GuildId, ActiveGuilds),
maps:put(active_guilds, NewActiveGuilds, SessionData).
-spec set_passive(guild_id(), session_data()) -> session_data().
set_passive(GuildId, SessionData) ->
ActiveGuilds = maps:get(active_guilds, SessionData, sets:new()),
NewActiveGuilds = sets:del_element(GuildId, ActiveGuilds),
maps:put(active_guilds, NewActiveGuilds, SessionData).
-spec should_receive_event(event(), map(), guild_id(), session_data(), guild_state()) -> boolean().
should_receive_event(Event, EventData, GuildId, SessionData, State) ->
case Event of
typing_start ->
@@ -63,15 +69,17 @@ should_receive_event(Event, EventData, GuildId, SessionData, State) ->
true ->
true;
false ->
case is_message_event(Event) of
case is_lazy_guild_event(Event) of
true ->
case is_small_guild(State) of
true ->
true;
false ->
case is_passive(GuildId, SessionData) of
false -> true;
true -> should_passive_receive(Event, EventData, SessionData)
false ->
true;
true ->
should_passive_receive(Event, EventData, SessionData)
end
end;
false ->
@@ -83,31 +91,36 @@ should_receive_event(Event, EventData, GuildId, SessionData, State) ->
end
end.
-spec is_small_guild(guild_state()) -> boolean().
is_small_guild(State) ->
MemberCount = maps:get(member_count, State, undefined),
case MemberCount of
undefined -> false; %% Conservative: treat as large
undefined -> false;
Count when is_integer(Count) -> Count =< 250
end.
-spec is_message_event(event()) -> boolean().
is_message_event(message_create) -> true;
is_message_event(message_update) -> true;
is_message_event(message_delete) -> true;
is_message_event(message_delete_bulk) -> true;
is_message_event(_) -> false.
-spec is_lazy_guild_event(event()) -> boolean().
is_lazy_guild_event(Event) ->
is_message_event(Event) orelse Event =:= voice_state_update.
-spec should_passive_receive(event(), map(), session_data()) -> boolean().
should_passive_receive(message_create, EventData, SessionData) ->
is_user_mentioned(EventData, SessionData);
Mentioned = is_user_mentioned(EventData, SessionData),
case Mentioned of
true ->
true;
false ->
false
end;
should_passive_receive(guild_delete, _EventData, _SessionData) ->
true;
should_passive_receive(channel_create, _EventData, _SessionData) ->
true;
should_passive_receive(channel_delete, _EventData, _SessionData) ->
true;
should_passive_receive(passive_updates, _EventData, _SessionData) ->
true;
should_passive_receive(guild_update, _EventData, _SessionData) ->
true;
should_passive_receive(guild_member_update, EventData, SessionData) ->
UserId = maps:get(user_id, SessionData),
MemberUser = maps:get(<<"user">>, EventData, #{}),
@@ -118,26 +131,23 @@ should_passive_receive(guild_member_remove, EventData, SessionData) ->
MemberUser = maps:get(<<"user">>, EventData, #{}),
MemberUserId = map_utils:get_integer(MemberUser, <<"id">>, undefined),
UserId =:= MemberUserId;
should_passive_receive(voice_state_update, EventData, SessionData) ->
UserId = maps:get(user_id, SessionData),
EventUserId = map_utils:get_integer(EventData, <<"user_id">>, undefined),
UserId =:= EventUserId;
should_passive_receive(voice_server_update, _EventData, _SessionData) ->
should_passive_receive(passive_updates, _EventData, _SessionData) ->
true;
should_passive_receive(_, _, _) ->
false.
-spec is_user_mentioned(map(), session_data()) -> boolean().
is_user_mentioned(EventData, SessionData) ->
UserId = maps:get(user_id, SessionData),
MentionEveryone = maps:get(<<"mention_everyone">>, EventData, false),
Mentions = maps:get(<<"mentions">>, EventData, []),
MentionRoles = maps:get(<<"mention_roles">>, EventData, []),
UserRoles = maps:get(user_roles, SessionData, []),
MentionEveryone orelse
is_user_in_mentions(UserId, Mentions) orelse
has_mentioned_role(UserRoles, MentionRoles).
-spec is_user_in_mentions(user_id(), [map()]) -> boolean().
is_user_in_mentions(_UserId, []) ->
false;
is_user_in_mentions(UserId, [#{<<"id">> := Id} | Rest]) when is_binary(Id) ->
@@ -150,6 +160,7 @@ is_user_in_mentions(UserId, [#{<<"id">> := Id} | Rest]) when is_binary(Id) ->
is_user_in_mentions(UserId, [_ | Rest]) ->
is_user_in_mentions(UserId, Rest).
-spec has_mentioned_role([integer()], [binary() | integer()]) -> boolean().
has_mentioned_role([], _MentionRoles) ->
false;
has_mentioned_role([RoleId | Rest], MentionRoles) ->
@@ -158,39 +169,31 @@ has_mentioned_role([RoleId | Rest], MentionRoles) ->
lists:member(RoleId, MentionRoles) orelse
has_mentioned_role(Rest, MentionRoles).
-spec get_user_roles_for_guild(user_id(), guild_state()) -> [integer()].
get_user_roles_for_guild(UserId, GuildState) ->
Data = maps:get(data, GuildState, #{}),
Members = maps:get(<<"members">>, Data, []),
case find_member_by_user_id(UserId, Members) of
case guild_permissions:find_member_by_user_id(UserId, GuildState) of
undefined -> [];
Member -> extract_role_ids(maps:get(<<"roles">>, Member, []))
end.
find_member_by_user_id(_UserId, []) ->
undefined;
find_member_by_user_id(UserId, [Member | Rest]) ->
User = maps:get(<<"user">>, Member, #{}),
MemberUserId = map_utils:get_integer(User, <<"id">>, undefined),
case UserId =:= MemberUserId of
true -> Member;
false -> find_member_by_user_id(UserId, Rest)
end.
-spec extract_role_ids([binary() | integer()]) -> [integer()].
extract_role_ids(Roles) ->
lists:filtermap(
fun(Role) when is_binary(Role) ->
case validation:validate_snowflake(<<"role">>, Role) of
{ok, RoleId} -> {true, RoleId};
{error, _, _} -> false
end;
(Role) when is_integer(Role) ->
{true, Role};
(_) ->
false
fun
(Role) when is_binary(Role) ->
case validation:validate_snowflake(<<"role">>, Role) of
{ok, RoleId} -> {true, RoleId};
{error, _, _} -> false
end;
(Role) when is_integer(Role) ->
{true, Role};
(_) ->
false
end,
Roles
).
-spec should_receive_typing(guild_id(), session_data()) -> boolean().
should_receive_typing(GuildId, SessionData) ->
case get_typing_override(GuildId, SessionData) of
undefined ->
@@ -199,30 +202,36 @@ should_receive_typing(GuildId, SessionData) ->
TypingFlag
end.
-spec set_typing_override(guild_id(), boolean(), session_data()) -> session_data().
set_typing_override(GuildId, TypingFlag, SessionData) ->
TypingOverrides = maps:get(typing_overrides, SessionData, #{}),
NewTypingOverrides = maps:put(GuildId, TypingFlag, TypingOverrides),
maps:put(typing_overrides, NewTypingOverrides, SessionData).
-spec get_typing_override(guild_id(), session_data()) -> boolean() | undefined.
get_typing_override(GuildId, SessionData) ->
TypingOverrides = maps:get(typing_overrides, SessionData, #{}),
maps:get(GuildId, TypingOverrides, undefined).
-spec is_guild_synced(guild_id(), session_data()) -> boolean().
is_guild_synced(GuildId, SessionData) ->
SyncedGuilds = maps:get(synced_guilds, SessionData, sets:new()),
sets:is_element(GuildId, SyncedGuilds).
-spec mark_guild_synced(guild_id(), session_data()) -> session_data().
mark_guild_synced(GuildId, SessionData) ->
SyncedGuilds = maps:get(synced_guilds, SessionData, sets:new()),
NewSyncedGuilds = sets:add_element(GuildId, SyncedGuilds),
maps:put(synced_guilds, NewSyncedGuilds, SessionData).
-spec clear_guild_synced(guild_id(), session_data()) -> session_data().
clear_guild_synced(GuildId, SessionData) ->
SyncedGuilds = maps:get(synced_guilds, SessionData, sets:new()),
NewSyncedGuilds = sets:del_element(GuildId, SyncedGuilds),
maps:put(synced_guilds, NewSyncedGuilds, SessionData).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
is_passive_test() ->
SessionData = #{active_guilds => sets:from_list([123, 456])},
@@ -262,13 +271,13 @@ should_receive_event_passive_guild_delete_test() ->
should_receive_event_passive_channel_create_test() ->
SessionData = #{user_id => 1, active_guilds => sets:new()},
State = #{member_count => 100},
?assertEqual(true, should_receive_event(channel_create, #{}, 123, SessionData, State)),
?assertEqual(false, should_receive_event(channel_create, #{}, 123, SessionData, State)),
ok.
should_receive_event_passive_channel_delete_test() ->
SessionData = #{user_id => 1, active_guilds => sets:new()},
State = #{member_count => 100},
?assertEqual(true, should_receive_event(channel_delete, #{}, 123, SessionData, State)),
?assertEqual(false, should_receive_event(channel_delete, #{}, 123, SessionData, State)),
ok.
should_receive_event_passive_passive_updates_test() ->
@@ -280,7 +289,7 @@ should_receive_event_passive_passive_updates_test() ->
should_receive_event_passive_message_not_mentioned_test() ->
SessionData = #{user_id => 1, active_guilds => sets:new(), user_roles => []},
EventData = #{<<"mentions">> => [], <<"mention_roles">> => [], <<"mention_everyone">> => false},
State = #{member_count => 300}, %% Large guild
State = #{member_count => 300},
?assertEqual(false, should_receive_event(message_create, EventData, 123, SessionData, State)),
ok.
@@ -291,14 +300,14 @@ should_receive_event_passive_message_user_mentioned_test() ->
<<"mention_roles">> => [],
<<"mention_everyone">> => false
},
State = #{member_count => 300}, %% Large guild
State = #{member_count => 300},
?assertEqual(true, should_receive_event(message_create, EventData, 123, SessionData, State)),
ok.
should_receive_event_passive_message_mention_everyone_test() ->
SessionData = #{user_id => 1, active_guilds => sets:new(), user_roles => []},
EventData = #{<<"mentions">> => [], <<"mention_roles">> => [], <<"mention_everyone">> => true},
State = #{member_count => 300}, %% Large guild
State = #{member_count => 300},
?assertEqual(true, should_receive_event(message_create, EventData, 123, SessionData, State)),
ok.
@@ -307,25 +316,43 @@ should_receive_event_passive_message_role_mentioned_test() ->
EventData = #{
<<"mentions">> => [], <<"mention_roles">> => [<<"100">>], <<"mention_everyone">> => false
},
State = #{member_count => 300}, %% Large guild
State = #{member_count => 300},
?assertEqual(true, should_receive_event(message_create, EventData, 123, SessionData, State)),
ok.
should_receive_event_passive_other_event_test() ->
SessionData = #{user_id => 1, active_guilds => sets:new()},
State = #{member_count => 300}, %% Large guild
State = #{member_count => 300},
?assertEqual(false, should_receive_event(typing_start, #{}, 123, SessionData, State)),
?assertEqual(false, should_receive_event(message_update, #{}, 123, SessionData, State)),
ok.
should_receive_event_small_guild_all_sessions_receive_messages_test() ->
SessionData = #{user_id => 1, active_guilds => sets:new()},
State = #{member_count => 100}, %% Small guild
State = #{member_count => 100},
?assertEqual(true, should_receive_event(message_create, #{}, 123, SessionData, State)),
?assertEqual(true, should_receive_event(message_update, #{}, 123, SessionData, State)),
?assertEqual(true, should_receive_event(message_delete, #{}, 123, SessionData, State)),
ok.
should_receive_event_small_guild_voice_state_test() ->
SessionData = #{user_id => 1, active_guilds => sets:new()},
State = #{member_count => 100},
EventData = #{<<"user_id">> => <<"2">>},
?assertEqual(
true, should_receive_event(voice_state_update, EventData, 123, SessionData, State)
),
ok.
should_receive_event_passive_voice_state_blocked_test() ->
SessionData = #{user_id => 1, active_guilds => sets:new()},
State = #{member_count => 300},
EventData = #{<<"user_id">> => <<"2">>},
?assertEqual(
false, should_receive_event(voice_state_update, EventData, 123, SessionData, State)
),
ok.
is_passive_bot_always_active_test() ->
BotSessionData = #{user_id => 1, active_guilds => sets:new(), bot => true},
?assertEqual(false, is_passive(123, BotSessionData)),
@@ -342,4 +369,35 @@ should_receive_event_bot_always_receives_test() ->
?assertEqual(true, should_receive_event(guild_delete, #{}, 123, BotSessionData, State)),
ok.
is_small_guild_test() ->
?assertEqual(true, is_small_guild(#{member_count => 100})),
?assertEqual(true, is_small_guild(#{member_count => 250})),
?assertEqual(false, is_small_guild(#{member_count => 251})),
?assertEqual(false, is_small_guild(#{member_count => 1000})),
?assertEqual(false, is_small_guild(#{})),
ok.
is_message_event_test() ->
?assertEqual(true, is_message_event(message_create)),
?assertEqual(true, is_message_event(message_update)),
?assertEqual(true, is_message_event(message_delete)),
?assertEqual(true, is_message_event(message_delete_bulk)),
?assertEqual(false, is_message_event(typing_start)),
?assertEqual(false, is_message_event(guild_create)),
ok.
is_lazy_guild_event_test() ->
?assertEqual(true, is_lazy_guild_event(message_create)),
?assertEqual(true, is_lazy_guild_event(voice_state_update)),
?assertEqual(false, is_lazy_guild_event(typing_start)),
?assertEqual(false, is_lazy_guild_event(channel_create)),
ok.
extract_role_ids_test() ->
?assertEqual([123], extract_role_ids([<<"123">>])),
?assertEqual([456], extract_role_ids([456])),
?assertEqual([123, 456], extract_role_ids([<<"123">>, 456])),
?assertEqual([], extract_role_ids([<<"invalid">>])),
ok.
-endif.

View File

@@ -25,29 +25,29 @@
update_ready_guilds/2
]).
-type session_state() :: session:session_state().
-type guild_id() :: session:guild_id().
-type channel_id() :: session:channel_id().
-type user_id() :: session:user_id().
-spec process_guild_state(map(), session_state()) -> {noreply, session_state()}.
process_guild_state(GuildState, State) ->
Ready = maps:get(ready, State),
CollectedGuilds = maps:get(collected_guild_states, State),
case Ready of
undefined ->
{noreply, StateAfterCreate} = session_dispatch:handle_dispatch(
guild_create, GuildState, State
),
dispatch_guild_initial_presences(GuildState, StateAfterCreate);
Event = guild_state_event(GuildState),
session_dispatch:handle_dispatch(Event, GuildState, State);
_ ->
NewCollectedGuilds = [GuildState | CollectedGuilds],
NewState = maps:put(collected_guild_states, NewCollectedGuilds, State),
check_readiness(update_ready_guilds(GuildState, NewState))
end.
dispatch_guild_initial_presences(_GuildState, State) ->
{noreply, State}.
-spec mark_guild_unavailable(guild_id(), session_state()) -> {noreply, session_state()}.
mark_guild_unavailable(GuildId, State) ->
CollectedGuilds = maps:get(collected_guild_states, State),
Ready = maps:get(ready, State),
UnavailableState = #{<<"id">> => integer_to_binary(GuildId), <<"unavailable">> => true},
NewCollectedGuilds = [UnavailableState | CollectedGuilds],
NewState = maps:put(collected_guild_states, NewCollectedGuilds, State),
@@ -56,24 +56,26 @@ mark_guild_unavailable(GuildId, State) ->
_ -> {noreply, update_ready_guilds(UnavailableState, NewState)}
end.
-spec check_readiness(session_state()) -> {noreply, session_state()}.
check_readiness(State) ->
Ready = maps:get(ready, State),
PresencePid = maps:get(presence_pid, State, undefined),
Guilds = maps:get(guilds, State),
case Ready of
undefined ->
{noreply, State};
_ when PresencePid =/= undefined ->
AllGuildsReady = lists:all(fun({_, V}) -> V =/= undefined end, maps:to_list(Guilds)),
if
AllGuildsReady -> dispatch_ready_data(State);
true -> {noreply, State}
case AllGuildsReady of
true -> dispatch_ready_data(State);
false -> {noreply, State}
end;
_ ->
{noreply, State}
end.
-spec dispatch_ready_data(session_state()) ->
{noreply, session_state()} | {stop, normal, session_state()}.
dispatch_ready_data(State) ->
Ready = maps:get(ready, State),
CollectedGuilds = maps:get(collected_guild_states, State),
@@ -86,41 +88,29 @@ dispatch_ready_data(State) ->
SocketPid = maps:get(socket_pid, State, undefined),
Guilds = maps:get(guilds, State),
IsBot = maps:get(bot, State, false),
ReadyData =
case Ready of
undefined -> #{<<"guilds">> => []};
R -> R
end,
ReadyDataWithStrippedRelationships = strip_user_from_relationships(ReadyData),
ReadyDataBotStripped =
case IsBot of
true -> maps:put(<<"guilds">>, [], ReadyDataWithStrippedRelationships);
false -> ReadyDataWithStrippedRelationships
end,
UnavailableGuilds = [
#{<<"id">> => integer_to_binary(GuildId), <<"unavailable">> => true}
|| {GuildId, undefined} <- maps:to_list(Guilds)
],
StrippedGuilds = [strip_users_from_guild_members(G) || G <- lists:reverse(CollectedGuilds)],
AllGuildStates = StrippedGuilds ++ UnavailableGuilds,
ReadyDataWithoutGuildIds = maps:remove(<<"guild_ids">>, ReadyDataBotStripped),
GuildsForReady =
case IsBot of
true -> [];
false -> AllGuildStates
end,
logger:debug(
"[session_ready] dispatching READY for user ~p session ~p",
[UserId, SessionId]
),
FinalReadyData = maps:merge(ReadyDataWithoutGuildIds, #{
<<"guilds">> => GuildsForReady,
<<"sessions">> => CollectedSessions,
@@ -129,12 +119,11 @@ dispatch_ready_data(State) ->
<<"version">> => Version,
<<"session_id">> => SessionId
}),
case SocketPid of
undefined ->
{stop, normal, State};
Pid when is_pid(Pid) ->
metrics_client:counter(<<"gateway.ready">>),
otel_metrics:counter(<<"gateway.ready">>, 1, #{}),
StateAfterReady = dispatch_event(ready, FinalReadyData, State),
SessionCount = length(CollectedSessions),
GuildCount = length(GuildsForReady),
@@ -144,16 +133,15 @@ dispatch_ready_data(State) ->
<<"user_id">> => integer_to_binary(UserId),
<<"bot">> => bool_to_binary(IsBot)
},
metrics_client:gauge(<<"gateway.sessions.active">>, Dimensions, SessionCount),
metrics_client:gauge(<<"gateway.guilds.active">>, Dimensions, GuildCount),
metrics_client:gauge(<<"gateway.presences.active">>, Dimensions, PresenceCount),
otel_metrics:gauge(<<"gateway.sessions.active">>, SessionCount, Dimensions),
otel_metrics:gauge(<<"gateway.guilds.active">>, GuildCount, Dimensions),
otel_metrics:gauge(<<"gateway.presences.active">>, PresenceCount, Dimensions),
StateAfterGuildCreates =
case IsBot of
true ->
lists:foldl(
fun(GuildState, AccState) ->
dispatch_event(guild_create, GuildState, AccState)
dispatch_event(guild_state_event(GuildState), GuildState, AccState)
end,
StateAfterReady,
AllGuildStates
@@ -161,14 +149,13 @@ dispatch_ready_data(State) ->
false ->
StateAfterReady
end,
PrivateChannels = get_private_channels(StateAfterGuildCreates),
SessionPid = self(),
spawn(fun() ->
dispatch_call_creates_for_channels(
PrivateChannels, SessionId, StateAfterGuildCreates
PrivateChannels, SessionId, SessionPid
)
end),
FinalState = maps:merge(StateAfterGuildCreates, #{
ready => undefined,
collected_guild_states => [],
@@ -177,6 +164,20 @@ dispatch_ready_data(State) ->
{noreply, FinalState}
end.
-spec is_unavailable_guild_state(map()) -> boolean().
is_unavailable_guild_state(GuildState) ->
maps:get(<<"unavailable">>, GuildState, false) =:= true.
-spec guild_state_event(map()) -> guild_create | guild_delete.
guild_state_event(GuildState) ->
case is_unavailable_guild_state(GuildState) of
true ->
guild_delete;
false ->
guild_create
end.
-spec dispatch_event(atom(), map(), session_state()) -> session_state().
dispatch_event(Event, Data, State) ->
Seq = maps:get(seq, State),
SocketPid = maps:get(socket_pid, State, undefined),
@@ -187,6 +188,7 @@ dispatch_event(Event, Data, State) ->
end,
maps:put(seq, NewSeq, State).
-spec update_ready_guilds(map(), session_state()) -> session_state().
update_ready_guilds(GuildState, State) ->
case maps:get(bot, State, false) of
true ->
@@ -204,14 +206,14 @@ update_ready_guilds(GuildState, State) ->
end
end.
-spec collect_ready_users(session_state(), [map()]) -> [map()].
collect_ready_users(State, CollectedGuilds) ->
case maps:get(bot, State, false) of
true ->
[];
false ->
collect_ready_users_nonbot(State, CollectedGuilds)
true -> [];
false -> collect_ready_users_nonbot(State, CollectedGuilds)
end.
-spec collect_ready_users_nonbot(session_state(), [map()]) -> [map()].
collect_ready_users_nonbot(State, CollectedGuilds) ->
Ready = maps:get(ready, State, #{}),
Relationships = map_utils:ensure_list(map_utils:get_safe(Ready, <<"relationships">>, [])),
@@ -225,10 +227,10 @@ collect_ready_users_nonbot(State, CollectedGuilds) ->
Users0 = [U || U <- RelUsers ++ ChannelUsers ++ GuildUsers, is_map(U)],
dedup_users(Users0).
-spec collect_ready_presences(session_state(), [map()]) -> [map()].
collect_ready_presences(State, _CollectedGuilds) ->
CurrentUserId = maps:get(user_id, State),
IsBot = maps:get(bot, State, false),
{FriendIds, GdmIds} =
case IsBot of
true ->
@@ -242,7 +244,6 @@ collect_ready_presences(State, _CollectedGuilds) ->
]),
{FIds, GIds}
end,
Targets = lists:usort(FriendIds ++ GdmIds) -- [CurrentUserId],
case Targets of
[] ->
@@ -253,30 +254,33 @@ collect_ready_presences(State, _CollectedGuilds) ->
dedup_presences(Visible)
end.
-spec presence_user_id(map()) -> user_id() | undefined.
presence_user_id(P) when is_map(P) ->
User = maps:get(<<"user">>, P, #{}),
map_utils:get_integer(User, <<"id">>, undefined);
presence_user_id(_) ->
undefined.
-spec presence_visible(map()) -> boolean().
presence_visible(P) ->
Status = maps:get(<<"status">>, P, <<"offline">>),
Status =/= <<"offline">> andalso Status =/= <<"invisible">>.
-spec dedup_presences([map()]) -> [map()].
dedup_presences(Presences) ->
Map =
lists:foldl(
fun(P, Acc) ->
case presence_user_id(P) of
undefined -> Acc;
Id -> maps:put(Id, P, Acc)
end
end,
#{},
Presences
),
Map = lists:foldl(
fun(P, Acc) ->
case presence_user_id(P) of
undefined -> Acc;
Id -> maps:put(Id, P, Acc)
end
end,
#{},
Presences
),
maps:values(Map).
-spec collect_channel_users([map()]) -> [map()].
collect_channel_users(Channels) ->
lists:foldl(
fun(Channel, Acc) ->
@@ -294,6 +298,7 @@ collect_channel_users(Channels) ->
Channels
).
-spec collect_guild_users([map()]) -> [map()].
collect_guild_users(GuildStates) ->
lists:foldl(
fun(GuildState, Acc) ->
@@ -308,25 +313,26 @@ collect_guild_users(GuildStates) ->
ensure_list(GuildStates)
).
-spec dedup_users([map()]) -> [map()].
dedup_users(Users) ->
Map =
lists:foldl(
fun(U, Acc) ->
Id = maps:get(<<"id">>, U, undefined),
case Id of
undefined -> Acc;
_ -> maps:put(Id, U, Acc)
end
end,
#{},
Users
),
Map = lists:foldl(
fun(U, Acc) ->
Id = maps:get(<<"id">>, U, undefined),
case Id of
undefined -> Acc;
_ -> maps:put(Id, U, Acc)
end
end,
#{},
Users
),
maps:values(Map).
ensure_list(List) when is_list(List) -> List;
ensure_list(_) -> [].
-spec ensure_list([term()]) -> [term()].
ensure_list(List) -> List.
strip_users_from_guild_members(GuildState) when is_map(GuildState) ->
-spec strip_users_from_guild_members(map()) -> map().
strip_users_from_guild_members(GuildState) ->
case maps:get(<<"unavailable">>, GuildState, false) of
true ->
GuildState;
@@ -334,10 +340,9 @@ strip_users_from_guild_members(GuildState) when is_map(GuildState) ->
Members = map_utils:ensure_list(maps:get(<<"members">>, GuildState, [])),
StrippedMembers = [strip_user_from_member(M) || M <- Members],
maps:put(<<"members">>, StrippedMembers, GuildState)
end;
strip_users_from_guild_members(GuildState) ->
GuildState.
end.
-spec strip_user_from_member(map()) -> map().
strip_user_from_member(Member) when is_map(Member) ->
case maps:get(<<"user">>, Member, undefined) of
undefined ->
@@ -351,6 +356,7 @@ strip_user_from_member(Member) when is_map(Member) ->
strip_user_from_member(Member) ->
Member.
-spec strip_user_from_relationships(map()) -> map().
strip_user_from_relationships(ReadyData) when is_map(ReadyData) ->
Relationships = map_utils:ensure_list(maps:get(<<"relationships">>, ReadyData, [])),
StrippedRelationships = [strip_user_from_relationship(R) || R <- Relationships],
@@ -358,6 +364,7 @@ strip_user_from_relationships(ReadyData) when is_map(ReadyData) ->
strip_user_from_relationships(ReadyData) ->
ReadyData.
-spec strip_user_from_relationship(map()) -> map().
strip_user_from_relationship(Relationship) when is_map(Relationship) ->
case maps:get(<<"user">>, Relationship, undefined) of
undefined ->
@@ -375,6 +382,7 @@ strip_user_from_relationship(Relationship) when is_map(Relationship) ->
strip_user_from_relationship(Relationship) ->
Relationship.
-spec get_private_channels(session_state()) -> #{channel_id() => map()}.
get_private_channels(State) ->
Channels = maps:get(channels, State, #{}),
maps:filter(
@@ -385,19 +393,21 @@ get_private_channels(State) ->
Channels
).
dispatch_call_creates_for_channels(PrivateChannels, SessionId, State) ->
-spec dispatch_call_creates_for_channels(#{channel_id() => map()}, binary(), pid()) -> ok.
dispatch_call_creates_for_channels(PrivateChannels, SessionId, SessionPid) ->
lists:foreach(
fun({ChannelId, _Channel}) ->
dispatch_call_create_for_channel(ChannelId, SessionId, State)
dispatch_call_create_for_channel(ChannelId, SessionId, SessionPid)
end,
maps:to_list(PrivateChannels)
).
dispatch_call_create_for_channel(ChannelId, _SessionId, State) ->
-spec dispatch_call_create_for_channel(channel_id(), binary(), pid()) -> ok.
dispatch_call_create_for_channel(ChannelId, SessionId, SessionPid) ->
try
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
{ok, CallPid} ->
dispatch_call_create_from_pid(CallPid, State);
dispatch_call_create_from_pid(CallPid, ChannelId, SessionId, SessionPid);
_ ->
ok
end
@@ -405,39 +415,269 @@ dispatch_call_create_for_channel(ChannelId, _SessionId, State) ->
_:_ -> ok
end.
dispatch_call_create_from_pid(CallPid, State) ->
-spec dispatch_call_create_from_pid(pid(), channel_id(), binary(), pid()) -> ok.
dispatch_call_create_from_pid(CallPid, ChannelId, SessionId, SessionPid) ->
case gen_server:call(CallPid, {get_state}, 5000) of
{ok, CallData} ->
CreatedAt = maps:get(created_at, CallData, 0),
Now = erlang:system_time(millisecond),
CallAge = Now - CreatedAt,
case CallAge < 5000 of
true ->
ok;
false ->
ChannelIdBin = maps:get(channel_id, CallData),
case validation:validate_snowflake(<<"channel_id">>, ChannelIdBin) of
{ok, ChannelId} ->
SessionPid = self(),
gen_server:cast(SessionPid, {call_monitor, ChannelId, CallPid}),
dispatch_event(call_create, CallData, State),
SessionId = maps:get(id, State),
metrics_client:counter(<<"gateway.calls.total">>, #{
<<"channel_id">> => integer_to_binary(ChannelId),
<<"session_id">> => SessionId,
<<"status">> => <<"create">>
});
{error, _, Reason} ->
logger:warning("[session_ready] Invalid channel_id in call data: ~p", [
Reason
]),
ok
end
end;
gen_server:cast(SessionPid, {call_monitor, ChannelId, CallPid}),
gen_server:cast(SessionPid, {dispatch, call_create, CallData}),
otel_metrics:counter(<<"gateway.calls.total">>, 1, #{
<<"channel_id">> => integer_to_binary(ChannelId),
<<"session_id">> => SessionId,
<<"status">> => <<"create">>
}),
ok;
_ ->
ok
end.
-spec bool_to_binary(term()) -> binary().
-spec bool_to_binary(boolean()) -> binary().
bool_to_binary(true) -> <<"true">>;
bool_to_binary(false) -> <<"false">>.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
ensure_list_test() ->
?assertEqual([1, 2, 3], ensure_list([1, 2, 3])),
?assertEqual([], ensure_list([])),
ok.
bool_to_binary_test() ->
?assertEqual(<<"true">>, bool_to_binary(true)),
?assertEqual(<<"false">>, bool_to_binary(false)),
ok.
presence_visible_test() ->
?assertEqual(true, presence_visible(#{<<"status">> => <<"online">>})),
?assertEqual(true, presence_visible(#{<<"status">> => <<"idle">>})),
?assertEqual(true, presence_visible(#{<<"status">> => <<"dnd">>})),
?assertEqual(false, presence_visible(#{<<"status">> => <<"offline">>})),
?assertEqual(false, presence_visible(#{<<"status">> => <<"invisible">>})),
?assertEqual(false, presence_visible(#{})),
ok.
dedup_users_test() ->
Users = [
#{<<"id">> => <<"1">>, <<"username">> => <<"alice">>},
#{<<"id">> => <<"2">>, <<"username">> => <<"bob">>},
#{<<"id">> => <<"1">>, <<"username">> => <<"alice_duplicate">>}
],
Result = dedup_users(Users),
?assertEqual(2, length(Result)),
ok.
strip_user_from_member_test() ->
Member = #{
<<"user">> => #{<<"id">> => <<"123">>, <<"username">> => <<"test">>},
<<"nick">> => <<"nickname">>
},
Stripped = strip_user_from_member(Member),
?assertEqual(#{<<"id">> => <<"123">>}, maps:get(<<"user">>, Stripped)),
?assertEqual(<<"nickname">>, maps:get(<<"nick">>, Stripped)),
ok.
strip_user_from_relationship_test() ->
Rel = #{
<<"user">> => #{<<"id">> => <<"100">>, <<"username">> => <<"friend">>}, <<"type">> => 1
},
Stripped = strip_user_from_relationship(Rel),
?assertEqual(undefined, maps:get(<<"user">>, Stripped, undefined)),
?assertEqual(<<"100">>, maps:get(<<"id">>, Stripped)),
?assertEqual(1, maps:get(<<"type">>, Stripped)),
ok.
process_guild_state_unavailable_dispatches_guild_delete_test() ->
State0 = base_state_for_guild_dispatch_test(),
GuildState = #{<<"id">> => <<"123">>, <<"unavailable">> => true},
{noreply, State1} = process_guild_state(GuildState, State0),
Buffer = maps:get(buffer, State1),
?assertEqual(1, length(Buffer)),
FirstEvent = maps:get(event, hd(Buffer)),
?assertEqual(guild_delete, FirstEvent),
ok.
process_guild_state_available_dispatches_guild_create_test() ->
State0 = base_state_for_guild_dispatch_test(),
GuildState = #{
<<"id">> => <<"123">>,
<<"unavailable">> => false,
<<"channels">> => [],
<<"members">> => []
},
{noreply, State1} = process_guild_state(GuildState, State0),
Buffer = maps:get(buffer, State1),
?assertEqual(1, length(Buffer)),
FirstEvent = maps:get(event, hd(Buffer)),
?assertEqual(guild_create, FirstEvent),
ok.
dispatch_ready_data_bot_unavailable_dispatches_guild_delete_test() ->
drain_mailbox(),
UnavailableGuild = #{<<"id">> => <<"987">>, <<"unavailable">> => true},
State0 = #{
id => <<"session-ready-test">>,
user_id => 42,
version => 1,
ready => #{<<"v">> => 9, <<"guilds">> => []},
bot => true,
guilds => #{987 => cached_unavailable},
channels => #{},
relationships => #{},
collected_guild_states => [UnavailableGuild],
collected_sessions => [],
seq => 0,
buffer => [],
socket_pid => self(),
ignored_events => #{},
suppress_presence_updates => false,
pending_presences => [],
debounce_reactions => false,
reaction_buffer => [],
reaction_buffer_timer => undefined,
presence_pid => undefined
},
{noreply, _State1} = dispatch_ready_data(State0),
receive
{dispatch, ready, ReadyData, _ReadySeq} ->
?assertEqual([], maps:get(<<"guilds">>, ReadyData, []));
OtherReady ->
?assert(false, {unexpected_ready_message, OtherReady})
after 1000 ->
?assert(false, ready_not_dispatched)
end,
receive
{dispatch, guild_delete, GuildDeleteData, _GuildDeleteSeq} ->
?assertEqual(<<"987">>, maps:get(<<"id">>, GuildDeleteData)),
?assertEqual(true, maps:get(<<"unavailable">>, GuildDeleteData));
OtherDelete ->
?assert(false, {unexpected_guild_event, OtherDelete})
after 1000 ->
?assert(false, guild_delete_not_dispatched)
end,
receive
{dispatch, guild_create, _GuildCreateData, _GuildCreateSeq} ->
?assert(false, unexpected_guild_create_for_unavailable_guild)
after 100 ->
ok
end.
dispatch_ready_data_nonbot_includes_unavailable_guild_test() ->
drain_mailbox(),
UnavailableGuild = #{<<"id">> => <<"654">>, <<"unavailable">> => true},
State0 = #{
id => <<"session-ready-nonbot-test">>,
user_id => 43,
version => 1,
ready => #{<<"v">> => 9, <<"guilds">> => []},
bot => false,
guilds => #{654 => cached_unavailable},
channels => #{},
relationships => #{},
collected_guild_states => [UnavailableGuild],
collected_sessions => [],
seq => 0,
buffer => [],
socket_pid => self(),
ignored_events => #{},
suppress_presence_updates => false,
pending_presences => [],
debounce_reactions => false,
reaction_buffer => [],
reaction_buffer_timer => undefined,
presence_pid => undefined
},
{noreply, _State1} = dispatch_ready_data(State0),
receive
{dispatch, ready, ReadyData, _ReadySeq} ->
ReadyGuilds = maps:get(<<"guilds">>, ReadyData, []),
?assertEqual(1, length(ReadyGuilds)),
ReadyGuild = hd(ReadyGuilds),
?assertEqual(<<"654">>, maps:get(<<"id">>, ReadyGuild)),
?assertEqual(true, maps:get(<<"unavailable">>, ReadyGuild));
OtherReady ->
?assert(false, {unexpected_ready_message, OtherReady})
after 1000 ->
?assert(false, ready_not_dispatched)
end,
receive
{dispatch, guild_create, _GuildCreateData, _GuildCreateSeq} ->
?assert(false, unexpected_guild_create_for_nonbot_ready)
after 100 ->
ok
end,
receive
{dispatch, guild_delete, _GuildDeleteData, _GuildDeleteSeq} ->
?assert(false, unexpected_guild_delete_for_nonbot_ready)
after 100 ->
ok
end.
dispatch_call_create_from_pid_always_casts_to_session_test() ->
ChannelId = 1234,
SessionId = <<"session-ready-call-test">>,
CallData = #{
channel_id => integer_to_binary(ChannelId),
message_id => <<"9001">>,
region => null,
ringing => [],
recipients => [],
voice_states => [],
created_at => erlang:system_time(millisecond)
},
CallPid = spawn(fun() -> call_state_stub_loop(CallData) end),
ok = dispatch_call_create_from_pid(CallPid, ChannelId, SessionId, self()),
receive
{'$gen_cast', {call_monitor, ChannelId, CallPid}} ->
ok
after 1000 ->
?assert(false, call_monitor_not_cast_to_session)
end,
receive
{'$gen_cast', {dispatch, call_create, DispatchData}} ->
?assertEqual(CallData, DispatchData)
after 1000 ->
?assert(false, call_create_not_cast_to_session)
end,
exit(CallPid, kill),
ok.
-spec call_state_stub_loop(map()) -> ok.
call_state_stub_loop(CallData) ->
receive
{'$gen_call', From, {get_state}} ->
gen_server:reply(From, {ok, CallData}),
call_state_stub_loop(CallData);
_ ->
call_state_stub_loop(CallData)
end.
-spec drain_mailbox() -> ok.
drain_mailbox() ->
receive
_Message ->
drain_mailbox()
after 0 ->
ok
end.
-spec base_state_for_guild_dispatch_test() -> session_state().
base_state_for_guild_dispatch_test() ->
#{
ready => undefined,
seq => 0,
buffer => [],
socket_pid => undefined,
ignored_events => #{},
channels => #{},
relationships => #{},
suppress_presence_updates => false,
pending_presences => [],
presence_pid => undefined,
debounce_reactions => false,
reaction_buffer => [],
reaction_buffer_timer => undefined,
collected_guild_states => []
}.
-endif.

View File

@@ -18,11 +18,47 @@
-module(session_voice).
-export([
init_voice_queue/0,
process_voice_queue/1,
handle_voice_state_update/2,
handle_voice_disconnect/1,
handle_voice_token_request/8
handle_voice_disconnect/1
]).
-type session_state() :: session:session_state().
-type guild_id() :: session:guild_id().
-type channel_id() :: session:channel_id().
-type user_id() :: session:user_id().
-type voice_state_reply() ::
{reply, ok, session_state()}
| {reply, {error, term(), term()}, session_state()}.
-spec init_voice_queue() -> #{voice_queue := queue:queue(), voice_queue_timer := undefined}.
init_voice_queue() ->
#{voice_queue => queue:new(), voice_queue_timer => undefined}.
-spec process_voice_queue(session_state()) -> session_state().
process_voice_queue(State) ->
VoiceQueue = maps:get(voice_queue, State, queue:new()),
case queue:out(VoiceQueue) of
{empty, _} ->
State;
{{value, Item}, NewQueue} ->
process_voice_queue_item(Item, maps:put(voice_queue, NewQueue, State))
end.
-spec process_voice_queue_item(map(), session_state()) -> session_state().
process_voice_queue_item(Item, State) ->
case maps:get(type, Item, undefined) of
voice_state_update ->
Data = maps:get(data, Item),
{reply, _, NewState} = handle_voice_state_update(Data, State),
NewState;
_ ->
State
end.
-spec handle_voice_state_update(map(), session_state()) -> voice_state_reply().
handle_voice_state_update(Data, State) ->
GuildIdRaw = maps:get(<<"guild_id">>, Data, null),
ChannelIdRaw = maps:get(<<"channel_id">>, Data, null),
@@ -31,18 +67,15 @@ handle_voice_state_update(Data, State) ->
SelfDeaf = maps:get(<<"self_deaf">>, Data, false),
SelfVideo = maps:get(<<"self_video">>, Data, false),
SelfStream = maps:get(<<"self_stream">>, Data, false),
ViewerStreamKey = maps:get(<<"viewer_stream_key">>, Data, undefined),
ViewerStreamKeys = maps:get(<<"viewer_stream_keys">>, Data, undefined),
IsMobile = maps:get(<<"is_mobile">>, Data, false),
Latitude = maps:get(<<"latitude">>, Data, null),
Longitude = maps:get(<<"longitude">>, Data, null),
SessionId = maps:get(id, State),
UserId = maps:get(user_id, State),
Guilds = maps:get(guilds, State),
GuildIdResult = validation:validate_optional_snowflake(GuildIdRaw),
ChannelIdResult = validation:validate_optional_snowflake(ChannelIdRaw),
case {GuildIdResult, ChannelIdResult} of
{{ok, GuildId}, {ok, ChannelId}} ->
handle_validated_voice_state_update(
@@ -53,7 +86,7 @@ handle_voice_state_update(Data, State) ->
SelfDeaf,
SelfVideo,
SelfStream,
ViewerStreamKey,
ViewerStreamKeys,
IsMobile,
Latitude,
Longitude,
@@ -68,6 +101,23 @@ handle_voice_state_update(Data, State) ->
{reply, Error, State}
end.
-spec handle_validated_voice_state_update(
guild_id() | null,
channel_id() | null,
binary() | null,
boolean(),
boolean(),
boolean(),
boolean(),
list() | undefined,
boolean(),
number() | null,
number() | null,
binary(),
user_id(),
map(),
session_state()
) -> voice_state_reply().
handle_validated_voice_state_update(
null,
null,
@@ -76,7 +126,7 @@ handle_validated_voice_state_update(
_SelfDeaf,
_SelfVideo,
_SelfStream,
_ViewerStreamKey,
_ViewerStreamKeys,
_IsMobile,
_Latitude,
_Longitude,
@@ -94,7 +144,7 @@ handle_validated_voice_state_update(
_SelfDeaf,
_SelfVideo,
_SelfStream,
_ViewerStreamKey,
_ViewerStreamKeys,
_IsMobile,
_Latitude,
_Longitude,
@@ -112,12 +162,11 @@ handle_validated_voice_state_update(
self_deaf => false,
self_video => false,
self_stream => false,
viewer_stream_key => null,
viewer_stream_keys => [],
is_mobile => false,
latitude => null,
longitude => null
},
StateWithSessionPid = maps:put(session_pid, self(), State),
case dm_voice:voice_state_update(Request, StateWithSessionPid) of
{reply, #{success := true}, NewState} ->
@@ -134,7 +183,7 @@ handle_validated_voice_state_update(
SelfDeaf,
SelfVideo,
SelfStream,
ViewerStreamKey,
ViewerStreamKeys,
IsMobile,
Latitude,
Longitude,
@@ -152,12 +201,11 @@ handle_validated_voice_state_update(
self_deaf => SelfDeaf,
self_video => SelfVideo,
self_stream => SelfStream,
viewer_stream_key => ViewerStreamKey,
viewer_stream_keys => ViewerStreamKeys,
is_mobile => IsMobile,
latitude => Latitude,
longitude => Longitude
},
StateWithSessionPid = maps:put(session_pid, self(), State),
case dm_voice:voice_state_update(Request, StateWithSessionPid) of
{reply, #{success := true, needs_token := true}, NewState} ->
@@ -183,7 +231,7 @@ handle_validated_voice_state_update(
SelfDeaf,
SelfVideo,
SelfStream,
ViewerStreamKey,
ViewerStreamKeys,
IsMobile,
Latitude,
Longitude,
@@ -194,7 +242,6 @@ handle_validated_voice_state_update(
) when is_integer(GuildId) ->
case maps:get(GuildId, Guilds, undefined) of
undefined ->
logger:warning("[session_voice] Guild not found in session: ~p", [GuildId]),
{reply, gateway_errors:error(voice_guild_not_found), State};
{GuildPid, _Ref} when is_pid(GuildPid) ->
Request = #{
@@ -206,64 +253,35 @@ handle_validated_voice_state_update(
self_deaf => SelfDeaf,
self_video => SelfVideo,
self_stream => SelfStream,
viewer_stream_key => ViewerStreamKey,
viewer_stream_keys => ViewerStreamKeys,
is_mobile => IsMobile,
latitude => Latitude,
longitude => Longitude
},
logger:debug(
"[session_voice] Calling guild process for voice state update: GuildId=~p, ChannelId=~p, ConnectionId=~p",
[GuildId, ChannelId, ConnectionId]
),
case guild_client:voice_state_update(GuildPid, Request, 12000) of
{ok, #{needs_token := true}} ->
logger:debug("[session_voice] Voice state update succeeded, needs token"),
SessionPid = self(),
spawn(fun() ->
handle_voice_token_request(
GuildId,
ChannelId,
UserId,
ConnectionId,
SessionId,
SessionPid,
Latitude,
Longitude
)
end),
{reply, ok, State};
{ok, _} ->
logger:debug("[session_voice] Voice state update succeeded"),
{reply, ok, State};
{error, timeout} ->
logger:error(
"[session_voice] Voice state update timed out (>12s) for GuildId=~p, ChannelId=~p",
[GuildId, ChannelId]
),
{reply, gateway_errors:error(timeout), State};
{error, noproc} ->
logger:error(
"[session_voice] Guild process not running for GuildId=~p",
[GuildId]
),
{reply, gateway_errors:error(internal_error), State};
{error, Category, ErrorAtom} ->
logger:warning("[session_voice] Voice state update failed: ~p", [ErrorAtom]),
{reply, {error, Category, ErrorAtom}, State}
end;
queue_guild_voice_state_update(
GuildPid,
GuildId,
ChannelId,
UserId,
ConnectionId,
SessionId,
Request,
Latitude,
Longitude,
State
);
_ ->
logger:warning("[session_voice] Invalid guild pid in session"),
{reply, gateway_errors:error(internal_error), State}
end;
handle_validated_voice_state_update(
GuildId,
ChannelId,
ConnectionId,
_GuildId,
_ChannelId,
_ConnectionId,
_SelfMute,
_SelfDeaf,
_SelfVideo,
_SelfStream,
_ViewerStreamKey,
_ViewerStreamKeys,
_IsMobile,
_Latitude,
_Longitude,
@@ -272,62 +290,170 @@ handle_validated_voice_state_update(
_Guilds,
State
) ->
logger:warning(
"[session_voice] Invalid voice state update parameters: GuildId=~p, ChannelId=~p, ConnectionId=~p",
[GuildId, ChannelId, ConnectionId]
),
{reply, gateway_errors:error(validation_invalid_params), State}.
handle_voice_disconnect(State) ->
Guilds = maps:get(guilds, State),
UserId = maps:get(user_id, State),
SessionId = maps:get(id, State),
ConnectionId = maps:get(connection_id, State),
lists:foreach(
fun
({_GuildId, {GuildPid, _Ref}}) when is_pid(GuildPid) ->
Request = #{
user_id => UserId,
channel_id => null,
session_id => SessionId,
connection_id => ConnectionId,
self_mute => false,
self_deaf => false,
self_video => false,
self_stream => false,
viewer_stream_key => null
},
_ = guild_client:voice_state_update(GuildPid, Request, 10000);
(_) ->
ok
end,
maps:to_list(Guilds)
),
{reply, #{success := true}, NewState} = dm_voice:disconnect_voice_user(UserId, State),
{reply, ok, NewState}.
handle_voice_token_request(
GuildId, ChannelId, UserId, ConnectionId, _SessionId, SessionPid, Latitude, Longitude
-spec queue_guild_voice_state_update(
pid(),
guild_id(),
channel_id() | null,
user_id(),
binary() | null,
binary(),
map(),
number() | null,
number() | null,
session_state()
) ->
Req = voice_utils:build_voice_token_rpc_request(
GuildId, ChannelId, UserId, ConnectionId, Latitude, Longitude
),
voice_state_reply().
queue_guild_voice_state_update(
GuildPid,
GuildId,
ChannelId,
UserId,
ConnectionId,
SessionId,
Request,
Latitude,
Longitude,
State
) ->
SessionPid = self(),
spawn(fun() ->
handle_guild_voice_state_update(
GuildPid,
GuildId,
ChannelId,
UserId,
ConnectionId,
SessionId,
Request,
Latitude,
Longitude,
SessionPid
)
end),
{reply, ok, State}.
case rpc_client:call(Req) of
{ok, Data} ->
Token = maps:get(<<"token">>, Data),
Endpoint = maps:get(<<"endpoint">>, Data),
-spec handle_guild_voice_state_update(
pid(),
guild_id(),
channel_id() | null,
user_id(),
binary() | null,
binary(),
map(),
number() | null,
number() | null,
pid()
) ->
ok.
handle_guild_voice_state_update(
GuildPid,
GuildId,
ChannelId,
_UserId,
_ConnectionId,
_SessionId,
Request,
_Latitude,
_Longitude,
SessionPid
) ->
case guild_client:voice_state_update(GuildPid, Request, 12000) of
{ok, Reply} when is_map(Reply) ->
maybe_dispatch_voice_server_update_from_reply(Reply, GuildId, ChannelId, SessionPid),
ok;
{error, timeout} ->
ok;
{error, noproc} ->
ok;
{error, _Category, _ErrorAtom} ->
ok
end.
-spec maybe_dispatch_voice_server_update_from_reply(map(), guild_id(), channel_id() | null, pid()) ->
ok.
maybe_dispatch_voice_server_update_from_reply(Reply, GuildId, ChannelId, SessionPid) ->
case
{
maps:get(token, Reply, undefined),
maps:get(endpoint, Reply, undefined),
maps:get(connection_id, Reply, undefined),
ChannelId
}
of
{Token, Endpoint, ConnectionId, ChannelIdValue} when
is_binary(Token),
is_binary(Endpoint),
is_binary(ConnectionId),
is_integer(ChannelIdValue),
is_pid(SessionPid)
->
VoiceServerUpdate = #{
<<"token">> => Token,
<<"endpoint">> => Endpoint,
<<"guild_id">> => integer_to_binary(GuildId),
<<"channel_id">> => integer_to_binary(ChannelIdValue),
<<"connection_id">> => ConnectionId
},
gen_server:cast(SessionPid, {dispatch, voice_server_update, VoiceServerUpdate});
{error, _Reason} ->
gen_server:cast(SessionPid, {dispatch, voice_server_update, VoiceServerUpdate}),
ok;
_ ->
ok
end.
-spec handle_voice_disconnect(session_state()) -> voice_state_reply().
handle_voice_disconnect(State) ->
Guilds = maps:get(guilds, State),
UserId = maps:get(user_id, State),
SessionId = maps:get(id, State),
ConnectionId = maps:get(connection_id, State, null),
Request = #{
user_id => UserId,
channel_id => null,
session_id => SessionId,
connection_id => ConnectionId,
self_mute => false,
self_deaf => false,
self_video => false,
self_stream => false,
viewer_stream_keys => []
},
dispatch_guild_voice_disconnects(Guilds, Request),
{reply, #{success := true}, NewState} = dm_voice:disconnect_voice_user(UserId, State),
{reply, ok, NewState}.
-spec dispatch_guild_voice_disconnects(map(), map()) -> ok.
dispatch_guild_voice_disconnects(Guilds, Request) ->
lists:foreach(
fun
({_GuildId, {GuildPid, _Ref}}) when is_pid(GuildPid) ->
spawn(fun() ->
_ = guild_client:voice_state_update(GuildPid, Request, 10000),
ok
end),
ok;
(_) ->
ok
end,
maps:to_list(Guilds)
).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
init_voice_queue_test() ->
Result = init_voice_queue(),
?assert(maps:is_key(voice_queue, Result)),
?assert(maps:is_key(voice_queue_timer, Result)),
?assertEqual(undefined, maps:get(voice_queue_timer, Result)),
?assert(queue:is_empty(maps:get(voice_queue, Result))),
ok.
process_voice_queue_empty_test() ->
State = #{voice_queue => queue:new()},
Result = process_voice_queue(State),
?assertEqual(State, Result),
ok.
-endif.