refactor progress

This commit is contained in:
Hampus Kraft
2026-02-17 12:22:36 +00:00
parent cb31608523
commit d5abd1a7e4
8257 changed files with 1190207 additions and 761040 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -21,123 +21,533 @@
is_guild_unavailable_for_user/2,
is_user_staff/2,
check_unavailability_transition/2,
handle_unavailability_transition/2
handle_unavailability_transition/2,
get_cached_unavailability_mode/1,
is_guild_unavailable_for_user_from_cache/2,
update_unavailability_cache_for_state/1
]).
-import(guild_permissions, [find_member_by_user_id/2]).
-import(guild_data, [get_guild_state/2]).
-type guild_state() :: map().
-type user_id() :: integer().
-type guild_id() :: integer().
-type unavailability_mode() ::
available
| unavailable_for_everyone
| unavailable_for_everyone_but_staff.
-type transition_result() :: {unavailable_enabled, boolean()} | unavailable_disabled | no_change.
-define(GUILD_UNAVAILABILITY_CACHE, guild_unavailability_cache).
-define(STAFF_USER_FLAG, 16#1).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec is_guild_unavailable_for_user(user_id(), guild_state()) -> boolean().
is_guild_unavailable_for_user(UserId, State) ->
Data = maps:get(data, State),
Guild = maps:get(<<"guild">>, Data),
Features = maps:get(<<"features">>, Guild, []),
HasUnavailableForEveryone = lists:member(<<"UNAVAILABLE_FOR_EVERYONE">>, Features),
HasUnavailableForEveryoneButStaff =
lists:member(<<"UNAVAILABLE_FOR_EVERYONE_BUT_STAFF">>, Features),
case {HasUnavailableForEveryone, HasUnavailableForEveryoneButStaff} of
{true, _} ->
case get_unavailability_mode_from_state(State) of
unavailable_for_everyone ->
true;
{false, true} ->
unavailable_for_everyone_but_staff ->
not is_user_staff(UserId, State);
{false, false} ->
available ->
false
end.
-spec is_user_staff(user_id(), guild_state()) -> boolean().
is_user_staff(UserId, State) ->
case find_member_by_user_id(UserId, State) of
undefined ->
case is_user_staff_from_sessions(UserId, State) of
true ->
true;
false ->
false;
Member ->
User = maps:get(<<"user">>, Member, #{}),
Flags = utils:binary_to_integer_safe(maps:get(<<"flags">>, User, <<"0">>)),
(Flags band 16#1) =:= 16#1
end.
check_unavailability_transition(OldState, NewState) ->
OldData = maps:get(data, OldState),
OldGuild = maps:get(<<"guild">>, OldData),
OldFeatures = maps:get(<<"features">>, OldGuild, []),
NewData = maps:get(data, NewState),
NewGuild = maps:get(<<"guild">>, NewData),
NewFeatures = maps:get(<<"features">>, NewGuild, []),
OldUnavailableForEveryone = lists:member(<<"UNAVAILABLE_FOR_EVERYONE">>, OldFeatures),
NewUnavailableForEveryone = lists:member(<<"UNAVAILABLE_FOR_EVERYONE">>, NewFeatures),
OldUnavailableForEveryoneButStaff =
lists:member(<<"UNAVAILABLE_FOR_EVERYONE_BUT_STAFF">>, OldFeatures),
NewUnavailableForEveryoneButStaff =
lists:member(<<"UNAVAILABLE_FOR_EVERYONE_BUT_STAFF">>, NewFeatures),
OldIsUnavailable = OldUnavailableForEveryone orelse OldUnavailableForEveryoneButStaff,
NewIsUnavailable = NewUnavailableForEveryone orelse NewUnavailableForEveryoneButStaff,
case {OldIsUnavailable, NewIsUnavailable} of
{false, true} ->
{unavailable_enabled, NewUnavailableForEveryoneButStaff};
{true, false} ->
unavailable_disabled;
_ ->
case
{OldUnavailableForEveryoneButStaff, NewUnavailableForEveryoneButStaff,
OldUnavailableForEveryone, NewUnavailableForEveryone}
of
{true, false, false, true} ->
{unavailable_enabled, false};
{false, true, true, false} ->
{unavailable_enabled, true};
_ ->
no_change
undefined ->
case guild_permissions:find_member_by_user_id(UserId, State) of
undefined ->
false;
Member ->
User = maps:get(<<"user">>, Member, #{}),
is_user_staff_from_user_data(User)
end
end.
handle_unavailability_transition(OldState, NewState) ->
GuildId = maps:get(id, NewState),
UnavailablePayload = #{
<<"id">> => integer_to_binary(GuildId),
<<"unavailable">> => true
},
-spec is_user_staff_from_sessions(user_id(), guild_state()) -> boolean() | undefined.
is_user_staff_from_sessions(UserId, State) ->
Sessions = maps:get(sessions, State, #{}),
maps:fold(
fun(_SessionId, SessionData, Acc) ->
case Acc of
undefined ->
SessionUserId = maps:get(user_id, SessionData, undefined),
SessionIsStaff = maps:get(is_staff, SessionData, undefined),
case {SessionUserId =:= UserId, SessionIsStaff} of
{true, true} ->
true;
{true, false} ->
false;
_ ->
undefined
end;
_ ->
Acc
end
end,
undefined,
Sessions
).
-spec get_cached_unavailability_mode(guild_id()) -> unavailability_mode().
get_cached_unavailability_mode(GuildId) ->
ensure_unavailability_cache_table(),
case ets:lookup(?GUILD_UNAVAILABILITY_CACHE, GuildId) of
[{GuildId, Mode}] ->
normalize_unavailability_mode(Mode);
[] ->
available
end.
-spec is_guild_unavailable_for_user_from_cache(guild_id(), map()) -> boolean().
is_guild_unavailable_for_user_from_cache(GuildId, UserData) ->
case get_cached_unavailability_mode(GuildId) of
unavailable_for_everyone ->
true;
unavailable_for_everyone_but_staff ->
not is_user_staff_from_user_data(UserData);
available ->
false
end.
-spec update_unavailability_cache_for_state(guild_state()) -> unavailability_mode().
update_unavailability_cache_for_state(State) ->
GuildId = maps:get(id, State),
Mode = get_unavailability_mode_from_state(State),
set_cached_unavailability_mode(GuildId, Mode),
Mode.
-spec check_unavailability_transition(guild_state(), guild_state()) -> transition_result().
check_unavailability_transition(OldState, NewState) ->
OldMode = get_unavailability_mode_from_state(OldState),
NewMode = get_unavailability_mode_from_state(NewState),
case {OldMode, NewMode} of
{available, unavailable_for_everyone} ->
{unavailable_enabled, false};
{available, unavailable_for_everyone_but_staff} ->
{unavailable_enabled, true};
{unavailable_for_everyone, available} ->
unavailable_disabled;
{unavailable_for_everyone_but_staff, available} ->
unavailable_disabled;
{unavailable_for_everyone_but_staff, unavailable_for_everyone} ->
{unavailable_enabled, false};
{unavailable_for_everyone, unavailable_for_everyone_but_staff} ->
{unavailable_enabled, true};
_ ->
no_change
end.
-spec handle_unavailability_transition(guild_state(), guild_state()) -> guild_state().
handle_unavailability_transition(OldState, NewState) ->
_ = update_unavailability_cache_for_state(NewState),
GuildId = maps:get(id, NewState),
case check_unavailability_transition(OldState, NewState) of
{unavailable_enabled, StaffOnly} ->
Sessions = maps:get(sessions, NewState, #{}),
lists:foreach(
fun({_SessionId, SessionData}) ->
UserId = maps:get(user_id, SessionData),
Pid = maps:get(pid, SessionData),
ShouldBeUnavailable =
case StaffOnly of
true -> not is_user_staff(UserId, NewState);
false -> true
end,
case ShouldBeUnavailable of
true ->
gen_server:cast(Pid, {dispatch, guild_delete, UnavailablePayload});
false ->
ok
end
end,
maps:to_list(Sessions)
);
disconnect_ineligible_sessions(StaffOnly, NewState, GuildId);
unavailable_disabled ->
Sessions = maps:get(sessions, NewState, #{}),
GuildId = maps:get(id, NewState),
BulkPresences = presence_utils:collect_guild_member_presences(NewState),
lists:foreach(
fun({_SessionId, SessionData}) ->
UserId = maps:get(user_id, SessionData),
Pid = maps:get(pid, SessionData),
GuildState = get_guild_state(UserId, NewState),
gen_server:cast(Pid, {dispatch, guild_create, GuildState}),
presence_utils:send_presence_bulk(Pid, GuildId, UserId, BulkPresences)
case maps:get(pending_connect, SessionData, false) of
true ->
ok;
false ->
UserId = maps:get(user_id, SessionData),
Pid = maps:get(pid, SessionData),
GuildState = guild_data:get_guild_state(UserId, NewState),
gen_server:cast(Pid, {dispatch, guild_create, GuildState}),
presence_utils:send_presence_bulk(Pid, GuildId, UserId, BulkPresences)
end
end,
maps:to_list(Sessions)
);
),
NewState;
no_change ->
NewState
end.
-spec disconnect_ineligible_sessions(boolean(), guild_state(), guild_id()) -> guild_state().
disconnect_ineligible_sessions(StaffOnly, State, GuildId) ->
Sessions = maps:get(sessions, State, #{}),
{FinalState, _DisconnectedUsers} = lists:foldl(
fun({SessionId, SessionData}, {AccState, ProcessedUsers}) ->
UserId = maps:get(user_id, SessionData),
case should_disconnect_user(UserId, StaffOnly, AccState) of
true ->
Pid = maps:get(pid, SessionData, undefined),
maybe_send_guild_leave(Pid, GuildId),
{VoiceState, UpdatedUsers} =
maybe_disconnect_voice_for_user(UserId, ProcessedUsers, AccState),
NewState = guild_sessions:remove_session(SessionId, VoiceState),
{NewState, UpdatedUsers};
false ->
{AccState, ProcessedUsers}
end
end,
{State, sets:new()},
maps:to_list(Sessions)
),
FinalState.
-spec should_disconnect_user(user_id(), boolean(), guild_state()) -> boolean().
should_disconnect_user(UserId, true, State) ->
not is_user_staff(UserId, State);
should_disconnect_user(_UserId, false, _State) ->
true.
-spec maybe_send_guild_leave(pid() | undefined, guild_id()) -> ok.
maybe_send_guild_leave(Pid, GuildId) when is_pid(Pid) ->
gen_server:cast(Pid, {guild_leave, GuildId, forced_unavailable}),
ok;
maybe_send_guild_leave(_Pid, _GuildId) ->
ok.
-spec maybe_disconnect_voice_for_user(user_id(), sets:set(user_id()), guild_state()) ->
{guild_state(), sets:set(user_id())}.
maybe_disconnect_voice_for_user(UserId, ProcessedUsers, State) ->
case sets:is_element(UserId, ProcessedUsers) of
true ->
{State, ProcessedUsers};
false ->
{reply, _Result, VoiceState} = guild_voice_disconnect:disconnect_voice_user(
#{user_id => UserId, connection_id => null},
State
),
{VoiceState, sets:add_element(UserId, ProcessedUsers)}
end.
-spec ensure_unavailability_cache_table() -> ok.
ensure_unavailability_cache_table() ->
case ets:whereis(?GUILD_UNAVAILABILITY_CACHE) of
undefined ->
try ets:new(?GUILD_UNAVAILABILITY_CACHE, [named_table, public, set, {read_concurrency, true}]) of
_ -> ok
catch
error:badarg -> ok
end;
_ ->
ok
end.
-spec set_cached_unavailability_mode(guild_id(), unavailability_mode()) -> ok.
set_cached_unavailability_mode(GuildId, available) ->
ensure_unavailability_cache_table(),
ets:delete(?GUILD_UNAVAILABILITY_CACHE, GuildId),
ok;
set_cached_unavailability_mode(GuildId, Mode) ->
ensure_unavailability_cache_table(),
ets:insert(?GUILD_UNAVAILABILITY_CACHE, {GuildId, Mode}),
ok.
-spec normalize_unavailability_mode(term()) -> unavailability_mode().
normalize_unavailability_mode(unavailable_for_everyone) ->
unavailable_for_everyone;
normalize_unavailability_mode(unavailable_for_everyone_but_staff) ->
unavailable_for_everyone_but_staff;
normalize_unavailability_mode(_) ->
available.
-spec get_unavailability_mode_from_state(guild_state()) -> unavailability_mode().
get_unavailability_mode_from_state(State) ->
Data = maps:get(data, State, #{}),
Guild = maps:get(<<"guild">>, Data, #{}),
Features = maps:get(<<"features">>, Guild, []),
get_unavailability_mode_from_features(Features).
-spec get_unavailability_mode_from_features(term()) -> unavailability_mode().
get_unavailability_mode_from_features(Features) when is_list(Features) ->
HasUnavailableForEveryone = lists:member(<<"UNAVAILABLE_FOR_EVERYONE">>, Features),
HasUnavailableForEveryoneButStaff =
lists:member(<<"UNAVAILABLE_FOR_EVERYONE_BUT_STAFF">>, Features),
case {HasUnavailableForEveryone, HasUnavailableForEveryoneButStaff} of
{true, _} ->
unavailable_for_everyone;
{false, true} ->
unavailable_for_everyone_but_staff;
{false, false} ->
available
end;
get_unavailability_mode_from_features(_) ->
available.
-spec is_user_staff_from_user_data(map()) -> boolean().
is_user_staff_from_user_data(UserData) when is_map(UserData) ->
case parse_is_staff_value(maps:get(<<"is_staff">>, UserData, undefined)) of
undefined ->
is_user_staff_from_flags(UserData);
IsStaff ->
IsStaff
end;
is_user_staff_from_user_data(_) ->
false.
-spec parse_is_staff_value(term()) -> boolean() | undefined.
parse_is_staff_value(true) ->
true;
parse_is_staff_value(false) ->
false;
parse_is_staff_value(<<"true">>) ->
true;
parse_is_staff_value(<<"false">>) ->
false;
parse_is_staff_value(_) ->
undefined.
-spec is_user_staff_from_flags(map()) -> boolean().
is_user_staff_from_flags(UserData) ->
FlagsValue = maps:get(<<"flags">>, UserData, 0),
Flags = type_conv:to_integer(FlagsValue),
case Flags of
undefined ->
false;
Value when is_integer(Value) ->
(Value band ?STAFF_USER_FLAG) =:= ?STAFF_USER_FLAG
end.
-ifdef(TEST).
-spec cleanup_unavailability_cache(guild_id()) -> ok.
cleanup_unavailability_cache(GuildId) ->
set_cached_unavailability_mode(GuildId, available).
disconnect_ineligible_sessions_staff_only_test() ->
Parent = self(),
GuildId = 99001,
NonStaffPid = start_session_capture(non_staff, Parent),
StaffPid = start_session_capture(staff, Parent),
try
BaseState = state_for_unavailability_transition_test(GuildId, NonStaffPid, StaffPid),
OldState = BaseState,
NewState = BaseState#{
data => #{
<<"guild">> => #{
<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE_BUT_STAFF">>]
},
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"1001">>, <<"flags">> => <<"0">>}},
#{<<"user">> => #{<<"id">> => <<"1002">>, <<"flags">> => <<"1">>}}
]
}
},
UpdatedState = handle_unavailability_transition(OldState, NewState),
Sessions = maps:get(sessions, UpdatedState, #{}),
?assertEqual(1, map_size(Sessions)),
?assert(maps:is_key(<<"staff">>, Sessions)),
?assertEqual(unavailable_for_everyone_but_staff, get_cached_unavailability_mode(GuildId)),
receive
{non_staff, {'$gen_cast', {guild_leave, GuildId, forced_unavailable}}} -> ok
after 1000 ->
?assert(false)
end,
receive
{staff, {'$gen_cast', {guild_leave, GuildId, forced_unavailable}}} ->
?assert(false)
after 200 ->
ok
end
after
cleanup_unavailability_cache(GuildId),
NonStaffPid ! stop,
StaffPid ! stop
end.
disconnect_ineligible_sessions_everyone_test() ->
Parent = self(),
GuildId = 99002,
UserOnePid = start_session_capture(user_one, Parent),
UserTwoPid = start_session_capture(user_two, Parent),
try
BaseState = state_for_unavailability_transition_test(GuildId, UserOnePid, UserTwoPid),
OldState = BaseState,
NewState = BaseState#{
data => #{
<<"guild">> => #{
<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]
},
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"1001">>, <<"flags">> => <<"0">>}},
#{<<"user">> => #{<<"id">> => <<"1002">>, <<"flags">> => <<"1">>}}
]
}
},
UpdatedState = handle_unavailability_transition(OldState, NewState),
?assertEqual(#{}, maps:get(sessions, UpdatedState, #{})),
?assertEqual(unavailable_for_everyone, get_cached_unavailability_mode(GuildId)),
receive
{user_one, {'$gen_cast', {guild_leave, GuildId, forced_unavailable}}} -> ok
after 1000 ->
?assert(false)
end,
receive
{user_two, {'$gen_cast', {guild_leave, GuildId, forced_unavailable}}} -> ok
after 1000 ->
?assert(false)
end
after
cleanup_unavailability_cache(GuildId),
UserOnePid ! stop,
UserTwoPid ! stop
end.
is_guild_unavailable_for_user_from_cache_is_staff_test() ->
GuildId = 99003,
try
set_cached_unavailability_mode(GuildId, unavailable_for_everyone_but_staff),
?assertEqual(
true,
is_guild_unavailable_for_user_from_cache(
GuildId,
#{<<"is_staff">> => false}
)
),
?assertEqual(
false,
is_guild_unavailable_for_user_from_cache(
GuildId,
#{<<"is_staff">> => true}
)
)
after
cleanup_unavailability_cache(GuildId)
end.
-spec start_session_capture(atom(), pid()) -> pid().
start_session_capture(Tag, Parent) ->
spawn(fun() -> session_capture_loop(Tag, Parent) end).
-spec session_capture_loop(atom(), pid()) -> ok.
session_capture_loop(Tag, Parent) ->
receive
stop ->
ok;
{'$gen_cast', Msg} ->
Parent ! {Tag, {'$gen_cast', Msg}},
session_capture_loop(Tag, Parent);
_Other ->
session_capture_loop(Tag, Parent)
end.
-spec state_for_unavailability_transition_test(guild_id(), pid(), pid()) -> guild_state().
state_for_unavailability_transition_test(GuildId, NonStaffPid, StaffPid) ->
#{
id => GuildId,
sessions => #{
<<"non_staff">> => #{
session_id => <<"non_staff">>,
user_id => 1001,
pid => NonStaffPid,
mref => make_ref(),
active_guilds => sets:new(),
user_roles => [],
bot => false,
is_staff => false,
previous_passive_updates => #{},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
},
<<"staff">> => #{
session_id => <<"staff">>,
user_id => 1002,
pid => StaffPid,
mref => make_ref(),
active_guilds => sets:new(),
user_roles => [],
bot => false,
is_staff => true,
previous_passive_updates => #{},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
}
},
presence_subscriptions => #{},
member_list_subscriptions => #{},
member_subscriptions => guild_subscriptions:init_state(),
data => #{
<<"guild">> => #{
<<"features">> => []
},
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"1001">>, <<"flags">> => <<"0">>}},
#{<<"user">> => #{<<"id">> => <<"1002">>, <<"flags">> => <<"1">>}}
]
},
voice_states => #{},
pending_voice_connections => #{}
}.
is_guild_unavailable_for_user_unavailable_for_everyone_test() ->
State = #{
data => #{
<<"guild">> => #{
<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]
},
<<"members">> => []
}
},
?assertEqual(true, is_guild_unavailable_for_user(123, State)).
is_guild_unavailable_for_user_available_test() ->
State = #{
data => #{
<<"guild">> => #{
<<"features">> => []
},
<<"members">> => []
}
},
?assertEqual(false, is_guild_unavailable_for_user(123, State)).
check_unavailability_transition_no_change_test() ->
State = #{
data => #{
<<"guild">> => #{
<<"features">> => []
}
}
},
?assertEqual(no_change, check_unavailability_transition(State, State)).
check_unavailability_transition_enabled_test() ->
OldState = #{
data => #{
<<"guild">> => #{
<<"features">> => []
}
}
},
NewState = #{
data => #{
<<"guild">> => #{
<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]
}
}
},
?assertEqual({unavailable_enabled, false}, check_unavailability_transition(OldState, NewState)).
check_unavailability_transition_disabled_test() ->
OldState = #{
data => #{
<<"guild">> => #{
<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]
}
}
},
NewState = #{
data => #{
<<"guild">> => #{
<<"features">> => []
}
}
},
?assertEqual(unavailable_disabled, check_unavailability_transition(OldState, NewState)).
-endif.

View File

@@ -17,9 +17,7 @@
-module(guild_client).
-export([
voice_state_update/3
]).
-export([voice_state_update/3]).
-export_type([
voice_state_update_success/0,
@@ -27,7 +25,10 @@
voice_state_update_result/0
]).
-define(DEFAULT_TIMEOUT, 12000).
-define(CIRCUIT_BREAKER_TABLE, guild_circuit_breaker).
-define(FAILURE_THRESHOLD, 5).
-define(RECOVERY_TIMEOUT_MS, 30000).
-define(MAX_CONCURRENT, 50).
-type voice_state_update_success() :: #{
success := true,
@@ -44,10 +45,39 @@
{ok, voice_state_update_success()}
| {error, timeout}
| {error, noproc}
| {error, circuit_breaker_open}
| {error, too_many_requests}
| {error, atom(), atom()}.
-type circuit_state() :: closed | open | half_open.
-spec voice_state_update(pid(), map(), timeout()) -> voice_state_update_result().
voice_state_update(GuildPid, Request, Timeout) ->
ensure_table(),
case acquire_slot(GuildPid) of
ok ->
try
execute_with_circuit_breaker(GuildPid, Request, Timeout)
after
release_slot(GuildPid)
end;
{error, Reason} ->
{error, Reason}
end.
-spec execute_with_circuit_breaker(pid(), map(), timeout()) -> voice_state_update_result().
execute_with_circuit_breaker(GuildPid, Request, Timeout) ->
case get_circuit_state(GuildPid) of
open ->
{error, circuit_breaker_open};
State when State =:= closed; State =:= half_open ->
Result = do_call(GuildPid, Request, Timeout),
update_circuit_state(GuildPid, Result, State),
Result
end.
-spec do_call(pid(), map(), timeout()) -> voice_state_update_result().
do_call(GuildPid, Request, Timeout) ->
try gen_server:call(GuildPid, {voice_state_update, Request}, Timeout) of
Response when is_map(Response) ->
case maps:get(success, Response, false) of
@@ -57,12 +87,138 @@ voice_state_update(GuildPid, Request, Timeout) ->
{error, Category, ErrorAtom} when is_atom(Category), is_atom(ErrorAtom) ->
{error, Category, ErrorAtom}
catch
exit:{timeout, _} ->
{error, timeout};
exit:{noproc, _} ->
{error, noproc};
exit:{normal, _} ->
{error, noproc}
exit:{timeout, _} -> {error, timeout};
exit:{noproc, _} -> {error, noproc};
exit:{normal, _} -> {error, noproc}
end.
-spec get_circuit_state(pid()) -> circuit_state().
get_circuit_state(GuildPid) ->
case safe_lookup(GuildPid) of
[] ->
closed;
[{_, #{state := open, opened_at := OpenedAt}}] ->
Now = erlang:system_time(millisecond),
case Now - OpenedAt > ?RECOVERY_TIMEOUT_MS of
true -> half_open;
false -> open
end;
[{_, #{state := State}}] ->
State
end.
-spec update_circuit_state(pid(), voice_state_update_result(), circuit_state()) -> ok.
update_circuit_state(GuildPid, Result, PrevState) ->
IsSuccess = is_success_result(Result),
case {IsSuccess, PrevState} of
{true, half_open} ->
ets:delete(?CIRCUIT_BREAKER_TABLE, GuildPid),
ok;
{true, closed} ->
reset_failures(GuildPid);
{false, _} ->
record_failure(GuildPid)
end.
-spec is_success_result(voice_state_update_result()) -> boolean().
is_success_result({ok, _}) -> true;
is_success_result(_) -> false.
-spec reset_failures(pid()) -> ok.
reset_failures(GuildPid) ->
case safe_lookup(GuildPid) of
[{_, State}] ->
ets:insert(?CIRCUIT_BREAKER_TABLE, {GuildPid, State#{failures => 0}}),
ok;
[] ->
ok
end.
-spec record_failure(pid()) -> ok.
record_failure(GuildPid) ->
Now = erlang:system_time(millisecond),
case safe_lookup(GuildPid) of
[] ->
ets:insert(
?CIRCUIT_BREAKER_TABLE,
{GuildPid, #{
state => closed,
failures => 1,
concurrent => 0
}}
),
ok;
[{_, #{failures := F} = State}] when F + 1 >= ?FAILURE_THRESHOLD ->
ets:insert(
?CIRCUIT_BREAKER_TABLE,
{GuildPid, State#{
state => open,
failures => F + 1,
opened_at => Now
}}
),
ok;
[{_, #{failures := F} = State}] ->
ets:insert(?CIRCUIT_BREAKER_TABLE, {GuildPid, State#{failures => F + 1}}),
ok
end.
-spec acquire_slot(pid()) -> ok | {error, too_many_requests}.
acquire_slot(GuildPid) ->
case safe_lookup(GuildPid) of
[] ->
ets:insert(
?CIRCUIT_BREAKER_TABLE,
{GuildPid, #{
state => closed,
failures => 0,
concurrent => 1
}}
),
ok;
[{_, #{concurrent := C}}] when C >= ?MAX_CONCURRENT ->
{error, too_many_requests};
[{_, #{concurrent := C} = State}] ->
ets:insert(?CIRCUIT_BREAKER_TABLE, {GuildPid, State#{concurrent => C + 1}}),
ok
end.
-spec release_slot(pid()) -> ok.
release_slot(GuildPid) ->
case safe_lookup(GuildPid) of
[{_, #{concurrent := C} = State}] when C > 0 ->
ets:insert(?CIRCUIT_BREAKER_TABLE, {GuildPid, State#{concurrent => C - 1}}),
ok;
_ ->
ok
end.
-spec safe_lookup(pid()) -> list().
safe_lookup(GuildPid) ->
try ets:lookup(?CIRCUIT_BREAKER_TABLE, GuildPid) of
Result -> Result
catch
error:badarg -> []
end.
-spec ensure_table() -> ok.
ensure_table() ->
case ets:whereis(?CIRCUIT_BREAKER_TABLE) of
undefined ->
try
ets:new(?CIRCUIT_BREAKER_TABLE, [
named_table,
public,
set,
{read_concurrency, true},
{write_concurrency, true}
]),
ok
catch
error:badarg -> ok
end;
_ ->
ok
end.
-ifdef(TEST).
@@ -72,4 +228,136 @@ module_exports_test() ->
Exports = guild_client:module_info(exports),
?assert(lists:member({voice_state_update, 3}, Exports)).
ensure_table_creates_table_test() ->
catch ets:delete(?CIRCUIT_BREAKER_TABLE),
?assertEqual(undefined, ets:whereis(?CIRCUIT_BREAKER_TABLE)),
ensure_table(),
?assertNotEqual(undefined, ets:whereis(?CIRCUIT_BREAKER_TABLE)).
ensure_table_idempotent_test() ->
ensure_table(),
ensure_table(),
?assertNotEqual(undefined, ets:whereis(?CIRCUIT_BREAKER_TABLE)).
acquire_slot_creates_entry_test() ->
ensure_table(),
Pid = spawn(fun() ->
receive
done -> ok
end
end),
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
?assertEqual(ok, acquire_slot(Pid)),
[{Pid, State}] = ets:lookup(?CIRCUIT_BREAKER_TABLE, Pid),
?assertEqual(1, maps:get(concurrent, State)),
Pid ! done.
acquire_slot_increments_test() ->
ensure_table(),
Pid = spawn(fun() ->
receive
done -> ok
end
end),
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
acquire_slot(Pid),
acquire_slot(Pid),
[{Pid, State}] = ets:lookup(?CIRCUIT_BREAKER_TABLE, Pid),
?assertEqual(2, maps:get(concurrent, State)),
Pid ! done.
release_slot_decrements_test() ->
ensure_table(),
Pid = spawn(fun() ->
receive
done -> ok
end
end),
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
acquire_slot(Pid),
acquire_slot(Pid),
release_slot(Pid),
[{Pid, State}] = ets:lookup(?CIRCUIT_BREAKER_TABLE, Pid),
?assertEqual(1, maps:get(concurrent, State)),
Pid ! done.
get_circuit_state_closed_test() ->
ensure_table(),
Pid = spawn(fun() ->
receive
done -> ok
end
end),
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
?assertEqual(closed, get_circuit_state(Pid)),
Pid ! done.
get_circuit_state_open_test() ->
ensure_table(),
Pid = spawn(fun() ->
receive
done -> ok
end
end),
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
Now = erlang:system_time(millisecond),
ets:insert(
?CIRCUIT_BREAKER_TABLE,
{Pid, #{
state => open,
failures => 5,
concurrent => 0,
opened_at => Now
}}
),
?assertEqual(open, get_circuit_state(Pid)),
Pid ! done.
get_circuit_state_half_open_test() ->
ensure_table(),
Pid = spawn(fun() ->
receive
done -> ok
end
end),
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
OldTime = erlang:system_time(millisecond) - ?RECOVERY_TIMEOUT_MS - 1000,
ets:insert(
?CIRCUIT_BREAKER_TABLE,
{Pid, #{
state => open,
failures => 5,
concurrent => 0,
opened_at => OldTime
}}
),
?assertEqual(half_open, get_circuit_state(Pid)),
Pid ! done.
record_failure_opens_circuit_test() ->
ensure_table(),
Pid = spawn(fun() ->
receive
done -> ok
end
end),
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
ets:insert(
?CIRCUIT_BREAKER_TABLE,
{Pid, #{
state => closed,
failures => ?FAILURE_THRESHOLD - 1,
concurrent => 0
}}
),
record_failure(Pid),
[{Pid, State}] = ets:lookup(?CIRCUIT_BREAKER_TABLE, Pid),
?assertEqual(open, maps:get(state, State)),
Pid ! done.
is_success_result_test() ->
?assertEqual(true, is_success_result({ok, #{}})),
?assertEqual(false, is_success_result({error, timeout})),
?assertEqual(false, is_success_result({error, noproc})).
-endif.

View File

@@ -31,6 +31,7 @@
-type guild_data_map() :: map().
-type guild_member() :: map().
-type channel_list() :: [map()].
-type user_id() :: integer().
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
@@ -45,11 +46,10 @@ get_guild_data(#{user_id := UserId}, State) ->
Reply = #{guild_data => GuildData},
{reply, Reply, State};
_ ->
Members = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
case member_in_list(UserId, Members) of
false ->
case guild_data_index:get_member(UserId, Data) of
undefined ->
{reply, #{guild_data => null, error_reason => <<"forbidden">>}, State};
true ->
_ ->
GuildData = build_complete_guild_data(Data, State),
{reply, #{guild_data => GuildData}, State}
end
@@ -76,7 +76,7 @@ has_member(#{user_id := UserId}, State) ->
-spec list_guild_members(map(), guild_state()) -> guild_reply(map()).
list_guild_members(#{limit := Limit, offset := Offset}, State) ->
Data = guild_data_map(State),
AllMembers = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
AllMembers = guild_data_index:member_list(Data),
TotalCount = length(AllMembers),
PaginatedMembers = paginate_members(AllMembers, Limit, Offset),
{reply, #{members => PaginatedMembers, total => TotalCount}, State}.
@@ -93,19 +93,24 @@ get_first_viewable_text_channel(State) ->
EveryoneChannelId = find_everyone_viewable_text_channel(Channels, State),
{reply, #{channel_id => EveryoneChannelId}, State}.
-spec get_guild_state(integer(), guild_state()) -> map().
-spec get_guild_state(user_id(), guild_state()) -> map().
get_guild_state(UserId, State) ->
Data = guild_data_map(State),
GuildId = map_utils:get_integer(State, id, 0),
AllChannels = channels_from_data(Data),
AllMembers = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
AllMembers = guild_data_index:member_values(Data),
Member = find_member_by_user_id(UserId, State),
{ViewableChannels, JoinedAt} = derive_member_view(UserId, Member, State, AllChannels),
OnlineCount = guild_member_list:get_online_count(State),
OwnMemberList = case Member of
undefined -> [];
M -> [M]
end,
OwnMemberList =
case Member of
undefined -> [];
M -> [M]
end,
VoiceStates = guild_voice:get_voice_states_list(State),
VoiceMembers = voice_members_from_states(VoiceStates, AllMembers),
Members = merge_members(OwnMemberList, VoiceMembers),
MemberCount = maps:get(member_count, State, length(AllMembers)),
#{
<<"id">> => integer_to_binary(GuildId),
<<"properties">> => maps:get(<<"guild">>, Data, #{}),
@@ -113,11 +118,11 @@ get_guild_state(UserId, State) ->
<<"channels">> => ViewableChannels,
<<"emojis">> => maps:get(<<"emojis">>, Data, []),
<<"stickers">> => maps:get(<<"stickers">>, Data, []),
<<"members">> => OwnMemberList,
<<"member_count">> => length(AllMembers),
<<"members">> => Members,
<<"member_count">> => MemberCount,
<<"online_count">> => OnlineCount,
<<"presences">> => [],
<<"voice_states">> => guild_voice:get_voice_states_list(State),
<<"voice_states">> => VoiceStates,
<<"joined_at">> => JoinedAt
}.
@@ -140,6 +145,7 @@ find_everyone_viewable_text_channel(Channels, State) ->
map_utils:ensure_list(Channels)
).
-spec find_member_by_user_id(user_id(), guild_state()) -> guild_member() | undefined.
find_member_by_user_id(UserId, State) ->
guild_permissions:find_member_by_user_id(UserId, State).
@@ -164,19 +170,7 @@ channels_from_state(State) ->
-spec channels_from_data(guild_data_map()) -> channel_list().
channels_from_data(Data) ->
map_utils:ensure_list(maps:get(<<"channels">>, Data, [])).
-spec member_in_list(integer(), [guild_member()]) -> boolean().
member_in_list(UserId, Members) ->
lists:any(fun(Member) -> member_matches(UserId, Member) end, Members).
-spec member_matches(integer(), guild_member()) -> boolean().
member_matches(UserId, Member) ->
MemberUser = map_utils:ensure_map(maps:get(<<"user">>, Member, #{})),
case map_utils:get_integer(MemberUser, <<"id">>, undefined) of
undefined -> false;
Id -> Id =:= UserId
end.
guild_data_index:channel_list(Data).
-spec paginate_members([guild_member()], non_neg_integer(), non_neg_integer()) -> [guild_member()].
paginate_members(Members, Limit, Offset) ->
@@ -191,7 +185,7 @@ paginate_members(Members, Limit, Offset) ->
end
end.
-spec derive_member_view(integer(), guild_member() | undefined, guild_state(), channel_list()) ->
-spec derive_member_view(user_id(), guild_member() | undefined, guild_state(), channel_list()) ->
{channel_list(), term()}.
derive_member_view(_UserId, undefined, _State, _Channels) ->
{[], null};
@@ -210,7 +204,63 @@ derive_member_view(UserId, Member, State, Channels) ->
JoinedAt = maps:get(<<"joined_at">>, Member, null),
{Filtered, JoinedAt}.
-spec role_permissions_for_id(list(), integer()) -> integer().
-spec voice_members_from_states([map()], [guild_member()]) -> [guild_member()].
voice_members_from_states(VoiceStates, Members) ->
MemberIndex = build_member_index(Members),
lists:filtermap(
fun(VoiceState) ->
case voice_state_utils:voice_state_user_id(VoiceState) of
undefined ->
false;
UserId ->
case maps:get(UserId, MemberIndex, undefined) of
undefined -> false;
Member -> {true, Member}
end
end
end,
VoiceStates
).
-spec build_member_index([guild_member()]) -> #{integer() => guild_member()}.
build_member_index(Members) ->
lists:foldl(
fun(Member, Acc) ->
case member_user_id(Member) of
undefined -> Acc;
UserId -> maps:put(UserId, Member, Acc)
end
end,
#{},
Members
).
-spec merge_members([guild_member()], [guild_member()]) -> [guild_member()].
merge_members(Primary, Secondary) ->
{Merged, _} =
lists:foldl(
fun(Member, {Acc, Seen}) ->
case member_user_id(Member) of
undefined ->
{Acc, Seen};
UserId ->
case sets:is_element(UserId, Seen) of
true -> {Acc, Seen};
false -> {[Member | Acc], sets:add_element(UserId, Seen)}
end
end
end,
{[], sets:new()},
Primary ++ Secondary
),
lists:reverse(Merged).
-spec member_user_id(guild_member()) -> integer() | undefined.
member_user_id(Member) ->
MemberUser = map_utils:ensure_map(maps:get(<<"user">>, Member, #{})),
map_utils:get_integer(MemberUser, <<"id">>, undefined).
-spec role_permissions_for_id([map()], integer()) -> integer().
role_permissions_for_id(Roles, GuildId) ->
lists:foldl(
fun(Role, Acc) ->
@@ -229,6 +279,10 @@ select_first_viewable(Channel, GuildId, BasePerms) ->
ChannelId = map_utils:get_integer(Channel, <<"id">>, undefined),
select_first_viewable(ChannelType, ChannelId, Channel, GuildId, BasePerms).
-spec select_first_viewable(
integer() | undefined, integer() | undefined, map(), integer(), integer()
) ->
integer() | null.
select_first_viewable(0, ChannelId, Channel, GuildId, BasePerms) when is_integer(ChannelId) ->
case (BasePerms band constants:administrator_permission()) =/= 0 of
true ->
@@ -273,6 +327,13 @@ find_everyone_viewable_text_channel_test() ->
ChannelId = find_everyone_viewable_text_channel(Channels, State),
?assertEqual(500, ChannelId).
paginate_members_test() ->
Members = [#{<<"id">> => 1}, #{<<"id">> => 2}, #{<<"id">> => 3}],
?assertEqual([#{<<"id">> => 1}, #{<<"id">> => 2}], paginate_members(Members, 2, 0)),
?assertEqual([#{<<"id">> => 2}, #{<<"id">> => 3}], paginate_members(Members, 2, 1)),
?assertEqual([#{<<"id">> => 3}], paginate_members(Members, 2, 2)),
?assertEqual([], paginate_members(Members, 2, 5)).
test_state() ->
GuildId = 100,
ViewPerm = constants:view_channel_permission(),

View File

@@ -0,0 +1,480 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(guild_data_index).
-export([
normalize_data/1,
member_map/1,
member_values/1,
member_list/1,
member_count/1,
member_ids/1,
member_role_index/1,
get_member/2,
put_member/2,
put_member_map/2,
put_member_list/2,
remove_member/2,
role_list/1,
role_index/1,
put_roles/2,
channel_list/1,
channel_index/1,
put_channels/2
]).
-type guild_data() :: map().
-type member() :: map().
-type role() :: map().
-type channel() :: map().
-type user_id() :: integer().
-type snowflake_id() :: integer().
-type role_member_index() :: #{snowflake_id() => #{user_id() => true}}.
-spec normalize_data(guild_data()) -> guild_data().
normalize_data(Data) when is_map(Data) ->
MemberMap = member_map(Data),
Roles = role_list(Data),
Channels = channel_list(Data),
Data1 = maps:put(<<"members">>, MemberMap, Data),
Data2 = maps:put(<<"roles">>, Roles, Data1),
Data3 = maps:put(<<"channels">>, Channels, Data2),
Data4 = maps:put(<<"role_index">>, build_id_index(Roles), Data3),
Data5 = maps:put(<<"channel_index">>, build_id_index(Channels), Data4),
maps:put(<<"member_role_index">>, build_member_role_index(MemberMap), Data5);
normalize_data(Data) ->
Data.
-spec member_map(guild_data()) -> #{user_id() => member()}.
member_map(Data) when is_map(Data) ->
case maps:get(<<"members">>, Data, #{}) of
Members when is_map(Members) ->
normalize_member_map(Members);
Members when is_list(Members) ->
build_member_map(Members);
_ ->
#{}
end;
member_map(_) ->
#{}.
-spec member_list(guild_data()) -> [member()].
member_list(Data) ->
MemberPairs = maps:to_list(member_map(Data)),
SortedPairs = lists:sort(fun({A, _}, {B, _}) -> A =< B end, MemberPairs),
[Member || {_UserId, Member} <- SortedPairs].
-spec member_values(guild_data()) -> [member()].
member_values(Data) ->
maps:values(member_map(Data)).
-spec member_count(guild_data()) -> non_neg_integer().
member_count(Data) ->
map_size(member_map(Data)).
-spec member_ids(guild_data()) -> [user_id()].
member_ids(Data) ->
maps:keys(member_map(Data)).
-spec member_role_index(guild_data()) -> role_member_index().
member_role_index(Data) when is_map(Data) ->
case maps:get(<<"member_role_index">>, Data, undefined) of
Index when is_map(Index) -> normalize_member_role_index(Index);
_ -> build_member_role_index(member_map(Data))
end;
member_role_index(_) ->
#{}.
-spec get_member(user_id(), guild_data()) -> member() | undefined.
get_member(UserId, Data) when is_integer(UserId) ->
maps:get(UserId, member_map(Data), undefined);
get_member(_, _) ->
undefined.
-spec put_member(member(), guild_data()) -> guild_data().
put_member(Member, Data) when is_map(Member), is_map(Data) ->
case member_user_id(Member) of
undefined ->
Data;
UserId ->
MemberMap = member_map(Data),
ExistingMember = maps:get(UserId, MemberMap, undefined),
ExistingRoles = member_role_ids(ExistingMember),
UpdatedRoles = member_role_ids(Member),
RoleIndex = member_role_index(Data),
RoleIndex1 = remove_user_from_member_role_index(UserId, ExistingRoles, RoleIndex),
RoleIndex2 = add_user_to_member_role_index(UserId, UpdatedRoles, RoleIndex1),
Data1 = maps:put(<<"members">>, maps:put(UserId, Member, MemberMap), Data),
maps:put(<<"member_role_index">>, RoleIndex2, Data1)
end;
put_member(_, Data) ->
Data.
-spec put_member_map(#{user_id() => member()}, guild_data()) -> guild_data().
put_member_map(MemberMap, Data) when is_map(MemberMap), is_map(Data) ->
NormalizedMemberMap = normalize_member_map(MemberMap),
Data1 = maps:put(<<"members">>, NormalizedMemberMap, Data),
maps:put(<<"member_role_index">>, build_member_role_index(NormalizedMemberMap), Data1);
put_member_map(_, Data) ->
Data.
-spec put_member_list([member()], guild_data()) -> guild_data().
put_member_list(Members, Data) when is_list(Members), is_map(Data) ->
put_member_map(build_member_map(Members), Data);
put_member_list(_, Data) ->
Data.
-spec remove_member(user_id(), guild_data()) -> guild_data().
remove_member(UserId, Data) when is_integer(UserId), is_map(Data) ->
MemberMap = member_map(Data),
Member = maps:get(UserId, MemberMap, undefined),
MemberRoles = member_role_ids(Member),
RoleIndex = member_role_index(Data),
RoleIndex1 = remove_user_from_member_role_index(UserId, MemberRoles, RoleIndex),
Data1 = maps:put(<<"members">>, maps:remove(UserId, MemberMap), Data),
maps:put(<<"member_role_index">>, RoleIndex1, Data1);
remove_member(_, Data) ->
Data.
-spec role_list(guild_data()) -> [role()].
role_list(Data) when is_map(Data) ->
ensure_list(maps:get(<<"roles">>, Data, []));
role_list(_) ->
[].
-spec role_index(guild_data()) -> #{snowflake_id() => role()}.
role_index(Data) when is_map(Data) ->
case maps:get(<<"role_index">>, Data, undefined) of
Index when is_map(Index) -> normalize_id_index(Index);
_ -> build_id_index(role_list(Data))
end;
role_index(_) ->
#{}.
-spec put_roles([role()], guild_data()) -> guild_data().
put_roles(Roles, Data) when is_map(Data) ->
RoleList = ensure_list(Roles),
Data1 = maps:put(<<"roles">>, RoleList, Data),
maps:put(<<"role_index">>, build_id_index(RoleList), Data1);
put_roles(_, Data) ->
Data.
-spec channel_list(guild_data()) -> [channel()].
channel_list(Data) when is_map(Data) ->
ensure_list(maps:get(<<"channels">>, Data, []));
channel_list(_) ->
[].
-spec channel_index(guild_data()) -> #{snowflake_id() => channel()}.
channel_index(Data) when is_map(Data) ->
case maps:get(<<"channel_index">>, Data, undefined) of
Index when is_map(Index) -> normalize_id_index(Index);
_ -> build_id_index(channel_list(Data))
end;
channel_index(_) ->
#{}.
-spec put_channels([channel()], guild_data()) -> guild_data().
put_channels(Channels, Data) when is_map(Data) ->
ChannelList = ensure_list(Channels),
Data1 = maps:put(<<"channels">>, ChannelList, Data),
maps:put(<<"channel_index">>, build_id_index(ChannelList), Data1);
put_channels(_, Data) ->
Data.
-spec build_member_map([member()]) -> #{user_id() => member()}.
build_member_map(Members) ->
lists:foldl(
fun(Member, Acc) ->
case member_user_id(Member) of
undefined ->
Acc;
UserId ->
maps:put(UserId, Member, Acc)
end
end,
#{},
Members
).
-spec normalize_member_map(map()) -> #{user_id() => member()}.
normalize_member_map(MemberMap) ->
maps:fold(
fun(Key, Member, Acc) ->
case normalize_member_key(Key, Member) of
undefined ->
Acc;
UserId ->
maps:put(UserId, Member, Acc)
end
end,
#{},
MemberMap
).
-spec normalize_member_key(term(), member()) -> user_id() | undefined.
normalize_member_key(Key, Member) ->
case type_conv:to_integer(Key) of
undefined -> member_user_id(Member);
UserId -> UserId
end.
-spec member_user_id(member()) -> user_id() | undefined.
member_user_id(Member) when is_map(Member) ->
User = maps:get(<<"user">>, Member, #{}),
map_utils:get_integer(User, <<"id">>, undefined);
member_user_id(_) ->
undefined.
-spec member_role_ids(term()) -> [snowflake_id()].
member_role_ids(Member) when is_map(Member) ->
extract_integer_list(maps:get(<<"roles">>, Member, []));
member_role_ids(_) ->
[].
-spec build_member_role_index(#{user_id() => member()}) -> role_member_index().
build_member_role_index(MemberMap) ->
maps:fold(
fun(UserId, Member, Acc) ->
add_user_to_member_role_index(UserId, member_role_ids(Member), Acc)
end,
#{},
MemberMap
).
-spec normalize_member_role_index(map()) -> role_member_index().
normalize_member_role_index(Index) ->
maps:fold(
fun(RoleKey, Members, Acc) ->
case type_conv:to_integer(RoleKey) of
undefined ->
Acc;
RoleId ->
NormalizedMembers = normalize_member_role_members(Members),
case map_size(NormalizedMembers) of
0 ->
Acc;
_ ->
maps:put(RoleId, NormalizedMembers, Acc)
end
end
end,
#{},
Index
).
-spec normalize_member_role_members(term()) -> #{user_id() => true}.
normalize_member_role_members(Members) when is_map(Members) ->
maps:fold(
fun(UserKey, _Flag, Acc) ->
case type_conv:to_integer(UserKey) of
undefined ->
Acc;
UserId ->
maps:put(UserId, true, Acc)
end
end,
#{},
Members
);
normalize_member_role_members(_) ->
#{}.
-spec add_user_to_member_role_index(user_id(), [snowflake_id()], role_member_index()) ->
role_member_index().
add_user_to_member_role_index(UserId, RoleIds, RoleIndex) when is_integer(UserId) ->
lists:foldl(
fun(RoleId, Acc) ->
RoleMembers = maps:get(RoleId, Acc, #{}),
maps:put(RoleId, maps:put(UserId, true, RoleMembers), Acc)
end,
RoleIndex,
RoleIds
).
-spec remove_user_from_member_role_index(user_id(), [snowflake_id()], role_member_index()) ->
role_member_index().
remove_user_from_member_role_index(UserId, RoleIds, RoleIndex) when is_integer(UserId) ->
lists:foldl(
fun(RoleId, Acc) ->
RoleMembers = maps:get(RoleId, Acc, #{}),
UpdatedRoleMembers = maps:remove(UserId, RoleMembers),
case map_size(UpdatedRoleMembers) of
0 ->
maps:remove(RoleId, Acc);
_ ->
maps:put(RoleId, UpdatedRoleMembers, Acc)
end
end,
RoleIndex,
RoleIds
).
-spec build_id_index([map()]) -> #{snowflake_id() => map()}.
build_id_index(Items) ->
lists:foldl(
fun(Item, Acc) ->
case map_utils:get_integer(Item, <<"id">>, undefined) of
undefined ->
Acc;
Id ->
maps:put(Id, Item, Acc)
end
end,
#{},
Items
).
-spec normalize_id_index(map()) -> #{snowflake_id() => map()}.
normalize_id_index(Index) ->
maps:fold(
fun(Key, Item, Acc) ->
case type_conv:to_integer(Key) of
undefined ->
case map_utils:get_integer(Item, <<"id">>, undefined) of
undefined -> Acc;
Id -> maps:put(Id, Item, Acc)
end;
Id ->
maps:put(Id, Item, Acc)
end
end,
#{},
Index
).
-spec ensure_list(term()) -> list().
ensure_list(List) when is_list(List) ->
List;
ensure_list(_) ->
[].
-spec extract_integer_list(term()) -> [integer()].
extract_integer_list(List) when is_list(List) ->
lists:reverse(
lists:foldl(
fun(Value, Acc) ->
case type_conv:to_integer(Value) of
undefined -> Acc;
Int -> [Int | Acc]
end
end,
[],
List
)
);
extract_integer_list(_) ->
[].
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
member_map_from_list_test() ->
Data = #{
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"2">>}, <<"nick">> => <<"beta">>},
#{<<"user">> => #{<<"id">> => <<"1">>}, <<"nick">> => <<"alpha">>}
]
},
MemberMap = member_map(Data),
?assertMatch(#{1 := _, 2 := _}, MemberMap),
?assertEqual(2, member_count(Data)).
member_list_is_sorted_test() ->
Data = #{
<<"members">> => #{
5 => #{<<"user">> => #{<<"id">> => <<"5">>}},
2 => #{<<"user">> => #{<<"id">> => <<"2">>}}
}
},
[First, Second] = member_list(Data),
?assertEqual(2, member_user_id(First)),
?assertEqual(5, member_user_id(Second)).
member_values_returns_members_without_sorting_test() ->
Data = #{
<<"members">> => #{
9 => #{<<"user">> => #{<<"id">> => <<"9">>}},
2 => #{<<"user">> => #{<<"id">> => <<"2">>}}
}
},
Values = member_values(Data),
?assertEqual(2, length(Values)).
put_member_updates_entry_test() ->
Data = #{
<<"members">> => #{
10 => #{<<"user">> => #{<<"id">> => <<"10">>}, <<"nick">> => <<"old">>}
}
},
UpdatedData = put_member(
#{<<"user">> => #{<<"id">> => <<"10">>}, <<"nick">> => <<"new">>},
Data
),
UpdatedMember = get_member(10, UpdatedData),
?assertEqual(<<"new">>, maps:get(<<"nick">>, UpdatedMember)).
remove_member_removes_entry_test() ->
Data = #{
<<"members">> => #{
10 => #{<<"user">> => #{<<"id">> => <<"10">>}}
}
},
UpdatedData = remove_member(10, Data),
?assertEqual(undefined, get_member(10, UpdatedData)).
normalize_data_builds_indexes_test() ->
Data = #{
<<"members">> => [#{<<"user">> => #{<<"id">> => <<"1">>}}],
<<"roles">> => [#{<<"id">> => <<"100">>}],
<<"channels">> => [#{<<"id">> => <<"200">>}]
},
Normalized = normalize_data(Data),
?assert(is_map(maps:get(<<"members">>, Normalized))),
?assertMatch(#{100 := _}, role_index(Normalized)),
?assertMatch(#{200 := _}, channel_index(Normalized)).
member_role_index_builds_role_to_user_lookup_test() ->
Data = #{
<<"members">> => #{
1 => #{<<"user">> => #{<<"id">> => <<"1">>}, <<"roles">> => [<<"10">>, <<"11">>]},
2 => #{<<"user">> => #{<<"id">> => <<"2">>}, <<"roles">> => [<<"11">>]}
}
},
Index = member_role_index(Data),
?assertEqual(#{1 => true}, maps:get(10, Index)),
?assertEqual(#{1 => true, 2 => true}, maps:get(11, Index)).
put_member_and_remove_member_keep_member_role_index_in_sync_test() ->
Data0 = #{
<<"members">> => #{
3 => #{<<"user">> => #{<<"id">> => <<"3">>}, <<"roles">> => [<<"20">>]}
}
},
Data1 = put_member(
#{<<"user">> => #{<<"id">> => <<"3">>}, <<"roles">> => [<<"30">>]},
Data0
),
Index1 = member_role_index(Data1),
?assertEqual(undefined, maps:get(20, Index1, undefined)),
?assertEqual(#{3 => true}, maps:get(30, Index1)),
Data2 = remove_member(3, Data1),
Index2 = member_role_index(Data2),
?assertEqual(undefined, maps:get(30, Index2, undefined)).
-endif.

File diff suppressed because it is too large Load Diff

View File

@@ -23,23 +23,15 @@
-export([start_link/0]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-define(GUILD_PID_CACHE, guild_pid_cache).
-type guild_id() :: integer().
-type shard_map() :: #{pid => pid(), ref => reference()}.
-type shard_map() :: #{pid := pid(), ref := reference()}.
-type state() :: #{
shards => #{non_neg_integer() => shard_map()},
shard_count => pos_integer()
shards := #{non_neg_integer() => shard_map()},
shard_count := pos_integer()
}.
-record(shard, {
pid :: pid(),
ref :: reference()
}).
-record(state, {
shards = #{} :: #{non_neg_integer() => #shard{}},
shard_count = 1 :: pos_integer()
}).
-spec start_link() -> {ok, pid()} | {error, term()}.
start_link() ->
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
@@ -47,23 +39,23 @@ start_link() ->
-spec init(list()) -> {ok, state()}.
init([]) ->
process_flag(trap_exit, true),
{ShardCount, Source} = determine_shard_count(),
ShardMap = start_shards(ShardCount, #{}),
maybe_log_shard_source(guild_manager, ShardCount, Source),
ets:new(?GUILD_PID_CACHE, [named_table, public, set, {read_concurrency, true}]),
{ShardCount, _Source} = determine_shard_count(),
ShardMap = start_shards(ShardCount),
{ok, #{shards => ShardMap, shard_count => ShardCount}}.
-spec handle_call(term(), gen_server:from(), state()) -> {reply, term(), state()}.
handle_call({start_or_lookup, GuildId} = Request, _From, State) ->
{Reply, NewState} = forward_call(GuildId, Request, State),
handle_call({start_or_lookup, GuildId}, _From, State) ->
{Reply, NewState} = forward_call(GuildId, {start_or_lookup, GuildId}, State),
{reply, Reply, NewState};
handle_call({stop_guild, GuildId} = Request, _From, State) ->
{Reply, NewState} = forward_call(GuildId, Request, State),
handle_call({stop_guild, GuildId}, _From, State) ->
{Reply, NewState} = forward_call(GuildId, {stop_guild, GuildId}, State),
{reply, Reply, NewState};
handle_call({reload_guild, GuildId} = Request, _From, State) ->
{Reply, NewState} = forward_call(GuildId, Request, State),
handle_call({reload_guild, GuildId}, _From, State) ->
{Reply, NewState} = forward_call(GuildId, {reload_guild, GuildId}, State),
{reply, Reply, NewState};
handle_call({shutdown_guild, GuildId} = Request, _From, State) ->
{Reply, NewState} = forward_call(GuildId, Request, State),
handle_call({shutdown_guild, GuildId}, _From, State) ->
{Reply, NewState} = forward_call(GuildId, {shutdown_guild, GuildId}, State),
{reply, Reply, NewState};
handle_call({reload_all_guilds, GuildIds}, _From, State) ->
{Reply, NewState} = handle_reload_all(GuildIds, State),
@@ -74,8 +66,7 @@ handle_call(get_local_count, _From, State) ->
handle_call(get_global_count, _From, State) ->
{Count, NewState} = aggregate_counts(get_global_count, State),
{reply, {ok, Count}, NewState};
handle_call(Request, _From, State) ->
logger:warning("[guild_manager] unknown request ~p", [Request]),
handle_call(_Request, _From, State) ->
{reply, ok, State}.
-spec handle_cast(term(), state()) -> {noreply, state()}.
@@ -83,21 +74,20 @@ handle_cast(_Msg, State) ->
{noreply, State}.
-spec handle_info(term(), state()) -> {noreply, state()}.
handle_info({'DOWN', Ref, process, _Pid, Reason}, State) ->
handle_info({'DOWN', Ref, process, Pid, _Reason}, State) ->
Shards = maps:get(shards, State),
case find_shard_by_ref(Ref, Shards) of
{ok, Index} ->
logger:warning("[guild_manager] shard ~p crashed: ~p", [Index, Reason]),
{_Shard, NewState} = restart_shard(Index, State),
{noreply, NewState};
not_found ->
cleanup_guild_from_cache(Pid),
{noreply, State}
end;
handle_info({'EXIT', Pid, Reason}, State) ->
handle_info({'EXIT', Pid, _Reason}, State) ->
Shards = maps:get(shards, State),
case find_shard_by_pid(Pid, Shards) of
{ok, Index} ->
logger:warning("[guild_manager] shard ~p exited: ~p", [Index, Reason]),
{_Shard, NewState} = restart_shard(Index, State),
{noreply, NewState};
not_found ->
@@ -116,17 +106,10 @@ terminate(_Reason, State) ->
end,
maps:values(Shards)
),
catch ets:delete(?GUILD_PID_CACHE),
ok.
-spec code_change(term(), term(), term()) -> {ok, state()}.
code_change(_OldVsn, #state{shards = OldShards, shard_count = ShardCount}, _Extra) ->
NewShards = maps:map(
fun(_Index, #shard{pid = Pid, ref = Ref}) ->
#{pid => Pid, ref => Ref}
end,
OldShards
),
{ok, #{shards => NewShards, shard_count => ShardCount}};
code_change(_OldVsn, State, _Extra) when is_map(State) ->
{ok, State}.
@@ -139,19 +122,26 @@ determine_shard_count() ->
{default_shard_count(), auto}
end.
-spec start_shards(pos_integer(), #{}) -> #{non_neg_integer() => shard_map()}.
start_shards(Count, Acc) ->
-spec default_shard_count() -> pos_integer().
default_shard_count() ->
Candidates = [
erlang:system_info(logical_processors_available),
erlang:system_info(schedulers_online)
],
max(1, lists:max([C || C <- Candidates, is_integer(C), C > 0] ++ [1])).
-spec start_shards(pos_integer()) -> #{non_neg_integer() => shard_map()}.
start_shards(Count) ->
lists:foldl(
fun(Index, MapAcc) ->
case start_shard(Index) of
{ok, Shard} ->
maps:put(Index, Shard, MapAcc);
{error, Reason} ->
logger:warning("[guild_manager] failed to start shard ~p: ~p", [Index, Reason]),
{error, _Reason} ->
MapAcc
end
end,
Acc,
#{},
lists:seq(0, Count - 1)
).
@@ -172,14 +162,31 @@ restart_shard(Index, State) ->
{ok, Shard} ->
Updated = State#{shards => maps:put(Index, Shard, Shards)},
{Shard, Updated};
{error, Reason} ->
logger:error("[guild_manager] failed to restart shard ~p: ~p", [Index, Reason]),
Dummy = #{pid => spawn(fun() -> exit(normal) end), ref => make_ref()},
{error, _Reason} ->
DummyPid = spawn(fun() -> ok end),
Dummy = #{pid => DummyPid, ref => make_ref()},
{Dummy, State}
end.
-spec forward_call(guild_id(), term(), state()) -> {term(), state()}.
forward_call(GuildId, {start_or_lookup, _} = Request, State) ->
case ets:lookup(?GUILD_PID_CACHE, GuildId) of
[{GuildId, GuildPid}] when is_pid(GuildPid) ->
case erlang:is_process_alive(GuildPid) of
true ->
{{ok, GuildPid}, State};
false ->
ets:delete(?GUILD_PID_CACHE, GuildId),
forward_call_to_shard(GuildId, Request, State)
end;
[] ->
forward_call_to_shard(GuildId, Request, State)
end;
forward_call(GuildId, Request, State) ->
forward_call_to_shard(GuildId, Request, State).
-spec forward_call_to_shard(guild_id(), term(), state()) -> {term(), state()}.
forward_call_to_shard(GuildId, Request, State) ->
{Index, State1} = ensure_shard(GuildId, State),
Shards = maps:get(shards, State1),
ShardMap = maps:get(Index, Shards),
@@ -187,7 +194,11 @@ forward_call(GuildId, Request, State) ->
case catch gen_server:call(Pid, Request, ?DEFAULT_GEN_SERVER_TIMEOUT) of
{'EXIT', _} ->
{_Shard, State2} = restart_shard(Index, State1),
forward_call(GuildId, Request, State2);
forward_call_to_shard(GuildId, Request, State2);
{ok, GuildPid} = Reply ->
ets:insert(?GUILD_PID_CACHE, {GuildId, GuildPid}),
erlang:monitor(process, GuildPid),
{Reply, State1};
Reply ->
{Reply, State1}
end.
@@ -223,146 +234,137 @@ select_shard(GuildId, Count) when Count > 0 ->
-spec aggregate_counts(term(), state()) -> {non_neg_integer(), state()}.
aggregate_counts(Request, State) ->
Shards = maps:get(shards, State),
Counts =
[
begin
Pid = maps:get(pid, ShardMap),
case catch gen_server:call(Pid, Request, ?DEFAULT_GEN_SERVER_TIMEOUT) of
{ok, Count} -> Count;
_ -> 0
end
Counts = lists:map(
fun(ShardMap) ->
Pid = maps:get(pid, ShardMap),
case catch gen_server:call(Pid, Request, ?DEFAULT_GEN_SERVER_TIMEOUT) of
{ok, Count} -> Count;
_ -> 0
end
|| ShardMap <- maps:values(Shards)
],
end,
maps:values(Shards)
),
{lists:sum(Counts), State}.
-spec handle_reload_all([guild_id()], state()) -> {#{count => non_neg_integer()}, state()}.
-spec handle_reload_all([guild_id()], state()) -> {#{count := non_neg_integer()}, state()}.
handle_reload_all([], State) ->
Shards = maps:get(shards, State),
{Replies, FinalState} =
lists:foldl(
fun({_Index, ShardMap}, {AccReplies, AccState}) ->
Pid = maps:get(pid, ShardMap),
case catch gen_server:call(Pid, {reload_all_guilds, []}, 60000) of
Reply ->
{AccReplies ++ [Reply], AccState}
end
end,
{[], State},
maps:to_list(Shards)
),
Count = lists:sum([maps:get(count, Reply, 0) || Reply <- Replies]),
{Replies, FinalState} = lists:foldl(
fun({_Index, ShardMap}, {AccReplies, AccState}) ->
Pid = maps:get(pid, ShardMap),
Reply = catch gen_server:call(Pid, {reload_all_guilds, []}, 15000),
{[Reply | AccReplies], AccState}
end,
{[], State},
maps:to_list(Shards)
),
Count = lists:sum([maps:get(count, Reply, 0) || Reply <- Replies, is_map(Reply)]),
{#{count => Count}, FinalState};
handle_reload_all(GuildIds, State) ->
Count = maps:get(shard_count, State),
Groups = group_ids_by_shard(GuildIds, Count),
{TotalCount, FinalState} =
lists:foldl(
fun({Index, Ids}, {AccCount, AccState}) ->
{ShardIdx, State1} = ensure_shard_for_index(Index, AccState),
Shards = maps:get(shards, State1),
ShardMap = maps:get(ShardIdx, Shards),
Pid = maps:get(pid, ShardMap),
case catch gen_server:call(Pid, {reload_all_guilds, Ids}, 60000) of
#{count := CountReply} ->
{AccCount + CountReply, State1};
_ ->
{AccCount, State1}
end
end,
{0, State},
Groups
),
{TotalCount, FinalState} = lists:foldl(
fun({Index, Ids}, {AccCount, AccState}) ->
{ShardIdx, State1} = ensure_shard_for_index(Index, AccState),
Shards = maps:get(shards, State1),
ShardMap = maps:get(ShardIdx, Shards),
Pid = maps:get(pid, ShardMap),
case catch gen_server:call(Pid, {reload_all_guilds, Ids}, 15000) of
#{count := CountReply} ->
{AccCount + CountReply, State1};
_ ->
{AccCount, State1}
end
end,
{0, State},
Groups
),
{#{count => TotalCount}, FinalState}.
-spec group_ids_by_shard([guild_id()], pos_integer()) -> [{non_neg_integer(), [guild_id()]}].
group_ids_by_shard(GuildIds, ShardCount) ->
lists:foldl(
fun(GuildId, Acc) ->
Index = select_shard(GuildId, ShardCount),
case lists:keytake(Index, 1, Acc) of
{value, {Index, Ids}, Rest} ->
[{Index, [GuildId | Ids]} | Rest];
false ->
[{Index, [GuildId]} | Acc]
end
end,
[],
GuildIds
).
rendezvous_router:group_keys(GuildIds, ShardCount).
-spec find_shard_by_ref(reference(), #{non_neg_integer() => shard_map()}) ->
{ok, non_neg_integer()} | not_found.
find_shard_by_ref(Ref, Shards) ->
maps:fold(
fun
(Index, ShardMap, _) when is_map(ShardMap) ->
case maps:get(ref, ShardMap) of
R when R =:= Ref -> {ok, Index};
_ -> not_found
end;
(_, _, Acc) ->
Acc
end,
not_found,
Shards
).
find_shard_by(fun(#{ref := R}) -> R =:= Ref end, Shards).
-spec find_shard_by_pid(pid(), #{non_neg_integer() => shard_map()}) ->
{ok, non_neg_integer()} | not_found.
find_shard_by_pid(Pid, Shards) ->
find_shard_by(fun(#{pid := P}) -> P =:= Pid end, Shards).
-spec find_shard_by(fun((shard_map()) -> boolean()), #{non_neg_integer() => shard_map()}) ->
{ok, non_neg_integer()} | not_found.
find_shard_by(Pred, Shards) ->
maps:fold(
fun
(Index, ShardMap, _) when is_map(ShardMap) ->
case maps:get(pid, ShardMap) of
P when P =:= Pid -> {ok, Index};
_ -> not_found
end;
(_, _, Acc) ->
Acc
(_, _, {ok, _} = Found) ->
Found;
(Index, ShardMap, not_found) ->
case Pred(ShardMap) of
true -> {ok, Index};
false -> not_found
end
end,
not_found,
Shards
).
-spec default_shard_count() -> pos_integer().
default_shard_count() ->
Candidates = [
erlang:system_info(logical_processors_available), erlang:system_info(schedulers_online)
],
lists:max([C || C <- Candidates, is_integer(C), C > 0] ++ [1]).
-spec maybe_log_shard_source(atom(), pos_integer(), configured | auto) -> ok.
maybe_log_shard_source(Name, Count, configured) ->
logger:info("[~p] starting with ~p shards (configured)", [Name, Count]),
ok;
maybe_log_shard_source(Name, Count, auto) ->
logger:info("[~p] starting with ~p shards (auto)", [Name, Count]),
-spec cleanup_guild_from_cache(pid()) -> ok.
cleanup_guild_from_cache(Pid) ->
case ets:match_object(?GUILD_PID_CACHE, {'$1', Pid}) of
[{GuildId, _Pid}] ->
ets:delete(?GUILD_PID_CACHE, GuildId);
[] ->
ok
end,
ok.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
determine_shard_count_configured_test() ->
with_runtime_config(guild_shards, 4, fun() ->
?assertMatch({4, configured}, determine_shard_count())
end).
default_shard_count_positive_test() ->
Count = default_shard_count(),
?assert(Count >= 1).
determine_shard_count_auto_test() ->
with_runtime_config(guild_shards, undefined, fun() ->
{Count, auto} = determine_shard_count(),
?assert(Count > 0)
end).
select_shard_deterministic_test() ->
GuildId = 12345,
ShardCount = 8,
Shard1 = select_shard(GuildId, ShardCount),
Shard2 = select_shard(GuildId, ShardCount),
?assertEqual(Shard1, Shard2).
select_shard_in_range_test() ->
ShardCount = 8,
lists:foreach(
fun(GuildId) ->
Shard = select_shard(GuildId, ShardCount),
?assert(Shard >= 0 andalso Shard < ShardCount)
end,
lists:seq(1, 100)
).
group_ids_by_shard_test() ->
GuildIds = [1, 2, 3, 4, 5],
ShardCount = 2,
Groups = group_ids_by_shard(GuildIds, ShardCount),
AllIds = lists:flatten([Ids || {_, Ids} <- Groups]),
?assertEqual(lists:sort(GuildIds), lists:sort(AllIds)).
find_shard_by_ref_found_test() ->
Ref = make_ref(),
Shards = #{0 => #{pid => self(), ref => Ref}},
?assertMatch({ok, 0}, find_shard_by_ref(Ref, Shards)).
find_shard_by_ref_not_found_test() ->
Shards = #{0 => #{pid => self(), ref => make_ref()}},
?assertEqual(not_found, find_shard_by_ref(make_ref(), Shards)).
find_shard_by_pid_found_test() ->
Pid = self(),
Shards = #{0 => #{pid => Pid, ref => make_ref()}},
?assertMatch({ok, 0}, find_shard_by_pid(Pid, Shards)).
with_runtime_config(Key, Value, Fun) ->
Original = fluxer_gateway_env:get(Key),
fluxer_gateway_env:patch(#{Key => Value}),
Result = Fun(),
fluxer_gateway_env:update(fun(Map) ->
case Original of
undefined -> maps:remove(Key, Map);
Val -> maps:put(Key, Val, Map)
end
end),
Result.
-endif.

View File

@@ -21,6 +21,8 @@
-include_lib("fluxer_gateway/include/timeout_config.hrl").
-define(GUILD_API_CANARY_PERCENTAGE, 5).
-define(BATCH_SIZE, 10).
-define(BATCH_DELAY_MS, 100).
-export([start_link/1]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
@@ -30,56 +32,34 @@
-type guild_data() :: #{binary() => term()}.
-type fetch_result() :: {ok, guild_data()} | {error, term()}.
-type state() :: #{
guilds => #{guild_id() => guild_ref() | loading},
api_host => string(),
api_canary_host => undefined | string(),
pending_requests => #{guild_id() => [gen_server:from()]}
guilds := #{guild_id() => guild_ref() | loading},
api_host := string(),
api_canary_host := undefined | string(),
pending_requests := #{guild_id() => [gen_server:from()]},
shard_index := non_neg_integer()
}.
-record(state, {
guilds = #{} :: #{guild_id() => guild_ref() | loading},
api_host :: string(),
api_canary_host :: undefined | string(),
pending_requests = #{} :: #{guild_id() => [gen_server:from()]}
}).
-spec start_link(non_neg_integer()) -> {ok, pid()} | {error, term()}.
start_link(ShardIndex) ->
gen_server:start_link(?MODULE, #{shard_index => ShardIndex}, []).
-spec init(map()) -> {ok, state()}.
init(_Args) ->
init(Args) ->
process_flag(trap_exit, true),
fluxer_gateway_env:load(),
ApiHost = fluxer_gateway_env:get(api_host),
ApiCanaryHost = fluxer_gateway_env:get(api_canary_host),
ShardIndex = maps:get(shard_index, Args, 0),
{ok, #{
guilds => #{},
api_host => ApiHost,
api_canary_host => ApiCanaryHost,
pending_requests => #{}
pending_requests => #{},
shard_index => ShardIndex
}}.
-spec handle_call(Request, From, State) -> Result when
Request ::
{start_or_lookup, guild_id()}
| {stop_guild, guild_id()}
| {reload_guild, guild_id()}
| {reload_all_guilds, [guild_id()]}
| {shutdown_guild, guild_id()}
| get_local_count
| get_global_count
| term(),
From :: gen_server:from(),
State :: state(),
Result ::
{reply, Reply, state()}
| {noreply, state()},
Reply ::
{ok, pid()}
| {error, term()}
| ok
| {ok, non_neg_integer()}.
-spec handle_call(term(), gen_server:from(), state()) ->
{reply, term(), state()} | {noreply, state()}.
handle_call({start_or_lookup, GuildId}, From, State) ->
do_start_or_lookup(GuildId, From, State);
handle_call({stop_guild, GuildId}, _From, State) ->
@@ -87,37 +67,7 @@ handle_call({stop_guild, GuildId}, _From, State) ->
handle_call({reload_guild, GuildId}, From, State) ->
do_reload_guild(GuildId, From, State);
handle_call({reload_all_guilds, GuildIds}, From, State) ->
Guilds = maps:get(guilds, State),
GuildsToReload =
case GuildIds of
[] ->
[{GuildId, Pid} || {GuildId, {Pid, _Ref}} <- maps:to_list(Guilds)];
Ids ->
lists:filtermap(
fun(GuildId) ->
case maps:get(GuildId, Guilds, undefined) of
{Pid, _Ref} -> {true, {GuildId, Pid}};
_ -> false
end
end,
Ids
)
end,
Manager = self(),
spawn(fun() ->
try
reload_guilds_in_batches(GuildsToReload, Manager, State, 10, 100),
gen_server:cast(Manager, {all_guilds_reloaded, From, length(GuildsToReload)})
catch
Class:Error:Stacktrace ->
logger:error(
"[guild_manager] Spawned process failed: ~p:~p~n~p",
[Class, Error, Stacktrace]
),
gen_server:cast(Manager, {all_guilds_reloaded, From, 0})
end
end),
{noreply, State};
do_reload_all_guilds(GuildIds, From, State);
handle_call({shutdown_guild, GuildId}, _From, State) ->
do_shutdown_guild(GuildId, State);
handle_call(get_local_count, _From, State) ->
@@ -131,60 +81,18 @@ handle_call(get_global_count, _From, State) ->
handle_call(_Unknown, _From, State) ->
{reply, ok, State}.
-spec handle_cast(Request, State) -> {noreply, state()} when
Request ::
{guild_data_fetched, guild_id(), fetch_result()}
| {guild_data_reloaded, guild_id(), pid(), gen_server:from(), fetch_result()}
| {all_guilds_reloaded, gen_server:from(), non_neg_integer()}
| term(),
State :: state().
-spec handle_cast(term(), state()) -> {noreply, state()}.
handle_cast({guild_data_fetched, GuildId, Result}, State) ->
Pending = maps:get(pending_requests, State),
Requests = maps:get(GuildId, Pending, []),
Guilds = maps:get(guilds, State),
case Result of
{ok, Data} ->
case start_guild(GuildId, Data, State) of
{ok, Pid, NewState} ->
lists:foreach(fun(From) -> gen_server:reply(From, {ok, Pid}) end, Requests),
NewPending = maps:remove(GuildId, Pending),
NewGuilds = maps:get(guilds, NewState),
CleanGuilds = maps:remove(GuildId, NewGuilds),
{noreply, NewState#{pending_requests => NewPending, guilds => CleanGuilds}};
{error, Reason} ->
logger:error("[guild_manager] Failed to start guild ~p: ~p", [GuildId, Reason]),
lists:foreach(
fun(From) -> gen_server:reply(From, {error, Reason}) end, Requests
),
NewGuilds = maps:remove(GuildId, Guilds),
NewPending = maps:remove(GuildId, Pending),
{noreply, State#{guilds => NewGuilds, pending_requests => NewPending}}
end;
_ ->
lists:foreach(fun(From) -> gen_server:reply(From, {error, not_found}) end, Requests),
NewGuilds = maps:remove(GuildId, Guilds),
NewPending = maps:remove(GuildId, Pending),
{noreply, State#{guilds => NewGuilds, pending_requests => NewPending}}
end;
handle_cast({guild_data_reloaded, _GuildId, Pid, From, Result}, State) ->
case Result of
{ok, Data} ->
gen_server:call(Pid, {reload, Data}, ?GUILD_CALL_TIMEOUT),
gen_server:reply(From, ok),
{noreply, State};
_ ->
gen_server:reply(From, {error, fetch_failed}),
{noreply, State}
end;
handle_guild_data_fetched(GuildId, Result, State);
handle_cast({guild_data_reloaded, GuildId, Pid, From, Result}, State) ->
handle_guild_data_reloaded(GuildId, Pid, From, Result, State);
handle_cast({all_guilds_reloaded, From, Count}, State) ->
gen_server:reply(From, #{count => Count}),
{noreply, State};
handle_cast(_Unknown, State) ->
{noreply, State}.
-spec handle_info(Info, State) -> {noreply, state()} when
Info :: {'DOWN', reference(), process, pid(), term()} | term(),
State :: state().
-spec handle_info(term(), state()) -> {noreply, state()}.
handle_info({'DOWN', _Ref, process, Pid, _Reason}, State) ->
Guilds = maps:get(guilds, State),
NewGuilds = process_registry:cleanup_on_down(Pid, Guilds),
@@ -192,23 +100,323 @@ handle_info({'DOWN', _Ref, process, Pid, _Reason}, State) ->
handle_info(_Unknown, State) ->
{noreply, State}.
-spec terminate(Reason, State) -> ok when
Reason :: term(),
State :: state().
-spec terminate(term(), state()) -> ok.
terminate(_Reason, _State) ->
ok.
-spec code_change(term(), term(), term()) -> {ok, state()}.
code_change(_OldVsn, #state{guilds = Guilds, api_host = ApiHost, api_canary_host = ApiCanaryHost, pending_requests = Pending}, _Extra) ->
{ok, #{
guilds => Guilds,
api_host => ApiHost,
api_canary_host => ApiCanaryHost,
pending_requests => Pending
}};
code_change(_OldVsn, State, _Extra) when is_map(State) ->
{ok, State}.
-spec do_start_or_lookup(guild_id(), gen_server:from(), state()) ->
{reply, {ok, pid()} | {error, term()}, state()} | {noreply, state()}.
do_start_or_lookup(GuildId, From, State) ->
Guilds = maps:get(guilds, State),
case maps:get(GuildId, Guilds, undefined) of
{Pid, _Ref} ->
{reply, {ok, Pid}, State};
loading ->
add_pending_request(GuildId, From, State);
undefined ->
lookup_or_fetch(GuildId, From, State)
end.
-spec lookup_or_fetch(guild_id(), gen_server:from(), state()) ->
{reply, {ok, pid()}, state()} | {noreply, state()}.
lookup_or_fetch(GuildId, From, State) ->
GuildName = process_registry:build_process_name(guild, GuildId),
case whereis(GuildName) of
undefined ->
start_fetch(GuildId, From, State);
_ExistingPid ->
Guilds = maps:get(guilds, State),
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
{ok, Pid, _Ref, NewGuilds} ->
{reply, {ok, Pid}, State#{guilds => NewGuilds}};
{error, not_found} ->
{reply, {error, process_died}, State}
end
end.
-spec start_fetch(guild_id(), gen_server:from(), state()) -> {noreply, state()}.
start_fetch(GuildId, From, State) ->
Guilds = maps:get(guilds, State),
NewGuilds = maps:put(GuildId, loading, Guilds),
Pending = maps:get(pending_requests, State),
NewPending = maps:put(GuildId, [From], Pending),
NewState = State#{guilds => NewGuilds, pending_requests => NewPending},
spawn_fetch(GuildId, State),
{noreply, NewState}.
-spec spawn_fetch(guild_id(), state()) -> pid().
spawn_fetch(GuildId, State) ->
Manager = self(),
ApiHostInfo = select_api_host(State),
spawn(fun() ->
try
Result = fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State),
gen_server:cast(Manager, {guild_data_fetched, GuildId, Result})
catch
_:_:_ ->
gen_server:cast(Manager, {guild_data_fetched, GuildId, {error, fetch_failed}})
end
end).
-spec add_pending_request(guild_id(), gen_server:from(), state()) -> {noreply, state()}.
add_pending_request(GuildId, From, State) ->
Pending = maps:get(pending_requests, State),
Requests = maps:get(GuildId, Pending, []),
NewPending = maps:put(GuildId, [From | Requests], Pending),
{noreply, State#{pending_requests => NewPending}}.
-spec handle_guild_data_fetched(guild_id(), fetch_result(), state()) -> {noreply, state()}.
handle_guild_data_fetched(GuildId, Result, State) ->
Pending = maps:get(pending_requests, State),
Requests = maps:get(GuildId, Pending, []),
Guilds = maps:get(guilds, State),
case Result of
{ok, Data} ->
case start_guild(GuildId, Data, State) of
{ok, Pid, NewState} ->
reply_to_all(Requests, {ok, Pid}),
NewPending = maps:remove(GuildId, Pending),
{noreply, NewState#{pending_requests => NewPending}};
{error, Reason} ->
reply_to_all(Requests, {error, Reason}),
NewGuilds = maps:remove(GuildId, Guilds),
NewPending = maps:remove(GuildId, Pending),
{noreply, State#{guilds => NewGuilds, pending_requests => NewPending}}
end;
{error, Reason} ->
reply_to_all(Requests, {error, Reason}),
NewGuilds = maps:remove(GuildId, Guilds),
NewPending = maps:remove(GuildId, Pending),
{noreply, State#{guilds => NewGuilds, pending_requests => NewPending}}
end.
-spec handle_guild_data_reloaded(guild_id(), pid(), gen_server:from(), fetch_result(), state()) ->
{noreply, state()}.
handle_guild_data_reloaded(_GuildId, Pid, From, Result, State) ->
case Result of
{ok, Data} ->
catch gen_server:call(Pid, {reload, Data}, ?GUILD_CALL_TIMEOUT),
gen_server:reply(From, ok);
_ ->
gen_server:reply(From, {error, fetch_failed})
end,
{noreply, State}.
-spec reply_to_all([gen_server:from()], term()) -> ok.
reply_to_all(Requests, Reply) ->
lists:foreach(fun(From) -> gen_server:reply(From, Reply) end, Requests).
-spec do_stop_guild(guild_id(), state()) -> {reply, ok, state()}.
do_stop_guild(GuildId, State) ->
Guilds = maps:get(guilds, State),
GuildName = process_registry:build_process_name(guild, GuildId),
case maps:get(GuildId, Guilds, undefined) of
{Pid, Ref} ->
demonitor(Ref, [flush]),
catch gen_server:stop(Pid, normal, ?SHUTDOWN_TIMEOUT),
process_registry:safe_unregister(GuildName),
NewGuilds = maps:remove(GuildId, Guilds),
{reply, ok, State#{guilds => NewGuilds}};
_ ->
case whereis(GuildName) of
undefined ->
{reply, ok, State};
ExistingPid ->
catch gen_server:stop(ExistingPid, normal, ?SHUTDOWN_TIMEOUT),
process_registry:safe_unregister(GuildName),
{reply, ok, State}
end
end.
-spec do_reload_guild(guild_id(), gen_server:from(), state()) ->
{reply, {error, not_found}, state()} | {noreply, state()}.
do_reload_guild(GuildId, From, State) ->
Guilds = maps:get(guilds, State),
GuildName = process_registry:build_process_name(guild, GuildId),
case maps:get(GuildId, Guilds, undefined) of
{Pid, _Ref} ->
spawn_reload(GuildId, Pid, From, State),
{noreply, State};
_ ->
case whereis(GuildName) of
undefined ->
{reply, {error, not_found}, State};
_ExistingPid ->
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
{ok, Pid, _Ref, NewGuilds} ->
NewState = State#{guilds => NewGuilds},
spawn_reload(GuildId, Pid, From, NewState),
{noreply, NewState};
{error, not_found} ->
{reply, {error, not_found}, State}
end
end
end.
-spec spawn_reload(guild_id(), pid(), gen_server:from(), state()) -> pid().
spawn_reload(GuildId, Pid, From, State) ->
Manager = self(),
ApiHostInfo = select_api_host(State),
spawn(fun() ->
try
Result = fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State),
gen_server:cast(Manager, {guild_data_reloaded, GuildId, Pid, From, Result})
catch
_:_:_ ->
gen_server:cast(
Manager, {guild_data_reloaded, GuildId, Pid, From, {error, fetch_failed}}
)
end
end).
-spec do_reload_all_guilds([guild_id()], gen_server:from(), state()) -> {noreply, state()}.
do_reload_all_guilds(GuildIds, From, State) ->
Guilds = maps:get(guilds, State),
GuildsToReload = select_guilds_to_reload(GuildIds, Guilds),
Manager = self(),
spawn(fun() ->
try
reload_guilds_in_batches(GuildsToReload, State),
gen_server:cast(Manager, {all_guilds_reloaded, From, length(GuildsToReload)})
catch
_:_:_ ->
gen_server:cast(Manager, {all_guilds_reloaded, From, 0})
end
end),
{noreply, State}.
-spec select_guilds_to_reload([guild_id()], #{guild_id() => guild_ref() | loading}) ->
[{guild_id(), pid()}].
select_guilds_to_reload([], Guilds) ->
[{GuildId, Pid} || {GuildId, {Pid, _Ref}} <- maps:to_list(Guilds)];
select_guilds_to_reload(GuildIds, Guilds) ->
lists:filtermap(
fun(GuildId) ->
case maps:get(GuildId, Guilds, undefined) of
{Pid, _Ref} -> {true, {GuildId, Pid}};
_ -> false
end
end,
GuildIds
).
-spec reload_guilds_in_batches([{guild_id(), pid()}], state()) -> ok.
reload_guilds_in_batches([], _State) ->
ok;
reload_guilds_in_batches(Guilds, State) ->
{Batch, Remaining} = lists:split(min(?BATCH_SIZE, length(Guilds)), Guilds),
reload_batch(Batch, State),
case Remaining of
[] ->
ok;
_ ->
timer:sleep(?BATCH_DELAY_MS),
reload_guilds_in_batches(Remaining, State)
end.
-spec reload_batch([{guild_id(), pid()}], state()) -> ok.
reload_batch(Batch, State) ->
ApiHostInfo = select_api_host(State),
lists:foreach(
fun({GuildId, Pid}) ->
spawn(fun() ->
try
case fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State) of
{ok, Data} ->
catch gen_server:call(Pid, {reload, Data}, ?GUILD_CALL_TIMEOUT);
{error, _Reason} ->
ok
end
catch
_:_ -> ok
end
end)
end,
Batch
).
-spec do_shutdown_guild(guild_id(), state()) -> {reply, ok, state()}.
do_shutdown_guild(GuildId, State) ->
Guilds = maps:get(guilds, State),
GuildName = process_registry:build_process_name(guild, GuildId),
case maps:get(GuildId, Guilds, undefined) of
{Pid, Ref} ->
demonitor(Ref, [flush]),
catch gen_server:call(Pid, {terminate}, ?SHUTDOWN_TIMEOUT),
process_registry:safe_unregister(GuildName),
NewGuilds = maps:remove(GuildId, Guilds),
{reply, ok, State#{guilds => NewGuilds}};
_ ->
case whereis(GuildName) of
undefined ->
{reply, ok, State};
ExistingPid ->
catch gen_server:call(ExistingPid, {terminate}, ?SHUTDOWN_TIMEOUT),
process_registry:safe_unregister(GuildName),
{reply, ok, State}
end
end.
-spec start_guild(guild_id(), guild_data(), state()) -> {ok, pid(), state()} | {error, term()}.
start_guild(GuildId, Data, State) ->
GuildName = process_registry:build_process_name(guild, GuildId),
case whereis(GuildName) of
undefined ->
start_new_guild(GuildId, Data, GuildName, State);
_ExistingPid ->
lookup_existing_guild(GuildId, GuildName, State)
end.
-spec start_new_guild(guild_id(), guild_data(), atom(), state()) ->
{ok, pid(), state()} | {error, term()}.
start_new_guild(GuildId, Data, GuildName, State) ->
GuildState = #{
id => GuildId,
data => Data,
sessions => #{}
},
Guilds = maps:get(guilds, State),
GuildModule =
case is_very_large_guild(Data) of
true -> very_large_guild;
false -> guild
end,
case GuildModule:start_link(GuildState) of
{ok, Pid} ->
case process_registry:register_and_monitor(GuildName, Pid, Guilds) of
{ok, RegisteredPid, Ref, NewGuilds0} ->
CleanGuilds = maps:remove(GuildName, NewGuilds0),
NewGuilds = maps:put(GuildId, {RegisteredPid, Ref}, CleanGuilds),
{ok, RegisteredPid, State#{guilds => NewGuilds}};
{error, Reason} ->
{error, Reason}
end;
Error ->
Error
end.
-spec is_very_large_guild(guild_data()) -> boolean().
is_very_large_guild(Data) when is_map(Data) ->
Guild = maps:get(<<"guild">>, Data, #{}),
Features = maps:get(<<"features">>, Guild, []),
lists:member(<<"VERY_LARGE_GUILD">>, Features);
is_very_large_guild(_) ->
false.
-spec lookup_existing_guild(guild_id(), atom(), state()) -> {ok, pid(), state()} | {error, term()}.
lookup_existing_guild(GuildId, GuildName, State) ->
Guilds = maps:get(guilds, State),
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
{ok, Pid, _Ref, NewGuilds} ->
{ok, Pid, State#{guilds => NewGuilds}};
{error, not_found} ->
{error, process_died}
end.
-spec fetch_guild_data(guild_id(), string()) -> fetch_result().
fetch_guild_data(GuildId, ApiHost) ->
RpcRequest = #{
@@ -217,43 +425,27 @@ fetch_guild_data(GuildId, ApiHost) ->
<<"version">> => 1
},
Url = rpc_client:get_rpc_url(ApiHost),
Headers =
rpc_client:get_rpc_headers() ++ [{<<"content-type">>, <<"application/json">>}],
Body = jsx:encode(RpcRequest),
case
hackney:request(post, Url, Headers, Body, [{recv_timeout, 30000}, {connect_timeout, 5000}])
of
{ok, 200, _RespHeaders, ClientRef} ->
case hackney:body(ClientRef) of
{ok, RespBody} ->
hackney:close(ClientRef),
Response = jsx:decode(RespBody, [return_maps]),
Data = maps:get(<<"data">>, Response, #{}),
{ok, Data};
{error, BodyReason} ->
hackney:close(ClientRef),
logger:error("[guild_manager] Failed to read guild response body: ~p", [
BodyReason
]),
{error, fetch_failed}
end;
{ok, StatusCode, _RespHeaders, ClientRef} ->
ErrorBody =
case hackney:body(ClientRef) of
{ok, Body2} -> Body2;
{error, _} -> <<"<unable to read error body>">>
end,
hackney:close(ClientRef),
logger:error(
"[guild_manager] Guild RPC failed with status ~p: ~s",
[StatusCode, ErrorBody]
),
{error, fetch_failed};
Headers = rpc_client:get_rpc_headers() ++ [{<<"content-type">>, <<"application/json">>}],
Body = json:encode(RpcRequest),
case gateway_http_client:request(rpc, post, Url, Headers, Body) of
{ok, 200, _RespHeaders, RespBody} ->
handle_fetch_response(RespBody);
{ok, StatusCode, _RespHeaders, _RespBody} ->
handle_fetch_error(StatusCode);
{error, Reason} ->
logger:error("[guild_manager] Guild RPC request failed: ~p", [Reason]),
{error, fetch_failed}
{error, {request_failed, Reason}}
end.
-spec handle_fetch_response(binary()) -> fetch_result().
handle_fetch_response(RespBody) ->
Response = json:decode(RespBody),
Data = maps:get(<<"data">>, Response, #{}),
{ok, Data}.
-spec handle_fetch_error(integer()) -> {error, {http_status, integer()}}.
handle_fetch_error(StatusCode) ->
{error, {http_status, StatusCode}}.
-spec select_api_host(state()) -> {string(), boolean()}.
select_api_host(State) ->
case maps:get(api_canary_host, State) of
@@ -270,12 +462,8 @@ select_api_host(State) ->
should_use_canary_api() ->
erlang:unique_integer([positive]) rem 100 < ?GUILD_API_CANARY_PERCENTAGE.
-spec fetch_guild_data_with_fallback(
guild_id(),
{string(), boolean()},
state()
) -> fetch_result().
fetch_guild_data_with_fallback(GuildId, {ApiHost, false}, _) ->
-spec fetch_guild_data_with_fallback(guild_id(), {string(), boolean()}, state()) -> fetch_result().
fetch_guild_data_with_fallback(GuildId, {ApiHost, false}, _State) ->
fetch_guild_data(GuildId, ApiHost);
fetch_guild_data_with_fallback(GuildId, {ApiHost, true}, State) ->
case fetch_guild_data(GuildId, ApiHost) of
@@ -287,244 +475,57 @@ fetch_guild_data_with_fallback(GuildId, {ApiHost, true}, State) ->
true ->
Error;
false ->
logger:warning(
"[guild_manager] Canary API request failed for ~p, retrying against stable host",
[GuildId]
),
fetch_guild_data(GuildId, StableHost)
end
end.
-spec start_guild(guild_id(), guild_data(), state()) -> {ok, pid(), state()} | {error, term()}.
start_guild(GuildId, Data, State) ->
GuildName = process_registry:build_process_name(guild, GuildId),
case whereis(GuildName) of
undefined ->
GuildState = #{
id => GuildId,
data => Data,
sessions => #{},
presences => #{}
},
Guilds = maps:get(guilds, State),
case guild:start_link(GuildState) of
{ok, Pid} ->
case process_registry:register_and_monitor(GuildName, Pid, Guilds) of
{ok, RegisteredPid, Ref, NewGuilds0} ->
CleanGuilds = maps:remove(GuildName, NewGuilds0),
NewGuilds = maps:put(GuildId, {RegisteredPid, Ref}, CleanGuilds),
{ok, RegisteredPid, State#{guilds => NewGuilds}};
{error, Reason} ->
{error, Reason}
end;
Error ->
Error
end;
_ExistingPid ->
Guilds = maps:get(guilds, State),
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
{ok, Pid, _Ref, NewGuilds} ->
{ok, Pid, State#{guilds => NewGuilds}};
{error, not_found} ->
{error, process_died}
end
end.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-spec reload_guilds_in_batches(
[{guild_id(), pid()}],
pid(),
state(),
pos_integer(),
non_neg_integer()
) -> ok.
reload_guilds_in_batches([], _Manager, _State, _BatchSize, _DelayMs) ->
ok;
reload_guilds_in_batches(Guilds, Manager, State, BatchSize, DelayMs) ->
{Batch, Remaining} = lists:split(min(BatchSize, length(Guilds)), Guilds),
lists:foreach(
fun({GuildId, Pid}) ->
ApiHostInfo = select_api_host(State),
spawn(fun() ->
try
case fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State) of
{ok, Data} ->
gen_server:call(Pid, {reload, Data}, ?GUILD_CALL_TIMEOUT);
{error, Reason} ->
logger:error("[guild_manager] Failed to reload guild ~p: ~p", [
GuildId, Reason
])
end
catch
Class:Error:Stacktrace ->
logger:error(
"[guild_manager] Spawned process failed: ~p:~p~n~p",
[Class, Error, Stacktrace]
)
end
end)
end,
Batch
),
case Remaining of
[] ->
ok;
_ ->
timer:sleep(DelayMs),
reload_guilds_in_batches(Remaining, Manager, State, BatchSize, DelayMs)
end.
select_api_host_no_canary_test() ->
State = #{api_host => "http://api.local", api_canary_host => undefined},
{Host, IsCanary} = select_api_host(State),
?assertEqual("http://api.local", Host),
?assertEqual(false, IsCanary).
-spec do_start_or_lookup(guild_id(), gen_server:from(), state()) ->
{reply, {ok, pid()} | {error, term()}, state()} | {noreply, state()}.
do_start_or_lookup(GuildId, From, State) ->
Guilds = maps:get(guilds, State),
case maps:get(GuildId, Guilds, undefined) of
{Pid, _Ref} ->
{reply, {ok, Pid}, State};
loading ->
Pending = maps:get(pending_requests, State),
Requests = maps:get(GuildId, Pending, []),
NewPending = maps:put(GuildId, [From | Requests], Pending),
{noreply, State#{pending_requests => NewPending}};
undefined ->
GuildName = process_registry:build_process_name(guild, GuildId),
case whereis(GuildName) of
undefined ->
NewGuilds = maps:put(GuildId, loading, Guilds),
Pending = maps:get(pending_requests, State),
NewPending = maps:put(GuildId, [From], Pending),
NewState = State#{guilds => NewGuilds, pending_requests => NewPending},
Manager = self(),
ApiHostInfo = select_api_host(State),
spawn(fun() ->
try
Result = fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State),
gen_server:cast(Manager, {guild_data_fetched, GuildId, Result})
catch
Class:Error:Stacktrace ->
logger:error(
"[guild_manager] Spawned process failed: ~p:~p~n~p",
[Class, Error, Stacktrace]
),
gen_server:cast(
Manager, {guild_data_fetched, GuildId, {error, fetch_failed}}
)
end
end),
{noreply, NewState};
_ExistingPid ->
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
{ok, Pid, _Ref, NewGuilds} ->
{reply, {ok, Pid}, State#{guilds => NewGuilds}};
{error, not_found} ->
{reply, {error, process_died}, State}
end
end
end.
should_use_canary_api_returns_boolean_test() ->
Result = should_use_canary_api(),
?assert(is_boolean(Result)).
-spec do_stop_guild(guild_id(), state()) -> {reply, ok, state()}.
do_stop_guild(GuildId, State) ->
Guilds = maps:get(guilds, State),
GuildName = process_registry:build_process_name(guild, GuildId),
case maps:get(GuildId, Guilds, undefined) of
{Pid, Ref} ->
demonitor(Ref, [flush]),
gen_server:stop(Pid, normal, ?SHUTDOWN_TIMEOUT),
process_registry:safe_unregister(GuildName),
NewGuilds = maps:remove(GuildId, Guilds),
{reply, ok, State#{guilds => NewGuilds}};
_ ->
case whereis(GuildName) of
undefined ->
{reply, ok, State};
ExistingPid ->
gen_server:stop(ExistingPid, normal, ?SHUTDOWN_TIMEOUT),
process_registry:safe_unregister(GuildName),
{reply, ok, State}
end
end.
select_guilds_to_reload_empty_ids_test() ->
Guilds = #{1 => {self(), make_ref()}, 2 => {self(), make_ref()}},
Result = select_guilds_to_reload([], Guilds),
?assertEqual(2, length(Result)).
-spec do_reload_guild(guild_id(), gen_server:from(), state()) ->
{reply, {error, not_found}, state()} | {noreply, state()}.
do_reload_guild(GuildId, From, State) ->
Guilds = maps:get(guilds, State),
GuildName = process_registry:build_process_name(guild, GuildId),
case maps:get(GuildId, Guilds, undefined) of
{Pid, _Ref} ->
Manager = self(),
ApiHostInfo = select_api_host(State),
spawn(fun() ->
try
Result = fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State),
gen_server:cast(Manager, {guild_data_reloaded, GuildId, Pid, From, Result})
catch
Class:Error:Stacktrace ->
logger:error(
"[guild_manager] Spawned process failed: ~p:~p~n~p",
[Class, Error, Stacktrace]
),
gen_server:cast(
Manager,
{guild_data_reloaded, GuildId, Pid, From, {error, fetch_failed}}
)
end
end),
{noreply, State};
_ ->
case whereis(GuildName) of
undefined ->
{reply, {error, not_found}, State};
_ExistingPid ->
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
{ok, Pid, _Ref, NewGuilds} ->
NewState = State#{guilds => NewGuilds},
Manager = self(),
ApiHostInfo = select_api_host(NewState),
spawn(fun() ->
try
Result = fetch_guild_data_with_fallback(
GuildId, ApiHostInfo, NewState
),
gen_server:cast(
Manager, {guild_data_reloaded, GuildId, Pid, From, Result}
)
catch
Class:Error:Stacktrace ->
logger:error(
"[guild_manager] Spawned process failed: ~p:~p~n~p",
[Class, Error, Stacktrace]
),
gen_server:cast(
Manager,
{guild_data_reloaded, GuildId, Pid, From,
{error, fetch_failed}}
)
end
end),
{noreply, NewState};
{error, not_found} ->
{reply, {error, not_found}, State}
end
end
end.
select_guilds_to_reload_specific_ids_test() ->
Pid = self(),
Ref = make_ref(),
Guilds = #{1 => {Pid, Ref}, 2 => {Pid, Ref}, 3 => loading},
Result = select_guilds_to_reload([1, 3], Guilds),
?assertEqual(1, length(Result)).
-spec do_shutdown_guild(guild_id(), state()) -> {reply, ok, state()}.
do_shutdown_guild(GuildId, State) ->
Guilds = maps:get(guilds, State),
GuildName = process_registry:build_process_name(guild, GuildId),
case maps:get(GuildId, Guilds, undefined) of
{Pid, Ref} ->
demonitor(Ref, [flush]),
gen_server:call(Pid, {terminate}, ?SHUTDOWN_TIMEOUT),
process_registry:safe_unregister(GuildName),
NewGuilds = maps:remove(GuildId, Guilds),
{reply, ok, State#{guilds => NewGuilds}};
_ ->
case whereis(GuildName) of
undefined ->
{reply, ok, State};
ExistingPid ->
catch gen_server:call(ExistingPid, {terminate}, ?SHUTDOWN_TIMEOUT),
process_registry:safe_unregister(GuildName),
{reply, ok, State}
end
end.
reply_to_all_empty_list_test() ->
?assertEqual(ok, reply_to_all([], ok)).
do_start_or_lookup_loading_deduplicates_requests_test() ->
GuildId = 4444,
From1 = {self(), make_ref()},
From2 = {self(), make_ref()},
State0 = #{
guilds => #{GuildId => loading},
api_host => "http://api.local",
api_canary_host => undefined,
pending_requests => #{},
shard_index => 0
},
{noreply, State1} = do_start_or_lookup(GuildId, From1, State0),
Pending1 = maps:get(pending_requests, State1),
?assertEqual([From1], maps:get(GuildId, Pending1)),
{noreply, State2} = do_start_or_lookup(GuildId, From2, State1),
Pending2 = maps:get(pending_requests, State2),
Requests = maps:get(GuildId, Pending2),
?assertEqual(2, length(Requests)),
?assert(lists:member(From1, Requests)),
?assert(lists:member(From2, Requests)).
-endif.

View File

@@ -25,11 +25,14 @@
get_items_in_range/3,
handle_member_update/3,
build_sync_response/4,
member_list_delta/4,
member_list_snapshot/2,
get_online_count/1,
broadcast_member_list_updates/3,
broadcast_all_member_list_updates/1,
broadcast_member_list_updates_for_channel/2,
normalize_ranges/1
normalize_ranges/1,
get_members_cursor/2
]).
-ifdef(TEST).
@@ -43,13 +46,19 @@
-type group_item() :: map().
-type list_item() :: member_item() | group_item().
-type user_id() :: integer().
-type channel_id() :: integer().
-define(MAX_RANGE_END, 100000).
-spec validate_range(range()) -> range() | invalid.
validate_range({Start, End}) when is_integer(Start), is_integer(End),
Start >= 0, End >= 0, Start =< End,
End =< ?MAX_RANGE_END ->
validate_range({Start, End}) when
is_integer(Start),
is_integer(End),
Start >= 0,
End >= 0,
Start =< End,
End =< ?MAX_RANGE_END
->
{Start, End};
validate_range(_) ->
invalid.
@@ -77,27 +86,39 @@ merge_overlapping_ranges([{S1, E1}, {S2, E2} | Rest]) when S2 =< E1 + 1 ->
merge_overlapping_ranges([Range | Rest]) ->
[Range | merge_overlapping_ranges(Rest)].
-spec calculate_list_id(integer(), guild_state()) -> list_id().
calculate_list_id(ChannelId, _State) when is_integer(ChannelId), ChannelId > 0 ->
integer_to_binary(ChannelId);
-spec calculate_list_id(channel_id(), guild_state()) -> list_id().
calculate_list_id(ChannelId, State) when is_integer(ChannelId), ChannelId > 0 ->
Data = maps:get(data, State, #{}),
Channels = map_utils:ensure_list(maps:get(<<"channels">>, Data, [])),
ChannelIdBin = integer_to_binary(ChannelId),
ChannelExists = lists:any(
fun(Ch) -> maps:get(<<"id">>, Ch, undefined) =:= ChannelIdBin end, Channels
),
case ChannelExists of
true -> ChannelIdBin;
false -> <<"0">>
end;
calculate_list_id(_, _) ->
<<"0">>.
-spec get_member_groups(list_id(), guild_state()) -> [group_item()].
get_member_groups(ListId, State) ->
Data = maps:get(data, State, #{}),
Members = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
Members = guild_data_index:member_values(Data),
Roles = map_utils:ensure_list(maps:get(<<"roles">>, Data, [])),
GuildId = maps:get(id, State, 0),
HoistedRoles = get_hoisted_roles_sorted(Roles, GuildId),
FilteredMembers = filter_members_for_list(ListId, Members, State),
{OnlineMembers, OfflineMembers} = partition_members_by_online(FilteredMembers, State),
RoleGroups = build_role_groups(HoistedRoles, OnlineMembers),
OnlineGroup = #{<<"id">> => <<"online">>, <<"count">> => count_ungrouped_online(OnlineMembers, HoistedRoles)},
OnlineGroup = #{
<<"id">> => <<"online">>, <<"count">> => count_ungrouped_online(OnlineMembers, HoistedRoles)
},
OfflineGroup = #{<<"id">> => <<"offline">>, <<"count">> => length(OfflineMembers)},
RoleGroups ++ [OnlineGroup, OfflineGroup].
-spec subscribe_ranges(binary(), list_id(), [range()], guild_state()) -> {guild_state(), boolean(), [range()]}.
-spec subscribe_ranges(binary(), list_id(), [range()], guild_state()) ->
{guild_state(), boolean(), [range()]}.
subscribe_ranges(SessionId, ListId, Ranges, State) ->
NormalizedRanges = normalize_ranges(Ranges),
Subscriptions = maps:get(member_list_subscriptions, State, #{}),
@@ -205,8 +226,9 @@ build_sync_response(GuildId, ListId, Ranges, State) ->
-spec get_online_count(guild_state()) -> non_neg_integer().
get_online_count(State) ->
Data = maps:get(data, State, #{}),
Members = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
{OnlineMembers, _} = partition_members_by_online(Members, State),
Members = guild_data_index:member_values(Data),
EligibleMembers = filter_members_for_list(<<"0">>, Members, State),
{OnlineMembers, _} = partition_members_by_online(EligibleMembers, State),
length(OnlineMembers).
-spec broadcast_member_list_updates(user_id(), guild_state(), guild_state()) -> ok.
@@ -226,7 +248,9 @@ broadcast_member_list_updates(_UserId, OldState, UpdatedState) ->
<<"groups">> => Groups,
<<"ops">> => Ops
},
send_member_list_update_to_sessions(ListId, ListSubs, Sessions, Payload, UpdatedState);
send_member_list_update_to_sessions(
ListId, ListSubs, Sessions, Payload, UpdatedState
);
_ ->
ok
end
@@ -246,7 +270,9 @@ send_member_list_update_to_sessions(ListId, ListSubs, Sessions, Payload, State)
true ->
case session_can_view_channel(SessionData, ChannelId, State) of
true ->
gen_server:cast(SessionPid, {dispatch, guild_member_list_update, Payload});
gen_server:cast(
SessionPid, {dispatch, guild_member_list_update, Payload}
);
false ->
ok
end;
@@ -260,7 +286,7 @@ send_member_list_update_to_sessions(ListId, ListSubs, Sessions, Payload, State)
ListSubs
).
-spec member_list_delta(list_id(), guild_state(), guild_state(), user_id()) ->
-spec member_list_delta(list_id(), guild_state(), guild_state(), user_id()) ->
{non_neg_integer(), non_neg_integer(), [group_item()], [list_item()], boolean()}.
member_list_delta(ListId, OldState, UpdatedState, UserId) ->
{OldCount, OldOnline, OldGroups, OldItems} = member_list_snapshot(ListId, OldState),
@@ -270,10 +296,14 @@ member_list_delta(ListId, OldState, UpdatedState, UserId) ->
{MemberCount, OnlineCount, Groups, Ops, true};
{false, _} ->
Ops = diff_items_to_ops(OldItems, Items),
Changed = Ops =/= [] orelse OldCount =/= MemberCount orelse OldOnline =/= OnlineCount orelse OldGroups =/= Groups,
Changed =
Ops =/= [] orelse OldCount =/= MemberCount orelse OldOnline =/= OnlineCount orelse
OldGroups =/= Groups,
{MemberCount, OnlineCount, Groups, Ops, Changed}
end.
-spec presence_move_ops(user_id(), guild_state(), guild_state(), [list_item()], [list_item()]) ->
{boolean(), [map()]}.
presence_move_ops(UserId, OldState, UpdatedState, OldItems, NewItems) ->
case presence_status_changed(UserId, OldState, UpdatedState) of
false ->
@@ -295,6 +325,7 @@ presence_move_ops(UserId, OldState, UpdatedState, OldItems, NewItems) ->
end
end.
-spec presence_status_changed(user_id(), guild_state(), guild_state()) -> boolean().
presence_status_changed(UserId, OldState, UpdatedState) ->
OldPresence = resolve_presence_for_user(OldState, UserId),
NewPresence = resolve_presence_for_user(UpdatedState, UserId),
@@ -302,9 +333,13 @@ presence_status_changed(UserId, OldState, UpdatedState) ->
NewStatus = maps:get(<<"status">>, NewPresence, <<"offline">>),
OldStatus =/= NewStatus.
-spec find_member_entry(user_id(), [list_item()]) ->
{ok, non_neg_integer(), list_item()} | {error, not_found}.
find_member_entry(UserId, Items) ->
find_member_entry(UserId, Items, 0).
-spec find_member_entry(user_id(), [list_item()], non_neg_integer()) ->
{ok, non_neg_integer(), list_item()} | {error, not_found}.
find_member_entry(_UserId, [], _Index) ->
{error, not_found};
find_member_entry(UserId, [Item | Rest], Index) ->
@@ -318,6 +353,7 @@ find_member_entry(UserId, [Item | Rest], Index) ->
end
end.
-spec adjusted_insert_index(non_neg_integer(), non_neg_integer()) -> non_neg_integer().
adjusted_insert_index(OldIdx, NewIdx) when NewIdx > OldIdx ->
NewIdx - 1;
adjusted_insert_index(_OldIdx, NewIdx) ->
@@ -345,9 +381,16 @@ broadcast_all_member_list_updates(State) ->
case maps:get(SessionId, Sessions, undefined) of
SessionData when is_map(SessionData) ->
SessionPid = maps:get(pid, SessionData, undefined),
case {is_pid(SessionPid), session_can_view_channel(SessionData, ChannelId, State)} of
case
{
is_pid(SessionPid),
session_can_view_channel(SessionData, ChannelId, State)
}
of
{true, true} ->
SyncResponse = build_sync_response(GuildId, ListId, Ranges, State),
SyncResponse = build_sync_response(
GuildId, ListId, Ranges, State
),
gen_server:cast(
SessionPid,
{dispatch, guild_member_list_update, SyncResponse}
@@ -366,7 +409,7 @@ broadcast_all_member_list_updates(State) ->
),
ok.
-spec broadcast_member_list_updates_for_channel(integer(), guild_state()) -> ok.
-spec broadcast_member_list_updates_for_channel(channel_id(), guild_state()) -> ok.
broadcast_member_list_updates_for_channel(ChannelId, State) ->
GuildId = maps:get(id, State, 0),
ListId = calculate_list_id(ChannelId, State),
@@ -381,13 +424,23 @@ broadcast_member_list_updates_for_channel(ChannelId, State) ->
case maps:get(SessionId, Sessions, undefined) of
SessionData when is_map(SessionData) ->
SessionPid = maps:get(pid, SessionData, undefined),
case {is_pid(SessionPid), session_can_view_channel(SessionData, ChannelId, State)} of
case
{
is_pid(SessionPid),
session_can_view_channel(SessionData, ChannelId, State)
}
of
{true, true} ->
SyncResponse = build_sync_response(GuildId, ListId, Ranges, State),
SyncResponseWithChannel = maps:put(<<"channel_id">>, integer_to_binary(ChannelId), SyncResponse),
SyncResponse = build_sync_response(
GuildId, ListId, Ranges, State
),
SyncResponseWithChannel = maps:put(
<<"channel_id">>, integer_to_binary(ChannelId), SyncResponse
),
gen_server:cast(
SessionPid,
{dispatch, guild_member_list_update, SyncResponseWithChannel}
{dispatch, guild_member_list_update,
SyncResponseWithChannel}
);
_ ->
ok
@@ -445,14 +498,33 @@ filter_members_for_list(ListId, Members, State) ->
)
end.
-spec connected_session_user_ids(guild_state()) -> sets:set().
connected_session_user_ids(State) ->
Sessions = maps:get(sessions, State, #{}),
maps:fold(
fun(_SessionId, SessionData, Acc) ->
case maps:get(user_id, SessionData, undefined) of
UserId when is_integer(UserId), UserId > 0 ->
sets:add_element(UserId, Acc);
_ ->
Acc
end
end,
sets:new(),
Sessions
).
-spec partition_members_by_online([map()], guild_state()) -> {[map()], [map()]}.
partition_members_by_online(Members, State) ->
ConnectedUserIds = connected_session_user_ids(State),
lists:partition(
fun(Member) ->
UserId = get_member_user_id(Member),
Presence = resolve_presence_for_user(State, UserId),
Status = maps:get(<<"status">>, Presence, <<"offline">>),
Status =/= <<"offline">> andalso Status =/= <<"invisible">>
IsConnected = sets:is_element(UserId, ConnectedUserIds),
IsOnlineStatus = Status =/= <<"offline">> andalso Status =/= <<"invisible">>,
IsConnected andalso IsOnlineStatus
end,
Members
).
@@ -471,15 +543,17 @@ build_role_groups(HoistedRoles, OnlineMembers) ->
-spec count_members_with_top_role(integer(), [map()], [map()]) -> non_neg_integer().
count_members_with_top_role(RoleId, Members, HoistedRoles) ->
HoistedRoleIds = [map_utils:get_integer(R, <<"id">>, 0) || R <- HoistedRoles],
length(lists:filter(
fun(Member) ->
MemberRoles = map_utils:ensure_list(maps:get(<<"roles">>, Member, [])),
MemberRoleIds = [type_conv:to_integer(R) || R <- MemberRoles],
TopHoisted = find_top_hoisted_role(MemberRoleIds, HoistedRoleIds),
TopHoisted =:= RoleId
end,
Members
)).
length(
lists:filter(
fun(Member) ->
MemberRoles = map_utils:ensure_list(maps:get(<<"roles">>, Member, [])),
MemberRoleIds = [type_conv:to_integer(R) || R <- MemberRoles],
TopHoisted = find_top_hoisted_role(MemberRoleIds, HoistedRoleIds),
TopHoisted =:= RoleId
end,
Members
)
).
-spec find_top_hoisted_role([integer()], [integer()]) -> integer() | undefined.
find_top_hoisted_role(MemberRoleIds, HoistedRoleIds) ->
@@ -491,20 +565,22 @@ find_top_hoisted_role(MemberRoleIds, HoistedRoleIds) ->
-spec count_ungrouped_online([map()], [map()]) -> non_neg_integer().
count_ungrouped_online(OnlineMembers, HoistedRoles) ->
HoistedRoleIds = [map_utils:get_integer(R, <<"id">>, 0) || R <- HoistedRoles],
length(lists:filter(
fun(Member) ->
MemberRoles = map_utils:ensure_list(maps:get(<<"roles">>, Member, [])),
MemberRoleIds = [type_conv:to_integer(R) || R <- MemberRoles],
TopHoisted = find_top_hoisted_role(MemberRoleIds, HoistedRoleIds),
TopHoisted =:= undefined
end,
OnlineMembers
)).
length(
lists:filter(
fun(Member) ->
MemberRoles = map_utils:ensure_list(maps:get(<<"roles">>, Member, [])),
MemberRoleIds = [type_conv:to_integer(R) || R <- MemberRoles],
TopHoisted = find_top_hoisted_role(MemberRoleIds, HoistedRoleIds),
TopHoisted =:= undefined
end,
OnlineMembers
)
).
-spec get_sorted_members_for_list(list_id(), guild_state()) -> [map()].
get_sorted_members_for_list(ListId, State) ->
Data = maps:get(data, State, #{}),
Members = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
Members = guild_data_index:member_values(Data),
FilteredMembers = filter_members_for_list(ListId, Members, State),
lists:sort(
fun(A, B) ->
@@ -519,11 +595,18 @@ get_member_user_id(Member) ->
map_utils:get_integer(User, <<"id">>, 0).
-spec normalize_name(term()) -> binary().
normalize_name(undefined) -> <<>>;
normalize_name(null) -> <<>>;
normalize_name(<<_/binary>> = B) -> B;
normalize_name(undefined) ->
<<>>;
normalize_name(null) ->
<<>>;
normalize_name(<<_/binary>> = B) ->
B;
normalize_name(L) when is_list(L) ->
try unicode:characters_to_binary(L) catch _:_ -> <<>> end;
try
unicode:characters_to_binary(L)
catch
_:_ -> <<>>
end;
normalize_name(I) when is_integer(I) ->
integer_to_binary(I);
normalize_name(_) ->
@@ -560,11 +643,13 @@ casefold_binary(Value) ->
_:_ -> Bin
end.
-spec add_presence_to_member(map(), guild_state()) -> map().
add_presence_to_member(Member, State) ->
UserId = get_member_user_id(Member),
Presence = resolve_presence_for_user(State, UserId),
maps:put(<<"presence">>, Presence, Member).
-spec default_presence() -> map().
default_presence() ->
#{
<<"status">> => <<"offline">>,
@@ -572,20 +657,10 @@ default_presence() ->
<<"afk">> => false
}.
-spec resolve_presence_for_user(guild_state(), user_id()) -> map().
resolve_presence_for_user(State, UserId) ->
Presences = maps:get(presences, State, #{}),
case maps:get(UserId, Presences, undefined) of
undefined -> fetch_presence_from_cache(UserId);
Presence -> Presence
end.
fetch_presence_from_cache(UserId) ->
try presence_cache:get(UserId) of
{ok, Presence} -> Presence;
_ -> default_presence()
catch
exit:{noproc, _} -> default_presence()
end.
MemberPresence = maps:get(member_presence, State, #{}),
maps:get(UserId, MemberPresence, default_presence()).
-spec build_member_list_items([group_item()], [map()], guild_state()) -> [list_item()].
build_member_list_items(Groups, Members, State) ->
@@ -609,9 +684,21 @@ build_member_list_items(Groups, Members, State) ->
end,
OnlineMembers
),
[GroupHeader | [#{<<"member">> => add_presence_to_member(M, State)} || M <- UngroupedOnline]];
[
GroupHeader
| [
#{<<"member">> => add_presence_to_member(M, State)}
|| M <- UngroupedOnline
]
];
<<"offline">> ->
[GroupHeader | [#{<<"member">> => add_presence_to_member(M, State)} || M <- OfflineMembers]];
[
GroupHeader
| [
#{<<"member">> => add_presence_to_member(M, State)}
|| M <- OfflineMembers
]
];
RoleIdBin ->
RoleId = type_conv:to_integer(RoleIdBin),
RoleMembers = lists:filter(
@@ -622,7 +709,10 @@ build_member_list_items(Groups, Members, State) ->
end,
OnlineMembers
),
[GroupHeader | [#{<<"member">> => add_presence_to_member(M, State)} || M <- RoleMembers]]
[
GroupHeader
| [#{<<"member">> => add_presence_to_member(M, State)} || M <- RoleMembers]
]
end
end,
Groups
@@ -636,7 +726,7 @@ slice_items(Items, Start, End) ->
false -> lists:sublist(Items, Start + 1, SafeEnd - Start + 1)
end.
-spec member_in_list(integer(), [map()]) -> boolean().
-spec member_in_list(user_id(), [map()]) -> boolean().
member_in_list(UserId, Members) ->
lists:any(fun(M) -> get_member_user_id(M) =:= UserId end, Members).
@@ -650,50 +740,83 @@ full_sync_ops(ListId, State) ->
SortedMembers = get_sorted_members_for_list(ListId, State),
Items = build_full_items(ListId, State, SortedMembers),
case length(Items) of
0 -> [];
0 ->
[];
N ->
[#{
<<"op">> => <<"SYNC">>,
<<"range">> => [0, N - 1],
<<"items">> => Items
}]
[
#{
<<"op">> => <<"SYNC">>,
<<"range">> => [0, N - 1],
<<"items">> => Items
}
]
end.
-spec upsert_member_in_state(integer(), map(), guild_state()) -> {map() | undefined, map(), guild_state()}.
-spec get_members_cursor(map(), guild_state()) -> {reply, map(), guild_state()}.
get_members_cursor(Request, State) ->
Limit = maps:get(<<"limit">>, Request, 1),
AfterId = maps:get(<<"after">>, Request, undefined),
SortedMembers = sort_members_by_user_id(State),
FilteredMembers = filter_members_after(SortedMembers, AfterId),
ResponseMembers = take_first(FilteredMembers, Limit),
{reply,
#{
members => ResponseMembers,
total => length(SortedMembers)
},
State}.
-spec take_first([map()], integer()) -> [map()].
take_first(_Members, Limit) when Limit =< 0 ->
[];
take_first(Members, Limit) ->
Count = min(Limit, length(Members)),
case Count of
0 -> [];
_ -> lists:sublist(Members, 1, Count)
end.
-spec filter_members_after([map()], integer() | undefined) -> [map()].
filter_members_after(Members, undefined) ->
Members;
filter_members_after(Members, AfterId) ->
lists:dropwhile(
fun(Member) ->
get_member_user_id(Member) =< AfterId
end,
Members
).
-spec sort_members_by_user_id(guild_state()) -> [map()].
sort_members_by_user_id(State) ->
Data = maps:get(data, State, #{}),
Members = guild_data_index:member_values(Data),
lists:sort(
fun(A, B) ->
get_member_user_id(A) =< get_member_user_id(B)
end,
Members
).
-spec upsert_member_in_state(user_id(), map(), guild_state()) ->
{map() | undefined, map(), guild_state()}.
upsert_member_in_state(UserId, MemberUpdate, State) ->
Data = maps:get(data, State, #{}),
Members0 = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
{Found, CurrentMember} = find_member_by_user_id(UserId, Members0),
Members0 = guild_data_index:member_map(Data),
CurrentMember = maps:get(UserId, Members0, undefined),
UpdatedMember =
case Found of
true -> deep_merge_member(CurrentMember, MemberUpdate);
false -> MemberUpdate
case CurrentMember of
undefined -> MemberUpdate;
_ -> deep_merge_member(CurrentMember, MemberUpdate)
end,
Members1 =
case Found of
true ->
lists:map(
fun(M) ->
case get_member_user_id(M) =:= UserId of
true -> UpdatedMember;
false -> M
end
end,
Members0
);
false ->
Members0 ++ [UpdatedMember]
end,
Data1 = maps:put(<<"members">>, Members1, Data),
Members1 = maps:put(UserId, UpdatedMember, Members0),
Data1 = guild_data_index:put_member_map(Members1, Data),
NewState = maps:put(data, Data1, State),
{case Found of true -> CurrentMember; false -> undefined end, UpdatedMember, NewState}.
-spec find_member_by_user_id(integer(), [map()]) -> {boolean(), map()} | {false, undefined}.
find_member_by_user_id(UserId, Members) ->
case lists:search(fun(M) -> get_member_user_id(M) =:= UserId end, Members) of
{value, M} -> {true, M};
false -> {false, undefined}
end.
{
CurrentMember,
UpdatedMember,
NewState
}.
-spec deep_merge_member(map(), map()) -> map().
deep_merge_member(CurrentMember, MemberUpdate) ->
@@ -741,13 +864,16 @@ diff_items_to_ops(OldItems, NewItems) ->
-spec full_sync_from_items([list_item()]) -> [map()].
full_sync_from_items(Items) ->
case length(Items) of
0 -> [];
0 ->
[];
N ->
[#{
<<"op">> => <<"SYNC">>,
<<"range">> => [0, N - 1],
<<"items">> => Items
}]
[
#{
<<"op">> => <<"SYNC">>,
<<"range">> => [0, N - 1],
<<"items">> => Items
}
]
end.
-spec item_keys([list_item()]) -> [term()].
@@ -777,11 +903,14 @@ updates_for_changed_items(OldItems, NewItems) ->
true ->
Acc;
false ->
[#{
<<"op">> => <<"UPDATE">>,
<<"index">> => Idx,
<<"item">> => NewItem
} | Acc]
[
#{
<<"op">> => <<"UPDATE">>,
<<"index">> => Idx,
<<"item">> => NewItem
}
| Acc
]
end
end,
[],
@@ -792,6 +921,9 @@ updates_for_changed_items(OldItems, NewItems) ->
-spec zip_with_index([term()], [term()]) -> [{non_neg_integer(), term(), term()}].
zip_with_index(A, B) ->
zip_with_index(A, B, 0, []).
-spec zip_with_index([term()], [term()], non_neg_integer(), [{non_neg_integer(), term(), term()}]) ->
[{non_neg_integer(), term(), term()}].
zip_with_index([], [], _I, Acc) ->
lists:reverse(Acc);
zip_with_index([HA | TA], [HB | TB], I, Acc) ->
@@ -802,6 +934,11 @@ zip_with_index(_, _, _I, Acc) ->
-spec mismatch_span([term()], [term()]) -> none | {non_neg_integer(), non_neg_integer()}.
mismatch_span(A, B) ->
mismatch_span(A, B, 0, none).
-spec mismatch_span(
[term()], [term()], non_neg_integer(), none | {non_neg_integer(), non_neg_integer()}
) ->
none | {non_neg_integer(), non_neg_integer()}.
mismatch_span([], [], _I, none) ->
none;
mismatch_span([], [], _I, {S, E}) ->
@@ -831,7 +968,8 @@ sync_range_op(Start, End, NewItems) ->
<<"items">> => slice_items(NewItems, Start, End)
}.
-spec try_pure_insert_delete([list_item()], [list_item()], [term()], [term()]) -> {ok, [map()]} | error.
-spec try_pure_insert_delete([list_item()], [list_item()], [term()], [term()]) ->
{ok, [map()]} | error.
try_pure_insert_delete(OldItems, NewItems, OldKeys, NewKeys) ->
LenOld = length(OldKeys),
LenNew = length(NewKeys),
@@ -872,6 +1010,8 @@ try_pure_insert_delete(OldItems, NewItems, OldKeys, NewKeys) ->
-spec first_mismatch_index([term()], [term()]) -> non_neg_integer().
first_mismatch_index(A, B) ->
first_mismatch_index(A, B, 0).
-spec first_mismatch_index([term()], [term()], non_neg_integer()) -> non_neg_integer().
first_mismatch_index([], _B, I) ->
I;
first_mismatch_index(_A, [], I) ->
@@ -896,6 +1036,8 @@ delete_many(List, Index, Count) ->
-spec insert_ops(non_neg_integer(), [list_item()]) -> [map()].
insert_ops(StartIdx, Items) ->
insert_ops(StartIdx, Items, 0, []).
-spec insert_ops(non_neg_integer(), [list_item()], non_neg_integer(), [map()]) -> [map()].
insert_ops(_StartIdx, [], _Offset, Acc) ->
lists:reverse(Acc);
insert_ops(StartIdx, [Item | Rest], Offset, Acc) ->
@@ -920,18 +1062,23 @@ delete_ops(Idx, Count) ->
lists:seq(1, Count)
).
-spec session_can_view_channel(map(), integer(), guild_state()) -> boolean().
session_can_view_channel(_SessionData, ChannelId, _State) when not is_integer(ChannelId); ChannelId =< 0 ->
-spec session_can_view_channel(map(), channel_id(), guild_state()) -> boolean().
session_can_view_channel(_SessionData, ChannelId, _State) when
not is_integer(ChannelId); ChannelId =< 0
->
false;
session_can_view_channel(SessionData, ChannelId, State) ->
case maps:get(user_id, SessionData, undefined) of
UserId when is_integer(UserId) ->
case {maps:get(user_id, SessionData, undefined), maps:get(viewable_channels, SessionData, undefined)} of
{UserId, ViewableChannels} when is_integer(UserId), is_map(ViewableChannels) ->
maps:is_key(ChannelId, ViewableChannels) orelse
guild_permissions:can_view_channel(UserId, ChannelId, undefined, State);
{UserId, _} when is_integer(UserId) ->
guild_permissions:can_view_channel(UserId, ChannelId, undefined, State);
_ ->
false
end.
-spec list_id_channel_id(list_id()) -> integer().
-spec list_id_channel_id(list_id()) -> channel_id().
list_id_channel_id(ListId) when is_binary(ListId) ->
case type_conv:to_integer(ListId) of
undefined -> 0;
@@ -966,6 +1113,16 @@ list_id_channel_id_parses_binary_test() ->
list_id_channel_id_invalid_value_test() ->
?assertEqual(0, list_id_channel_id(<<"abc">>)).
session_can_view_channel_uses_cached_visibility_test() ->
SessionData = #{user_id => 12, viewable_channels => #{500 => true}},
State = #{data => #{<<"members">> => #{}}},
?assertEqual(true, session_can_view_channel(SessionData, 500, State)).
session_can_view_channel_rejects_when_cache_misses_and_user_missing_test() ->
SessionData = #{user_id => 99, viewable_channels => #{}},
State = #{data => #{<<"members">> => #{}}},
?assertEqual(false, session_can_view_channel(SessionData, 500, State)).
subscribe_ranges_test() ->
State = #{member_list_subscriptions => #{}},
{NewState, ShouldSync, NormalizedRanges} =
@@ -1024,4 +1181,64 @@ validate_range_invalid_negative_test() ->
validate_range_invalid_too_large_test() ->
?assertEqual(invalid, validate_range({0, 100001})).
get_members_cursor_returns_atom_keys_test() ->
Member1 = #{<<"user">> => #{<<"id">> => <<"2">>}},
Member2 = #{<<"user">> => #{<<"id">> => <<"1">>}},
State = #{data => #{<<"members">> => [Member1, Member2]}},
{reply, Reply, _NewState} = get_members_cursor(#{<<"limit">> => 1}, State),
?assert(maps:is_key(members, Reply)),
?assert(maps:is_key(total, Reply)),
?assertNot(maps:is_key(<<"members">>, Reply)),
?assertNot(maps:is_key(<<"total">>, Reply)).
get_online_count_ignores_members_without_connected_session_test() ->
Members = [
#{<<"user">> => #{<<"id">> => <<"1">>}},
#{<<"user">> => #{<<"id">> => <<"2">>}}
],
State = #{
data => #{<<"members">> => Members},
sessions => #{<<"s1">> => #{user_id => 1}},
member_presence => #{
1 => #{<<"status">> => <<"online">>},
2 => #{<<"status">> => <<"online">>}
}
},
?assertEqual(1, get_online_count(State)).
filter_members_for_list_keeps_members_without_connected_session_test() ->
Members = [
#{<<"user">> => #{<<"id">> => <<"1">>}},
#{<<"user">> => #{<<"id">> => <<"2">>}}
],
State = #{
sessions => #{<<"s1">> => #{user_id => 1}},
data => #{<<"channels">> => []}
},
FilteredMembers = filter_members_for_list(<<"0">>, Members, State),
?assertEqual([1, 2], [get_member_user_id(Member) || Member <- FilteredMembers]).
get_member_groups_counts_members_without_connected_session_as_offline_test() ->
Members = [
#{<<"user">> => #{<<"id">> => <<"1">>}},
#{<<"user">> => #{<<"id">> => <<"2">>}}
],
State = #{
id => 123,
data => #{
<<"members">> => Members,
<<"roles">> => []
},
sessions => #{<<"s1">> => #{user_id => 1}},
member_presence => #{
1 => #{<<"status">> => <<"online">>},
2 => #{<<"status">> => <<"online">>}
}
},
Groups = get_member_groups(<<"0">>, State),
OnlineGroup = lists:nth(1, Groups),
OfflineGroup = lists:nth(2, Groups),
?assertEqual(1, maps:get(<<"count">>, OnlineGroup)),
?assertEqual(1, maps:get(<<"count">>, OfflineGroup)).
-endif.

View File

@@ -29,40 +29,35 @@
compute_list_id/1
]).
-record(member_storage, {
members_table :: ets:tid(),
display_name_index :: gb_trees:tree()
}).
-type storage() :: #member_storage{}.
-type storage() :: #{
members_table := ets:tid(),
display_name_index := gb_trees:tree({binary(), user_id()}, user_id())
}.
-type user_id() :: integer().
-type member() :: map().
-type index_key() :: {binary(), user_id()}.
-export_type([storage/0]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec new() -> storage().
new() ->
MembersTable = ets:new(members, [set, private]),
MembersTable = ets:new(members, [set, private, {read_concurrency, true}]),
DisplayNameIndex = gb_trees:empty(),
#member_storage{
members_table = MembersTable,
display_name_index = DisplayNameIndex
#{
members_table => MembersTable,
display_name_index => DisplayNameIndex
}.
-spec insert_member(member(), storage()) -> storage().
insert_member(Member, Storage) ->
UserId = extract_user_id(Member),
case UserId of
case extract_user_id(Member) of
undefined ->
Storage;
_ ->
UserId ->
OldMember = get_member(UserId, Storage),
Storage1 = remove_from_index(OldMember, Storage),
ets:insert(Storage1#member_storage.members_table, {UserId, Member}),
#{members_table := MembersTable} = Storage1,
ets:insert(MembersTable, {UserId, Member}),
add_to_index(UserId, Member, Storage1)
end.
@@ -73,13 +68,15 @@ remove_member(UserId, Storage) ->
Storage;
Member ->
Storage1 = remove_from_index(Member, Storage),
ets:delete(Storage1#member_storage.members_table, UserId),
#{members_table := MembersTable} = Storage1,
ets:delete(MembersTable, UserId),
Storage1
end.
-spec get_member(user_id(), storage()) -> member() | undefined.
get_member(UserId, Storage) ->
case ets:lookup(Storage#member_storage.members_table, UserId) of
#{members_table := MembersTable} = Storage,
case ets:lookup(MembersTable, UserId) of
[{UserId, Member}] -> Member;
[] -> undefined
end.
@@ -100,41 +97,46 @@ get_members_by_ids(UserIds, Storage) ->
search_members(Query, Limit, Storage) when is_binary(Query), Limit > 0 ->
NormalizedQuery = normalize_display_name(Query),
case NormalizedQuery of
<<>> ->
[];
_ ->
search_by_prefix(NormalizedQuery, Limit, Storage)
<<>> -> [];
_ -> search_by_prefix(NormalizedQuery, Limit, Storage)
end;
search_members(_, _, _) ->
[].
-spec get_range(non_neg_integer(), non_neg_integer(), storage()) -> [member()].
get_range(Offset, Limit, Storage) when is_integer(Offset), is_integer(Limit), Limit > 0 ->
Index = Storage#member_storage.display_name_index,
case gb_trees:size(Index) of
Size when Offset >= Size ->
[];
Size ->
AllKeys = gb_trees:keys(Index),
EndIdx = min(Offset + Limit, Size),
SelectedKeys = lists:sublist(AllKeys, Offset + 1, EndIdx - Offset),
lists:filtermap(
fun(Key) ->
UserId = gb_trees:get(Key, Index),
case get_member(UserId, Storage) of
undefined -> false;
Member -> {true, Member}
end
end,
SelectedKeys
)
#{display_name_index := Index} = Storage,
Size = gb_trees:size(Index),
case Offset >= Size of
true -> [];
false -> get_range_from_index(Offset, Limit, Size, Index, Storage)
end;
get_range(_, _, _) ->
[].
-spec get_range_from_index(
non_neg_integer(), non_neg_integer(), non_neg_integer(), gb_trees:tree(), storage()
) ->
[member()].
get_range_from_index(Offset, Limit, Size, Index, Storage) ->
AllKeys = gb_trees:keys(Index),
EndIdx = min(Offset + Limit, Size),
SelectedKeys = lists:sublist(AllKeys, Offset + 1, EndIdx - Offset),
lists:filtermap(
fun(Key) ->
UserId = gb_trees:get(Key, Index),
case get_member(UserId, Storage) of
undefined -> false;
Member -> {true, Member}
end
end,
SelectedKeys
).
-spec count(storage()) -> non_neg_integer().
count(Storage) ->
ets:info(Storage#member_storage.members_table, size).
#{members_table := MembersTable} = Storage,
ets:info(MembersTable, size).
-spec compute_list_id([user_id()]) -> integer().
compute_list_id(UserIds) ->
@@ -155,19 +157,17 @@ extract_user_id(_) ->
-spec get_display_name(member()) -> binary().
get_display_name(Member) when is_map(Member) ->
Nick = maps:get(<<"nick">>, Member, undefined),
case Nick of
undefined ->
User = maps:get(<<"user">>, Member, #{}),
GlobalName = maps:get(<<"global_name">>, User, undefined),
case GlobalName of
undefined ->
maps:get(<<"username">>, User, <<>>);
_ ->
GlobalName
end;
_ ->
Nick
case maps:get(<<"nick">>, Member, undefined) of
undefined -> get_display_name_from_user(Member);
Nick -> Nick
end.
-spec get_display_name_from_user(member()) -> binary().
get_display_name_from_user(Member) ->
User = maps:get(<<"user">>, Member, #{}),
case maps:get(<<"global_name">>, User, undefined) of
undefined -> maps:get(<<"username">>, User, <<>>);
GlobalName -> GlobalName
end.
-spec normalize_display_name(binary()) -> binary().
@@ -180,9 +180,9 @@ add_to_index(UserId, Member, Storage) ->
DisplayName = get_display_name(Member),
NormalizedName = normalize_display_name(DisplayName),
Key = make_index_key(NormalizedName, UserId),
Index = Storage#member_storage.display_name_index,
#{display_name_index := Index} = Storage,
NewIndex = gb_trees:enter(Key, UserId, Index),
Storage#member_storage{display_name_index = NewIndex}.
Storage#{display_name_index => NewIndex}.
-spec remove_from_index(member() | undefined, storage()) -> storage().
remove_from_index(undefined, Storage) ->
@@ -192,41 +192,46 @@ remove_from_index(Member, Storage) ->
DisplayName = get_display_name(Member),
NormalizedName = normalize_display_name(DisplayName),
Key = make_index_key(NormalizedName, UserId),
Index = Storage#member_storage.display_name_index,
#{display_name_index := Index} = Storage,
case gb_trees:is_defined(Key, Index) of
true ->
NewIndex = gb_trees:delete(Key, Index),
Storage#member_storage{display_name_index = NewIndex};
Storage#{display_name_index => NewIndex};
false ->
Storage
end.
-spec make_index_key(binary(), user_id()) -> {binary(), user_id()}.
-spec make_index_key(binary(), user_id()) -> index_key().
make_index_key(NormalizedName, UserId) ->
{NormalizedName, UserId}.
-spec search_by_prefix(binary(), non_neg_integer(), storage()) -> [member()].
search_by_prefix(Prefix, Limit, Storage) ->
Index = Storage#member_storage.display_name_index,
#{display_name_index := Index} = Storage,
AllKeys = gb_trees:keys(Index),
Matches = lists:filtermap(
fun({Name, UserId}) ->
PrefixLen = byte_size(Prefix),
case Name of
<<Prefix:PrefixLen/binary, _/binary>> ->
case get_member(UserId, Storage) of
undefined -> false;
Member -> {true, Member}
end;
_ ->
false
end
end,
AllKeys
),
PrefixLen = byte_size(Prefix),
Matches = find_prefix_matches(AllKeys, Prefix, PrefixLen, Storage, []),
lists:sublist(Matches, Limit).
-spec find_prefix_matches([index_key()], binary(), non_neg_integer(), storage(), [member()]) ->
[member()].
find_prefix_matches([], _Prefix, _PrefixLen, _Storage, Acc) ->
lists:reverse(Acc);
find_prefix_matches([{Name, UserId} | Rest], Prefix, PrefixLen, Storage, Acc) ->
case Name of
<<Prefix:PrefixLen/binary, _/binary>> ->
case get_member(UserId, Storage) of
undefined ->
find_prefix_matches(Rest, Prefix, PrefixLen, Storage, Acc);
Member ->
find_prefix_matches(Rest, Prefix, PrefixLen, Storage, [Member | Acc])
end;
_ ->
find_prefix_matches(Rest, Prefix, PrefixLen, Storage, Acc)
end.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
new_creates_empty_storage_test() ->
Storage = new(),
@@ -309,11 +314,16 @@ display_name_username_fallback_test() ->
},
?assertEqual(<<"user">>, get_display_name(Member)).
compute_list_id_test() ->
compute_list_id_deterministic_test() ->
Id1 = compute_list_id([1, 2, 3]),
Id2 = compute_list_id([3, 2, 1]),
?assertEqual(Id1, Id2).
compute_list_id_different_for_different_lists_test() ->
Id1 = compute_list_id([1, 2, 3]),
Id2 = compute_list_id([1, 2, 4]),
?assertNotEqual(Id1, Id2).
get_range_test() ->
Storage = new(),
Member1 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"alice">>}},
@@ -325,4 +335,46 @@ get_range_test() ->
Results = get_range(1, 2, Storage3),
?assertEqual(2, length(Results)).
get_range_offset_beyond_size_test() ->
Storage = new(),
Member1 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"alice">>}},
Storage1 = insert_member(Member1, Storage),
Results = get_range(10, 5, Storage1),
?assertEqual([], Results).
search_members_empty_query_test() ->
Storage = new(),
Member1 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"alice">>}},
Storage1 = insert_member(Member1, Storage),
Results = search_members(<<>>, 10, Storage1),
?assertEqual([], Results).
search_members_limit_test() ->
Storage = new(),
Member1 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"aaa">>}},
Member2 = #{<<"user">> => #{<<"id">> => <<"2">>, <<"username">> => <<"aab">>}},
Member3 = #{<<"user">> => #{<<"id">> => <<"3">>, <<"username">> => <<"aac">>}},
Storage1 = insert_member(Member1, Storage),
Storage2 = insert_member(Member2, Storage1),
Storage3 = insert_member(Member3, Storage2),
Results = search_members(<<"aa">>, 2, Storage3),
?assertEqual(2, length(Results)).
normalize_display_name_test() ->
?assertEqual(<<"hello">>, normalize_display_name(<<"HELLO">>)),
?assertEqual(<<"hello">>, normalize_display_name(<<"Hello">>)),
?assertEqual(<<"hello">>, normalize_display_name(<<"hello">>)).
insert_member_updates_index_test() ->
Storage = new(),
Member1 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"alice">>}},
Storage1 = insert_member(Member1, Storage),
Member2 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"bob">>}},
Storage2 = insert_member(Member2, Storage1),
?assertEqual(1, count(Storage2)),
AliceResults = search_members(<<"alice">>, 10, Storage2),
?assertEqual(0, length(AliceResults)),
BobResults = search_members(<<"bob">>, 10, Storage2),
?assertEqual(1, length(BobResults)).
-endif.

View File

@@ -17,16 +17,18 @@
-module(guild_members).
-export([get_users_to_mention_by_roles/2]).
-export([get_users_to_mention_by_user_ids/2]).
-export([get_all_users_to_mention/2]).
-export([resolve_all_mentions/2]).
-export([get_members_with_role/2]).
-export([can_manage_roles/2]).
-export([can_manage_role/2]).
-export([get_assignable_roles/2]).
-export([check_target_member/2]).
-export([get_viewable_channels/2]).
-export([
get_users_to_mention_by_roles/2,
get_users_to_mention_by_user_ids/2,
get_all_users_to_mention/2,
resolve_all_mentions/2,
get_members_with_role/2,
can_manage_roles/2,
can_manage_role/2,
get_assignable_roles/2,
check_target_member/2,
get_viewable_channels/2
]).
-type guild_state() :: map().
-type guild_reply(T) :: {reply, T, guild_state()}.
@@ -37,22 +39,18 @@
-type role_id() :: integer().
-type channel_id() :: integer().
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec get_users_to_mention_by_roles(map(), guild_state()) -> guild_reply(map()).
get_users_to_mention_by_roles(
#{channel_id := ChannelId, role_ids := RoleIds, author_id := AuthorId}, State
) ->
Members = guild_members(State),
RoleIdSet = normalize_int_list(RoleIds),
UserIds = collect_mentions(
Members,
RoleIdList = normalize_int_list(RoleIds),
CandidateUserIds = user_ids_for_any_role(RoleIdList, State),
UserIds = collect_mentions_for_user_ids(
CandidateUserIds,
AuthorId,
ChannelId,
State,
fun(Member) -> member_has_any_role(Member, RoleIdSet) end
fun(_UserId, _Member) -> true end
),
{reply, #{user_ids => UserIds}, State}.
@@ -60,19 +58,13 @@ get_users_to_mention_by_roles(
get_users_to_mention_by_user_ids(
#{channel_id := ChannelId, user_ids := UserIdsReq, author_id := AuthorId}, State
) ->
Members = guild_members(State),
TargetIds = normalize_int_list(UserIdsReq),
UserIds = collect_mentions(
Members,
UserIds = collect_mentions_for_user_ids(
TargetIds,
AuthorId,
ChannelId,
State,
fun(Member) ->
case member_user_id(Member) of
undefined -> false;
Id -> lists:member(Id, TargetIds)
end
end
fun(_UserId, _Member) -> true end
),
{reply, #{user_ids => UserIds}, State}.
@@ -83,7 +75,7 @@ get_all_users_to_mention(#{channel_id := ChannelId, author_id := AuthorId}, Stat
{reply, #{user_ids => UserIds}, State}.
-spec resolve_all_mentions(map(), guild_state()) -> guild_reply(map()).
resolve_all_mentions(
resolve_all_mentions(Request, State) ->
#{
channel_id := ChannelId,
author_id := AuthorId,
@@ -91,48 +83,122 @@ resolve_all_mentions(
mention_here := MentionHere,
role_ids := RoleIds,
user_ids := DirectUserIds
},
State
) ->
} = Request,
Members = guild_members(State),
MemberMap = guild_data_index:member_map(guild_data(State)),
Sessions = maps:get(sessions, State, #{}),
RoleIdSet = gb_sets:from_list(normalize_int_list(RoleIds)),
DirectUserIdSet = gb_sets:from_list(normalize_int_list(DirectUserIds)),
HasRoleMentions = not gb_sets:is_empty(RoleIdSet),
HasDirectMentions = not gb_sets:is_empty(DirectUserIdSet),
ConnectedUserIds =
case MentionHere of
ConnectedUserIds = build_connected_user_ids(MentionHere, Sessions),
UserIds =
case MentionEveryone of
true ->
gb_sets:from_list([
maps:get(user_id, S)
|| {_Sid, S} <- maps:to_list(Sessions)
]);
resolve_mentions(
Members,
AuthorId,
ChannelId,
MentionEveryone,
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds,
State
);
false ->
gb_sets:empty()
CandidateUserIds = candidate_user_ids_for_mentions(
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds,
State
),
resolve_mentions_for_user_ids(
CandidateUserIds,
MemberMap,
AuthorId,
ChannelId,
MentionEveryone,
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds,
State
)
end,
{reply, #{user_ids => UserIds}, State}.
UserIds = lists:filtermap(
-spec build_connected_user_ids(boolean(), map()) -> gb_sets:set(user_id()).
build_connected_user_ids(false, _Sessions) ->
gb_sets:empty();
build_connected_user_ids(true, Sessions) ->
gb_sets:from_list(
lists:filtermap(
fun({_Sid, SessionData}) ->
case maps:get(user_id, SessionData, undefined) of
UserId when is_integer(UserId) -> {true, UserId};
_ -> false
end
end,
maps:to_list(Sessions)
)
).
-spec resolve_mentions(
[member()],
user_id(),
channel_id(),
boolean(),
boolean(),
boolean(),
boolean(),
gb_sets:set(),
gb_sets:set(),
gb_sets:set(),
guild_state()
) -> [user_id()].
resolve_mentions(
Members,
AuthorId,
ChannelId,
MentionEveryone,
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds,
State
) ->
lists:filtermap(
fun(Member) ->
case member_user_id(Member) of
undefined ->
false;
UserId when UserId =:= AuthorId ->
false;
UserId when UserId =:= AuthorId -> false;
UserId ->
case is_member_bot(Member) of
true ->
false;
false ->
ShouldMention =
MentionEveryone orelse
(MentionHere andalso
gb_sets:is_member(UserId, ConnectedUserIds)) orelse
(HasRoleMentions andalso
member_has_any_role_set(Member, RoleIdSet)) orelse
(HasDirectMentions andalso
gb_sets:is_member(UserId, DirectUserIdSet)),
ShouldMention = check_should_mention(
UserId,
Member,
MentionEveryone,
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds
),
case
ShouldMention andalso
member_can_view_channel(UserId, ChannelId, Member, State)
@@ -144,61 +210,206 @@ resolve_all_mentions(
end
end,
Members
),
{reply, #{user_ids => UserIds}, State}.
).
-spec get_members_with_role(map(), guild_state()) -> guild_reply(map()).
get_members_with_role(#{role_id := RoleId}, State) ->
Members = guild_members(State),
TargetRoles = [RoleId],
UserIds = lists:filtermap(
fun(Member) ->
case member_user_id(Member) of
undefined ->
-spec resolve_mentions_for_user_ids(
[user_id()],
#{user_id() => member()},
user_id(),
channel_id(),
boolean(),
boolean(),
boolean(),
boolean(),
gb_sets:set(),
gb_sets:set(),
gb_sets:set(),
guild_state()
) -> [user_id()].
resolve_mentions_for_user_ids(
CandidateUserIds,
MemberMap,
AuthorId,
ChannelId,
MentionEveryone,
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds,
State
) ->
lists:filtermap(
fun(UserId) ->
case UserId =:= AuthorId of
true ->
false;
UserId ->
case member_has_any_role(Member, TargetRoles) of
true -> {true, UserId};
false -> false
false ->
case maps:get(UserId, MemberMap, undefined) of
undefined ->
false;
Member ->
case is_member_bot(Member) of
true ->
false;
false ->
ShouldMention = check_should_mention(
UserId,
Member,
MentionEveryone,
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds
),
case
ShouldMention andalso
member_can_view_channel(UserId, ChannelId, Member, State)
of
true -> {true, UserId};
false -> false
end
end
end
end
end,
Members
lists:usort(CandidateUserIds)
).
-spec candidate_user_ids_for_mentions(
boolean(),
boolean(),
boolean(),
gb_sets:set(),
gb_sets:set(),
gb_sets:set(),
guild_state()
) -> [user_id()].
candidate_user_ids_for_mentions(
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds,
State
) ->
HereSet =
case MentionHere of
true -> ConnectedUserIds;
false -> gb_sets:empty()
end,
RoleUsersSet =
case HasRoleMentions of
true ->
gb_sets:from_list(user_ids_for_any_role(gb_sets:to_list(RoleIdSet), State));
false ->
gb_sets:empty()
end,
DirectSet =
case HasDirectMentions of
true -> DirectUserIdSet;
false -> gb_sets:empty()
end,
gb_sets:to_list(gb_sets:union(HereSet, gb_sets:union(RoleUsersSet, DirectSet))).
-spec user_ids_for_any_role([role_id()], guild_state()) -> [user_id()].
user_ids_for_any_role(RoleIds, State) ->
Data = guild_data(State),
MemberRoleIndex = guild_data_index:member_role_index(Data),
RoleUserIds = lists:foldl(
fun(RoleId, AccSet) ->
case maps:get(RoleId, MemberRoleIndex, undefined) of
undefined ->
AccSet;
UserMap ->
lists:foldl(
fun(UserId, InnerSet) -> gb_sets:add(UserId, InnerSet) end,
AccSet,
maps:keys(UserMap)
)
end
end,
gb_sets:empty(),
RoleIds
),
gb_sets:to_list(RoleUserIds).
-spec check_should_mention(
user_id(),
member(),
boolean(),
boolean(),
boolean(),
boolean(),
gb_sets:set(),
gb_sets:set(),
gb_sets:set()
) -> boolean().
check_should_mention(
UserId,
Member,
MentionEveryone,
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds
) ->
MentionEveryone orelse
(MentionHere andalso gb_sets:is_member(UserId, ConnectedUserIds)) orelse
(HasRoleMentions andalso member_has_any_role_set(Member, RoleIdSet)) orelse
(HasDirectMentions andalso gb_sets:is_member(UserId, DirectUserIdSet)).
-spec get_members_with_role(map(), guild_state()) -> guild_reply(map()).
get_members_with_role(#{role_id := RoleId}, State) ->
Data = guild_data(State),
MemberRoleIndex = guild_data_index:member_role_index(Data),
TargetRoleId = type_conv:to_integer(RoleId),
UserMap =
case TargetRoleId of
undefined ->
#{};
_ ->
maps:get(TargetRoleId, MemberRoleIndex, #{})
end,
UserIds = lists:sort(maps:keys(UserMap)),
{reply, #{user_ids => UserIds}, State}.
-spec can_manage_roles(map(), guild_state()) -> guild_reply(map()).
can_manage_roles(#{user_id := UserId, role_id := RoleId}, State) ->
Data = guild_data(State),
OwnerId = owner_id(State),
Reply =
if
UserId =:= OwnerId ->
true;
true ->
UserPermissions = guild_permissions:get_member_permissions(
UserId, undefined, State
),
case (UserPermissions band constants:manage_roles_permission()) =/= 0 of
false ->
false;
true ->
Roles = maps:get(<<"roles">>, Data, []),
case find_role_by_id(RoleId, Roles) of
undefined ->
false;
Role ->
UserMax = guild_permissions:get_max_role_position(UserId, State),
UserMax > role_position(Role)
end
end
end,
Reply = check_can_manage_roles(UserId, RoleId, OwnerId, Data, State),
{reply, #{can_manage => Reply}, State}.
-spec check_can_manage_roles(user_id(), role_id(), user_id(), map(), guild_state()) -> boolean().
check_can_manage_roles(UserId, _RoleId, UserId, _Data, _State) ->
true;
check_can_manage_roles(UserId, RoleId, _OwnerId, Data, State) ->
UserPermissions = guild_permissions:get_member_permissions(UserId, undefined, State),
case (UserPermissions band constants:manage_roles_permission()) =/= 0 of
false ->
false;
true ->
Roles = guild_data_index:role_index(Data),
case find_role_by_id(RoleId, Roles) of
undefined ->
false;
Role ->
UserMax = guild_permissions:get_max_role_position(UserId, State),
UserMax > role_position(Role)
end
end.
-spec can_manage_role(map(), guild_state()) -> guild_reply(map()).
can_manage_role(#{user_id := UserId, role_id := RoleId}, State) ->
Data = guild_data(State),
Roles = maps:get(<<"roles">>, Data, []),
Roles = guild_data_index:role_index(Data),
Reply =
case find_role_by_id(RoleId, Roles) of
undefined ->
@@ -212,6 +423,7 @@ can_manage_role(#{user_id := UserId, role_id := RoleId}, State) ->
end,
{reply, #{can_manage => Reply}, State}.
-spec compare_role_ids_for_equal_position(user_id(), role_id(), guild_state()) -> boolean().
compare_role_ids_for_equal_position(UserId, TargetRoleId, State) ->
case guild_permissions:find_member_by_user_id(UserId, State) of
undefined ->
@@ -219,9 +431,8 @@ compare_role_ids_for_equal_position(UserId, TargetRoleId, State) ->
Member ->
MemberRoles = member_roles(Member),
Data = guild_data(State),
Roles = maps:get(<<"roles">>, Data, []),
UserHighestRole = get_highest_role(MemberRoles, Roles),
case UserHighestRole of
Roles = guild_data_index:role_index(Data),
case get_highest_role(MemberRoles, Roles) of
undefined ->
false;
HighestRole ->
@@ -230,39 +441,42 @@ compare_role_ids_for_equal_position(UserId, TargetRoleId, State) ->
end
end.
-spec get_highest_role([role_id()], [role()] | map()) -> role() | undefined.
get_highest_role(MemberRoleIds, Roles) ->
lists:foldl(
fun(RoleId, Acc) ->
case find_role_by_id(RoleId, Roles) of
undefined ->
Acc;
Role ->
case Acc of
undefined ->
Role;
AccRole ->
AccPos = role_position(AccRole),
RolePos = role_position(Role),
if
RolePos > AccPos ->
Role;
RolePos =:= AccPos ->
AccId = map_utils:get_integer(AccRole, <<"id">>, 0),
RId = map_utils:get_integer(Role, <<"id">>, 0),
if
RId < AccId -> Role;
true -> AccRole
end;
true ->
AccRole
end
end
undefined -> Acc;
Role -> compare_roles(Role, Acc)
end
end,
undefined,
MemberRoleIds
).
-spec compare_roles(role(), role() | undefined) -> role().
compare_roles(Role, undefined) ->
Role;
compare_roles(Role, AccRole) ->
AccPos = role_position(AccRole),
RolePos = role_position(Role),
case RolePos > AccPos of
true ->
Role;
false ->
case RolePos =:= AccPos of
true ->
AccId = map_utils:get_integer(AccRole, <<"id">>, 0),
RId = map_utils:get_integer(Role, <<"id">>, 0),
case RId < AccId of
true -> Role;
false -> AccRole
end;
false ->
AccRole
end
end.
-spec get_assignable_roles(map(), guild_state()) -> guild_reply(map()).
get_assignable_roles(#{user_id := UserId}, State) ->
Roles = guild_roles(State),
@@ -270,6 +484,7 @@ get_assignable_roles(#{user_id := UserId}, State) ->
RoleIds = get_assignable_role_ids(UserId, OwnerId, Roles, State),
{reply, #{role_ids => RoleIds}, State}.
-spec get_assignable_role_ids(user_id(), user_id(), [role()], guild_state()) -> [role_id()].
get_assignable_role_ids(OwnerId, OwnerId, Roles, _State) ->
role_ids_from_roles(Roles);
get_assignable_role_ids(UserId, _OwnerId, Roles, State) ->
@@ -279,6 +494,7 @@ get_assignable_role_ids(UserId, _OwnerId, Roles, State) ->
Roles
).
-spec filter_assignable_role(role(), integer()) -> {true, role_id()} | false.
filter_assignable_role(Role, UserMaxPosition) ->
case role_position(Role) < UserMaxPosition of
true ->
@@ -293,19 +509,19 @@ filter_assignable_role(Role, UserMaxPosition) ->
-spec check_target_member(map(), guild_state()) -> guild_reply(map()).
check_target_member(#{user_id := UserId, target_user_id := TargetUserId}, State) ->
OwnerId = owner_id(State),
CanManage =
if
UserId =:= OwnerId ->
true;
TargetUserId =:= OwnerId ->
false;
true ->
UserMaxPos = guild_permissions:get_max_role_position(UserId, State),
TargetMaxPos = guild_permissions:get_max_role_position(TargetUserId, State),
UserMaxPos > TargetMaxPos
end,
CanManage = check_can_manage_target(UserId, TargetUserId, OwnerId, State),
{reply, #{can_manage => CanManage}, State}.
-spec check_can_manage_target(user_id(), user_id(), user_id(), guild_state()) -> boolean().
check_can_manage_target(UserId, _TargetUserId, UserId, _State) ->
true;
check_can_manage_target(_UserId, OwnerId, OwnerId, _State) ->
false;
check_can_manage_target(UserId, TargetUserId, _OwnerId, State) ->
UserMaxPos = guild_permissions:get_max_role_position(UserId, State),
TargetMaxPos = guild_permissions:get_max_role_position(TargetUserId, State),
UserMaxPos > TargetMaxPos.
-spec get_viewable_channels(map(), guild_state()) -> guild_reply(map()).
get_viewable_channels(#{user_id := UserId}, State) ->
Channels = guild_channels(State),
@@ -313,29 +529,33 @@ get_viewable_channels(#{user_id := UserId}, State) ->
undefined ->
{reply, #{channel_ids => []}, State};
Member ->
ChannelIds = lists:filtermap(
fun(Channel) ->
ChannelId = map_utils:get_integer(Channel, <<"id">>, undefined),
case ChannelId of
undefined ->
false;
_ ->
case
guild_permissions:can_view_channel(UserId, ChannelId, Member, State)
of
true -> {true, ChannelId};
false -> false
end
end
end,
Channels
),
ChannelIds = filter_viewable_channels(Channels, UserId, Member, State),
{reply, #{channel_ids => ChannelIds}, State}
end.
-spec filter_viewable_channels([channel()], user_id(), member(), guild_state()) -> [channel_id()].
filter_viewable_channels(Channels, UserId, Member, State) ->
lists:filtermap(
fun(Channel) ->
ChannelId = map_utils:get_integer(Channel, <<"id">>, undefined),
case ChannelId of
undefined ->
false;
_ ->
case guild_permissions:can_view_channel(UserId, ChannelId, Member, State) of
true -> {true, ChannelId};
false -> false
end
end
end,
Channels
).
-spec find_member_by_user_id(user_id(), guild_state()) -> member() | undefined.
find_member_by_user_id(UserId, State) ->
guild_permissions:find_member_by_user_id(UserId, State).
-spec find_role_by_id(role_id(), [role()] | map()) -> role() | undefined.
find_role_by_id(RoleId, Roles) ->
guild_permissions:find_role_by_id(RoleId, Roles).
@@ -345,15 +565,15 @@ guild_data(State) ->
-spec guild_members(guild_state()) -> [member()].
guild_members(State) ->
map_utils:ensure_list(maps:get(<<"members">>, guild_data(State), [])).
guild_data_index:member_values(guild_data(State)).
-spec guild_roles(guild_state()) -> [role()].
guild_roles(State) ->
map_utils:ensure_list(maps:get(<<"roles">>, guild_data(State), [])).
guild_data_index:role_list(guild_data(State)).
-spec guild_channels(guild_state()) -> [channel()].
guild_channels(State) ->
map_utils:ensure_list(maps:get(<<"channels">>, guild_data(State), [])).
guild_data_index:channel_list(guild_data(State)).
-spec owner_id(guild_state()) -> user_id().
owner_id(State) ->
@@ -369,11 +589,6 @@ member_user_id(Member) ->
member_roles(Member) ->
normalize_int_list(map_utils:ensure_list(maps:get(<<"roles">>, Member, []))).
-spec member_has_any_role(member(), [role_id()]) -> boolean().
member_has_any_role(Member, RoleIds) ->
MemberRoles = member_roles(Member),
lists:any(fun(RoleId) -> lists:member(RoleId, MemberRoles) end, RoleIds).
-spec member_has_any_role_set(member(), gb_sets:set(role_id())) -> boolean().
member_has_any_role_set(Member, RoleIdSet) ->
MemberRoles = member_roles(Member),
@@ -390,6 +605,39 @@ member_can_view_channel(UserId, ChannelId, Member, State) when is_integer(Channe
member_can_view_channel(_, _, _, _) ->
false.
-spec collect_mentions_for_user_ids(
[user_id()],
user_id(),
channel_id(),
guild_state(),
fun((user_id(), member()) -> boolean())
) ->
[user_id()].
collect_mentions_for_user_ids(UserIds, AuthorId, ChannelId, State, Predicate) ->
MemberMap = guild_data_index:member_map(guild_data(State)),
lists:filtermap(
fun(UserId) ->
case UserId =:= AuthorId of
true ->
false;
false ->
case maps:get(UserId, MemberMap, undefined) of
undefined ->
false;
Member ->
case
Predicate(UserId, Member) andalso
member_can_view_channel(UserId, ChannelId, Member, State)
of
true -> {true, UserId};
false -> false
end
end
end
end,
lists:usort(UserIds)
).
-spec collect_mentions([member()], user_id(), channel_id(), guild_state(), fun(
(member()) -> boolean()
)) ->
@@ -446,6 +694,7 @@ role_position(Role) ->
maps:get(<<"position">>, Role, 0).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
get_users_to_mention_by_roles_basic_test() ->
State = test_state(),
@@ -455,6 +704,30 @@ get_users_to_mention_by_roles_basic_test() ->
{reply, #{user_ids := UserIds}, _} = get_users_to_mention_by_roles(Request, State),
?assertEqual([2], UserIds).
get_users_to_mention_by_user_ids_filters_unknown_members_test() ->
State = test_state(),
Request = #{channel_id => 500, user_ids => [2, 999], author_id => 1},
{reply, #{user_ids := UserIds}, _} = get_users_to_mention_by_user_ids(Request, State),
?assertEqual([2], UserIds).
get_members_with_role_uses_member_role_index_test() ->
State = test_state(),
{reply, #{user_ids := UserIds}, _} = get_members_with_role(#{role_id => 200}, State),
?assertEqual([2], UserIds).
resolve_all_mentions_merges_role_and_direct_candidates_test() ->
State = test_state(),
Request = #{
channel_id => 500,
author_id => 1,
mention_everyone => false,
mention_here => false,
role_ids => [200],
user_ids => [3]
},
{reply, #{user_ids := UserIds}, _} = resolve_all_mentions(Request, State),
?assertEqual([2, 3], UserIds).
get_assignable_roles_owner_test() ->
State = test_state(),
{reply, #{role_ids := RoleIds}, _} = get_assignable_roles(#{user_id => 1}, State),
@@ -470,6 +743,32 @@ get_viewable_channels_filters_test() ->
{reply, #{channel_ids := ChannelIds}, _} = get_viewable_channels(#{user_id => 2}, State),
?assert(lists:member(500, ChannelIds)).
member_has_any_role_set_test() ->
Member = #{<<"roles">> => [<<"100">>, <<"200">>]},
RoleSet = gb_sets:from_list([200]),
MissingRoleSet = gb_sets:from_list([999]),
?assertEqual(true, member_has_any_role_set(Member, RoleSet)),
?assertEqual(false, member_has_any_role_set(Member, MissingRoleSet)).
is_member_bot_test() ->
BotMember = #{<<"user">> => #{<<"bot">> => true}},
HumanMember = #{<<"user">> => #{<<"bot">> => false}},
?assertEqual(true, is_member_bot(BotMember)),
?assertEqual(false, is_member_bot(HumanMember)).
normalize_int_list_test() ->
?assertEqual([1, 2, 3], normalize_int_list([<<"1">>, <<"2">>, <<"3">>])),
?assertEqual([1, 2], normalize_int_list([1, 2])),
?assertEqual([], normalize_int_list([])).
role_ids_from_roles_test() ->
Roles = [
#{<<"id">> => <<"100">>},
#{<<"id">> => <<"200">>},
#{<<"name">> => <<"no_id">>}
],
?assertEqual([100, 200], role_ids_from_roles(Roles)).
test_state() ->
GuildId = 100,
OwnerId = 1,

View File

@@ -21,7 +21,8 @@
schedule_passive_sync/1,
handle_passive_sync/1,
send_passive_updates_to_sessions/1,
compute_delta/2
compute_delta/2,
compute_channel_diffs/2
]).
-ifdef(TEST).
@@ -30,71 +31,156 @@
-define(PASSIVE_SYNC_INTERVAL, 30000).
-type guild_state() :: map().
-type channel_id() :: binary().
-type last_message_id() :: binary().
-type version() :: integer().
-type voice_state() :: map().
-spec schedule_passive_sync(guild_state()) -> guild_state().
schedule_passive_sync(State) ->
erlang:send_after(?PASSIVE_SYNC_INTERVAL, self(), passive_sync),
State.
-spec handle_passive_sync(guild_state()) -> {noreply, guild_state()}.
handle_passive_sync(State) ->
NewState = send_passive_updates_to_sessions(State),
schedule_passive_sync(NewState),
{noreply, NewState}.
-spec send_passive_updates_to_sessions(guild_state()) -> guild_state().
send_passive_updates_to_sessions(State) ->
GuildId = maps:get(id, State),
Sessions = maps:get(sessions, State, #{}),
Data = maps:get(data, State, #{}),
Channels = maps:get(<<"channels">>, Data, []),
MemberCount = maps:get(member_count, State, undefined),
IsLargeGuild = case MemberCount of
undefined -> false;
Count when is_integer(Count) -> Count > 250
end,
IsLargeGuild =
case MemberCount of
undefined -> false;
Count when is_integer(Count) -> Count > 250
end,
PassiveSessions = maps:filter(
fun(_SessionId, SessionData) ->
IsLargeGuild andalso session_passive:is_passive(GuildId, SessionData)
end,
Sessions
),
case map_size(PassiveSessions) of
0 ->
State;
_ ->
UpdatedSessions = lists:foldl(
fun({SessionId, SessionData}, AccSessions) ->
Pid = maps:get(pid, SessionData),
UserId = maps:get(user_id, SessionData),
Member = guild_permissions:find_member_by_user_id(UserId, State),
CurrentLastMessageIds = build_last_message_ids(Channels, UserId, Member, State),
PreviousLastMessageIds = maps:get(previous_passive_updates, SessionData, #{}),
Delta = compute_delta(CurrentLastMessageIds, PreviousLastMessageIds),
case {map_size(Delta), is_pid(Pid)} of
{0, _} ->
AccSessions;
{_, true} ->
EventData = #{
<<"guild_id">> => integer_to_binary(GuildId),
<<"channels">> => Delta
},
gen_server:cast(Pid, {dispatch, passive_updates, EventData}),
MergedLastMessageIds = maps:merge(PreviousLastMessageIds, Delta),
UpdatedSessionData = maps:put(previous_passive_updates, MergedLastMessageIds, SessionData),
maps:put(SessionId, UpdatedSessionData, AccSessions);
_ ->
AccSessions
end
end,
Sessions,
maps:to_list(PassiveSessions)
UpdatedSessions = process_passive_sessions(
maps:to_list(PassiveSessions), GuildId, Sessions, Channels, State
),
maps:put(sessions, UpdatedSessions, State)
end.
-spec process_passive_sessions([{binary(), map()}], integer(), map(), [map()], guild_state()) ->
map().
process_passive_sessions(PassiveSessionList, GuildId, Sessions, Channels, State) ->
lists:foldl(
fun({SessionId, SessionData}, AccSessions) ->
process_single_passive_session(
SessionId, SessionData, GuildId, Channels, State, AccSessions
)
end,
Sessions,
PassiveSessionList
).
-spec process_single_passive_session(binary(), map(), integer(), [map()], guild_state(), map()) ->
map().
process_single_passive_session(SessionId, SessionData, GuildId, Channels, State, AccSessions) ->
Pid = maps:get(pid, SessionData),
UserId = maps:get(user_id, SessionData),
Member = guild_permissions:find_member_by_user_id(UserId, State),
CurrentLastMessageIds = build_last_message_ids(Channels, UserId, Member, State),
PreviousLastMessageIds = maps:get(previous_passive_updates, SessionData, #{}),
Delta = compute_delta(CurrentLastMessageIds, PreviousLastMessageIds),
PreviousChannelVersions = maps:get(previous_passive_channel_versions, SessionData, #{}),
{CurrentChannelVersions, CurrentChannelsById} =
build_viewable_channel_snapshots(Channels, UserId, Member, State),
{CreatedChannelIds, UpdatedChannelIds, DeletedChannelIds} =
compute_channel_diffs(CurrentChannelVersions, PreviousChannelVersions),
CreatedChannels = [maps:get(Id, CurrentChannelsById) || Id <- CreatedChannelIds],
UpdatedChannels = [maps:get(Id, CurrentChannelsById) || Id <- UpdatedChannelIds],
ViewableChannels = guild_visibility:viewable_channel_set(UserId, State),
CurrentVoiceStates = build_current_voice_state_map(ViewableChannels, State),
PreviousVoiceStates = maps:get(previous_passive_voice_states, SessionData, #{}),
VoiceStateUpdates = compute_voice_state_updates(
CurrentVoiceStates, PreviousVoiceStates, GuildId
),
UpdatedSessionDataBase =
maps:put(previous_passive_voice_states, CurrentVoiceStates, SessionData),
HasChannelDelta = map_size(Delta) > 0,
HasVoiceUpdates = VoiceStateUpdates =/= [],
HasCreatedChannels = CreatedChannels =/= [],
HasUpdatedChannels = UpdatedChannels =/= [],
HasDeletedChannels = DeletedChannelIds =/= [],
ShouldSend =
HasChannelDelta orelse HasVoiceUpdates orelse
HasCreatedChannels orelse HasUpdatedChannels orelse HasDeletedChannels,
case {ShouldSend, is_pid(Pid)} of
{true, true} ->
EventData = build_passive_event_data(
GuildId,
Delta,
CreatedChannels,
UpdatedChannels,
DeletedChannelIds,
VoiceStateUpdates
),
gen_server:cast(Pid, {dispatch, passive_updates, EventData}),
PreviousLastMessageIds1 = maps:without(DeletedChannelIds, PreviousLastMessageIds),
MergedLastMessageIds = maps:merge(PreviousLastMessageIds1, Delta),
UpdatedSessionData0 =
maps:put(previous_passive_updates, MergedLastMessageIds, UpdatedSessionDataBase),
UpdatedSessionData =
maps:put(
previous_passive_channel_versions, CurrentChannelVersions, UpdatedSessionData0
),
maps:put(SessionId, UpdatedSessionData, AccSessions);
_ ->
UpdatedSessionData =
maps:put(
previous_passive_channel_versions,
CurrentChannelVersions,
UpdatedSessionDataBase
),
maps:put(SessionId, UpdatedSessionData, AccSessions)
end.
-spec build_passive_event_data(integer(), map(), [map()], [map()], [binary()], [map()]) -> map().
build_passive_event_data(
GuildId, Delta, CreatedChannels, UpdatedChannels, DeletedChannelIds, VoiceStateUpdates
) ->
EventDataBase = #{
<<"guild_id">> => integer_to_binary(GuildId),
<<"channels">> => Delta
},
EventData1 =
case CreatedChannels of
[] -> EventDataBase;
_ -> maps:put(<<"created_channels">>, CreatedChannels, EventDataBase)
end,
EventData2 =
case UpdatedChannels of
[] -> EventData1;
_ -> maps:put(<<"updated_channels">>, UpdatedChannels, EventData1)
end,
EventData3 =
case DeletedChannelIds of
[] -> EventData2;
_ -> maps:put(<<"deleted_channel_ids">>, DeletedChannelIds, EventData2)
end,
case VoiceStateUpdates of
[] -> EventData3;
_ -> maps:put(<<"voice_states">>, VoiceStateUpdates, EventData3)
end.
-spec compute_delta(#{channel_id() => last_message_id()}, #{channel_id() => last_message_id()}) ->
#{channel_id() => last_message_id()}.
compute_delta(CurrentLastMessageIds, PreviousLastMessageIds) ->
maps:filter(
fun(ChannelId, CurrentValue) ->
@@ -106,6 +192,23 @@ compute_delta(CurrentLastMessageIds, PreviousLastMessageIds) ->
CurrentLastMessageIds
).
-spec compute_channel_diffs(#{channel_id() => version()}, #{channel_id() => version()}) ->
{[channel_id()], [channel_id()], [channel_id()]}.
compute_channel_diffs(Current, Previous) ->
Created =
[Id || {Id, _} <- maps:to_list(Current), not maps:is_key(Id, Previous)],
Updated =
[
Id
|| {Id, CurV} <- maps:to_list(Current),
maps:is_key(Id, Previous) andalso maps:get(Id, Previous) =/= CurV
],
Deleted =
[Id || {Id, _} <- maps:to_list(Previous), not maps:is_key(Id, Current)],
{Created, Updated, Deleted}.
-spec build_last_message_ids([map()], integer(), map() | undefined, guild_state()) ->
#{channel_id() => last_message_id()}.
build_last_message_ids(Channels, UserId, Member, State) ->
lists:foldl(
fun(Channel, Acc) ->
@@ -117,16 +220,22 @@ build_last_message_ids(Channels, UserId, Member, State) ->
{_, null} ->
Acc;
_ ->
ChannelId = validation:snowflake_or_default(<<"id">>, ChannelIdBin, 0),
case Member of
case parse_snowflake(<<"id">>, ChannelIdBin) of
undefined ->
Acc;
_ ->
case guild_permissions:can_view_channel(UserId, ChannelId, Member, State) of
true ->
maps:put(ChannelIdBin, LastMessageId, Acc);
false ->
Acc
ChannelId ->
case Member of
undefined ->
Acc;
_ ->
case
guild_permissions:can_view_channel(UserId, ChannelId, Member, State)
of
true ->
maps:put(ChannelIdBin, LastMessageId, Acc);
false ->
Acc
end
end
end
end
@@ -135,41 +244,196 @@ build_last_message_ids(Channels, UserId, Member, State) ->
Channels
).
-spec build_viewable_channel_snapshots([map()], integer(), map() | undefined, guild_state()) ->
{#{channel_id() => version()}, #{channel_id() => map()}}.
build_viewable_channel_snapshots(Channels, UserId, Member, State) ->
case Member of
undefined ->
{#{}, #{}};
_ ->
lists:foldl(
fun(Channel, {VersionsAcc, ChannelsAcc}) ->
ChannelIdBin = maps:get(<<"id">>, Channel, undefined),
case ChannelIdBin of
undefined ->
{VersionsAcc, ChannelsAcc};
_ ->
case parse_snowflake(<<"id">>, ChannelIdBin) of
undefined ->
{VersionsAcc, ChannelsAcc};
ChannelId ->
case
guild_permissions:can_view_channel(UserId, ChannelId, Member, State)
of
true ->
Version = map_utils:get_integer(Channel, <<"version">>, 0),
{
maps:put(ChannelIdBin, Version, VersionsAcc),
maps:put(ChannelIdBin, Channel, ChannelsAcc)
};
false ->
{VersionsAcc, ChannelsAcc}
end
end
end
end,
{#{}, #{}},
Channels
)
end.
-spec build_current_voice_state_map(sets:set(), guild_state()) -> #{binary() => voice_state()}.
build_current_voice_state_map(ViewableChannels, State) ->
VoiceStates = maps:get(voice_states, State, #{}),
maps:fold(
fun(ConnectionId, VoiceState, Acc) ->
ChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
case ChannelIdBin of
null ->
Acc;
_ ->
case parse_snowflake(<<"channel_id">>, ChannelIdBin) of
undefined ->
Acc;
ChannelId ->
case sets:is_element(ChannelId, ViewableChannels) of
true -> maps:put(ConnectionId, VoiceState, Acc);
false -> Acc
end
end
end
end,
#{},
VoiceStates
).
-spec parse_snowflake(binary(), term()) -> integer() | undefined.
parse_snowflake(FieldName, Value) ->
case validation:validate_snowflake(FieldName, Value) of
{ok, Id} -> Id;
{error, _, _} -> undefined
end.
-spec compute_voice_state_updates(
#{binary() => voice_state()}, #{binary() => voice_state()}, integer()
) ->
[voice_state()].
compute_voice_state_updates(Current, Previous, GuildId) ->
GuildIdBin = integer_to_binary(GuildId),
Updated = maps:fold(
fun(ConnectionId, VoiceState, Acc) ->
PrevState = maps:get(ConnectionId, Previous, undefined),
case is_voice_state_changed(VoiceState, PrevState) of
true -> [ensure_voice_state_guild(VoiceState, GuildIdBin) | Acc];
false -> Acc
end
end,
[],
Current
),
Removed = maps:fold(
fun(ConnectionId, PrevState, Acc) ->
case maps:is_key(ConnectionId, Current) of
true -> Acc;
false -> [build_removed_voice_state(PrevState, GuildIdBin) | Acc]
end
end,
[],
Previous
),
lists:reverse(Updated) ++ lists:reverse(Removed).
-spec is_voice_state_changed(voice_state(), voice_state() | undefined) -> boolean().
is_voice_state_changed(_Current, undefined) ->
true;
is_voice_state_changed(Current, Previous) ->
voice_state_version(Current) =/= voice_state_version(Previous).
-spec voice_state_version(voice_state()) -> integer().
voice_state_version(VoiceState) ->
map_utils:get_integer(VoiceState, <<"version">>, 0).
-spec ensure_voice_state_guild(voice_state(), binary()) -> voice_state().
ensure_voice_state_guild(VoiceState, GuildIdBin) ->
case maps:get(<<"guild_id">>, VoiceState, undefined) of
undefined -> maps:put(<<"guild_id">>, GuildIdBin, VoiceState);
_ -> VoiceState
end.
-spec build_removed_voice_state(voice_state(), binary()) -> voice_state().
build_removed_voice_state(PrevState, GuildIdBin) ->
PrevWithGuild = ensure_voice_state_guild(PrevState, GuildIdBin),
maps:put(<<"channel_id">>, null, PrevWithGuild).
-ifdef(TEST).
compute_delta_empty_previous_test() ->
Current = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
Previous = #{},
Delta = compute_delta(Current, Previous),
?assertEqual(Current, Delta),
ok.
?assertEqual(Current, Delta).
compute_delta_no_changes_test() ->
Current = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
Previous = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
Delta = compute_delta(Current, Previous),
?assertEqual(#{}, Delta),
ok.
?assertEqual(#{}, Delta).
compute_delta_partial_changes_test() ->
Current = #{<<"1">> => <<"101">>, <<"2">> => <<"200">>, <<"3">> => <<"300">>},
Previous = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
Delta = compute_delta(Current, Previous),
?assertEqual(#{<<"1">> => <<"101">>, <<"3">> => <<"300">>}, Delta),
ok.
?assertEqual(#{<<"1">> => <<"101">>, <<"3">> => <<"300">>}, Delta).
compute_delta_only_new_channels_test() ->
Current = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>, <<"3">> => <<"300">>},
Previous = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
Delta = compute_delta(Current, Previous),
?assertEqual(#{<<"3">> => <<"300">>}, Delta),
ok.
?assertEqual(#{<<"3">> => <<"300">>}, Delta).
compute_delta_ignores_removed_channels_test() ->
Current = #{<<"1">> => <<"100">>},
Previous = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
Delta = compute_delta(Current, Previous),
?assertEqual(#{}, Delta),
ok.
?assertEqual(#{}, Delta).
compute_channel_diffs_detects_created_updated_deleted_test() ->
Current = #{<<"1">> => 2, <<"2">> => 1},
Previous = #{<<"1">> => 1, <<"3">> => 9},
{Created, Updated, Deleted} = compute_channel_diffs(Current, Previous),
?assertEqual([<<"2">>], lists:sort(Created)),
?assertEqual([<<"1">>], lists:sort(Updated)),
?assertEqual([<<"3">>], lists:sort(Deleted)).
compute_voice_state_updates_reports_changes_test() ->
PrevVoiceState = #{
<<"connection_id">> => <<"conn1">>,
<<"channel_id">> => <<"100">>,
<<"user_id">> => <<"200">>,
<<"version">> => 1
},
CurrentVoiceState = maps:put(<<"version">>, 2, PrevVoiceState),
Current = #{<<"conn1">> => CurrentVoiceState},
Previous = #{<<"conn1">> => PrevVoiceState},
Updates = compute_voice_state_updates(Current, Previous, 999),
?assertEqual(1, length(Updates)),
Update = hd(Updates),
?assertEqual(<<"conn1">>, maps:get(<<"connection_id">>, Update)),
?assertEqual(integer_to_binary(999), maps:get(<<"guild_id">>, Update)).
compute_voice_state_updates_reports_removal_test() ->
RemovedVoiceState = #{
<<"connection_id">> => <<"conn2">>,
<<"channel_id">> => <<"200">>,
<<"user_id">> => <<"300">>,
<<"version">> => 3
},
Current = #{},
Previous = #{<<"conn2">> => RemovedVoiceState},
Updates = compute_voice_state_updates(Current, Previous, 101),
?assertEqual(1, length(Updates)),
Update = hd(Updates),
?assertEqual(null, maps:get(<<"channel_id">>, Update)),
?assertEqual(integer_to_binary(101), maps:get(<<"guild_id">>, Update)).
-endif.

View File

@@ -0,0 +1,132 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(guild_permission_cache).
-export([
put_state/1,
put_data/2,
delete/1,
get_permissions/3,
get_snapshot/1
]).
-type guild_id() :: integer().
-type user_id() :: integer().
-type channel_id() :: integer().
-type guild_state() :: map().
-type guild_data() :: map().
-define(TABLE, guild_permission_cache).
-spec put_state(guild_state()) -> ok.
put_state(State) when is_map(State) ->
GuildId = maps:get(id, State, undefined),
Data = maps:get(data, State, #{}),
case is_integer(GuildId) of
true ->
put_data(GuildId, Data);
false ->
ok
end;
put_state(_) ->
ok.
-spec put_data(guild_id(), guild_data()) -> ok.
put_data(GuildId, Data) when is_integer(GuildId), is_map(Data) ->
ensure_table(),
NormalizedData = guild_data_index:normalize_data(Data),
Snapshot = #{id => GuildId, data => NormalizedData},
true = ets:insert(?TABLE, {GuildId, Snapshot}),
ok;
put_data(_, _) ->
ok.
-spec delete(guild_id()) -> ok.
delete(GuildId) when is_integer(GuildId) ->
ensure_table(),
_ = ets:delete(?TABLE, GuildId),
ok;
delete(_) ->
ok.
-spec get_permissions(guild_id(), user_id(), channel_id() | undefined) ->
{ok, integer()} | {error, not_found}.
get_permissions(GuildId, UserId, ChannelId) when is_integer(GuildId), is_integer(UserId) ->
case get_snapshot(GuildId) of
{ok, Snapshot} ->
Permissions = guild_permissions:get_member_permissions(UserId, ChannelId, Snapshot),
{ok, Permissions};
{error, not_found} ->
{error, not_found}
end;
get_permissions(_, _, _) ->
{error, not_found}.
-spec get_snapshot(guild_id()) -> {ok, guild_state()} | {error, not_found}.
get_snapshot(GuildId) when is_integer(GuildId) ->
ensure_table(),
case ets:lookup(?TABLE, GuildId) of
[{GuildId, Snapshot}] ->
{ok, Snapshot};
[] ->
{error, not_found}
end;
get_snapshot(_) ->
{error, not_found}.
-spec ensure_table() -> ok.
ensure_table() ->
case ets:whereis(?TABLE) of
undefined ->
try ets:new(?TABLE, [named_table, public, set, {read_concurrency, true}]) of
_ -> ok
catch
error:badarg -> ok
end;
_ ->
ok
end.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
put_and_get_permissions_test() ->
GuildId = 101,
UserId = 44,
ViewPermission = constants:view_channel_permission(),
Data = #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(ViewPermission)}
],
<<"members">> => #{
UserId => #{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
},
<<"channels">> => [
#{<<"id">> => <<"500">>, <<"permission_overwrites">> => []}
]
},
ok = put_data(GuildId, Data),
{ok, Permissions} = get_permissions(GuildId, UserId, 500),
?assert((Permissions band ViewPermission) =/= 0),
ok = delete(GuildId).
missing_guild_returns_not_found_test() ->
?assertEqual({error, not_found}, get_permissions(999999, 1, undefined)).
-endif.

View File

@@ -19,22 +19,21 @@
-define(ALL_PERMISSIONS, 16#FFFFFFFFFFFFFFFF).
-export([get_member_permissions/3]).
-export([can_view_channel/4]).
-export([can_view_channel_by_permissions/4]).
-export([can_manage_channel/3]).
-export([apply_channel_overwrites/5]).
-export([get_max_role_position/2]).
-export([find_member_by_user_id/2]).
-export([find_role_by_id/2]).
-export([find_channel_by_id/2]).
-export([
get_member_permissions/3,
can_view_channel/4,
can_view_channel_by_permissions/4,
can_manage_channel/3,
can_access_message_by_permissions/3,
apply_channel_overwrites/5,
get_max_role_position/2,
find_member_by_user_id/2,
find_role_by_id/2,
find_channel_by_id/2
]).
-export_type([permission/0]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type permission() :: non_neg_integer().
-type user_id() :: integer().
-type role_id() :: integer().
@@ -61,13 +60,48 @@ can_view_channel(UserId, ChannelId, Member, State) ->
-spec can_view_channel_by_permissions(user_id(), channel_id(), maybe_member(), guild_state()) ->
boolean().
can_view_channel_by_permissions(UserId, ChannelId, Member, State) ->
(compute_member_permissions(UserId, ChannelId, Member, State) band
constants:view_channel_permission()) =/= 0.
Perms = compute_member_permissions(UserId, ChannelId, Member, State),
(Perms band constants:view_channel_permission()) =/= 0.
-spec can_manage_channel(user_id(), maybe_channel_id(), guild_state()) -> boolean().
can_manage_channel(UserId, ChannelId, State) ->
(get_member_permissions(UserId, ChannelId, State) band
constants:manage_channels_permission()) =/= 0.
Perms = get_member_permissions(UserId, ChannelId, State),
(Perms band constants:manage_channels_permission()) =/= 0.
-spec can_access_message_by_permissions(permission(), binary(), guild_state()) -> boolean().
can_access_message_by_permissions(Permissions, MessageId, State) ->
HasReadHistory = (Permissions band constants:read_message_history_permission()) =/= 0,
case HasReadHistory of
true ->
true;
false ->
case get_message_history_cutoff(State) of
null ->
false;
CutoffMs ->
MessageMs = snowflake_util:extract_timestamp(MessageId),
MessageMs >= CutoffMs
end
end.
-spec get_message_history_cutoff(guild_state()) -> integer() | null.
get_message_history_cutoff(State) ->
case resolve_data_map(State) of
undefined ->
null;
Data ->
Guild = maps:get(<<"guild">>, Data, #{}),
case maps:get(<<"message_history_cutoff">>, Guild, null) of
null ->
null;
CutoffBin when is_binary(CutoffBin) ->
calendar:rfc3339_to_system_time(
binary_to_list(CutoffBin), [{unit, millisecond}]
);
CutoffInt when is_integer(CutoffInt) ->
CutoffInt
end
end.
-spec apply_channel_overwrites(permission(), user_id(), member_roles(), channel(), role_id()) ->
permission().
@@ -86,64 +120,51 @@ get_max_role_position(UserId, State) ->
{_, undefined} ->
-1;
{Member, Data} ->
Roles = ensure_list(maps:get(<<"roles">>, Data, [])),
lists:foldl(
fun(RoleId, MaxPos) ->
case find_role_by_id(RoleId, Roles) of
undefined ->
MaxPos;
Role ->
Position = maps:get(<<"position">>, Role, 0),
max(Position, MaxPos)
end
end,
-1,
member_role_ids(Member)
)
Roles = guild_data_index:role_index(Data),
compute_max_position(Member, Roles)
end.
-spec compute_max_position(member(), [role()] | map()) -> integer().
compute_max_position(Member, Roles) ->
lists:foldl(
fun(RoleId, MaxPos) ->
case find_role_by_id(RoleId, Roles) of
undefined ->
MaxPos;
Role ->
Position = maps:get(<<"position">>, Role, 0),
max(Position, MaxPos)
end
end,
-1,
member_role_ids(Member)
).
-spec find_member_by_user_id(user_id(), guild_state()) -> member() | undefined.
find_member_by_user_id(UserId, State) when is_integer(UserId) ->
case resolve_data_map(State) of
undefined ->
undefined;
Data ->
Members = ensure_list(maps:get(<<"members">>, Data, [])),
lists:foldl(
fun(Member, Acc) ->
case Acc of
undefined ->
MUser = maps:get(<<"user">>, Member, #{}),
MemberId = to_int(maps:get(<<"id">>, MUser, <<"0">>)),
case MemberId =:= UserId of
true -> Member;
false -> undefined
end;
Found ->
Found
end
end,
undefined,
Members
)
Members = guild_data_index:member_map(Data),
maps:get(UserId, Members, undefined)
end;
find_member_by_user_id(_, _) ->
undefined.
-spec find_role_by_id(role_id(), list()) -> role() | undefined.
-spec find_role_by_id(role_id(), [role()] | map()) -> role() | undefined.
find_role_by_id(RoleId, Roles) when is_map(Roles) ->
maps:get(to_int(RoleId), Roles, undefined);
find_role_by_id(RoleId, Roles) ->
TargetId = to_int(RoleId),
lists:foldl(
fun(Role, Acc) ->
case Acc of
undefined ->
case role_id(Role) =:= TargetId of
true -> Role;
false -> undefined
end;
Found ->
Found
end
fun
(_, Found) when Found =/= undefined -> Found;
(Role, undefined) ->
case role_id(Role) =:= TargetId of
true -> Role;
false -> undefined
end
end,
undefined,
ensure_list(Roles)
@@ -155,23 +176,8 @@ find_channel_by_id(ChannelId, State) when is_integer(ChannelId) ->
undefined ->
undefined;
Data ->
Channels = ensure_list(maps:get(<<"channels">>, Data, [])),
lists:foldl(
fun(Channel, Acc) ->
case Acc of
undefined ->
ChanId = to_int(maps:get(<<"id">>, Channel, <<"0">>)),
case ChanId =:= ChannelId of
true -> Channel;
false -> undefined
end;
Found ->
Found
end
end,
undefined,
Channels
)
Channels = guild_data_index:channel_index(Data),
maps:get(ChannelId, Channels, undefined)
end;
find_channel_by_id(_, _) ->
undefined.
@@ -188,36 +194,36 @@ compute_member_permissions(UserId, ChannelId, ProvidedMember, State) when is_int
true ->
?ALL_PERMISSIONS;
false ->
case resolve_member(UserId, ProvidedMember, State) of
undefined ->
0;
Member ->
GuildId = guild_id(State),
Roles = ensure_list(maps:get(<<"roles">>, Data, [])),
BasePermissions = base_role_permissions(GuildId, Roles),
MemberRoles = member_role_ids(Member),
Permissions = aggregate_role_permissions(
MemberRoles, Roles, BasePermissions
),
case (Permissions band constants:administrator_permission()) =/= 0 of
true ->
?ALL_PERMISSIONS;
false ->
maybe_apply_channel_overwrites(
Permissions,
UserId,
MemberRoles,
ChannelId,
GuildId,
State
)
end
end
compute_non_owner_permissions(UserId, ChannelId, ProvidedMember, State, Data)
end
end;
compute_member_permissions(_, _, _, _) ->
0.
-spec compute_non_owner_permissions(
user_id(), maybe_channel_id(), maybe_member(), guild_state(), guild_data()
) ->
permission().
compute_non_owner_permissions(UserId, ChannelId, ProvidedMember, State, Data) ->
case resolve_member(UserId, ProvidedMember, State) of
undefined ->
0;
Member ->
GuildId = guild_id(State),
Roles = guild_data_index:role_index(Data),
BasePermissions = base_role_permissions(GuildId, Roles),
MemberRoles = member_role_ids(Member),
Permissions = aggregate_role_permissions(MemberRoles, Roles, BasePermissions),
case (Permissions band constants:administrator_permission()) =/= 0 of
true ->
?ALL_PERMISSIONS;
false ->
maybe_apply_channel_overwrites(
Permissions, UserId, MemberRoles, ChannelId, GuildId, State
)
end
end.
-spec resolve_member(user_id(), maybe_member(), guild_state()) -> maybe_member().
resolve_member(_UserId, Member, _State) when is_map(Member) ->
Member;
@@ -232,36 +238,25 @@ guild_owner_id(Data) ->
-spec guild_id(guild_state()) -> integer().
guild_id(State) ->
case maps:get(id, State, undefined) of
undefined ->
to_int(maps:get(<<"id">>, State, 0));
GuildId when is_integer(GuildId) ->
GuildId;
GuildId ->
to_int(GuildId)
undefined -> to_int(maps:get(<<"id">>, State, 0));
GuildId when is_integer(GuildId) -> GuildId;
GuildId -> to_int(GuildId)
end.
-spec base_role_permissions(role_id(), list()) -> permission().
-spec base_role_permissions(role_id(), map()) -> permission().
base_role_permissions(GuildId, Roles) ->
lists:foldl(
fun(Role, Acc) ->
case role_id(Role) =:= GuildId of
true -> role_permissions(Role);
false -> Acc
end
end,
0,
ensure_list(Roles)
).
case find_role_by_id(GuildId, Roles) of
undefined -> 0;
Role -> role_permissions(Role)
end.
-spec aggregate_role_permissions(member_roles(), list(), permission()) -> permission().
-spec aggregate_role_permissions(member_roles(), [role()] | map(), permission()) -> permission().
aggregate_role_permissions(MemberRoles, Roles, BasePermissions) ->
lists:foldl(
fun(RoleId, Acc) ->
case find_role_by_id(RoleId, Roles) of
undefined ->
Acc;
Role ->
Acc bor role_permissions(Role)
undefined -> Acc;
Role -> Acc bor role_permissions(Role)
end
end,
BasePermissions,
@@ -277,10 +272,8 @@ maybe_apply_channel_overwrites(Permissions, UserId, MemberRoles, ChannelId, Guil
is_integer(ChannelId)
->
case find_channel_by_id(ChannelId, State) of
undefined ->
Permissions;
Channel ->
apply_channel_overwrites(Permissions, UserId, MemberRoles, Channel, GuildId)
undefined -> Permissions;
Channel -> apply_channel_overwrites(Permissions, UserId, MemberRoles, Channel, GuildId)
end;
maybe_apply_channel_overwrites(Permissions, _UserId, _MemberRoles, _ChannelId, _GuildId, _State) ->
Permissions.
@@ -327,10 +320,8 @@ accumulate_role_overwrites(MemberRoles, Overwrites) ->
lists:foldl(
fun(Overwrite, {A, D}) ->
case overwrite_matches_role(Overwrite, RoleId) of
true ->
{A bor overwrite_allow(Overwrite), D bor overwrite_deny(Overwrite)};
false ->
{A, D}
true -> {A bor overwrite_allow(Overwrite), D bor overwrite_deny(Overwrite)};
false -> {A, D}
end
end,
{AllowAcc, DenyAcc},
@@ -406,10 +397,8 @@ extract_integer_list(_) ->
[].
-spec ensure_list(term()) -> list().
ensure_list(List) when is_list(List) ->
List;
ensure_list(_) ->
[].
ensure_list(List) when is_list(List) -> List;
ensure_list(_) -> [].
-spec to_int(term()) -> integer().
to_int(Value) ->
@@ -421,22 +410,20 @@ to_int(Value) ->
-spec resolve_data_map(guild_state() | map()) -> guild_data() | undefined.
resolve_data_map(State) when is_map(State) ->
case maps:find(data, State) of
{ok, Data} when is_map(Data) ->
Data;
{ok, Data} when is_map(Data) =:= false ->
{ok, Data} when is_map(Data) -> Data;
{ok, Data} ->
Data;
error ->
case State of
#{<<"members">> := _} ->
State;
_ ->
undefined
case maps:is_key(<<"members">>, State) of
true -> State;
false -> undefined
end
end;
resolve_data_map(_) ->
undefined.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
owner_receives_full_permissions_test() ->
OwnerId = 1,
@@ -537,19 +524,16 @@ administrator_role_grants_all_permissions_test() ->
data => #{
<<"guild">> => #{<<"owner_id">> => integer_to_binary(OwnerId)},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(Admin)}
#{
<<"id">> => integer_to_binary(GuildId),
<<"permissions">> => integer_to_binary(Admin)
}
],
<<"members">> => [
#{
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
<<"roles">> => []
}
#{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
],
<<"channels">> => [
#{
<<"id">> => integer_to_binary(ChannelId),
<<"permission_overwrites">> => []
}
#{<<"id">> => integer_to_binary(ChannelId), <<"permission_overwrites">> => []}
]
}
},
@@ -557,4 +541,122 @@ administrator_role_grants_all_permissions_test() ->
?assertEqual(?ALL_PERMISSIONS, get_member_permissions(UserId, ChannelId, State)),
?assert(can_view_channel(UserId, ChannelId, undefined, State)).
find_member_by_user_id_found_test() ->
State = #{
data => #{
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"123">>}, <<"nick">> => <<"Test">>}
]
}
},
Result = find_member_by_user_id(123, State),
?assertEqual(<<"Test">>, maps:get(<<"nick">>, Result)).
find_member_by_user_id_not_found_test() ->
State = #{data => #{<<"members">> => []}},
?assertEqual(undefined, find_member_by_user_id(123, State)).
find_member_by_user_id_map_storage_test() ->
State = #{
data => #{
<<"members">> => #{
321 => #{<<"user">> => #{<<"id">> => <<"321">>}, <<"nick">> => <<"Mapped">>}
}
}
},
Result = find_member_by_user_id(321, State),
?assertEqual(<<"Mapped">>, maps:get(<<"nick">>, Result)).
find_role_by_id_found_test() ->
Roles = [#{<<"id">> => <<"100">>, <<"name">> => <<"Admin">>}],
Result = find_role_by_id(100, Roles),
?assertEqual(<<"Admin">>, maps:get(<<"name">>, Result)).
find_role_by_id_not_found_test() ->
Roles = [#{<<"id">> => <<"100">>}],
?assertEqual(undefined, find_role_by_id(999, Roles)).
find_role_by_id_map_index_test() ->
Roles = #{
100 => #{<<"id">> => <<"100">>, <<"name">> => <<"Admin">>}
},
Result = find_role_by_id(100, Roles),
?assertEqual(<<"Admin">>, maps:get(<<"name">>, Result)).
find_channel_by_id_with_index_test() ->
State = #{
data => #{
<<"channels">> => [#{<<"id">> => <<"900">>, <<"name">> => <<"general">>}],
<<"channel_index">> => #{900 => #{<<"id">> => <<"900">>, <<"name">> => <<"general">>}}
}
},
Result = find_channel_by_id(900, State),
?assertEqual(<<"general">>, maps:get(<<"name">>, Result)).
to_int_test() ->
?assertEqual(123, to_int(123)),
?assertEqual(123, to_int(<<"123">>)),
?assertEqual(0, to_int(undefined)).
ensure_list_test() ->
?assertEqual([1, 2], ensure_list([1, 2])),
?assertEqual([], ensure_list(undefined)),
?assertEqual([], ensure_list(#{})).
can_access_message_with_read_history_test() ->
ReadHistory = constants:read_message_history_permission(),
State = #{data => #{<<"guild">> => #{}}},
MessageId = <<"100">>,
?assertEqual(true, can_access_message_by_permissions(ReadHistory, MessageId, State)).
can_access_message_no_read_history_no_cutoff_test() ->
State = #{data => #{<<"guild">> => #{}}},
MessageId = <<"100">>,
?assertEqual(false, can_access_message_by_permissions(0, MessageId, State)).
can_access_message_no_read_history_null_cutoff_test() ->
State = #{data => #{<<"guild">> => #{<<"message_history_cutoff">> => null}}},
MessageId = <<"100">>,
?assertEqual(false, can_access_message_by_permissions(0, MessageId, State)).
can_access_message_no_read_history_message_before_cutoff_test() ->
CutoffMs = 1704067200000,
BeforeCutoffTimestamp = CutoffMs - 60000,
FluxerEpoch = 1420070400000,
RelativeTs = BeforeCutoffTimestamp - FluxerEpoch,
Snowflake = RelativeTs bsl 22,
MessageId = integer_to_binary(Snowflake),
State = #{data => #{<<"guild">> => #{<<"message_history_cutoff">> => CutoffMs}}},
?assertEqual(false, can_access_message_by_permissions(0, MessageId, State)).
can_access_message_no_read_history_message_after_cutoff_test() ->
CutoffMs = 1704067200000,
AfterCutoffTimestamp = CutoffMs + 60000,
FluxerEpoch = 1420070400000,
RelativeTs = AfterCutoffTimestamp - FluxerEpoch,
Snowflake = RelativeTs bsl 22,
MessageId = integer_to_binary(Snowflake),
State = #{data => #{<<"guild">> => #{<<"message_history_cutoff">> => CutoffMs}}},
?assertEqual(true, can_access_message_by_permissions(0, MessageId, State)).
can_access_message_no_read_history_message_at_cutoff_test() ->
CutoffMs = 1704067200000,
FluxerEpoch = 1420070400000,
RelativeTs = CutoffMs - FluxerEpoch,
Snowflake = RelativeTs bsl 22,
MessageId = integer_to_binary(Snowflake),
State = #{data => #{<<"guild">> => #{<<"message_history_cutoff">> => CutoffMs}}},
?assertEqual(true, can_access_message_by_permissions(0, MessageId, State)).
can_access_message_with_rfc3339_cutoff_test() ->
CutoffBin = <<"2024-01-01T00:00:00Z">>,
CutoffMs = calendar:rfc3339_to_system_time("2024-01-01T00:00:00Z", [{unit, millisecond}]),
AfterCutoffTimestamp = CutoffMs + 60000,
FluxerEpoch = 1420070400000,
RelativeTs = AfterCutoffTimestamp - FluxerEpoch,
Snowflake = RelativeTs bsl 22,
MessageId = integer_to_binary(Snowflake),
State = #{data => #{<<"guild">> => #{<<"message_history_cutoff">> => CutoffBin}}},
?assertEqual(true, can_access_message_by_permissions(0, MessageId, State)).
-endif.

View File

@@ -20,8 +20,6 @@
-export([handle_bus_presence/3, send_cached_presence_to_session/3]).
-export([broadcast_presence_update/3]).
-import(guild_sessions, [handle_user_offline/2]).
-type guild_state() :: map().
-type member() :: map().
-type user_id() :: integer().
@@ -31,13 +29,19 @@
-endif.
-spec handle_bus_presence(user_id(), map(), guild_state()) -> {noreply, guild_state()}.
-spec send_cached_presence_to_session(user_id(), binary(), guild_state()) -> guild_state().
handle_bus_presence(UserId, Payload, State) ->
case maps:get(<<"user_update">>, Payload, false) of
true ->
UserData = maps:get(<<"user">>, Payload, #{}),
UpdatedState = handle_user_data_update(UserId, UserData, State),
guild_member_list:broadcast_member_list_updates(UserId, State, UpdatedState),
MemberData = find_member_by_user_id(UserId, UpdatedState),
case is_map(MemberData) of
true ->
maybe_notify_very_large_guild_member_list({upsert_member, MemberData}, UpdatedState);
false ->
maybe_notify_very_large_guild_member_list({notify_member_update, UserId}, UpdatedState)
end,
maybe_broadcast_member_list_updates(UserId, State, UpdatedState),
{noreply, UpdatedState};
false ->
Member = find_member_by_user_id(UserId, State),
@@ -50,7 +54,6 @@ handle_bus_presence(UserId, Payload, State) ->
Status = constants:status_type_atom(NormalizedStatusBin),
Mobile = maps:get(<<"mobile">>, Payload, false),
Afk = maps:get(<<"afk">>, Payload, false),
logger:debug("[guild_presence] Presence update for UserId=~p, Status=~p", [UserId, Status]),
MemberUser = maps:get(<<"user">>, Member, #{}),
CustomStatus = maps:get(<<"custom_status">>, Payload, null),
PresenceMap = presence_payload:build(
@@ -60,29 +63,48 @@ handle_bus_presence(UserId, Payload, State) ->
Afk,
CustomStatus
),
Presences = maps:get(presences, State, #{}),
UpdatedPresences = maps:put(UserId, PresenceMap, Presences),
StateWithPresences = maps:put(presences, UpdatedPresences, State),
broadcast_presence_update(UserId, PresenceMap, StateWithPresences),
logger:debug("[guild_presence] Broadcasting member list updates for UserId=~p", [UserId]),
guild_member_list:broadcast_member_list_updates(UserId, State, StateWithPresences),
StateWithPresence = store_member_presence(UserId, PresenceMap, State),
broadcast_presence_update(UserId, PresenceMap, StateWithPresence),
maybe_notify_very_large_guild_member_list(
{presence_update, UserId, PresenceMap}, StateWithPresence
),
maybe_broadcast_member_list_updates(UserId, State, StateWithPresence),
StateAfterOffline =
case Status of
offline ->
handle_user_offline(UserId, StateWithPresences);
guild_sessions:handle_user_offline(UserId, StateWithPresence);
_ ->
StateWithPresences
StateWithPresence
end,
{noreply, StateAfterOffline}
end
end.
-spec maybe_broadcast_member_list_updates(user_id(), guild_state(), guild_state()) -> ok.
maybe_broadcast_member_list_updates(UserId, OldState, UpdatedState) ->
case maps:get(very_large_guild_coordinator_pid, UpdatedState, undefined) of
Pid when is_pid(Pid) ->
ok;
_ ->
guild_member_list:broadcast_member_list_updates(UserId, OldState, UpdatedState)
end.
-spec maybe_notify_very_large_guild_member_list(term(), guild_state()) -> ok.
maybe_notify_very_large_guild_member_list(NotifyMsg, State) ->
case maps:get(very_large_guild_coordinator_pid, State, undefined) of
CoordPid when is_pid(CoordPid) ->
gen_server:cast(CoordPid, {very_large_guild_member_list_notify, NotifyMsg}),
ok;
_ ->
ok
end.
-spec broadcast_presence_update(user_id(), map(), guild_state()) -> ok.
broadcast_presence_update(UserId, Payload, State) ->
case find_member_by_user_id(UserId, State) of
undefined ->
ok;
_Member ->
_Member ->
GuildId = map_utils:get_integer(State, id, 0),
PresenceUpdate = maps:put(<<"guild_id">>, integer_to_binary(GuildId), Payload),
Sessions = maps:get(sessions, State, #{}),
@@ -90,7 +112,9 @@ broadcast_presence_update(UserId, Payload, State) ->
SubscribedSessionIds = guild_subscriptions:get_subscribed_sessions(UserId, MemberSubs),
TargetChannels = guild_visibility:viewable_channel_set(UserId, State),
{ValidSessionIds, InvalidSessionIds} =
partition_subscribed_sessions(SubscribedSessionIds, Sessions, TargetChannels, UserId, State),
partition_subscribed_sessions(
SubscribedSessionIds, Sessions, TargetChannels, UserId, State
),
StateAfterInvalidRemovals =
lists:foldl(
fun(SessionId, AccState) ->
@@ -122,10 +146,12 @@ broadcast_presence_update(UserId, Payload, State) ->
ok
end.
-spec normalize_presence_status(binary() | term()) -> binary().
normalize_presence_status(<<"invisible">>) -> <<"offline">>;
normalize_presence_status(Status) when is_binary(Status) -> Status;
normalize_presence_status(_) -> <<"offline">>.
-spec send_cached_presence_to_session(user_id(), binary(), guild_state()) -> guild_state().
send_cached_presence_to_session(UserId, SessionId, State) ->
case presence_cache:get(UserId) of
{ok, Payload} ->
@@ -134,6 +160,7 @@ send_cached_presence_to_session(UserId, SessionId, State) ->
State
end.
-spec send_presence_payload_to_session(user_id(), binary(), map(), guild_state()) -> guild_state().
send_presence_payload_to_session(UserId, SessionId, Payload, State) ->
GuildId = map_utils:get_integer(State, id, 0),
Sessions = maps:get(sessions, State, #{}),
@@ -151,7 +178,9 @@ send_presence_payload_to_session(UserId, SessionId, Payload, State) ->
CustomStatus = maps:get(<<"custom_status">>, Payload, null),
PresenceBase =
presence_payload:build(MemberUser, StatusBin, Mobile, Afk, CustomStatus),
PresenceUpdate = maps:put(<<"guild_id">>, integer_to_binary(GuildId), PresenceBase),
PresenceUpdate = maps:put(
<<"guild_id">>, integer_to_binary(GuildId), PresenceBase
),
gen_server:cast(SessionPid, {dispatch, presence_update, PresenceUpdate}),
State
end;
@@ -162,7 +191,7 @@ send_presence_payload_to_session(UserId, SessionId, Payload, State) ->
-spec handle_user_data_update(user_id(), map(), guild_state()) -> guild_state().
handle_user_data_update(UserId, UserData, State) ->
Data = guild_data(State),
Members = guild_members(State),
Members = guild_data_index:member_map(Data),
case find_member_by_user_id(UserId, State) of
undefined ->
State;
@@ -172,13 +201,13 @@ handle_user_data_update(UserId, UserData, State) ->
false ->
State;
true ->
UpdatedMembers = lists:map(
fun(M) ->
maybe_replace_member(M, UserId, UserData)
UpdatedMembers = maps:map(
fun(_MemberUserId, Member0) ->
maybe_replace_member(Member0, UserId, UserData)
end,
Members
),
UpdatedData = maps:put(<<"members">>, UpdatedMembers, Data),
UpdatedData = guild_data_index:put_member_map(UpdatedMembers, Data),
UpdatedState = maps:put(data, UpdatedData, State),
maybe_dispatch_member_update(UserId, UpdatedState),
UpdatedState
@@ -211,15 +240,13 @@ maybe_dispatch_member_update(UserId, State) ->
guild_data(State) ->
map_utils:ensure_map(map_utils:get_safe(State, data, #{})).
-spec guild_members(guild_state()) -> [map()].
guild_members(State) ->
map_utils:ensure_list(maps:get(<<"members">>, guild_data(State), [])).
-spec member_id(map()) -> user_id() | undefined.
member_id(Member) ->
User = map_utils:ensure_map(maps:get(<<"user">>, Member, #{})),
map_utils:get_integer(User, <<"id">>, undefined).
-spec partition_subscribed_sessions([binary()], map(), sets:set(), user_id(), guild_state()) ->
{[binary()], [binary()]}.
partition_subscribed_sessions(SessionIds, Sessions, TargetChannels, TargetUserId, State) ->
lists:foldl(
fun(SessionId, {Valids, Invalids}) ->
@@ -235,8 +262,12 @@ partition_subscribed_sessions(SessionIds, Sessions, TargetChannels, TargetUserId
UserId when UserId =:= TargetUserId ->
false;
_ ->
SessionChannels = guild_visibility:viewable_channel_set(SessionUserId, State),
not sets:is_empty(sets:intersection(SessionChannels, TargetChannels))
SessionChannels = guild_visibility:viewable_channel_set(
SessionUserId, State
),
not sets:is_empty(
sets:intersection(SessionChannels, TargetChannels)
)
end,
case Shared of
true -> {[SessionId | Valids], Invalids};
@@ -248,12 +279,27 @@ partition_subscribed_sessions(SessionIds, Sessions, TargetChannels, TargetUserId
SessionIds
).
-spec remove_session_member_subscription(binary(), user_id(), guild_state()) -> guild_state().
remove_session_member_subscription(SessionId, UserId, State) ->
MemberSubs = maps:get(member_subscriptions, State, guild_subscriptions:init_state()),
NewMemberSubs = guild_subscriptions:unsubscribe(SessionId, UserId, MemberSubs),
State1 = maps:put(member_subscriptions, NewMemberSubs, State),
guild_sessions:unsubscribe_from_user_presence(UserId, State1).
-spec check_user_data_differs(map(), map()) -> boolean().
check_user_data_differs(CurrentUserData, NewUserData) ->
utils:check_user_data_differs(CurrentUserData, NewUserData).
-spec find_member_by_user_id(user_id(), guild_state()) -> member() | undefined.
find_member_by_user_id(UserId, State) ->
guild_permissions:find_member_by_user_id(UserId, State).
-spec store_member_presence(user_id(), map(), guild_state()) -> guild_state().
store_member_presence(UserId, PresenceMap, State) ->
MemberPresence = maps:get(member_presence, State, #{}),
UpdatedPresence = maps:put(UserId, PresenceMap, MemberPresence),
maps:put(member_presence, UpdatedPresence, State).
-ifdef(TEST).
handle_bus_presence_non_member_noop_test() ->
@@ -279,24 +325,24 @@ handle_bus_presence_user_update_test() ->
Payload = #{<<"user">> => UserData, <<"user_update">> => true},
{noreply, NewState} = handle_bus_presence(1, Payload, State),
Data = maps:get(data, NewState),
[Member | _] = maps:get(<<"members">>, Data),
Member = maps:get(1, maps:get(<<"members">>, Data)),
?assertEqual(<<"Updated">>, maps:get(<<"username">>, maps:get(<<"user">>, Member))).
normalize_presence_status_test() ->
?assertEqual(<<"offline">>, normalize_presence_status(<<"invisible">>)),
?assertEqual(<<"online">>, normalize_presence_status(<<"online">>)),
?assertEqual(<<"idle">>, normalize_presence_status(<<"idle">>)),
?assertEqual(<<"offline">>, normalize_presence_status(undefined)).
presence_test_state() ->
#{
id => 42,
data => #{
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"Alpha">>}}
]
<<"members">> => #{
1 => #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"Alpha">>}}
}
},
sessions => #{}
}.
-endif.
check_user_data_differs(CurrentUserData, NewUserData) ->
utils:check_user_data_differs(CurrentUserData, NewUserData).
find_member_by_user_id(UserId, State) ->
guild_permissions:find_member_by_user_id(UserId, State).

View File

@@ -27,6 +27,8 @@
-type session_state() :: map().
-type request_data() :: map().
-type member() :: map().
-type presence() :: map().
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
@@ -34,13 +36,10 @@
-spec handle_request(request_data(), pid(), session_state()) -> ok | {error, atom()}.
handle_request(Data, SocketPid, SessionState) when is_map(Data), is_pid(SocketPid) ->
logger:debug("[guild_request_members] Handling guild members request: ~p", [Data]),
case parse_request(Data) of
{ok, Request} ->
logger:debug("[guild_request_members] Request parsed successfully: ~p", [Request]),
process_request(Request, SocketPid, SessionState);
{error, Reason} ->
logger:warning("[guild_request_members] Failed to parse request: ~p", [Reason]),
{error, Reason}
end;
handle_request(_, _, _) ->
@@ -55,7 +54,6 @@ parse_request(Data) ->
Presences = maps:get(<<"presences">>, Data, false),
Nonce = maps:get(<<"nonce">>, Data, null),
NormalizedNonce = normalize_nonce(Nonce),
case validate_guild_id(GuildIdRaw) of
{ok, GuildId} ->
case validate_user_ids(UserIdsRaw) of
@@ -137,25 +135,16 @@ process_request(Request, SocketPid, SessionState) ->
#{guild_id := GuildId, query := Query, limit := Limit, user_ids := UserIds} = Request,
UserIdBin = maps:get(user_id, SessionState),
UserId = type_conv:to_integer(UserIdBin),
logger:debug(
"[guild_request_members] Processing request for guild ~p, user ~p, user_ids: ~p",
[GuildId, UserId, UserIds]
),
case check_permission(UserId, GuildId, Query, Limit, UserIds, SessionState) of
ok ->
logger:debug("[guild_request_members] Permission check passed, fetching members"),
fetch_and_send_members(Request, SocketPid, SessionState);
{error, Reason} ->
logger:warning(
"[guild_request_members] Permission check failed: ~p",
[Reason]
),
{error, Reason}
end.
-spec check_permission(integer(), integer(), binary(), non_neg_integer(), [integer()], session_state()) ->
-spec check_permission(
integer(), integer(), binary(), non_neg_integer(), [integer()], session_state()
) ->
ok | {error, atom()}.
check_permission(UserId, GuildId, Query, Limit, UserIds, SessionState) ->
RequiresPermission = Query =:= <<>> andalso Limit =:= 0 andalso UserIds =:= [],
@@ -177,7 +166,6 @@ check_management_permission(UserId, _GuildId, GuildPid) ->
KickMembers = constants:kick_members_permission(),
BanMembers = constants:ban_members_permission(),
RequiredPermission = ManageRoles bor KickMembers bor BanMembers,
PermRequest = #{
user_id => UserId,
permission => RequiredPermission,
@@ -215,65 +203,45 @@ fetch_and_send_members(Request, _SocketPid, SessionState) ->
nonce := Nonce
} = Request,
SessionId = maps:get(session_id, SessionState),
logger:debug(
"[guild_request_members] Looking up guild ~p for member request",
[GuildId]
),
case lookup_guild(GuildId, SessionState) of
{ok, GuildPid} ->
logger:debug("[guild_request_members] Guild ~p found, fetching members", [GuildId]),
Members = fetch_members(GuildPid, Query, Limit, UserIds),
logger:debug("[guild_request_members] Found ~p members", [length(Members)]),
PresencesList = maybe_fetch_presences(Presences, GuildPid, Members),
send_member_chunks(GuildPid, SessionId, Members, PresencesList, Nonce),
logger:debug(
"[guild_request_members] Sent ~p member chunks for guild ~p with nonce ~p",
[max(1, (length(Members) + ?CHUNK_SIZE - 1) div ?CHUNK_SIZE), GuildId, Nonce]
),
ok;
{error, Reason} ->
logger:warning(
"[guild_request_members] Failed to lookup guild ~p: ~p",
[GuildId, Reason]
),
{error, Reason}
end.
-spec fetch_members(pid(), binary(), non_neg_integer(), [integer()]) -> [map()].
-spec fetch_members(pid(), binary(), non_neg_integer(), [integer()]) -> [member()].
fetch_members(GuildPid, _Query, _Limit, UserIds) when UserIds =/= [] ->
logger:debug("[guild_request_members] Fetching members by user_ids: ~p", [UserIds]),
case gen_server:call(GuildPid, {list_guild_members, #{limit => 100000, offset => 0}}, 10000) of
#{members := AllMembers} ->
logger:debug("[guild_request_members] Got ~p members from guild, filtering by user_ids", [length(AllMembers)]),
Filtered = filter_members_by_ids(AllMembers, UserIds),
logger:debug("[guild_request_members] Filtered to ~p members", [length(Filtered)]),
Filtered;
Other ->
logger:warning("[guild_request_members] Unexpected response from guild: ~p", [Other]),
filter_members_by_ids(AllMembers, UserIds);
_ ->
[]
end;
fetch_members(GuildPid, Query, Limit, []) ->
ActualLimit = case Limit of 0 -> 100000; L -> L end,
logger:debug("[guild_request_members] Fetching members with query '~s', limit ~p", [Query, ActualLimit]),
case gen_server:call(GuildPid, {list_guild_members, #{limit => ActualLimit, offset => 0}}, 10000) of
ActualLimit =
case Limit of
0 -> 100000;
L -> L
end,
case
gen_server:call(GuildPid, {list_guild_members, #{limit => ActualLimit, offset => 0}}, 10000)
of
#{members := AllMembers} ->
logger:debug("[guild_request_members] Got ~p members from guild", [length(AllMembers)]),
Result = case Query of
case Query of
<<>> ->
lists:sublist(AllMembers, ActualLimit);
_ ->
filter_members_by_query(AllMembers, Query, ActualLimit)
end,
logger:debug("[guild_request_members] Returning ~p members after query/filter", [length(Result)]),
Result;
Other ->
logger:warning("[guild_request_members] Unexpected response from guild: ~p", [Other]),
end;
_ ->
[]
end.
-spec filter_members_by_ids([map()], [integer()]) -> [map()].
-spec filter_members_by_ids([member()], [integer()]) -> [member()].
filter_members_by_ids(Members, UserIds) ->
UserIdSet = sets:from_list(UserIds),
lists:filter(
@@ -284,7 +252,7 @@ filter_members_by_ids(Members, UserIds) ->
Members
).
-spec filter_members_by_query([map()], binary(), non_neg_integer()) -> [map()].
-spec filter_members_by_query([member()], binary(), non_neg_integer()) -> [member()].
filter_members_by_query(Members, Query, Limit) ->
NormalizedQuery = string:lowercase(binary_to_list(Query)),
Matches = lists:filter(
@@ -297,50 +265,47 @@ filter_members_by_query(Members, Query, Limit) ->
),
lists:sublist(Matches, Limit).
-spec get_display_name(map()) -> binary().
-spec get_display_name(member()) -> binary().
get_display_name(Member) when is_map(Member) ->
Nick = maps:get(<<"nick">>, Member, undefined),
case Nick of
undefined -> nick_isundefined(Member);
null -> nick_isundefined(Member);
undefined -> get_fallback_name(Member);
null -> get_fallback_name(Member);
_ when is_binary(Nick) -> Nick;
_ -> nick_isundefined(Member)
_ -> get_fallback_name(Member)
end;
get_display_name(_) ->
<<>>.
nick_isundefined(Member) ->
-spec get_fallback_name(member()) -> binary().
get_fallback_name(Member) ->
User = maps:get(<<"user">>, Member, #{}),
GlobalName = maps:get(<<"global_name">>, User, undefined),
case GlobalName of
undefined ->
Username = maps:get(<<"username">>, User, <<>>),
case Username of
null -> <<>>;
undefined -> <<>>;
_ when is_binary(Username) -> Username;
_ -> <<>>
end;
null ->
Username = maps:get(<<"username">>, User, <<>>),
case Username of
null -> <<>>;
undefined -> <<>>;
_ when is_binary(Username) -> Username;
_ -> <<>>
end;
undefined -> get_username(User);
null -> get_username(User);
_ when is_binary(GlobalName) -> GlobalName;
_ -> get_username(User)
end.
-spec get_username(map()) -> binary().
get_username(User) ->
Username = maps:get(<<"username">>, User, <<>>),
case Username of
null -> <<>>;
undefined -> <<>>;
_ when is_binary(Username) -> Username;
_ -> <<>>
end.
-spec extract_user_id(map()) -> integer() | undefined.
-spec extract_user_id(member()) -> integer() | undefined.
extract_user_id(Member) when is_map(Member) ->
User = maps:get(<<"user">>, Member, #{}),
map_utils:get_integer(User, <<"id">>, undefined);
extract_user_id(_) ->
undefined.
-spec maybe_fetch_presences(boolean(), pid(), [map()]) -> [map()].
-spec maybe_fetch_presences(boolean(), pid(), [member()]) -> [presence()].
maybe_fetch_presences(false, _GuildPid, _Members) ->
[];
maybe_fetch_presences(true, _GuildPid, Members) ->
@@ -361,41 +326,30 @@ maybe_fetch_presences(true, _GuildPid, Members) ->
[P || P <- Cached, presence_visible(P)]
end.
-spec presence_visible(map()) -> boolean().
-spec presence_visible(presence()) -> boolean().
presence_visible(P) ->
Status = maps:get(<<"status">>, P, <<"offline">>),
Status =/= <<"offline">> andalso Status =/= <<"invisible">>.
-spec send_member_chunks(pid(), binary(), [map()], [map()], term()) -> ok.
-spec send_member_chunks(pid(), binary(), [member()], [presence()], term()) -> ok.
send_member_chunks(GuildPid, SessionId, Members, Presences, Nonce) ->
TotalChunks = max(1, (length(Members) + ?CHUNK_SIZE - 1) div ?CHUNK_SIZE),
MemberChunks = chunk_list(Members, ?CHUNK_SIZE),
PresenceChunks = chunk_presences(Presences, MemberChunks),
logger:debug(
"[guild_request_members] Sending ~p member chunks (total members: ~p, nonce: ~p)",
[TotalChunks, length(Members), Nonce]
),
lists:foldl(
fun({MemberChunk, PresenceChunk}, ChunkIndex) ->
ChunkData = build_chunk_data(
MemberChunk, PresenceChunk, ChunkIndex, TotalChunks, Nonce
),
logger:debug(
"[guild_request_members] Sending chunk ~p/~p with ~p members, nonce: ~p",
[ChunkIndex + 1, TotalChunks, length(MemberChunk), Nonce]
),
gen_server:cast(GuildPid, {send_members_chunk, SessionId, ChunkData}),
ChunkIndex + 1
end,
0,
lists:zip(MemberChunks, PresenceChunks)
),
logger:debug("[guild_request_members] All chunks sent successfully"),
ok.
-spec build_chunk_data([map()], [map()], non_neg_integer(), non_neg_integer(), term()) ->
-spec build_chunk_data([member()], [presence()], non_neg_integer(), non_neg_integer(), term()) ->
map().
build_chunk_data(Members, Presences, ChunkIndex, TotalChunks, Nonce) ->
Base = #{
@@ -403,15 +357,15 @@ build_chunk_data(Members, Presences, ChunkIndex, TotalChunks, Nonce) ->
<<"chunk_index">> => ChunkIndex,
<<"chunk_count">> => TotalChunks
},
WithPresences = case Presences of
[] -> Base;
_ -> maps:put(<<"presences">>, Presences, Base)
end,
WithNonce = case Nonce of
WithPresences =
case Presences of
[] -> Base;
_ -> maps:put(<<"presences">>, Presences, Base)
end,
case Nonce of
null -> WithPresences;
_ -> maps:put(<<"nonce">>, Nonce, WithPresences)
end,
WithNonce.
end.
-spec chunk_list([T], pos_integer()) -> [[T]] when T :: term().
chunk_list([], _Size) ->
@@ -419,13 +373,14 @@ chunk_list([], _Size) ->
chunk_list(List, Size) ->
chunk_list(List, Size, []).
-spec chunk_list([T], pos_integer(), [[T]]) -> [[T]] when T :: term().
chunk_list([], _Size, Acc) ->
lists:reverse(Acc);
chunk_list(List, Size, Acc) ->
{Chunk, Rest} = lists:split(min(Size, length(List)), List),
chunk_list(Rest, Size, [Chunk | Acc]).
-spec chunk_presences([map()], [[map()]]) -> [[map()]].
-spec chunk_presences([presence()], [[member()]]) -> [[presence()]].
chunk_presences(Presences, MemberChunks) ->
lists:map(
fun(MemberChunk) ->
@@ -491,15 +446,18 @@ display_name_priority_test() ->
<<"nick">> => <<"Nick">>
},
?assertEqual(<<"Nick">>, get_display_name(MemberWithNick)),
MemberWithGlobal = #{
<<"user">> => #{<<"username">> => <<"user">>, <<"global_name">> => <<"Global">>}
},
?assertEqual(<<"Global">>, get_display_name(MemberWithGlobal)),
MemberWithUsername = #{
<<"user">> => #{<<"username">> => <<"user">>}
},
?assertEqual(<<"user">>, get_display_name(MemberWithUsername)).
normalize_nonce_test() ->
?assertEqual(<<"abc">>, normalize_nonce(<<"abc">>)),
?assertEqual(null, normalize_nonce(<<"this_nonce_is_way_too_long_to_be_valid">>)),
?assertEqual(null, normalize_nonce(undefined)).
-endif.

View File

@@ -20,7 +20,9 @@
-export([
handle_session_connect/3,
handle_session_down/2,
remove_session/2,
filter_sessions_for_channel/4,
filter_sessions_for_message/5,
filter_sessions_for_manage_channels/4,
filter_sessions_exclude_session/2,
handle_user_offline/2,
@@ -29,67 +31,79 @@
build_initial_last_message_ids/1,
is_session_active/2,
subscribe_to_user_presence/2,
unsubscribe_from_user_presence/2
unsubscribe_from_user_presence/2,
set_session_viewable_channels/3,
refresh_all_viewable_channels/1
]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-import(guild_permissions, [can_view_channel/4, can_manage_channel/3, find_member_by_user_id/2]).
-import(guild_data, [get_guild_state/2]).
-import(guild_availability, [is_guild_unavailable_for_user/2]).
-type guild_state() :: map().
-type session_id() :: binary().
-type user_id() :: integer().
-type guild_id() :: integer().
-type channel_id() :: integer().
-type session_data() :: map().
-type sessions_map() :: #{session_id() => session_data()}.
-type session_pair() :: {session_id(), session_data()}.
-spec handle_session_connect(map(), pid(), guild_state()) ->
{reply, {ok, map()} | {ok, unavailable, map()}, guild_state()}.
handle_session_connect(Request, Pid, State) ->
#{session_id := SessionId, user_id := UserId} = Request,
Sessions = maps:get(sessions, State, #{}),
case maps:is_key(SessionId, Sessions) of
true ->
{reply, {ok, guild_data:get_guild_state(UserId, State)}, State};
false ->
register_new_session(Request, Pid, UserId, SessionId, State)
end.
-spec register_new_session(map(), pid(), user_id(), session_id(), guild_state()) ->
{reply, {ok, map()} | {ok, unavailable, map()}, guild_state()}.
register_new_session(Request, Pid, UserId, SessionId, State) ->
Sessions = maps:get(sessions, State, #{}),
ActiveGuilds = maps:get(active_guilds, Request, sets:new()),
InitialGuildId = maps:get(initial_guild_id, Request, undefined),
UserRoles = session_passive:get_user_roles_for_guild(UserId, State),
Bot = maps:get(bot, Request, false),
GuildId = maps:get(id, State),
case maps:is_key(SessionId, Sessions) of
Ref = monitor(process, Pid),
GuildState = guild_data:get_guild_state(UserId, State),
InitialLastMessageIds = build_initial_last_message_ids(GuildState),
InitialChannelVersions = build_initial_channel_versions(GuildState),
InitialViewableChannels = build_viewable_channel_map(
guild_visibility:get_user_viewable_channels(UserId, State)
),
SessionData = #{
session_id => SessionId,
user_id => UserId,
pid => Pid,
mref => Ref,
active_guilds => ActiveGuilds,
user_roles => UserRoles,
bot => Bot,
is_staff => maps:get(is_staff, Request, false),
previous_passive_updates => InitialLastMessageIds,
previous_passive_channel_versions => InitialChannelVersions,
previous_passive_voice_states => #{},
viewable_channels => InitialViewableChannels
},
NewSessions = maps:put(SessionId, SessionData, Sessions),
State1 = maps:put(sessions, NewSessions, State),
State2 = subscribe_to_user_presence(UserId, State1),
_ = maybe_notify_coordinator(session_connected, SessionId, UserId, State2),
case guild_availability:is_guild_unavailable_for_user(UserId, State2) of
true ->
{reply, {ok, get_guild_state(UserId, State)}, State};
false ->
Ref = monitor(process, Pid),
GuildState = get_guild_state(UserId, State),
InitialLastMessageIds = build_initial_last_message_ids(GuildState),
SessionData = #{
session_id => SessionId,
user_id => UserId,
pid => Pid,
mref => Ref,
active_guilds => ActiveGuilds,
user_roles => UserRoles,
bot => Bot,
previous_passive_updates => InitialLastMessageIds
UnavailableResponse = #{
<<"id">> => integer_to_binary(GuildId),
<<"unavailable">> => true
},
NewSessions = maps:put(SessionId, SessionData, Sessions),
State1 = maps:put(sessions, NewSessions, State),
State2 = subscribe_to_user_presence(UserId, State1),
case is_guild_unavailable_for_user(UserId, State2) of
true ->
GuildId = maps:get(id, State2),
UnavailableResponse = #{
<<"id">> => integer_to_binary(GuildId),
<<"unavailable">> => true
},
{reply, {ok, unavailable, UnavailableResponse}, State2};
false ->
SyncedState = maybe_auto_sync_initial_guild(
SessionId,
GuildId,
InitialGuildId,
State2
),
{reply, {ok, GuildState}, SyncedState}
end
{reply, {ok, unavailable, UnavailableResponse}, State2};
false ->
SyncedState = maybe_auto_sync_initial_guild(SessionId, GuildId, InitialGuildId, State2),
{reply, {ok, GuildState}, SyncedState}
end.
-spec build_initial_last_message_ids(map()) -> #{binary() => binary()}.
build_initial_last_message_ids(GuildState) ->
Channels = maps:get(<<"channels">>, GuildState, []),
lists:foldl(
@@ -106,10 +120,69 @@ build_initial_last_message_ids(GuildState) ->
Channels
).
-spec build_initial_channel_versions(map()) -> #{binary() => integer()}.
build_initial_channel_versions(GuildState) ->
Channels = maps:get(<<"channels">>, GuildState, []),
lists:foldl(
fun(Channel, Acc) ->
ChannelIdBin = maps:get(<<"id">>, Channel, undefined),
case ChannelIdBin of
undefined ->
Acc;
_ ->
Version = map_utils:get_integer(Channel, <<"version">>, 0),
maps:put(ChannelIdBin, Version, Acc)
end
end,
#{},
Channels
).
-spec handle_session_down(reference(), guild_state()) ->
{noreply, guild_state()} | {stop, normal, guild_state()}.
handle_session_down(Ref, State) ->
Sessions = maps:get(sessions, State, #{}),
DisconnectingSession = find_session_by_ref(Ref, Sessions),
State1 = cleanup_disconnecting_session(DisconnectingSession, State),
NewSessions = maps:filter(fun(_K, S) -> maps:get(mref, S) =/= Ref end, Sessions),
NewState = maps:put(sessions, NewSessions, State1),
case map_size(NewSessions) of
0 ->
case should_auto_stop_on_empty(NewState) of
true -> {stop, normal, NewState};
false -> {noreply, NewState}
end;
_ -> {noreply, NewState}
end.
DisconnectingSession = maps:fold(
-spec remove_session(session_id(), guild_state()) -> guild_state().
remove_session(SessionId, State) ->
Sessions = maps:get(sessions, State, #{}),
case maps:get(SessionId, Sessions, undefined) of
undefined ->
State;
Session ->
maybe_demonitor_session(Session),
StateAfterCleanup = cleanup_disconnecting_session(Session, State),
SessionsAfterCleanup = maps:get(sessions, StateAfterCleanup, #{}),
NewSessions = maps:remove(SessionId, SessionsAfterCleanup),
maps:put(sessions, NewSessions, StateAfterCleanup)
end.
-spec maybe_demonitor_session(session_data()) -> ok.
maybe_demonitor_session(Session) ->
Ref = maps:get(mref, Session, undefined),
case is_reference(Ref) of
true ->
demonitor(Ref, [flush]),
ok;
false ->
ok
end.
-spec find_session_by_ref(reference(), sessions_map()) -> session_data() | undefined.
find_session_by_ref(Ref, Sessions) ->
maps:fold(
fun(_K, S, Acc) ->
case maps:get(mref, S) =:= Ref of
true -> S;
@@ -118,95 +191,253 @@ handle_session_down(Ref, State) ->
end,
undefined,
Sessions
),
State1 =
case DisconnectingSession of
undefined ->
State;
Session ->
UserId = maps:get(user_id, Session),
SessionId = maps:get(session_id, Session),
StateAfterPresence = unsubscribe_from_user_presence(UserId, State),
StateAfterMemberList = guild_member_list:unsubscribe_session(
SessionId, StateAfterPresence
),
MemberSubs = maps:get(
member_subscriptions, StateAfterMemberList, guild_subscriptions:init_state()
),
NewMemberSubs = guild_subscriptions:unsubscribe_session(SessionId, MemberSubs),
maps:put(member_subscriptions, NewMemberSubs, StateAfterMemberList)
end,
NewSessions = maps:filter(fun(_K, S) -> maps:get(mref, S) =/= Ref end, Sessions),
NewState = maps:put(sessions, NewSessions, State1),
case map_size(NewSessions) of
0 ->
{stop, normal, NewState};
_ ->
{noreply, NewState}
end.
filter_sessions_for_channel(Sessions, ChannelId, SessionIdOpt, State) ->
GuildId = maps:get(id, State, 0),
lists:filter(
fun({Sid, S}) ->
UserId = maps:get(user_id, S),
Member = find_member_by_user_id(UserId, State),
ExcludeSession =
case SessionIdOpt of
undefined -> false;
SessionId -> Sid =:= SessionId
end,
case {ExcludeSession, Member} of
{true, _} ->
false;
{_, undefined} ->
logger:warning(
"[guild_sessions] Filtering out session with no member: "
"guild_id=~p session_id=~p user_id=~p",
[GuildId, Sid, UserId]
),
false;
{false, _} ->
can_view_channel(UserId, ChannelId, Member, State)
end
end,
maps:to_list(Sessions)
).
filter_sessions_for_manage_channels(Sessions, ChannelId, SessionIdOpt, State) ->
-spec cleanup_disconnecting_session(session_data() | undefined, guild_state()) -> guild_state().
cleanup_disconnecting_session(undefined, State) ->
State;
cleanup_disconnecting_session(Session, State) ->
UserId = maps:get(user_id, Session),
SessionId = maps:get(session_id, Session),
_ = maybe_notify_coordinator(session_disconnected, SessionId, UserId, State),
StateAfterPresence = unsubscribe_from_user_presence(UserId, State),
StateAfterMemberList = guild_member_list:unsubscribe_session(SessionId, StateAfterPresence),
MemberSubs = maps:get(
member_subscriptions, StateAfterMemberList, guild_subscriptions:init_state()
),
NewMemberSubs = guild_subscriptions:unsubscribe_session(SessionId, MemberSubs),
StateAfterSubs = maps:put(member_subscriptions, NewMemberSubs, StateAfterMemberList),
cleanup_connect_admission_for_session(SessionId, StateAfterSubs).
-spec cleanup_connect_admission_for_session(session_id(), guild_state()) -> guild_state().
cleanup_connect_admission_for_session(SessionId, State) ->
Pending0 = maps:get(session_connect_pending, State, undefined),
State1 =
case Pending0 of
Pending when is_map(Pending) ->
maps:put(session_connect_pending, maps:remove(SessionId, Pending), State);
_ ->
State
end,
Queue0 = maps:get(session_connect_queue, State1, undefined),
Queue = normalize_connect_queue(Queue0),
case queue:is_queue(Queue) of
true ->
Filtered = queue:filter(
fun(Item) ->
Request = maps:get(request, Item, #{}),
maps:get(session_id, Request, undefined) =/= SessionId
end,
Queue
),
maps:put(session_connect_queue, Filtered, State1);
false ->
State1
end.
-spec normalize_connect_queue(term()) -> queue:queue() | undefined.
normalize_connect_queue(Value) when is_list(Value) ->
queue:from_list(Value);
normalize_connect_queue(Value) ->
case queue:is_queue(Value) of
true -> Value;
false -> undefined
end.
-spec should_auto_stop_on_empty(guild_state()) -> boolean().
should_auto_stop_on_empty(State) ->
case maps:get(disable_auto_stop_on_empty, State, false) of
true ->
false;
false ->
case maps:get(very_large_guild_coordinator_pid, State, undefined) of
Pid when is_pid(Pid) -> false;
_ -> true
end
end.
-spec maybe_notify_coordinator(session_connected | session_disconnected, session_id(), user_id(), guild_state()) ->
ok.
maybe_notify_coordinator(Type, SessionId, UserId, State) ->
case {maps:get(very_large_guild_coordinator_pid, State, undefined),
maps:get(very_large_guild_shard_index, State, undefined)}
of
{CoordPid, ShardIndex} when is_pid(CoordPid), is_integer(ShardIndex) ->
Msg =
case Type of
session_connected ->
{very_large_guild_session_connected, ShardIndex, SessionId, UserId};
session_disconnected ->
{very_large_guild_session_disconnected, ShardIndex, SessionId, UserId}
end,
CoordPid ! Msg,
ok;
_ ->
ok
end.
-spec filter_sessions_for_channel(
sessions_map(), channel_id(), session_id() | undefined, guild_state()
) ->
[session_pair()].
filter_sessions_for_channel(Sessions, ChannelId, SessionIdOpt, State) ->
lists:filter(
fun({Sid, S}) ->
UserId = maps:get(user_id, S),
ExcludeSession =
case SessionIdOpt of
undefined -> false;
SessionId -> Sid =:= SessionId
end,
case maps:get(pending_connect, S, false) of
true ->
false;
false ->
ExcludeSession = should_exclude_session(Sid, SessionIdOpt),
case ExcludeSession of
true ->
false;
false ->
can_manage_channel(UserId, ChannelId, State)
session_can_view_channel(S, ChannelId, State)
end
end
end,
maps:to_list(Sessions)
).
filter_sessions_exclude_session(Sessions, SessionIdOpt) ->
case SessionIdOpt of
-spec filter_sessions_for_message(
sessions_map(), channel_id(), binary(), session_id() | undefined, guild_state()
) ->
[session_pair()].
filter_sessions_for_message(Sessions, ChannelId, MessageId, SessionIdOpt, State) ->
lists:filter(
fun({Sid, S}) ->
case maps:get(pending_connect, S, false) of
true ->
false;
false ->
ExcludeSession = should_exclude_session(Sid, SessionIdOpt),
case ExcludeSession of
true ->
false;
false ->
UserId = maps:get(user_id, S, undefined),
case session_can_view_channel(S, ChannelId, State) of
false ->
false;
true ->
Perms = guild_permissions:get_member_permissions(UserId, ChannelId, State),
guild_permissions:can_access_message_by_permissions(
Perms, MessageId, State
)
end
end
end
end,
maps:to_list(Sessions)
).
-spec filter_sessions_for_manage_channels(
sessions_map(), channel_id(), session_id() | undefined, guild_state()
) ->
[session_pair()].
filter_sessions_for_manage_channels(Sessions, ChannelId, SessionIdOpt, State) ->
lists:filter(
fun({Sid, S}) ->
case maps:get(pending_connect, S, false) of
true ->
false;
false ->
UserId = maps:get(user_id, S),
ExcludeSession = should_exclude_session(Sid, SessionIdOpt),
case ExcludeSession of
true -> false;
false -> guild_permissions:can_manage_channel(UserId, ChannelId, State)
end
end
end,
maps:to_list(Sessions)
).
-spec filter_sessions_exclude_session(sessions_map(), session_id() | undefined) -> [session_pair()].
filter_sessions_exclude_session(Sessions, undefined) ->
[{Sid, S} || {Sid, S} <- maps:to_list(Sessions), maps:get(pending_connect, S, false) =/= true];
filter_sessions_exclude_session(Sessions, SessionId) ->
[
{Sid, S}
|| {Sid, S} <- maps:to_list(Sessions),
Sid =/= SessionId,
maps:get(pending_connect, S, false) =/= true
].
-spec should_exclude_session(session_id(), session_id() | undefined) -> boolean().
should_exclude_session(_, undefined) -> false;
should_exclude_session(Sid, SessionId) -> Sid =:= SessionId.
-spec set_session_viewable_channels(session_id(), map(), guild_state()) -> guild_state().
set_session_viewable_channels(SessionId, ViewableChannels, State) ->
Sessions = maps:get(sessions, State, #{}),
case maps:get(SessionId, Sessions, undefined) of
undefined ->
maps:to_list(Sessions);
SessionId ->
[{Sid, S} || {Sid, S} <- maps:to_list(Sessions), Sid =/= SessionId]
State;
SessionData ->
NewSessionData = maps:put(viewable_channels, ViewableChannels, SessionData),
NewSessions = maps:put(SessionId, NewSessionData, Sessions),
maps:put(sessions, NewSessions, State)
end.
-spec refresh_all_viewable_channels(guild_state()) -> guild_state().
refresh_all_viewable_channels(State) ->
Sessions = maps:get(sessions, State, #{}),
lists:foldl(
fun({SessionId, SessionData}, AccState) ->
UserId = maps:get(user_id, SessionData, undefined),
case is_integer(UserId) of
true ->
ViewableChannels = build_viewable_channel_map(
guild_visibility:get_user_viewable_channels(UserId, AccState)
),
set_session_viewable_channels(SessionId, ViewableChannels, AccState);
false ->
AccState
end
end,
State,
maps:to_list(Sessions)
).
-spec session_can_view_channel(session_data(), channel_id(), guild_state()) -> boolean().
session_can_view_channel(SessionData, ChannelId, State) ->
UserId = maps:get(user_id, SessionData, undefined),
case {UserId, maps:get(viewable_channels, SessionData, undefined)} of
{Uid, ViewableChannels} when is_integer(Uid), is_map(ViewableChannels) ->
case maps:is_key(ChannelId, ViewableChannels) of
true ->
true;
false ->
check_member_channel_access(Uid, ChannelId, State)
end;
{Uid, _} when is_integer(Uid) ->
check_member_channel_access(Uid, ChannelId, State);
_ ->
false
end.
-spec check_member_channel_access(user_id(), channel_id(), guild_state()) -> boolean().
check_member_channel_access(UserId, ChannelId, State) ->
Member = guild_permissions:find_member_by_user_id(UserId, State),
case Member of
undefined ->
false;
_ ->
guild_permissions:can_view_channel(UserId, ChannelId, Member, State)
end.
-spec build_viewable_channel_map([channel_id()]) -> #{channel_id() => true}.
build_viewable_channel_map(ChannelIds) ->
lists:foldl(
fun(ChannelId, Acc) ->
maps:put(ChannelId, true, Acc)
end,
#{},
ChannelIds
).
-spec subscribe_to_user_presence(user_id(), guild_state()) -> guild_state().
subscribe_to_user_presence(UserId, State) ->
PresenceSubs = maps:get(presence_subscriptions, State, #{}),
CurrentCount = maps:get(UserId, PresenceSubs, 0),
@@ -221,6 +452,7 @@ subscribe_to_user_presence(UserId, State) ->
maps:put(presence_subscriptions, NewSubs, State)
end.
-spec unsubscribe_from_user_presence(user_id(), guild_state()) -> guild_state().
unsubscribe_from_user_presence(UserId, State) ->
PresenceSubs = maps:get(presence_subscriptions, State, #{}),
CurrentCount = maps:get(UserId, PresenceSubs, 0),
@@ -235,30 +467,33 @@ unsubscribe_from_user_presence(UserId, State) ->
maps:put(presence_subscriptions, NewSubs, State)
end.
-spec handle_user_offline(user_id(), guild_state()) -> guild_state().
handle_user_offline(UserId, State) ->
PresenceSubs = maps:get(presence_subscriptions, State, #{}),
case maps:get(UserId, PresenceSubs, undefined) of
0 ->
presence_bus:unsubscribe(UserId),
NewSubs = maps:remove(UserId, PresenceSubs),
maps:put(presence_subscriptions, NewSubs, State);
undefined ->
State;
StateWithSubs = maps:put(presence_subscriptions, NewSubs, State),
MemberPresence = maps:get(member_presence, StateWithSubs, #{}),
UpdatedMemberPresence = maps:remove(UserId, MemberPresence),
maps:put(member_presence, UpdatedMemberPresence, StateWithSubs);
_ ->
State
end.
-spec maybe_send_cached_presence(user_id(), guild_state()) -> guild_state().
maybe_send_cached_presence(UserId, State) ->
case presence_cache:get(UserId) of
{ok, Payload} ->
case guild_presence:handle_bus_presence(UserId, Payload, State) of
{noreply, UpdatedState} ->
UpdatedState
{noreply, UpdatedState} -> UpdatedState
end;
_ ->
State
end.
-spec set_session_active_guild(session_id(), guild_id(), guild_state()) -> guild_state().
set_session_active_guild(SessionId, GuildId, State) ->
Sessions = maps:get(sessions, State, #{}),
case maps:get(SessionId, Sessions, undefined) of
@@ -270,6 +505,7 @@ set_session_active_guild(SessionId, GuildId, State) ->
maps:put(sessions, NewSessions, State)
end.
-spec set_session_passive_guild(session_id(), guild_id(), guild_state()) -> guild_state().
set_session_passive_guild(SessionId, GuildId, State) ->
Sessions = maps:get(sessions, State, #{}),
case maps:get(SessionId, Sessions, undefined) of
@@ -282,39 +518,39 @@ set_session_passive_guild(SessionId, GuildId, State) ->
maps:put(sessions, NewSessions, State)
end.
-spec is_session_active(session_id(), guild_state()) -> boolean().
is_session_active(SessionId, State) ->
GuildId = maps:get(id, State, 0),
Sessions = maps:get(sessions, State, #{}),
case maps:get(SessionId, Sessions, undefined) of
undefined ->
false;
SessionData ->
not session_passive:is_passive(GuildId, SessionData)
undefined -> false;
SessionData -> not session_passive:is_passive(GuildId, SessionData)
end.
maybe_auto_sync_initial_guild(SessionId, GuildId, InitialGuildId, State) ->
case InitialGuildId of
GuildId ->
Sessions = maps:get(sessions, State, #{}),
case maps:get(SessionId, Sessions, undefined) of
undefined ->
State;
SessionData ->
SyncedSessionData = session_passive:mark_guild_synced(GuildId, SessionData),
NewSessions = maps:put(SessionId, SyncedSessionData, Sessions),
maps:put(sessions, NewSessions, State)
end;
_ ->
State
end.
-spec maybe_auto_sync_initial_guild(
session_id(), guild_id(), guild_id() | undefined, guild_state()
) ->
guild_state().
maybe_auto_sync_initial_guild(SessionId, GuildId, GuildId, State) ->
Sessions = maps:get(sessions, State, #{}),
case maps:get(SessionId, Sessions, undefined) of
undefined ->
State;
SessionData ->
SyncedSessionData = session_passive:mark_guild_synced(GuildId, SessionData),
NewSessions = maps:put(SessionId, SyncedSessionData, Sessions),
maps:put(sessions, NewSessions, State)
end;
maybe_auto_sync_initial_guild(_SessionId, _GuildId, _InitialGuildId, State) ->
State.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
build_initial_last_message_ids_empty_channels_test() ->
GuildState = #{<<"channels">> => []},
Result = build_initial_last_message_ids(GuildState),
?assertEqual(#{}, Result),
ok.
?assertEqual(#{}, Result).
build_initial_last_message_ids_with_channels_test() ->
GuildState = #{
@@ -324,8 +560,7 @@ build_initial_last_message_ids_with_channels_test() ->
]
},
Result = build_initial_last_message_ids(GuildState),
?assertEqual(#{<<"100">> => <<"500">>, <<"101">> => <<"600">>}, Result),
ok.
?assertEqual(#{<<"100">> => <<"500">>, <<"101">> => <<"600">>}, Result).
build_initial_last_message_ids_filters_null_test() ->
GuildState = #{
@@ -336,13 +571,237 @@ build_initial_last_message_ids_filters_null_test() ->
]
},
Result = build_initial_last_message_ids(GuildState),
?assertEqual(#{<<"100">> => <<"500">>}, Result),
ok.
?assertEqual(#{<<"100">> => <<"500">>}, Result).
build_initial_last_message_ids_no_channels_key_test() ->
GuildState = #{},
Result = build_initial_last_message_ids(GuildState),
?assertEqual(#{}, Result),
?assertEqual(#{}, Result).
build_initial_channel_versions_test() ->
GuildState = #{
<<"channels">> => [
#{<<"id">> => <<"100">>, <<"version">> => 5},
#{<<"id">> => <<"101">>}
]
},
Result = build_initial_channel_versions(GuildState),
?assertEqual(#{<<"100">> => 5, <<"101">> => 0}, Result).
filter_sessions_for_channel_uses_cached_viewable_channels_test() ->
SessionId = <<"s1">>,
SessionData = #{
session_id => SessionId,
user_id => 10,
pid => self(),
viewable_channels => #{200 => true}
},
Sessions = #{SessionId => SessionData},
State = #{
sessions => Sessions,
data => #{<<"members">> => #{}}
},
Result = filter_sessions_for_channel(Sessions, 200, undefined, State),
?assertEqual([{SessionId, SessionData}], Result).
refresh_all_viewable_channels_populates_cache_test() ->
SessionId = <<"s1">>,
UserId = 10,
GuildId = 42,
ViewPerm = constants:view_channel_permission(),
State = #{
id => GuildId,
sessions => #{
SessionId => #{
session_id => SessionId,
user_id => UserId,
pid => self()
}
},
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(ViewPerm)}
],
<<"members">> => #{
UserId => #{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
},
<<"channels">> => [
#{<<"id">> => <<"200">>, <<"permission_overwrites">> => []}
]
}
},
UpdatedState = refresh_all_viewable_channels(State),
UpdatedSessions = maps:get(sessions, UpdatedState, #{}),
UpdatedSession = maps:get(SessionId, UpdatedSessions),
ViewableChannels = maps:get(viewable_channels, UpdatedSession, #{}),
?assertEqual(true, maps:is_key(200, ViewableChannels)).
filter_sessions_exclude_session_undefined_test() ->
Sessions = #{<<"s1">> => #{}, <<"s2">> => #{}},
Result = filter_sessions_exclude_session(Sessions, undefined),
?assertEqual(2, length(Result)).
filter_sessions_exclude_session_specific_test() ->
Sessions = #{<<"s1">> => #{}, <<"s2">> => #{}},
Result = filter_sessions_exclude_session(Sessions, <<"s1">>),
?assertEqual(1, length(Result)),
?assertEqual([{<<"s2">>, #{}}], Result).
should_exclude_session_test() ->
?assertEqual(false, should_exclude_session(<<"s1">>, undefined)),
?assertEqual(true, should_exclude_session(<<"s1">>, <<"s1">>)),
?assertEqual(false, should_exclude_session(<<"s1">>, <<"s2">>)).
find_session_by_ref_found_test() ->
Ref = make_ref(),
Sessions = #{<<"s1">> => #{mref => Ref, user_id => 1}},
Result = find_session_by_ref(Ref, Sessions),
?assertEqual(#{mref => Ref, user_id => 1}, Result).
find_session_by_ref_not_found_test() ->
Ref = make_ref(),
OtherRef = make_ref(),
Sessions = #{<<"s1">> => #{mref => OtherRef, user_id => 1}},
Result = find_session_by_ref(Ref, Sessions),
?assertEqual(undefined, Result).
remove_session_removes_entry_test() ->
SessionId = <<"s1">>,
SessionData = #{
session_id => SessionId,
user_id => 1,
pid => self(),
mref => make_ref(),
active_guilds => sets:new(),
user_roles => [],
bot => false,
previous_passive_updates => #{},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
},
State = #{
sessions => #{SessionId => SessionData},
presence_subscriptions => #{1 => 1},
member_list_subscriptions => #{},
member_subscriptions => guild_subscriptions:init_state()
},
NewState = remove_session(SessionId, State),
?assertEqual(#{}, maps:get(sessions, NewState, #{})),
?assertEqual(0, maps:get(1, maps:get(presence_subscriptions, NewState, #{}), 0)).
remove_session_cleans_connect_pending_test() ->
SessionId = <<"s1">>,
SessionData = #{
session_id => SessionId,
user_id => 1,
pid => self(),
mref => make_ref(),
active_guilds => sets:new(),
user_roles => [],
bot => false,
previous_passive_updates => #{},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
},
State = #{
sessions => #{SessionId => SessionData},
presence_subscriptions => #{1 => 1},
member_list_subscriptions => #{},
member_subscriptions => guild_subscriptions:init_state(),
session_connect_pending => #{SessionId => 3, <<"s2">> => 1},
session_connect_queue => [
#{request => #{session_id => SessionId}, attempt => 3},
#{request => #{session_id => <<"s2">>}, attempt => 1}
]
},
NewState = remove_session(SessionId, State),
Pending = maps:get(session_connect_pending, NewState, #{}),
?assertEqual(false, maps:is_key(SessionId, Pending)),
?assertEqual(true, maps:is_key(<<"s2">>, Pending)),
Queue0 = maps:get(session_connect_queue, NewState, queue:new()),
Queue =
case Queue0 of
L when is_list(L) -> L;
_ ->
case queue:is_queue(Queue0) of
true -> queue:to_list(Queue0);
false -> []
end
end,
?assertEqual(
1,
length([
Item
|| Item <- Queue,
maps:get(session_id, maps:get(request, Item, #{}), undefined) =:= <<"s2">>
])
),
?assertEqual(
0,
length([
Item
|| Item <- Queue,
maps:get(session_id, maps:get(request, Item, #{}), undefined) =:= SessionId
])
),
ok.
cleanup_connect_admission_queue_format_test() ->
S1 = <<"s1">>,
S2 = <<"s2">>,
S3 = <<"s3">>,
Queue = queue:from_list([
#{request => #{session_id => S1}, attempt => 0},
#{request => #{session_id => S2}, attempt => 1},
#{request => #{session_id => S1}, attempt => 2},
#{request => #{session_id => S3}, attempt => 3}
]),
State0 = #{
sessions => #{},
session_connect_queue => Queue,
presence_subscriptions => #{},
member_list_subscriptions => #{},
member_subscriptions => guild_subscriptions:init_state()
},
State1 = cleanup_connect_admission_for_session(S1, State0),
ResultQueue0 = maps:get(session_connect_queue, State1),
ResultQueue =
case queue:is_queue(ResultQueue0) of
true -> queue:to_list(ResultQueue0);
false -> ResultQueue0
end,
?assertEqual(2, length(ResultQueue)),
SessionIds = [
maps:get(session_id, maps:get(request, Item, #{}), undefined)
|| Item <- ResultQueue
],
?assertEqual(false, lists:member(S1, SessionIds)),
?assertEqual(true, lists:member(S2, SessionIds)),
?assertEqual(true, lists:member(S3, SessionIds)).
pending_connect_filtered_from_channel_sessions_test() ->
NormalSession = #{
session_id => <<"s1">>,
user_id => 10,
pid => self(),
viewable_channels => #{200 => true}
},
PendingSession = #{
session_id => <<"s2">>,
user_id => 11,
pid => self(),
pending_connect => true,
viewable_channels => #{200 => true}
},
Sessions = #{<<"s1">> => NormalSession, <<"s2">> => PendingSession},
State = #{
sessions => Sessions,
data => #{<<"members">> => #{}}
},
Result = filter_sessions_for_channel(Sessions, 200, undefined, State),
?assertEqual(1, length(Result)),
[{ResultSid, _}] = Result,
?assertEqual(<<"s1">>, ResultSid).
-endif.

View File

@@ -17,24 +17,26 @@
-module(guild_state).
-export([
update_state/3
]).
-export([update_state/3]).
-import(guild_user_data, [maybe_update_cached_user_data/3]).
-import(guild_availability, [handle_unavailability_transition/2]).
-import(guild_visibility, [compute_and_dispatch_visibility_changes/2]).
-import(guild, [update_counts/1]).
-type guild_state() :: map().
-type guild_data() :: map().
-type event() :: atom().
-type event_data() :: map().
-type user_id() :: integer().
-type role_id() :: binary().
-spec update_state(event(), event_data(), guild_state()) -> guild_state().
update_state(Event, EventData, State) ->
StateWithUpdatedUser = maybe_update_cached_user_data(Event, EventData, State),
Data = maps:get(data, StateWithUpdatedUser),
StateWithUpdatedUser0 = guild_user_data:maybe_update_cached_user_data(Event, EventData, State),
Data0 = maps:get(data, StateWithUpdatedUser0),
Data = guild_data_index:normalize_data(Data0),
StateWithUpdatedUser = maps:put(data, Data, StateWithUpdatedUser0),
UpdatedData = update_data_for_event(Event, EventData, Data, State),
UpdatedState = maps:put(data, UpdatedData, StateWithUpdatedUser),
handle_post_update(Event, EventData, StateWithUpdatedUser, UpdatedState).
handle_post_update(Event, StateWithUpdatedUser, UpdatedState).
-spec update_data_for_event(event(), event_data(), guild_data(), guild_state()) -> guild_data().
update_data_for_event(guild_update, EventData, Data, _State) ->
handle_guild_update(EventData, Data);
update_data_for_event(guild_member_add, EventData, Data, _State) ->
@@ -70,23 +72,184 @@ update_data_for_event(guild_stickers_update, EventData, Data, _State) ->
update_data_for_event(_Event, _EventData, Data, _State) ->
Data.
handle_post_update(guild_update, StateWithUpdatedUser, UpdatedState) ->
handle_unavailability_transition(StateWithUpdatedUser, UpdatedState),
-spec handle_post_update(event(), event_data(), guild_state(), guild_state()) -> guild_state().
handle_post_update(guild_update, _EventData, StateWithUpdatedUser, UpdatedState) ->
guild_availability:handle_unavailability_transition(StateWithUpdatedUser, UpdatedState);
handle_post_update(guild_member_add, _EventData, _StateWithUpdatedUser, UpdatedState) ->
guild:update_counts(UpdatedState);
handle_post_update(guild_member_update, EventData, StateWithUpdatedUser, UpdatedState) ->
UserId = extract_user_id(EventData),
case is_integer(UserId) andalso UserId > 0 of
true ->
guild_visibility:compute_and_dispatch_visibility_changes_for_users(
[UserId],
StateWithUpdatedUser,
UpdatedState
);
false ->
guild_visibility:compute_and_dispatch_visibility_changes(
StateWithUpdatedUser,
UpdatedState
)
end;
handle_post_update(guild_role_create, _EventData, _StateWithUpdatedUser, UpdatedState) ->
UpdatedState;
handle_post_update(guild_member_add, _StateWithUpdatedUser, UpdatedState) ->
update_counts(UpdatedState);
handle_post_update(guild_member_remove, _StateWithUpdatedUser, UpdatedState) ->
handle_post_update(guild_role_update, EventData, StateWithUpdatedUser, UpdatedState) ->
recompute_visibility_for_roles(
extract_role_ids_from_role_update(EventData),
StateWithUpdatedUser,
UpdatedState
);
handle_post_update(guild_role_update_bulk, EventData, StateWithUpdatedUser, UpdatedState) ->
recompute_visibility_for_roles(
extract_role_ids_from_role_update_bulk(EventData),
StateWithUpdatedUser,
UpdatedState
);
handle_post_update(guild_role_delete, EventData, StateWithUpdatedUser, UpdatedState) ->
recompute_visibility_for_roles(
extract_role_ids_from_role_delete(EventData),
StateWithUpdatedUser,
UpdatedState
);
handle_post_update(channel_update, EventData, StateWithUpdatedUser, UpdatedState) ->
ChannelIds = extract_channel_ids_from_channel_update(EventData),
guild_visibility:compute_and_dispatch_visibility_changes_for_channels(
ChannelIds,
StateWithUpdatedUser,
UpdatedState
);
handle_post_update(channel_update_bulk, EventData, StateWithUpdatedUser, UpdatedState) ->
ChannelIds = extract_channel_ids_from_channel_update_bulk(EventData),
guild_visibility:compute_and_dispatch_visibility_changes_for_channels(
ChannelIds,
StateWithUpdatedUser,
UpdatedState
);
handle_post_update(guild_member_remove, EventData, _StateWithUpdatedUser, UpdatedState) ->
UserId = extract_user_id(EventData),
State1 = cleanup_removed_member_sessions(UpdatedState),
update_counts(State1);
handle_post_update(Event, StateWithUpdatedUser, UpdatedState) ->
State2 = maybe_disconnect_removed_member(UserId, State1),
guild:update_counts(State2);
handle_post_update(Event, _EventData, StateWithUpdatedUser, UpdatedState) ->
case needs_visibility_check(Event) of
true ->
compute_and_dispatch_visibility_changes(StateWithUpdatedUser, UpdatedState),
UpdatedState;
guild_visibility:compute_and_dispatch_visibility_changes(
StateWithUpdatedUser, UpdatedState
);
false ->
UpdatedState
end.
-spec maybe_disconnect_removed_member(user_id(), guild_state()) -> guild_state().
maybe_disconnect_removed_member(UserId, State) when is_integer(UserId), UserId > 0 ->
{reply, _Result, NewState} =
guild_voice_disconnect:disconnect_voice_user(
#{user_id => UserId, connection_id => null},
State
),
NewState;
maybe_disconnect_removed_member(_, State) ->
State.
-spec recompute_visibility_for_roles([integer()], guild_state(), guild_state()) -> guild_state().
recompute_visibility_for_roles(RoleIds, StateWithUpdatedUser, UpdatedState) ->
GuildId = maps:get(id, UpdatedState, 0),
case lists:any(fun(RoleId) -> RoleId =:= GuildId end, RoleIds) of
true ->
guild_visibility:compute_and_dispatch_visibility_changes(
StateWithUpdatedUser,
UpdatedState
);
false ->
UserIds = affected_user_ids_for_roles(RoleIds, StateWithUpdatedUser, UpdatedState),
case UserIds of
[] ->
UpdatedState;
_ ->
guild_visibility:compute_and_dispatch_visibility_changes_for_users(
UserIds,
StateWithUpdatedUser,
UpdatedState
)
end
end.
-spec affected_user_ids_for_roles([integer()], guild_state(), guild_state()) -> [user_id()].
affected_user_ids_for_roles(RoleIds, StateWithUpdatedUser, UpdatedState) ->
OldData = maps:get(data, StateWithUpdatedUser, #{}),
NewData = maps:get(data, UpdatedState, #{}),
OldMemberRoleIndex = guild_data_index:member_role_index(OldData),
NewMemberRoleIndex = guild_data_index:member_role_index(NewData),
UserIdSet = lists:foldl(
fun(RoleId, AccSet) ->
OldUsers = maps:keys(maps:get(RoleId, OldMemberRoleIndex, #{})),
NewUsers = maps:keys(maps:get(RoleId, NewMemberRoleIndex, #{})),
RoleUsers = OldUsers ++ NewUsers,
lists:foldl(
fun(UserId, InnerSet) -> sets:add_element(UserId, InnerSet) end,
AccSet,
RoleUsers
)
end,
sets:new(),
lists:usort(RoleIds)
),
sets:to_list(UserIdSet).
-spec extract_role_ids_from_role_update(event_data()) -> [integer()].
extract_role_ids_from_role_update(EventData) ->
RoleData = maps:get(<<"role">>, EventData, #{}),
case type_conv:to_integer(maps:get(<<"id">>, RoleData, undefined)) of
undefined -> [];
RoleId -> [RoleId]
end.
-spec extract_role_ids_from_role_update_bulk(event_data()) -> [integer()].
extract_role_ids_from_role_update_bulk(EventData) ->
Roles = maps:get(<<"roles">>, EventData, []),
lists:filtermap(
fun(RoleData) ->
case type_conv:to_integer(maps:get(<<"id">>, RoleData, undefined)) of
undefined ->
false;
RoleId ->
{true, RoleId}
end
end,
Roles
).
-spec extract_role_ids_from_role_delete(event_data()) -> [integer()].
extract_role_ids_from_role_delete(EventData) ->
case type_conv:to_integer(maps:get(<<"role_id">>, EventData, undefined)) of
undefined -> [];
RoleId -> [RoleId]
end.
-spec extract_channel_ids_from_channel_update(event_data()) -> [integer()].
extract_channel_ids_from_channel_update(EventData) ->
case type_conv:to_integer(maps:get(<<"id">>, EventData, undefined)) of
undefined -> [];
ChannelId -> [ChannelId]
end.
-spec extract_channel_ids_from_channel_update_bulk(event_data()) -> [integer()].
extract_channel_ids_from_channel_update_bulk(EventData) ->
Channels = maps:get(<<"channels">>, EventData, []),
lists:filtermap(
fun(ChannelData) ->
case type_conv:to_integer(maps:get(<<"id">>, ChannelData, undefined)) of
undefined ->
false;
ChannelId ->
{true, ChannelId}
end
end,
Channels
).
-spec needs_visibility_check(event()) -> boolean().
needs_visibility_check(guild_role_create) -> true;
needs_visibility_check(guild_role_update) -> true;
needs_visibility_check(guild_role_update_bulk) -> true;
@@ -96,33 +259,36 @@ needs_visibility_check(channel_update) -> true;
needs_visibility_check(channel_update_bulk) -> true;
needs_visibility_check(_) -> false.
-spec handle_guild_update(event_data(), guild_data()) -> guild_data().
handle_guild_update(EventData, Data) ->
Guild = maps:get(<<"guild">>, Data),
UpdatedGuild = maps:merge(Guild, EventData),
maps:put(<<"guild">>, UpdatedGuild, Data).
-spec handle_member_add(event_data(), guild_data()) -> guild_data().
handle_member_add(EventData, Data) ->
Members = maps:get(<<"members">>, Data, []),
UpdatedData = maps:put(<<"members">>, Members ++ [EventData], Data),
UpdatedData.
guild_data_index:put_member(EventData, Data).
-spec handle_member_update(event_data(), guild_data()) -> guild_data().
handle_member_update(EventData, Data) ->
Members = maps:get(<<"members">>, Data, []),
UserId = extract_user_id(EventData),
UpdatedMembers = replace_member_by_id(Members, UserId, EventData),
maps:put(<<"members">>, UpdatedMembers, Data).
Members = guild_data_index:member_map(Data),
case maps:is_key(UserId, Members) of
true ->
guild_data_index:put_member(EventData, Data);
false ->
Data
end.
-spec handle_member_remove(event_data(), guild_data(), guild_state()) -> guild_data().
handle_member_remove(EventData, Data, _State) ->
Members = maps:get(<<"members">>, Data, []),
UserId = extract_user_id(EventData),
FilteredMembers = remove_member_by_id(Members, UserId),
maps:put(<<"members">>, FilteredMembers, Data).
guild_data_index:remove_member(UserId, Data).
-spec cleanup_removed_member_sessions(guild_state()) -> guild_state().
cleanup_removed_member_sessions(State) ->
Data = maps:get(data, State),
Members = maps:get(<<"members">>, Data, []),
MemberUserIds = sets:from_list([extract_user_id_from_member(M) || M <- Members]),
MemberUserIds = sets:from_list(guild_data_index:member_ids(Data)),
Sessions = maps:get(sessions, State, #{}),
FilteredSessions = maps:filter(
fun(_K, S) ->
@@ -131,158 +297,141 @@ cleanup_removed_member_sessions(State) ->
end,
Sessions
),
maps:put(sessions, FilteredSessions, State).
Presences = maps:get(presences, State, #{}),
FilteredPresences = maps:filter(
fun(UserId, _V) ->
sets:is_element(UserId, MemberUserIds)
end,
Presences
),
State1 = maps:put(sessions, FilteredSessions, State),
maps:put(presences, FilteredPresences, State1).
extract_user_id_from_member(Member) when is_map(Member) ->
MUser = maps:get(<<"user">>, Member, #{}),
utils:binary_to_integer_safe(maps:get(<<"id">>, MUser, <<"0">>));
extract_user_id_from_member(_) ->
0.
-spec extract_user_id(event_data()) -> user_id().
extract_user_id(EventData) ->
MUser = maps:get(<<"user">>, EventData, #{}),
utils:binary_to_integer_safe(maps:get(<<"id">>, MUser, <<"0">>)).
replace_member_by_id(Members, UserId, NewMember) ->
lists:map(
fun(M) when is_map(M) ->
MMUser = maps:get(<<"user">>, M, #{}),
MUserId = utils:binary_to_integer_safe(maps:get(<<"id">>, MMUser, <<"0">>)),
case MUserId =:= UserId of
true -> NewMember;
false -> M
end
end,
Members
).
remove_member_by_id(Members, UserId) ->
lists:filter(
fun(M) when is_map(M) ->
MMUser = maps:get(<<"user">>, M, #{}),
MUserId = utils:binary_to_integer_safe(maps:get(<<"id">>, MMUser, <<"0">>)),
MUserId =/= UserId
end,
Members
).
-spec handle_role_create(event_data(), guild_data()) -> guild_data().
handle_role_create(EventData, Data) ->
Roles = maps:get(<<"roles">>, Data, []),
Roles = guild_data_index:role_list(Data),
RoleData = maps:get(<<"role">>, EventData),
maps:put(<<"roles">>, Roles ++ [RoleData], Data).
guild_data_index:put_roles(Roles ++ [RoleData], Data).
-spec handle_role_update(event_data(), guild_data()) -> guild_data().
handle_role_update(EventData, Data) ->
Roles = maps:get(<<"roles">>, Data, []),
Roles = guild_data_index:role_list(Data),
RoleData = maps:get(<<"role">>, EventData),
RoleId = maps:get(<<"id">>, RoleData),
UpdatedRoles = replace_item_by_id(Roles, RoleId, RoleData),
maps:put(<<"roles">>, UpdatedRoles, Data).
guild_data_index:put_roles(UpdatedRoles, Data).
-spec handle_role_update_bulk(event_data(), guild_data()) -> guild_data().
handle_role_update_bulk(EventData, Data) ->
Roles = maps:get(<<"roles">>, Data, []),
Roles = guild_data_index:role_list(Data),
BulkRoles = maps:get(<<"roles">>, EventData, []),
UpdatedRoles = bulk_update_items(Roles, BulkRoles),
maps:put(<<"roles">>, UpdatedRoles, Data).
guild_data_index:put_roles(UpdatedRoles, Data).
-spec handle_role_delete(event_data(), guild_data()) -> guild_data().
handle_role_delete(EventData, Data) ->
Roles = maps:get(<<"roles">>, Data, []),
Roles = guild_data_index:role_list(Data),
RoleId = maps:get(<<"role_id">>, EventData),
FilteredRoles = remove_item_by_id(Roles, RoleId),
Data1 = maps:put(<<"roles">>, FilteredRoles, Data),
Data1 = guild_data_index:put_roles(FilteredRoles, Data),
Data2 = strip_role_from_members(RoleId, Data1),
strip_role_from_channel_overwrites(RoleId, Data2).
-spec strip_role_from_members(role_id(), guild_data()) -> guild_data().
strip_role_from_members(RoleId, Data) ->
Members = maps:get(<<"members">>, Data, []),
UpdatedMembers = lists:map(
fun(Member) when is_map(Member) ->
MemberRoles = maps:get(<<"roles">>, Member, []),
FilteredRoles = lists:filter(
fun(R) ->
RoleIdInt = utils:binary_to_integer_safe(RoleId),
RInt = utils:binary_to_integer_safe(R),
RInt =/= RoleIdInt
end,
MemberRoles
),
maps:put(<<"roles">>, FilteredRoles, Member);
(Member) ->
Member
RoleIdInt = utils:binary_to_integer_safe(RoleId),
MemberRoleIndex = guild_data_index:member_role_index(Data),
AffectedUsers = maps:keys(maps:get(RoleIdInt, MemberRoleIndex, #{})),
lists:foldl(
fun(UserId, AccData) ->
case guild_data_index:get_member(UserId, AccData) of
undefined ->
AccData;
Member ->
MemberRoles = maps:get(<<"roles">>, Member, []),
FilteredRoles = lists:filter(
fun(R) ->
RInt = utils:binary_to_integer_safe(R),
RInt =/= RoleIdInt
end,
MemberRoles
),
guild_data_index:put_member(maps:put(<<"roles">>, FilteredRoles, Member), AccData)
end
end,
Members
),
maps:put(<<"members">>, UpdatedMembers, Data).
Data,
AffectedUsers
).
-spec strip_role_from_channel_overwrites(role_id(), guild_data()) -> guild_data().
strip_role_from_channel_overwrites(RoleId, Data) ->
Channels = maps:get(<<"channels">>, Data, []),
Channels = guild_data_index:channel_list(Data),
RoleIdInt = utils:binary_to_integer_safe(RoleId),
UpdatedChannels = lists:map(
fun(Channel) when is_map(Channel) ->
Overwrites = maps:get(<<"permission_overwrites">>, Channel, []),
FilteredOverwrites = lists:filter(
fun(Overwrite) when is_map(Overwrite) ->
OverwriteType = maps:get(<<"type">>, Overwrite, 0),
OverwriteId = utils:binary_to_integer_safe(maps:get(<<"id">>, Overwrite, <<"0">>)),
not (OverwriteType =:= 0 andalso OverwriteId =:= RoleIdInt);
(_) ->
true
end,
Overwrites
),
maps:put(<<"permission_overwrites">>, FilteredOverwrites, Channel);
(Channel) ->
Channel
fun
(Channel) when is_map(Channel) ->
Overwrites = maps:get(<<"permission_overwrites">>, Channel, []),
FilteredOverwrites = lists:filter(
fun
(Overwrite) when is_map(Overwrite) ->
OverwriteType = maps:get(<<"type">>, Overwrite, 0),
OverwriteId = utils:binary_to_integer_safe(
maps:get(<<"id">>, Overwrite, <<"0">>)
),
not (OverwriteType =:= 0 andalso OverwriteId =:= RoleIdInt);
(_) ->
true
end,
Overwrites
),
maps:put(<<"permission_overwrites">>, FilteredOverwrites, Channel);
(Channel) ->
Channel
end,
Channels
),
maps:put(<<"channels">>, UpdatedChannels, Data).
guild_data_index:put_channels(UpdatedChannels, Data).
-spec handle_channel_create(event_data(), guild_data()) -> guild_data().
handle_channel_create(EventData, Data) ->
Channels = maps:get(<<"channels">>, Data, []),
maps:put(<<"channels">>, Channels ++ [EventData], Data).
Channels = guild_data_index:channel_list(Data),
guild_data_index:put_channels(Channels ++ [EventData], Data).
-spec handle_channel_update(event_data(), guild_data()) -> guild_data().
handle_channel_update(EventData, Data) ->
Channels = maps:get(<<"channels">>, Data, []),
Channels = guild_data_index:channel_list(Data),
ChannelId = maps:get(<<"id">>, EventData),
UpdatedChannels = replace_item_by_id(Channels, ChannelId, EventData),
maps:put(<<"channels">>, UpdatedChannels, Data).
guild_data_index:put_channels(UpdatedChannels, Data).
-spec handle_channel_update_bulk(event_data(), guild_data()) -> guild_data().
handle_channel_update_bulk(EventData, Data) ->
Channels = maps:get(<<"channels">>, Data, []),
Channels = guild_data_index:channel_list(Data),
BulkChannels = maps:get(<<"channels">>, EventData, []),
UpdatedChannels = bulk_update_items(Channels, BulkChannels),
maps:put(<<"channels">>, UpdatedChannels, Data).
guild_data_index:put_channels(UpdatedChannels, Data).
-spec handle_channel_delete(event_data(), guild_data()) -> guild_data().
handle_channel_delete(EventData, Data) ->
Channels = maps:get(<<"channels">>, Data, []),
Channels = guild_data_index:channel_list(Data),
ChannelId = maps:get(<<"id">>, EventData),
FilteredChannels = remove_item_by_id(Channels, ChannelId),
maps:put(<<"channels">>, FilteredChannels, Data).
guild_data_index:put_channels(FilteredChannels, Data).
-spec handle_message_create(event_data(), guild_data()) -> guild_data().
handle_message_create(EventData, Data) ->
Channels = maps:get(<<"channels">>, Data, []),
Channels = guild_data_index:channel_list(Data),
ChannelId = maps:get(<<"channel_id">>, EventData),
MessageId = maps:get(<<"id">>, EventData),
UpdatedChannels = update_channel_field(Channels, ChannelId, <<"last_message_id">>, MessageId),
maps:put(<<"channels">>, UpdatedChannels, Data).
guild_data_index:put_channels(UpdatedChannels, Data).
-spec handle_channel_pins_update(event_data(), guild_data()) -> guild_data().
handle_channel_pins_update(EventData, Data) ->
Channels = maps:get(<<"channels">>, Data, []),
Channels = guild_data_index:channel_list(Data),
ChannelId = maps:get(<<"channel_id">>, EventData),
LastPin = maps:get(<<"last_pin_timestamp">>, EventData),
UpdatedChannels = update_channel_field(Channels, ChannelId, <<"last_pin_timestamp">>, LastPin),
maps:put(<<"channels">>, UpdatedChannels, Data).
guild_data_index:put_channels(UpdatedChannels, Data).
-spec update_channel_field([map()], binary(), binary(), term()) -> [map()].
update_channel_field(Channels, ChannelId, Field, Value) ->
lists:map(
fun(C) when is_map(C) ->
@@ -294,12 +443,15 @@ update_channel_field(Channels, ChannelId, Field, Value) ->
Channels
).
-spec handle_emojis_update(event_data(), guild_data()) -> guild_data().
handle_emojis_update(EventData, Data) ->
maps:put(<<"emojis">>, maps:get(<<"emojis">>, EventData, []), Data).
-spec handle_stickers_update(event_data(), guild_data()) -> guild_data().
handle_stickers_update(EventData, Data) ->
maps:put(<<"stickers">>, maps:get(<<"stickers">>, EventData, []), Data).
-spec replace_item_by_id([map()], binary(), map()) -> [map()].
replace_item_by_id(Items, Id, NewItem) ->
lists:map(
fun(Item) when is_map(Item) ->
@@ -311,6 +463,7 @@ replace_item_by_id(Items, Id, NewItem) ->
Items
).
-spec remove_item_by_id([map()], binary()) -> [map()].
remove_item_by_id(Items, Id) ->
lists:filter(
fun(Item) when is_map(Item) ->
@@ -319,6 +472,7 @@ remove_item_by_id(Items, Id) ->
Items
).
-spec bulk_update_items([map()], [map()]) -> [map()].
bulk_update_items(Items, BulkItems) ->
BulkMap = lists:foldl(
fun
@@ -333,7 +487,6 @@ bulk_update_items(Items, BulkItems) ->
#{},
BulkItems
),
lists:map(
fun
(Item) when is_map(Item) ->
@@ -358,26 +511,19 @@ handle_role_delete_strips_from_members_test() ->
#{<<"id">> => <<"100">>, <<"name">> => <<"Admin">>},
#{<<"id">> => <<"200">>, <<"name">> => <<"Moderator">>}
],
<<"members">> => [
#{
<<"user">> => #{<<"id">> => <<"1">>},
<<"roles">> => [<<"100">>, <<"200">>]
},
#{
<<"user">> => #{<<"id">> => <<"2">>},
<<"roles">> => [<<"200">>]
},
#{
<<"user">> => #{<<"id">> => <<"3">>},
<<"roles">> => [<<"100">>]
}
],
<<"members">> => #{
1 => #{<<"user">> => #{<<"id">> => <<"1">>}, <<"roles">> => [<<"100">>, <<"200">>]},
2 => #{<<"user">> => #{<<"id">> => <<"2">>}, <<"roles">> => [<<"200">>]},
3 => #{<<"user">> => #{<<"id">> => <<"3">>}, <<"roles">> => [<<"100">>]}
},
<<"channels">> => []
},
EventData = #{<<"role_id">> => RoleIdToDelete},
Result = handle_role_delete(EventData, Data),
Members = maps:get(<<"members">>, Result),
[M1, M2, M3] = Members,
M1 = maps:get(1, Members),
M2 = maps:get(2, Members),
M3 = maps:get(3, Members),
?assertEqual([<<"100">>], maps:get(<<"roles">>, M1)),
?assertEqual([], maps:get(<<"roles">>, M2)),
?assertEqual([<<"100">>], maps:get(<<"roles">>, M3)).
@@ -394,45 +540,24 @@ handle_role_delete_strips_from_channel_overwrites_test() ->
#{
<<"id">> => <<"500">>,
<<"permission_overwrites">> => [
#{<<"id">> => <<"100">>, <<"type">> => 0, <<"allow">> => <<"0">>, <<"deny">> => <<"1024">>},
#{<<"id">> => <<"200">>, <<"type">> => 0, <<"allow">> => <<"1024">>, <<"deny">> => <<"0">>},
#{<<"id">> => <<"1">>, <<"type">> => 1, <<"allow">> => <<"2048">>, <<"deny">> => <<"0">>}
]
},
#{
<<"id">> => <<"501">>,
<<"permission_overwrites">> => [
#{<<"id">> => <<"200">>, <<"type">> => 0, <<"allow">> => <<"1024">>, <<"deny">> => <<"0">>}
]
}
]
},
EventData = #{<<"role_id">> => RoleIdToDelete},
Result = handle_role_delete(EventData, Data),
Channels = maps:get(<<"channels">>, Result),
[Ch1, Ch2] = Channels,
Ch1Overwrites = maps:get(<<"permission_overwrites">>, Ch1),
Ch2Overwrites = maps:get(<<"permission_overwrites">>, Ch2),
?assertEqual(2, length(Ch1Overwrites)),
?assertEqual(0, length(Ch2Overwrites)),
OverwriteIds = [maps:get(<<"id">>, O) || O <- Ch1Overwrites],
?assert(lists:member(<<"100">>, OverwriteIds)),
?assert(lists:member(<<"1">>, OverwriteIds)),
?assertNot(lists:member(<<"200">>, OverwriteIds)).
handle_role_delete_preserves_user_overwrites_test() ->
RoleIdToDelete = <<"200">>,
Data = #{
<<"roles">> => [
#{<<"id">> => <<"200">>, <<"name">> => <<"Moderator">>}
],
<<"members">> => [],
<<"channels">> => [
#{
<<"id">> => <<"500">>,
<<"permission_overwrites">> => [
#{<<"id">> => <<"200">>, <<"type">> => 0, <<"allow">> => <<"1024">>, <<"deny">> => <<"0">>},
#{<<"id">> => <<"200">>, <<"type">> => 1, <<"allow">> => <<"2048">>, <<"deny">> => <<"0">>}
#{
<<"id">> => <<"100">>,
<<"type">> => 0,
<<"allow">> => <<"0">>,
<<"deny">> => <<"1024">>
},
#{
<<"id">> => <<"200">>,
<<"type">> => 0,
<<"allow">> => <<"1024">>,
<<"deny">> => <<"0">>
},
#{
<<"id">> => <<"1">>,
<<"type">> => 1,
<<"allow">> => <<"2048">>,
<<"deny">> => <<"0">>
}
]
}
]
@@ -441,26 +566,141 @@ handle_role_delete_preserves_user_overwrites_test() ->
Result = handle_role_delete(EventData, Data),
Channels = maps:get(<<"channels">>, Result),
[Ch1] = Channels,
Overwrites = maps:get(<<"permission_overwrites">>, Ch1),
?assertEqual(1, length(Overwrites)),
[RemainingOverwrite] = Overwrites,
?assertEqual(1, maps:get(<<"type">>, RemainingOverwrite)).
Ch1Overwrites = maps:get(<<"permission_overwrites">>, Ch1),
?assertEqual(2, length(Ch1Overwrites)).
handle_role_delete_removes_role_from_roles_list_test() ->
RoleIdToDelete = <<"200">>,
handle_member_add_test() ->
Data = #{<<"members">> => #{1 => #{<<"user">> => #{<<"id">> => <<"1">>}}}},
EventData = #{<<"user">> => #{<<"id">> => <<"2">>}},
Result = handle_member_add(EventData, Data),
Members = maps:get(<<"members">>, Result),
?assertEqual(2, map_size(Members)).
handle_member_update_test() ->
Data = #{
<<"roles">> => [
#{<<"id">> => <<"100">>, <<"name">> => <<"Admin">>},
#{<<"id">> => <<"200">>, <<"name">> => <<"Moderator">>}
],
<<"members">> => [],
<<"channels">> => []
<<"members">> => #{
1 => #{<<"user">> => #{<<"id">> => <<"1">>}, <<"nick">> => <<"OldNick">>}
}
},
EventData = #{<<"role_id">> => RoleIdToDelete},
Result = handle_role_delete(EventData, Data),
Roles = maps:get(<<"roles">>, Result),
?assertEqual(1, length(Roles)),
[RemainingRole] = Roles,
?assertEqual(<<"100">>, maps:get(<<"id">>, RemainingRole)).
EventData = #{<<"user">> => #{<<"id">> => <<"1">>}, <<"nick">> => <<"NewNick">>},
Result = handle_member_update(EventData, Data),
Members = maps:get(<<"members">>, Result),
Member = maps:get(1, Members),
?assertEqual(<<"NewNick">>, maps:get(<<"nick">>, Member)).
handle_channel_create_test() ->
Data = #{<<"channels">> => []},
EventData = #{<<"id">> => <<"100">>, <<"name">> => <<"general">>},
Result = handle_channel_create(EventData, Data),
Channels = maps:get(<<"channels">>, Result),
?assertEqual(1, length(Channels)).
bulk_update_items_test() ->
Items = [
#{<<"id">> => <<"1">>, <<"value">> => <<"old1">>},
#{<<"id">> => <<"2">>, <<"value">> => <<"old2">>}
],
BulkItems = [
#{<<"id">> => <<"1">>, <<"value">> => <<"new1">>}
],
Result = bulk_update_items(Items, BulkItems),
[Item1, Item2] = Result,
?assertEqual(<<"new1">>, maps:get(<<"value">>, Item1)),
?assertEqual(<<"old2">>, maps:get(<<"value">>, Item2)).
needs_visibility_check_test() ->
?assertEqual(true, needs_visibility_check(guild_role_update)),
?assertEqual(true, needs_visibility_check(channel_update)),
?assertEqual(false, needs_visibility_check(message_create)),
?assertEqual(false, needs_visibility_check(unknown_event)).
extract_channel_ids_from_channel_update_test() ->
?assertEqual([42], extract_channel_ids_from_channel_update(#{<<"id">> => <<"42">>})),
?assertEqual([], extract_channel_ids_from_channel_update(#{})).
extract_channel_ids_from_channel_update_bulk_test() ->
EventData = #{
<<"channels">> => [
#{<<"id">> => <<"10">>},
#{<<"id">> => <<"11">>},
#{<<"name">> => <<"missing_id">>}
]
},
?assertEqual([10, 11], extract_channel_ids_from_channel_update_bulk(EventData)).
guild_member_remove_disconnects_voice_test() ->
Self = self(),
TestFun = fun(GuildId, ChannelId, UserId, ConnectionId) ->
Self ! {force_disconnect, GuildId, ChannelId, UserId, ConnectionId},
{ok, #{success => true}}
end,
GuildId = 42,
UserId = 5,
ChannelId = 20,
Data = #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => <<"0">>}
],
<<"members">> => #{
UserId => #{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
},
<<"channels">> => [#{<<"id">> => integer_to_binary(ChannelId)}]
},
VoiceStates = #{
<<"conn">> => #{
<<"user_id">> => integer_to_binary(UserId),
<<"guild_id">> => integer_to_binary(GuildId),
<<"channel_id">> => integer_to_binary(ChannelId),
<<"connection_id">> => <<"conn">>
}
},
Sessions = #{<<"s1">> => #{user_id => UserId, pid => self()}},
State = #{
id => GuildId,
data => Data,
voice_states => VoiceStates,
sessions => Sessions,
test_force_disconnect_fun => TestFun
},
EventData = #{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}},
UpdatedState = update_state(guild_member_remove, EventData, State),
?assertEqual(#{}, maps:get(voice_states, UpdatedState)),
?assertEqual(#{}, maps:get(sessions, UpdatedState, #{})),
receive
{force_disconnect, GuildId, ChannelId, UserId, <<"conn">>} -> ok
after 200 ->
?assert(false)
end.
guild_update_syncs_unavailability_cache_test() ->
GuildId = 420042,
CleanupState = #{
id => GuildId,
data => #{
<<"guild">> => #{<<"features">> => []}
}
},
_ = guild_availability:update_unavailability_cache_for_state(CleanupState),
try
State0 = #{
id => GuildId,
data => #{
<<"guild">> => #{<<"features">> => []},
<<"members">> => []
},
sessions => #{}
},
State1 = update_state(
guild_update,
#{<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]},
State0
),
?assertEqual(unavailable_for_everyone, guild_availability:get_cached_unavailability_mode(GuildId)),
_State2 = update_state(guild_update, #{<<"features">> => []}, State1),
?assertEqual(available, guild_availability:get_cached_unavailability_mode(GuildId))
after
_ = guild_availability:update_unavailability_cache_for_state(CleanupState)
end.
-endif.

View File

@@ -77,10 +77,8 @@ unsubscribe_session(SessionId, State) ->
update_subscriptions(SessionId, NewMemberIds, State) ->
CurrentlySubscribed = get_user_ids_for_session(SessionId, State),
NewMemberIdSet = sets:from_list(NewMemberIds),
ToRemove = sets:subtract(CurrentlySubscribed, NewMemberIdSet),
ToAdd = sets:subtract(NewMemberIdSet, CurrentlySubscribed),
State1 = sets:fold(
fun(UserId, AccState) ->
unsubscribe(SessionId, UserId, AccState)
@@ -88,7 +86,6 @@ update_subscriptions(SessionId, NewMemberIds, State) ->
State,
ToRemove
),
sets:fold(
fun(UserId, AccState) ->
subscribe(SessionId, UserId, AccState)
@@ -182,4 +179,11 @@ update_subscriptions_test() ->
?assert(is_subscribed(<<"session1">>, 200, State3)),
?assert(is_subscribed(<<"session1">>, 300, State3)).
get_user_ids_for_session_test() ->
State0 = init_state(),
State1 = subscribe(<<"session1">>, 100, State0),
State2 = subscribe(<<"session1">>, 200, State1),
UserIds = get_user_ids_for_session(<<"session1">>, State2),
?assertEqual([100, 200], lists:sort(sets:to_list(UserIds))).
-endif.

View File

@@ -19,5 +19,6 @@
-export([send_guild_sync/2]).
-spec send_guild_sync(pid(), binary()) -> ok.
send_guild_sync(GuildPid, SessionId) ->
gen_server:cast(GuildPid, {send_guild_sync, SessionId}).

View File

@@ -23,95 +23,88 @@
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec handle_subscriptions(map(), pid(), map()) -> ok.
-type session_state() :: map().
-type session_id() :: binary() | undefined.
-spec handle_subscriptions(map(), pid(), session_state()) -> ok.
handle_subscriptions(Data, SocketPid, SessionState) ->
Subscriptions = maps:get(<<"subscriptions">>, Data, #{}),
Guilds = maps:get(guilds, SessionState, #{}),
SessionId = maps:get(id, SessionState, undefined),
logger:debug("[guild_unified_subscriptions] Processing ~p guild subscriptions for session ~p", [
map_size(Subscriptions), SessionId
]),
maps:foreach(
fun(GuildIdBin, GuildSubData) ->
process_guild_subscription(GuildIdBin, GuildSubData, Guilds, SessionId, SocketPid, SessionState)
process_guild_subscription(
GuildIdBin, GuildSubData, Guilds, SessionId, SocketPid, SessionState
)
end,
Subscriptions
),
ok.
-spec process_guild_subscription(binary(), map(), map(), binary() | undefined, pid(), map()) -> ok.
-spec process_guild_subscription(binary(), map(), map(), session_id(), pid(), session_state()) ->
ok.
process_guild_subscription(GuildIdBin, GuildSubData, Guilds, SessionId, SocketPid, SessionState) ->
case validation:validate_snowflake(<<"guild_id">>, GuildIdBin) of
{ok, GuildId} ->
case maps:get(GuildId, Guilds, undefined) of
{GuildPid, _Ref} when is_pid(GuildPid) ->
process_guild_sub_options(GuildId, GuildPid, GuildSubData, SessionId, SocketPid, SessionState);
process_guild_sub_options(
GuildId, GuildPid, GuildSubData, SessionId, SocketPid, SessionState
);
undefined ->
logger:warning("[guild_unified_subscriptions] Guild ~p not found in session state", [GuildId]),
ok;
_ ->
ok
end;
{error, _, Reason} ->
logger:warning("[guild_unified_subscriptions] Invalid guild_id ~p: ~p", [GuildIdBin, Reason]),
{error, _, _} ->
ok
end.
-spec process_guild_sub_options(integer(), pid(), map(), binary() | undefined, pid(), map()) -> ok.
-spec process_guild_sub_options(integer(), pid(), map(), session_id(), pid(), session_state()) ->
ok.
process_guild_sub_options(GuildId, GuildPid, GuildSubData, SessionId, SocketPid, SessionState) ->
WasActive = not session_passive:is_passive(GuildId, SessionState),
ActiveChanged = process_active_flag(GuildSubData, GuildPid, SessionId, WasActive),
process_sync_flag(GuildSubData, GuildId, GuildPid, SessionId, ActiveChanged),
process_member_list_channels(GuildSubData, GuildId, GuildPid, SessionId, SocketPid),
process_member_subscriptions(GuildSubData, GuildPid, SessionId),
process_typing_flag(GuildSubData, GuildPid, SessionId),
ok.
-spec process_active_flag(map(), pid(), binary() | undefined, boolean()) -> boolean().
-spec process_active_flag(map(), pid(), session_id(), boolean()) -> boolean().
process_active_flag(GuildSubData, GuildPid, SessionId, WasActive) ->
case maps:get(<<"active">>, GuildSubData, undefined) of
undefined ->
false;
true ->
gen_server:cast(GuildPid, {set_session_active, SessionId}),
logger:debug("[guild_unified_subscriptions] Set session ~p active", [SessionId]),
not WasActive;
false ->
gen_server:cast(GuildPid, {set_session_passive, SessionId}),
logger:debug("[guild_unified_subscriptions] Set session ~p passive", [SessionId]),
WasActive
end.
-spec process_sync_flag(map(), integer(), pid(), binary() | undefined, boolean()) -> ok.
process_sync_flag(GuildSubData, GuildId, GuildPid, SessionId, ActiveChanged) ->
-spec process_sync_flag(map(), integer(), pid(), session_id(), boolean()) -> ok.
process_sync_flag(GuildSubData, _GuildId, GuildPid, SessionId, ActiveChanged) ->
ShouldSync = maps:get(<<"sync">>, GuildSubData, false) =:= true orelse ActiveChanged,
case ShouldSync of
true ->
guild_sync:send_guild_sync(GuildPid, SessionId),
logger:debug("[guild_unified_subscriptions] Sent guild sync for guild ~p", [GuildId]);
guild_sync:send_guild_sync(GuildPid, SessionId);
false ->
ok
end.
-spec process_member_list_channels(map(), integer(), pid(), binary() | undefined, pid()) -> ok.
-spec process_member_list_channels(map(), integer(), pid(), session_id(), pid()) -> ok.
process_member_list_channels(GuildSubData, GuildId, GuildPid, SessionId, SocketPid) ->
case maps:get(<<"member_list_channels">>, GuildSubData, undefined) of
undefined ->
ok;
MemberListChannels when is_map(MemberListChannels) ->
logger:debug("[guild_unified_subscriptions] Processing ~p member list channels for guild ~p", [
map_size(MemberListChannels), GuildId
]),
maps:foreach(
fun(ChannelIdBin, Ranges) ->
process_channel_lazy_subscribe(ChannelIdBin, Ranges, GuildId, GuildPid, SessionId, SocketPid)
process_channel_lazy_subscribe(
ChannelIdBin, Ranges, GuildId, GuildPid, SessionId, SocketPid
)
end,
MemberListChannels
);
@@ -119,28 +112,29 @@ process_member_list_channels(GuildSubData, GuildId, GuildPid, SessionId, SocketP
ok
end.
-spec process_channel_lazy_subscribe(binary(), list(), integer(), pid(), binary() | undefined, pid()) -> ok.
-spec process_channel_lazy_subscribe(binary(), list(), integer(), pid(), session_id(), pid()) -> ok.
process_channel_lazy_subscribe(ChannelIdBin, Ranges, _GuildId, GuildPid, SessionId, _SocketPid) ->
case validation:validate_snowflake(<<"channel_id">>, ChannelIdBin) of
{ok, ChannelId} ->
ParsedRanges = parse_ranges(Ranges),
logger:debug("[guild_unified_subscriptions] Lazy subscribe channel ~p with ranges ~p", [
ChannelId, ParsedRanges
]),
case gen_server:call(GuildPid, {lazy_subscribe, #{
session_id => SessionId,
channel_id => ChannelId,
ranges => ParsedRanges
}}, 10000) of
case
gen_server:call(
GuildPid,
{lazy_subscribe, #{
session_id => SessionId,
channel_id => ChannelId,
ranges => ParsedRanges
}},
10000
)
of
ok ->
ok;
Error ->
logger:error("[guild_unified_subscriptions] lazy_subscribe failed for channel ~p: ~p", [
ChannelId, Error
])
_Error ->
ok
end;
{error, _, Reason} ->
logger:warning("[guild_unified_subscriptions] Invalid channel_id ~p: ~p", [ChannelIdBin, Reason])
{error, _, _} ->
ok
end,
ok.
@@ -160,16 +154,13 @@ parse_ranges(Ranges) when is_list(Ranges) ->
parse_ranges(_) ->
[].
-spec process_member_subscriptions(map(), pid(), binary() | undefined) -> ok.
-spec process_member_subscriptions(map(), pid(), session_id()) -> ok.
process_member_subscriptions(GuildSubData, GuildPid, SessionId) ->
case maps:get(<<"members">>, GuildSubData, undefined) of
undefined ->
ok;
Members when is_list(Members) ->
MemberIds = parse_member_ids(Members),
logger:debug("[guild_unified_subscriptions] Updating member subscriptions with ~p members", [
length(MemberIds)
]),
gen_server:cast(GuildPid, {update_member_subscriptions, SessionId, MemberIds});
_ ->
ok
@@ -189,16 +180,13 @@ parse_member_ids(Members) when is_list(Members) ->
parse_member_ids(_) ->
[].
-spec process_typing_flag(map(), pid(), binary() | undefined) -> ok.
-spec process_typing_flag(map(), pid(), session_id()) -> ok.
process_typing_flag(GuildSubData, GuildPid, SessionId) ->
case maps:get(<<"typing">>, GuildSubData, undefined) of
undefined ->
ok;
TypingFlag when is_boolean(TypingFlag) ->
gen_server:cast(GuildPid, {set_session_typing_override, SessionId, TypingFlag}),
logger:debug("[guild_unified_subscriptions] Set typing override to ~p for session ~p", [
TypingFlag, SessionId
]);
gen_server:cast(GuildPid, {set_session_typing_override, SessionId, TypingFlag});
_ ->
ok
end.

View File

@@ -24,129 +24,155 @@
check_user_data_differs/2
]).
-import(guild_permissions, [find_member_by_user_id/2]).
-type guild_state() :: map().
-type user_id() :: integer().
-type member() :: map().
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec update_user_data(map(), guild_state()) -> {noreply, guild_state()}.
update_user_data(EventData, State) ->
UserId = utils:binary_to_integer_safe(maps:get(<<"id">>, EventData)),
Data = maps:get(data, State),
Members = maps:get(<<"members">>, Data, []),
UpdatedMembers = lists:map(
fun(Member) when is_map(Member) ->
MUser = maps:get(<<"user">>, Member, #{}),
MemberId =
case is_map(MUser) of
true ->
utils:binary_to_integer_safe(maps:get(<<"id">>, MUser, <<"0">>));
false ->
undefined
end,
if
MemberId =:= UserId ->
maps:put(<<"user">>, EventData, Member);
true ->
Member
end
Members = guild_data_index:member_map(Data),
UpdatedMembers = maps:map(
fun(_MemberUserId, Member) ->
maybe_update_member_user(Member, UserId, EventData)
end,
Members
),
UpdatedData = maps:put(<<"members">>, UpdatedMembers, Data),
UpdatedData = guild_data_index:put_member_map(UpdatedMembers, Data),
UpdatedState = maps:put(data, UpdatedData, State),
UpdatedMember = find_member_by_user_id(UserId, UpdatedState),
case UpdatedMember of
undefined -> ok;
M -> gen_server:cast(self(), {dispatch, #{event => guild_member_update, data => M}})
end,
dispatch_member_update_if_found(UserId, UpdatedState),
{noreply, UpdatedState}.
handle_user_data_update(UserId, UserData, State) ->
Data = maps:get(data, State),
Members = maps:get(<<"members">>, Data, []),
-spec maybe_update_member_user(member(), user_id(), map()) -> member().
maybe_update_member_user(Member, UserId, EventData) ->
MUser = maps:get(<<"user">>, Member, #{}),
MemberId =
case is_map(MUser) of
true -> utils:binary_to_integer_safe(maps:get(<<"id">>, MUser, <<"0">>));
false -> undefined
end,
case MemberId =:= UserId of
true -> maps:put(<<"user">>, EventData, Member);
false -> Member
end.
CurrentMember = find_member_by_user_id(UserId, State),
case CurrentMember of
-spec dispatch_member_update_if_found(user_id(), guild_state()) -> ok.
dispatch_member_update_if_found(UserId, State) ->
case guild_permissions:find_member_by_user_id(UserId, State) of
undefined -> ok;
M -> gen_server:cast(self(), {dispatch, #{event => guild_member_update, data => M}})
end.
-spec handle_user_data_update(user_id(), map(), guild_state()) -> guild_state().
handle_user_data_update(UserId, UserData, State) ->
case guild_permissions:find_member_by_user_id(UserId, State) of
undefined ->
State;
Member ->
CurrentUserData = maps:get(<<"user">>, Member, #{}),
IsDifferent = check_user_data_differs(CurrentUserData, UserData),
if
IsDifferent ->
UpdatedMembers = lists:map(
fun(M) when is_map(M) ->
MUser = maps:get(<<"user">>, M, #{}),
MemberId =
case is_map(MUser) of
true ->
utils:binary_to_integer_safe(
maps:get(<<"id">>, MUser, <<"0">>)
);
false ->
undefined
end,
if
MemberId =:= UserId ->
maps:put(<<"user">>, UserData, M);
true ->
M
end
end,
Members
),
UpdatedData = maps:put(<<"members">>, UpdatedMembers, Data),
UpdatedState = maps:put(data, UpdatedData, State),
UpdatedMember = find_member_by_user_id(UserId, UpdatedState),
case UpdatedMember of
undefined ->
ok;
M ->
GuildId = maps:get(id, UpdatedState),
MemberUpdateData = maps:put(
<<"guild_id">>, integer_to_binary(GuildId), M
),
gen_server:cast(
self(),
{dispatch, #{
event => guild_member_update, data => MemberUpdateData
}}
)
end,
UpdatedState;
case check_user_data_differs(CurrentUserData, UserData) of
false ->
State;
true ->
State
apply_user_data_update(UserId, UserData, State)
end
end.
-spec apply_user_data_update(user_id(), map(), guild_state()) -> guild_state().
apply_user_data_update(UserId, UserData, State) ->
Data = maps:get(data, State),
Members = guild_data_index:member_map(Data),
UpdatedMembers = maps:map(
fun(_MemberUserId, Member) ->
maybe_update_member_user(Member, UserId, UserData)
end,
Members
),
UpdatedData = guild_data_index:put_member_map(UpdatedMembers, Data),
UpdatedState = maps:put(data, UpdatedData, State),
dispatch_guild_member_update(UserId, UpdatedState),
UpdatedState.
-spec dispatch_guild_member_update(user_id(), guild_state()) -> ok.
dispatch_guild_member_update(UserId, State) ->
case guild_permissions:find_member_by_user_id(UserId, State) of
undefined ->
ok;
M ->
GuildId = maps:get(id, State),
MemberUpdateData = maps:put(<<"guild_id">>, integer_to_binary(GuildId), M),
gen_server:cast(
self(),
{dispatch, #{event => guild_member_update, data => MemberUpdateData}}
)
end.
-spec check_user_data_differs(map(), map()) -> boolean().
check_user_data_differs(CurrentUserData, NewUserData) ->
utils:check_user_data_differs(CurrentUserData, NewUserData).
maybe_update_cached_user_data(Event, EventData, State) ->
case Event of
E when E =:= message_create; E =:= message_update ->
case maps:get(<<"author">>, EventData, undefined) of
-spec maybe_update_cached_user_data(atom(), map(), guild_state()) -> guild_state().
maybe_update_cached_user_data(Event, EventData, State) when
Event =:= message_create; Event =:= message_update
->
case maps:get(<<"author">>, EventData, undefined) of
undefined ->
State;
AuthorData ->
UserId = utils:binary_to_integer_safe(maps:get(<<"id">>, AuthorData, <<"0">>)),
case guild_permissions:find_member_by_user_id(UserId, State) of
undefined ->
State;
AuthorData ->
UserId = utils:binary_to_integer_safe(maps:get(<<"id">>, AuthorData, <<"0">>)),
case find_member_by_user_id(UserId, State) of
undefined ->
State;
Member ->
CurrentUserData = maps:get(<<"user">>, Member, #{}),
case check_user_data_differs(CurrentUserData, AuthorData) of
true ->
handle_user_data_update(UserId, AuthorData, State);
false ->
State
end
Member ->
CurrentUserData = maps:get(<<"user">>, Member, #{}),
case check_user_data_differs(CurrentUserData, AuthorData) of
true ->
handle_user_data_update(UserId, AuthorData, State);
false ->
State
end
end;
_ ->
State
end.
end
end;
maybe_update_cached_user_data(_, _, State) ->
State.
-ifdef(TEST).
update_user_data_updates_member_test() ->
State = test_state(),
EventData = #{<<"id">> => <<"100">>, <<"username">> => <<"updated">>},
{noreply, UpdatedState} = update_user_data(EventData, State),
Data = maps:get(data, UpdatedState),
Member = maps:get(100, maps:get(<<"members">>, Data)),
User = maps:get(<<"user">>, Member),
?assertEqual(<<"updated">>, maps:get(<<"username">>, User)).
handle_user_data_update_no_change_test() ->
State = test_state(),
UserData = #{<<"id">> => <<"100">>, <<"username">> => <<"alice">>},
NewState = handle_user_data_update(100, UserData, State),
?assertEqual(State, NewState).
check_user_data_differs_test() ->
Current = #{<<"username">> => <<"alice">>},
Same = #{<<"username">> => <<"alice">>},
Different = #{<<"username">> => <<"bob">>},
?assertEqual(false, check_user_data_differs(Current, Same)),
?assertEqual(true, check_user_data_differs(Current, Different)).
test_state() ->
#{
id => 42,
data => #{
<<"members">> => #{
100 => #{<<"user">> => #{<<"id">> => <<"100">>, <<"username">> => <<"alice">>}}
}
}
}.
-endif.

View File

@@ -23,18 +23,37 @@
has_virtual_access/3,
get_virtual_channels_for_user/2,
get_users_with_virtual_access/2,
dispatch_channel_visibility_change/4
dispatch_channel_visibility_change/4,
mark_pending_join/3,
clear_pending_join/3,
is_pending_join/3,
mark_preserve/3,
clear_preserve/3,
has_preserve/3,
mark_move_pending/3,
clear_move_pending/3,
is_move_pending/3
]).
-import(guild_permissions, [find_channel_by_id/2]).
-type guild_state() :: map().
-type user_id() :: integer().
-type channel_id() :: integer().
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec add_virtual_access(user_id(), channel_id(), guild_state()) -> guild_state().
add_virtual_access(UserId, ChannelId, State) ->
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
UserChannels = maps:get(UserId, VirtualAccess, sets:new()),
UpdatedUserChannels = sets:add_element(ChannelId, UserChannels),
UpdatedVirtualAccess = maps:put(UserId, UpdatedUserChannels, VirtualAccess),
maps:put(virtual_channel_access, UpdatedVirtualAccess, State).
State1 = maps:put(virtual_channel_access, UpdatedVirtualAccess, State),
State2 = update_user_session_view_cache(UserId, ChannelId, add, State1),
mark_pending_join(UserId, ChannelId, State2).
-spec remove_virtual_access(user_id(), channel_id(), guild_state()) -> guild_state().
remove_virtual_access(UserId, ChannelId, State) ->
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
case maps:get(UserId, VirtualAccess, undefined) of
@@ -44,32 +63,149 @@ remove_virtual_access(UserId, ChannelId, State) ->
UpdatedUserChannels = sets:del_element(ChannelId, UserChannels),
case sets:size(UpdatedUserChannels) of
0 ->
UpdatedVirtualAccess = maps:remove(UserId, VirtualAccess),
maps:put(virtual_channel_access, UpdatedVirtualAccess, State);
remove_all_user_virtual_access(UserId, State);
_ ->
UpdatedVirtualAccess = maps:put(UserId, UpdatedUserChannels, VirtualAccess),
maps:put(virtual_channel_access, UpdatedVirtualAccess, State)
update_user_virtual_access(UserId, ChannelId, UpdatedUserChannels, State)
end
end.
-spec remove_all_user_virtual_access(user_id(), guild_state()) -> guild_state().
remove_all_user_virtual_access(UserId, State) ->
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
PendingMap = maps:get(virtual_channel_access_pending, State, #{}),
PreserveMap = maps:get(virtual_channel_access_preserve, State, #{}),
MoveMap = maps:get(virtual_channel_access_move_pending, State, #{}),
UpdatedVirtualAccess = maps:remove(UserId, VirtualAccess),
UpdatedPending = maps:remove(UserId, PendingMap),
UpdatedPreserve = maps:remove(UserId, PreserveMap),
UpdatedMove = maps:remove(UserId, MoveMap),
State1 = maps:put(virtual_channel_access, UpdatedVirtualAccess, State),
State2 = maps:put(virtual_channel_access_pending, UpdatedPending, State1),
State3 = maps:put(virtual_channel_access_preserve, UpdatedPreserve, State2),
State4 = maps:put(virtual_channel_access_move_pending, UpdatedMove, State3),
clear_user_session_view_cache(UserId, State4).
-spec update_user_virtual_access(user_id(), channel_id(), sets:set(), guild_state()) ->
guild_state().
update_user_virtual_access(UserId, ChannelId, UpdatedUserChannels, State) ->
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
PendingMap = maps:get(virtual_channel_access_pending, State, #{}),
PreserveMap = maps:get(virtual_channel_access_preserve, State, #{}),
MoveMap = maps:get(virtual_channel_access_move_pending, State, #{}),
UpdatedVirtualAccess = maps:put(UserId, UpdatedUserChannels, VirtualAccess),
UpdatedUserPending = sets:del_element(ChannelId, maps:get(UserId, PendingMap, sets:new())),
UpdatedPending = maps:put(UserId, UpdatedUserPending, PendingMap),
UpdatedUserPreserve = sets:del_element(ChannelId, maps:get(UserId, PreserveMap, sets:new())),
UpdatedPreserve = maps:put(UserId, UpdatedUserPreserve, PreserveMap),
UpdatedUserMove = sets:del_element(ChannelId, maps:get(UserId, MoveMap, sets:new())),
UpdatedMove = maps:put(UserId, UpdatedUserMove, MoveMap),
State1 = maps:put(virtual_channel_access, UpdatedVirtualAccess, State),
State2 = maps:put(virtual_channel_access_pending, UpdatedPending, State1),
State3 = maps:put(virtual_channel_access_preserve, UpdatedPreserve, State2),
State4 = maps:put(virtual_channel_access_move_pending, UpdatedMove, State3),
update_user_session_view_cache(UserId, ChannelId, remove, State4).
-spec has_virtual_access(user_id(), channel_id(), guild_state()) -> boolean().
has_virtual_access(UserId, ChannelId, State) ->
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
case maps:get(UserId, VirtualAccess, undefined) of
undefined ->
false;
UserChannels ->
sets:is_element(ChannelId, UserChannels)
undefined -> false;
UserChannels -> sets:is_element(ChannelId, UserChannels)
end.
-spec get_virtual_channels_for_user(user_id(), guild_state()) -> [channel_id()].
get_virtual_channels_for_user(UserId, State) ->
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
case maps:get(UserId, VirtualAccess, undefined) of
undefined ->
[];
UserChannels ->
sets:to_list(UserChannels)
undefined -> [];
UserChannels -> sets:to_list(UserChannels)
end.
-spec mark_pending_join(user_id(), channel_id(), guild_state()) -> guild_state().
mark_pending_join(UserId, ChannelId, State) ->
PendingMap = maps:get(virtual_channel_access_pending, State, #{}),
UserPending = maps:get(UserId, PendingMap, sets:new()),
UpdatedUserPending = sets:add_element(ChannelId, UserPending),
UpdatedPending = maps:put(UserId, UpdatedUserPending, PendingMap),
maps:put(virtual_channel_access_pending, UpdatedPending, State).
-spec clear_pending_join(user_id(), channel_id(), guild_state()) -> guild_state().
clear_pending_join(UserId, ChannelId, State) ->
PendingMap = maps:get(virtual_channel_access_pending, State, #{}),
UserPending = maps:get(UserId, PendingMap, sets:new()),
UpdatedUserPending = sets:del_element(ChannelId, UserPending),
UpdatedPending =
case sets:size(UpdatedUserPending) of
0 -> maps:remove(UserId, PendingMap);
_ -> maps:put(UserId, UpdatedUserPending, PendingMap)
end,
maps:put(virtual_channel_access_pending, UpdatedPending, State).
-spec is_pending_join(user_id(), channel_id(), guild_state()) -> boolean().
is_pending_join(UserId, ChannelId, State) ->
PendingMap = maps:get(virtual_channel_access_pending, State, #{}),
case maps:get(UserId, PendingMap, undefined) of
undefined -> false;
UserPending -> sets:is_element(ChannelId, UserPending)
end.
-spec mark_preserve(user_id(), channel_id(), guild_state()) -> guild_state().
mark_preserve(UserId, ChannelId, State) ->
PreserveMap = maps:get(virtual_channel_access_preserve, State, #{}),
UserPreserve = maps:get(UserId, PreserveMap, sets:new()),
UpdatedUserPreserve = sets:add_element(ChannelId, UserPreserve),
UpdatedPreserve = maps:put(UserId, UpdatedUserPreserve, PreserveMap),
maps:put(virtual_channel_access_preserve, UpdatedPreserve, State).
-spec clear_preserve(user_id(), channel_id(), guild_state()) -> guild_state().
clear_preserve(UserId, ChannelId, State) ->
PreserveMap = maps:get(virtual_channel_access_preserve, State, #{}),
UserPreserve = maps:get(UserId, PreserveMap, sets:new()),
UpdatedUserPreserve = sets:del_element(ChannelId, UserPreserve),
UpdatedPreserve =
case sets:size(UpdatedUserPreserve) of
0 -> maps:remove(UserId, PreserveMap);
_ -> maps:put(UserId, UpdatedUserPreserve, PreserveMap)
end,
maps:put(virtual_channel_access_preserve, UpdatedPreserve, State).
-spec has_preserve(user_id(), channel_id(), guild_state()) -> boolean().
has_preserve(UserId, ChannelId, State) ->
PreserveMap = maps:get(virtual_channel_access_preserve, State, #{}),
case maps:get(UserId, PreserveMap, undefined) of
undefined -> false;
UserPreserve -> sets:is_element(ChannelId, UserPreserve)
end.
-spec mark_move_pending(user_id(), channel_id(), guild_state()) -> guild_state().
mark_move_pending(UserId, ChannelId, State) ->
MoveMap = maps:get(virtual_channel_access_move_pending, State, #{}),
UserMoves = maps:get(UserId, MoveMap, sets:new()),
UpdatedUserMoves = sets:add_element(ChannelId, UserMoves),
UpdatedMoveMap = maps:put(UserId, UpdatedUserMoves, MoveMap),
maps:put(virtual_channel_access_move_pending, UpdatedMoveMap, State).
-spec clear_move_pending(user_id(), channel_id(), guild_state()) -> guild_state().
clear_move_pending(UserId, ChannelId, State) ->
MoveMap = maps:get(virtual_channel_access_move_pending, State, #{}),
UserMoves = maps:get(UserId, MoveMap, sets:new()),
UpdatedUserMoves = sets:del_element(ChannelId, UserMoves),
UpdatedMoveMap =
case sets:size(UpdatedUserMoves) of
0 -> maps:remove(UserId, MoveMap);
_ -> maps:put(UserId, UpdatedUserMoves, MoveMap)
end,
maps:put(virtual_channel_access_move_pending, UpdatedMoveMap, State).
-spec is_move_pending(user_id(), channel_id(), guild_state()) -> boolean().
is_move_pending(UserId, ChannelId, State) ->
MoveMap = maps:get(virtual_channel_access_move_pending, State, #{}),
case maps:get(UserId, MoveMap, undefined) of
undefined -> false;
UserMoves -> sets:is_element(ChannelId, UserMoves)
end.
-spec get_users_with_virtual_access(channel_id(), guild_state()) -> [user_id()].
get_users_with_virtual_access(ChannelId, State) ->
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
maps:fold(
@@ -83,45 +219,160 @@ get_users_with_virtual_access(ChannelId, State) ->
VirtualAccess
).
-spec dispatch_channel_visibility_change(user_id(), channel_id(), add | remove, guild_state()) ->
ok.
dispatch_channel_visibility_change(UserId, ChannelId, Action, State) ->
Channel = find_channel_by_id(ChannelId, State),
Channel = guild_permissions:find_channel_by_id(ChannelId, State),
case Channel of
undefined ->
ok;
_ ->
Sessions = maps:get(sessions, State, #{}),
GuildId = maps:get(id, State),
UserSessions = maps:filter(
fun(_Sid, SessionData) ->
maps:get(user_id, SessionData) =:= UserId
end,
Sessions
),
case Action of
add ->
ChannelWithGuild = maps:put(
<<"guild_id">>, integer_to_binary(GuildId), Channel
),
maps:foreach(
fun(_Sid, SessionData) ->
Pid = maps:get(pid, SessionData),
gen_server:cast(Pid, {dispatch, channel_create, ChannelWithGuild})
end,
UserSessions
);
remove ->
ChannelDelete = #{
<<"id">> => integer_to_binary(ChannelId),
<<"guild_id">> => integer_to_binary(GuildId)
},
maps:foreach(
fun(_Sid, SessionData) ->
Pid = maps:get(pid, SessionData),
gen_server:cast(Pid, {dispatch, channel_delete, ChannelDelete})
end,
UserSessions
)
end
dispatch_to_user_sessions(Action, Channel, ChannelId, GuildId, UserSessions)
end.
-spec dispatch_to_user_sessions(add | remove, map(), channel_id(), integer(), map()) -> ok.
dispatch_to_user_sessions(add, Channel, _ChannelId, GuildId, UserSessions) ->
ChannelWithGuild = maps:put(<<"guild_id">>, integer_to_binary(GuildId), Channel),
maps:foreach(
fun(_Sid, SessionData) ->
Pid = maps:get(pid, SessionData),
gen_server:cast(Pid, {dispatch, channel_create, ChannelWithGuild})
end,
UserSessions
);
dispatch_to_user_sessions(remove, _Channel, ChannelId, GuildId, UserSessions) ->
ChannelDelete = #{
<<"id">> => integer_to_binary(ChannelId),
<<"guild_id">> => integer_to_binary(GuildId)
},
maps:foreach(
fun(_Sid, SessionData) ->
Pid = maps:get(pid, SessionData),
gen_server:cast(Pid, {dispatch, channel_delete, ChannelDelete})
end,
UserSessions
).
-spec update_user_session_view_cache(user_id(), channel_id(), add | remove, guild_state()) ->
guild_state().
update_user_session_view_cache(UserId, ChannelId, Action, State) ->
Sessions = maps:get(sessions, State, #{}),
UpdatedSessions = maps:map(
fun(_SessionId, SessionData) ->
case maps:get(user_id, SessionData, undefined) of
UserId ->
update_session_view_cache(SessionData, ChannelId, Action);
_ ->
SessionData
end
end,
Sessions
),
maps:put(sessions, UpdatedSessions, State).
-spec clear_user_session_view_cache(user_id(), guild_state()) -> guild_state().
clear_user_session_view_cache(UserId, State) ->
Sessions = maps:get(sessions, State, #{}),
UpdatedSessions = maps:map(
fun(_SessionId, SessionData) ->
case maps:get(user_id, SessionData, undefined) of
UserId ->
maps:put(viewable_channels, #{}, SessionData);
_ ->
SessionData
end
end,
Sessions
),
maps:put(sessions, UpdatedSessions, State).
-spec update_session_view_cache(map(), channel_id(), add | remove) -> map().
update_session_view_cache(SessionData, ChannelId, add) ->
ViewableChannels = ensure_viewable_channel_map(maps:get(viewable_channels, SessionData, #{})),
UpdatedViewableChannels = maps:put(ChannelId, true, ViewableChannels),
maps:put(viewable_channels, UpdatedViewableChannels, SessionData);
update_session_view_cache(SessionData, ChannelId, remove) ->
ViewableChannels = ensure_viewable_channel_map(maps:get(viewable_channels, SessionData, #{})),
UpdatedViewableChannels = maps:remove(ChannelId, ViewableChannels),
maps:put(viewable_channels, UpdatedViewableChannels, SessionData).
-spec ensure_viewable_channel_map(term()) -> map().
ensure_viewable_channel_map(ViewableChannels) when is_map(ViewableChannels) ->
ViewableChannels;
ensure_viewable_channel_map(_) ->
#{}.
-ifdef(TEST).
add_virtual_access_test() ->
State = #{},
State1 = add_virtual_access(100, 500, State),
?assertEqual(true, has_virtual_access(100, 500, State1)),
?assertEqual(true, is_pending_join(100, 500, State1)).
add_virtual_access_updates_session_cache_test() ->
State = #{
sessions => #{
<<"s1">> => #{user_id => 100, viewable_channels => #{}}
}
},
State1 = add_virtual_access(100, 500, State),
Session = maps:get(<<"s1">>, maps:get(sessions, State1)),
ViewableChannels = maps:get(viewable_channels, Session, #{}),
?assertEqual(true, maps:is_key(500, ViewableChannels)).
remove_virtual_access_test() ->
State = add_virtual_access(100, 500, #{}),
State1 = remove_virtual_access(100, 500, State),
?assertEqual(false, has_virtual_access(100, 500, State1)).
remove_virtual_access_updates_session_cache_test() ->
State = #{
sessions => #{
<<"s1">> => #{user_id => 100, viewable_channels => #{500 => true}}
},
virtual_channel_access => #{100 => sets:from_list([500])},
virtual_channel_access_pending => #{100 => sets:from_list([500])},
virtual_channel_access_preserve => #{100 => sets:new()},
virtual_channel_access_move_pending => #{100 => sets:new()}
},
State1 = remove_virtual_access(100, 500, State),
Session = maps:get(<<"s1">>, maps:get(sessions, State1)),
ViewableChannels = maps:get(viewable_channels, Session, #{}),
?assertEqual(false, maps:is_key(500, ViewableChannels)).
get_virtual_channels_for_user_test() ->
State = add_virtual_access(100, 500, #{}),
State1 = add_virtual_access(100, 501, State),
Channels = lists:sort(get_virtual_channels_for_user(100, State1)),
?assertEqual([500, 501], Channels).
get_users_with_virtual_access_test() ->
State = add_virtual_access(100, 500, #{}),
State1 = add_virtual_access(101, 500, State),
Users = lists:sort(get_users_with_virtual_access(500, State1)),
?assertEqual([100, 101], Users).
mark_and_clear_preserve_test() ->
State = add_virtual_access(100, 500, #{}),
State1 = mark_preserve(100, 500, State),
?assertEqual(true, has_preserve(100, 500, State1)),
State2 = clear_preserve(100, 500, State1),
?assertEqual(false, has_preserve(100, 500, State2)).
mark_and_clear_move_pending_test() ->
State = add_virtual_access(100, 500, #{}),
State1 = mark_move_pending(100, 500, State),
?assertEqual(true, is_move_pending(100, 500, State1)),
State2 = clear_move_pending(100, 500, State1),
?assertEqual(false, is_move_pending(100, 500, State2)).
-endif.

View File

@@ -20,19 +20,25 @@
-export([
get_user_viewable_channels/2,
compute_and_dispatch_visibility_changes/2,
compute_and_dispatch_visibility_changes_for_users/3,
compute_and_dispatch_visibility_changes_for_channels/3,
viewable_channel_set/2,
have_shared_viewable_channel/3
]).
-import(guild_member_list, [calculate_list_id/2, build_sync_response/4]).
-import(guild_permissions, [can_view_channel/4, find_member_by_user_id/2, find_channel_by_id/2]).
-type guild_state() :: map().
-type user_id() :: integer().
-type channel_id() :: integer().
-spec get_user_viewable_channels(integer(), map()) -> [integer()].
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec get_user_viewable_channels(user_id(), guild_state()) -> [channel_id()].
get_user_viewable_channels(UserId, State) ->
Data = map_utils:ensure_map(map_utils:get_safe(State, data, #{})),
Channels = map_utils:ensure_list(maps:get(<<"channels">>, Data, [])),
Member = find_member_by_user_id(UserId, State),
Member = guild_permissions:find_member_by_user_id(UserId, State),
case Member of
undefined ->
[];
@@ -44,7 +50,9 @@ get_user_viewable_channels(UserId, State) ->
undefined ->
false;
_ ->
case can_view_channel(UserId, ChannelId, Member, State) of
case
guild_permissions:can_view_channel(UserId, ChannelId, Member, State)
of
true -> {true, ChannelId};
false -> false
end
@@ -54,62 +62,346 @@ get_user_viewable_channels(UserId, State) ->
)
end.
-spec viewable_channel_set(integer(), map()) -> sets:set().
-spec viewable_channel_set(user_id(), guild_state()) -> sets:set(channel_id()).
viewable_channel_set(UserId, State) when is_integer(UserId) ->
sets:from_list(get_user_viewable_channels(UserId, State));
case get_cached_viewable_channel_map(UserId, State) of
undefined ->
sets:from_list(get_user_viewable_channels(UserId, State));
ViewableChannelMap ->
sets:from_list(maps:keys(ViewableChannelMap))
end;
viewable_channel_set(_, _) ->
sets:new().
-spec have_shared_viewable_channel(integer(), integer(), map()) -> boolean().
have_shared_viewable_channel(UserId, OtherUserId, State) when is_integer(UserId), is_integer(OtherUserId), UserId =/= OtherUserId ->
-spec get_cached_viewable_channel_map(user_id(), guild_state()) -> map() | undefined.
get_cached_viewable_channel_map(UserId, State) ->
Sessions = maps:get(sessions, State, #{}),
maps:fold(
fun(_SessionId, SessionData, Acc) ->
case Acc of
undefined ->
case maps:get(user_id, SessionData, undefined) of
UserId ->
case maps:get(viewable_channels, SessionData, undefined) of
ViewableChannels when is_map(ViewableChannels) ->
ViewableChannels;
_ ->
undefined
end;
_ ->
undefined
end;
_ ->
Acc
end
end,
undefined,
Sessions
).
-spec have_shared_viewable_channel(user_id(), user_id(), guild_state()) -> boolean().
have_shared_viewable_channel(UserId, OtherUserId, State) when
is_integer(UserId), is_integer(OtherUserId), UserId =/= OtherUserId
->
SetA = viewable_channel_set(UserId, State),
SetB = viewable_channel_set(OtherUserId, State),
not sets:is_empty(sets:intersection(SetA, SetB));
have_shared_viewable_channel(_, _, _) ->
false.
-spec compute_and_dispatch_visibility_changes(map(), map()) -> ok.
compute_and_dispatch_visibility_changes(OldState, NewState) ->
Sessions = maps:get(sessions, NewState, #{}),
GuildId = maps:get(id, NewState, 0),
-spec filter_connected_session_entries(map()) -> [{binary(), map()}].
filter_connected_session_entries(Sessions) ->
[{Sid, S} || {Sid, S} <- maps:to_list(Sessions), maps:get(pending_connect, S, false) =/= true].
lists:foreach(
fun({SessionId, SessionData}) ->
-spec compute_and_dispatch_visibility_changes(guild_state(), guild_state()) -> guild_state().
compute_and_dispatch_visibility_changes(OldState, NewState) ->
compute_and_dispatch_visibility_changes_for_sessions(
filter_connected_session_entries(maps:get(sessions, NewState, #{})),
OldState,
NewState
).
-spec compute_and_dispatch_visibility_changes_for_channels(
[channel_id()], guild_state(), guild_state()
) ->
guild_state().
compute_and_dispatch_visibility_changes_for_channels(ChannelIds, OldState, NewState) ->
ValidChannelIds = lists:usort([Id || Id <- ChannelIds, is_integer(Id), Id > 0]),
case ValidChannelIds of
[] ->
compute_and_dispatch_visibility_changes(OldState, NewState);
_ ->
compute_and_dispatch_channel_visibility_changes_for_sessions(
filter_connected_session_entries(maps:get(sessions, NewState, #{})),
ValidChannelIds,
OldState,
NewState
)
end.
-spec compute_and_dispatch_visibility_changes_for_users([user_id()], guild_state(), guild_state()) ->
guild_state().
compute_and_dispatch_visibility_changes_for_users(UserIds, OldState, NewState) ->
Sessions = maps:get(sessions, NewState, #{}),
UserIdSet = sets:from_list(UserIds),
TargetSessions = lists:filter(
fun({_SessionId, SessionData}) ->
maps:get(pending_connect, SessionData, false) =/= true andalso
begin
SessionUserId = maps:get(user_id, SessionData, undefined),
is_integer(SessionUserId) andalso sets:is_element(SessionUserId, UserIdSet)
end
end,
maps:to_list(Sessions)
),
compute_and_dispatch_visibility_changes_for_sessions(TargetSessions, OldState, NewState).
-spec compute_and_dispatch_channel_visibility_changes_for_sessions(
[{binary(), map()}], [channel_id()], guild_state(), guild_state()
) ->
guild_state().
compute_and_dispatch_channel_visibility_changes_for_sessions(
SessionEntries, ChannelIds, OldState, NewState
) ->
GuildId = maps:get(id, NewState, 0),
lists:foldl(
fun({SessionId, SessionData}, AccState) ->
compute_channel_visibility_changes_for_session(
SessionId,
SessionData,
ChannelIds,
OldState,
AccState,
GuildId
)
end,
NewState,
SessionEntries
).
-spec compute_channel_visibility_changes_for_session(
binary(), map(), [channel_id()], guild_state(), guild_state(), integer()
) ->
guild_state().
compute_channel_visibility_changes_for_session(
SessionId, SessionData, ChannelIds, OldState, NewState, GuildId
) ->
UserId = maps:get(user_id, SessionData, undefined),
case is_integer(UserId) of
false ->
NewState;
true ->
Pid = maps:get(pid, SessionData, undefined),
OldMember = guild_permissions:find_member_by_user_id(UserId, OldState),
ConnectedSet = connected_voice_channel_set(UserId, NewState),
InitialViewableMap = ensure_viewable_channel_map(SessionData, UserId, OldState),
{FinalViewableMap, StateAfterChannels} = lists:foldl(
fun(ChannelId, {ViewableMapAcc, StateAcc}) ->
OldVisible = channel_is_visible(UserId, ChannelId, OldMember, OldState),
{StateAfterPreserve, NewVisible} = ensure_new_channel_visibility(
UserId,
ChannelId,
ConnectedSet,
StateAcc
),
UpdatedViewableMap = update_viewable_map_for_channel(
ViewableMapAcc, ChannelId, NewVisible
),
case {OldVisible, NewVisible} of
{true, false} ->
dispatch_channel_delete(ChannelId, Pid, OldState, GuildId),
{UpdatedViewableMap, StateAfterPreserve};
{false, true} ->
dispatch_channel_create(
ChannelId, Pid, StateAfterPreserve, GuildId
),
send_member_list_sync(
SessionId,
SessionData,
ChannelId,
GuildId,
StateAfterPreserve
),
{UpdatedViewableMap, StateAfterPreserve};
_ ->
{UpdatedViewableMap, StateAfterPreserve}
end
end,
{InitialViewableMap, NewState},
ChannelIds
),
guild_sessions:set_session_viewable_channels(
SessionId,
FinalViewableMap,
StateAfterChannels
)
end.
-spec ensure_viewable_channel_map(map(), user_id(), guild_state()) -> #{channel_id() => true}.
ensure_viewable_channel_map(SessionData, UserId, State) ->
case maps:get(viewable_channels, SessionData, undefined) of
ViewableChannels when is_map(ViewableChannels) ->
ViewableChannels;
_ ->
viewable_channel_map(sets:from_list(get_user_viewable_channels(UserId, State)))
end.
-spec ensure_new_channel_visibility(
user_id(), channel_id(), sets:set(channel_id()), guild_state()
) ->
{guild_state(), boolean()}.
ensure_new_channel_visibility(UserId, ChannelId, ConnectedSet, State) ->
NewMember = guild_permissions:find_member_by_user_id(UserId, State),
NewVisible0 = channel_is_visible(UserId, ChannelId, NewMember, State),
case {NewVisible0, sets:is_element(ChannelId, ConnectedSet)} of
{true, _} ->
{State, true};
{false, false} ->
{State, false};
{false, true} ->
case guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, State) of
true ->
{State, true};
false ->
State1 = guild_virtual_channel_access:add_virtual_access(
UserId,
ChannelId,
State
),
State2 = guild_virtual_channel_access:clear_pending_join(
UserId,
ChannelId,
State1
),
{State2, true}
end
end.
-spec channel_is_visible(user_id(), channel_id(), map() | undefined, guild_state()) -> boolean().
channel_is_visible(UserId, ChannelId, Member, State) ->
guild_permissions:can_view_channel(UserId, ChannelId, Member, State).
-spec update_viewable_map_for_channel(map(), channel_id(), boolean()) -> map().
update_viewable_map_for_channel(ViewableMap, ChannelId, true) ->
maps:put(ChannelId, true, ViewableMap);
update_viewable_map_for_channel(ViewableMap, ChannelId, false) ->
maps:remove(ChannelId, ViewableMap).
-spec compute_and_dispatch_visibility_changes_for_sessions(
[{binary(), map()}], guild_state(), guild_state()
) -> guild_state().
compute_and_dispatch_visibility_changes_for_sessions(SessionEntries, OldState, NewState) ->
GuildId = maps:get(id, NewState, 0),
lists:foldl(
fun({SessionId, SessionData}, AccState) ->
UserId = maps:get(user_id, SessionData),
Pid = maps:get(pid, SessionData),
OldViewable = get_user_viewable_channels(UserId, OldState),
NewViewable = get_user_viewable_channels(UserId, NewState),
OldSet = sets:from_list(OldViewable),
OldSet = cached_viewable_channel_set(SessionData, UserId, OldState),
NewViewable = get_user_viewable_channels(UserId, AccState),
NewSet = sets:from_list(NewViewable),
Removed = sets:subtract(OldSet, NewSet),
Added = sets:subtract(NewSet, OldSet),
ConnectedSet = connected_voice_channel_set(UserId, AccState),
Removed0 = sets:subtract(OldSet, NewSet),
{StateWithAccess, PreservedSet} = preserve_connected_channels(
UserId,
Removed0,
ConnectedSet,
AccState
),
NewSet2 = sets:union(NewSet, PreservedSet),
StateWithCachedVisibility = guild_sessions:set_session_viewable_channels(
SessionId,
viewable_channel_map(NewSet2),
StateWithAccess
),
Removed = sets:subtract(OldSet, NewSet2),
Added = sets:subtract(NewSet2, OldSet),
lists:foreach(
fun(ChannelId) ->
dispatch_channel_delete(ChannelId, Pid, OldState, GuildId)
end,
sets:to_list(Removed)
),
lists:foreach(
fun(ChannelId) ->
dispatch_channel_create(ChannelId, Pid, NewState, GuildId),
send_member_list_sync(SessionId, SessionData, ChannelId, GuildId, NewState)
dispatch_channel_create(ChannelId, Pid, StateWithCachedVisibility, GuildId),
send_member_list_sync(
SessionId, SessionData, ChannelId, GuildId, StateWithCachedVisibility
)
end,
sets:to_list(Added)
)
),
StateWithCachedVisibility
end,
maps:to_list(Sessions)
),
ok.
NewState,
SessionEntries
).
-spec cached_viewable_channel_set(map(), user_id(), guild_state()) -> sets:set(channel_id()).
cached_viewable_channel_set(SessionData, UserId, State) ->
case maps:get(viewable_channels, SessionData, undefined) of
ViewableMap when is_map(ViewableMap) ->
sets:from_list(maps:keys(ViewableMap));
_ ->
sets:from_list(get_user_viewable_channels(UserId, State))
end.
-spec viewable_channel_map(sets:set(channel_id())) -> #{channel_id() => true}.
viewable_channel_map(ChannelSet) ->
sets:fold(
fun(ChannelId, Acc) ->
maps:put(ChannelId, true, Acc)
end,
#{},
ChannelSet
).
-spec connected_voice_channel_set(user_id(), guild_state()) -> sets:set(channel_id()).
connected_voice_channel_set(UserId, State) ->
VoiceStates = voice_state_utils:voice_states(State),
maps:fold(
fun(_ConnId, VoiceState, Acc) ->
case voice_state_utils:voice_state_user_id(VoiceState) of
UserId ->
case voice_state_utils:voice_state_channel_id(VoiceState) of
undefined -> Acc;
ChannelId -> sets:add_element(ChannelId, Acc)
end;
_ ->
Acc
end
end,
sets:new(),
VoiceStates
).
-spec preserve_connected_channels(
user_id(), sets:set(channel_id()), sets:set(channel_id()), guild_state()
) ->
{guild_state(), sets:set(channel_id())}.
preserve_connected_channels(UserId, RemovedSet, ConnectedSet, State) ->
ToPreserve = sets:intersection(RemovedSet, ConnectedSet),
UpdatedState = sets:fold(
fun(ChannelId, AccState) ->
case guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, AccState) of
true ->
AccState;
false ->
State1 = guild_virtual_channel_access:add_virtual_access(
UserId, ChannelId, AccState
),
guild_virtual_channel_access:clear_pending_join(UserId, ChannelId, State1)
end
end,
State,
ToPreserve
),
{UpdatedState, ToPreserve}.
-spec dispatch_channel_delete(channel_id(), pid(), guild_state(), integer()) -> ok.
dispatch_channel_delete(ChannelId, SessionPid, OldState, GuildId) ->
case is_pid(SessionPid) of
true ->
case find_channel_by_id(ChannelId, OldState) of
case guild_permissions:find_channel_by_id(ChannelId, OldState) of
undefined ->
ok;
_Channel ->
@@ -123,10 +415,11 @@ dispatch_channel_delete(ChannelId, SessionPid, OldState, GuildId) ->
ok
end.
-spec dispatch_channel_create(channel_id(), pid(), guild_state(), integer()) -> ok.
dispatch_channel_create(ChannelId, SessionPid, NewState, GuildId) ->
case is_pid(SessionPid) of
true ->
case find_channel_by_id(ChannelId, NewState) of
case guild_permissions:find_channel_by_id(ChannelId, NewState) of
undefined ->
ok;
Channel ->
@@ -139,13 +432,14 @@ dispatch_channel_create(ChannelId, SessionPid, NewState, GuildId) ->
ok
end.
-spec send_member_list_sync(binary(), map(), channel_id(), integer(), guild_state()) -> ok.
send_member_list_sync(SessionId, SessionData, ChannelId, GuildId, State) ->
SessionPid = maps:get(pid, SessionData),
case is_pid(SessionPid) of
false ->
ok;
true ->
ListId = calculate_list_id(ChannelId, State),
ListId = guild_member_list:calculate_list_id(ChannelId, State),
MemberListSubs = maps:get(member_list_subscriptions, State, #{}),
ListSubs = maps:get(ListId, MemberListSubs, #{}),
Ranges = maps:get(SessionId, ListSubs, []),
@@ -156,15 +450,254 @@ send_member_list_sync(SessionId, SessionData, ChannelId, GuildId, State) ->
SessionUserId = maps:get(user_id, SessionData),
case can_send_member_list(SessionUserId, ChannelId, State) of
true ->
SyncResponse = build_sync_response(GuildId, ListId, Ranges, State),
SyncResponseWithChannel = maps:put(<<"channel_id">>, integer_to_binary(ChannelId), SyncResponse),
gen_server:cast(SessionPid, {dispatch, guild_member_list_update, SyncResponseWithChannel});
SyncResponse = guild_member_list:build_sync_response(
GuildId, ListId, Ranges, State
),
SyncResponseWithChannel = maps:put(
<<"channel_id">>, integer_to_binary(ChannelId), SyncResponse
),
gen_server:cast(
SessionPid,
{dispatch, guild_member_list_update, SyncResponseWithChannel}
);
false ->
ok
end
end
end.
-spec can_send_member_list(user_id() | undefined, channel_id(), guild_state()) -> boolean().
can_send_member_list(UserId, ChannelId, State) ->
is_integer(UserId) andalso
guild_permissions:can_view_channel(UserId, ChannelId, undefined, State).
-ifdef(TEST).
get_user_viewable_channels_returns_empty_for_non_member_test() ->
State = #{
data => #{
<<"channels">> => [#{<<"id">> => <<"100">>, <<"type">> => 0}],
<<"members">> => []
}
},
?assertEqual([], get_user_viewable_channels(999, State)).
viewable_channel_set_returns_empty_for_invalid_user_test() ->
State = #{data => #{}},
?assertEqual(sets:new(), viewable_channel_set(undefined, State)).
have_shared_viewable_channel_same_user_test() ->
State = #{data => #{}},
?assertEqual(false, have_shared_viewable_channel(100, 100, State)).
preserves_connected_channel_visibility_on_permission_loss_test() ->
UserId = 10,
GuildId = 1,
ChannelId = 5,
ViewPerm = constants:view_channel_permission(),
OldState = visibility_state(GuildId, UserId, ChannelId, ViewPerm, true),
NewState = visibility_state(GuildId, UserId, ChannelId, 0, true),
UpdatedState = compute_and_dispatch_visibility_changes(OldState, NewState),
?assert(guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, UpdatedState)),
?assertEqual(
false,
guild_virtual_channel_access:is_pending_join(UserId, ChannelId, UpdatedState)
),
?assert(guild_permissions:can_view_channel(UserId, ChannelId, undefined, UpdatedState)).
does_not_add_virtual_access_when_not_connected_test() ->
UserId = 20,
GuildId = 2,
ChannelId = 6,
ViewPerm = constants:view_channel_permission(),
OldState = visibility_state(GuildId, UserId, ChannelId, ViewPerm, false),
NewState = visibility_state(GuildId, UserId, ChannelId, 0, false),
UpdatedState = compute_and_dispatch_visibility_changes(OldState, NewState),
?assertNot(guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, UpdatedState)).
does_not_add_virtual_access_when_permission_remains_test() ->
UserId = 30,
GuildId = 3,
ChannelId = 7,
ViewPerm = constants:view_channel_permission(),
OldState = visibility_state(GuildId, UserId, ChannelId, ViewPerm, true),
NewState = visibility_state(GuildId, UserId, ChannelId, ViewPerm, true),
UpdatedState = compute_and_dispatch_visibility_changes(OldState, NewState),
?assertNot(guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, UpdatedState)).
compute_and_dispatch_visibility_changes_for_users_targets_selected_users_test() ->
GuildId = 33,
ChannelId = 77,
RoleId = 101,
ViewPerm = constants:view_channel_permission(),
Session10 = #{
session_id => <<"s10">>,
user_id => 10,
pid => self(),
viewable_channels => #{ChannelId => true}
},
Session20 = #{
session_id => <<"s20">>,
user_id => 20,
pid => self(),
viewable_channels => #{ChannelId => true}
},
OldState = #{
id => GuildId,
sessions => #{<<"s10">> => Session10, <<"s20">> => Session20},
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => <<"0">>},
#{<<"id">> => integer_to_binary(RoleId), <<"permissions">> => integer_to_binary(ViewPerm)}
],
<<"members">> => #{
10 => #{<<"user">> => #{<<"id">> => <<"10">>}, <<"roles">> => [integer_to_binary(RoleId)]},
20 => #{<<"user">> => #{<<"id">> => <<"20">>}, <<"roles">> => [integer_to_binary(RoleId)]}
},
<<"channels">> => [#{<<"id">> => integer_to_binary(ChannelId), <<"permission_overwrites">> => []}]
}
},
NewState = #{
id => GuildId,
sessions => #{<<"s10">> => Session10, <<"s20">> => Session20},
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => <<"0">>},
#{<<"id">> => integer_to_binary(RoleId), <<"permissions">> => integer_to_binary(ViewPerm)}
],
<<"members">> => #{
10 => #{<<"user">> => #{<<"id">> => <<"10">>}, <<"roles">> => []},
20 => #{<<"user">> => #{<<"id">> => <<"20">>}, <<"roles">> => [integer_to_binary(RoleId)]}
},
<<"channels">> => [#{<<"id">> => integer_to_binary(ChannelId), <<"permission_overwrites">> => []}]
}
},
UpdatedState = compute_and_dispatch_visibility_changes_for_users([10], OldState, NewState),
UpdatedSessions = maps:get(sessions, UpdatedState),
UpdatedSession10 = maps:get(<<"s10">>, UpdatedSessions),
UpdatedSession20 = maps:get(<<"s20">>, UpdatedSessions),
?assertEqual(false, maps:is_key(ChannelId, maps:get(viewable_channels, UpdatedSession10, #{}))),
?assertEqual(true, maps:is_key(ChannelId, maps:get(viewable_channels, UpdatedSession20, #{}))).
compute_and_dispatch_visibility_changes_for_channels_limits_to_changed_channels_test() ->
GuildId = 44,
UserId = 10,
ChannelA = 100,
ChannelB = 101,
ViewPerm = constants:view_channel_permission(),
BaseRole = #{
<<"id">> => integer_to_binary(GuildId),
<<"permissions">> => integer_to_binary(ViewPerm)
},
Session = #{
session_id => <<"s10">>,
user_id => UserId,
pid => self(),
viewable_channels => #{ChannelA => true, ChannelB => true}
},
OldState = #{
id => GuildId,
sessions => #{<<"s10">> => Session},
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [BaseRole],
<<"members">> => #{
UserId => #{
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
<<"roles">> => []
}
},
<<"channels">> => [
#{<<"id">> => integer_to_binary(ChannelA), <<"permission_overwrites">> => []},
#{<<"id">> => integer_to_binary(ChannelB), <<"permission_overwrites">> => []}
]
}
},
NewState = #{
id => GuildId,
sessions => #{<<"s10">> => Session},
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [BaseRole],
<<"members">> => #{
UserId => #{
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
<<"roles">> => []
}
},
<<"channels">> => [
#{
<<"id">> => integer_to_binary(ChannelA),
<<"permission_overwrites">> => [
#{
<<"id">> => integer_to_binary(GuildId),
<<"type">> => 0,
<<"allow">> => <<"0">>,
<<"deny">> => integer_to_binary(ViewPerm)
}
]
},
#{<<"id">> => integer_to_binary(ChannelB), <<"permission_overwrites">> => []}
]
}
},
UpdatedState = compute_and_dispatch_visibility_changes_for_channels(
[ChannelA],
OldState,
NewState
),
UpdatedSession = maps:get(<<"s10">>, maps:get(sessions, UpdatedState)),
UpdatedViewable = maps:get(viewable_channels, UpdatedSession, #{}),
?assertEqual(false, maps:is_key(ChannelA, UpdatedViewable)),
?assertEqual(true, maps:is_key(ChannelB, UpdatedViewable)).
visibility_state(GuildId, UserId, ChannelId, Perms, Connected) ->
VoiceStates =
case Connected of
true ->
#{
<<"conn">> => #{
<<"user_id">> => integer_to_binary(UserId),
<<"guild_id">> => integer_to_binary(GuildId),
<<"channel_id">> => integer_to_binary(ChannelId),
<<"connection_id">> => <<"conn">>
}
};
false ->
#{}
end,
Sessions = #{<<"s1">> => #{user_id => UserId, pid => self()}},
Data = #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{
<<"id">> => integer_to_binary(GuildId),
<<"permissions">> => integer_to_binary(Perms)
}
],
<<"members">> => [
#{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
],
<<"channels">> => [
#{<<"id">> => integer_to_binary(ChannelId), <<"permission_overwrites">> => []}
]
},
#{
id => GuildId,
data => Data,
sessions => Sessions,
voice_states => VoiceStates
}.
filter_connected_session_entries_excludes_pending_test() ->
Normal = #{session_id => <<"s1">>, user_id => 1, pending_connect => false},
Pending = #{session_id => <<"s2">>, user_id => 2, pending_connect => true},
NoPending = #{session_id => <<"s3">>, user_id => 3},
Sessions = #{<<"s1">> => Normal, <<"s2">> => Pending, <<"s3">> => NoPending},
Result = filter_connected_session_entries(Sessions),
ResultIds = lists:sort([Sid || {Sid, _} <- Result]),
?assertEqual([<<"s1">>, <<"s3">>], ResultIds).
-endif.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -17,6 +17,8 @@
-module(dm_voice).
-import(guild_voice_unclaimed_account_utils, [parse_unclaimed_error/1]).
-export([voice_state_update/2]).
-export([get_voice_state/2]).
-export([get_voice_token/6]).
@@ -24,58 +26,48 @@
-export([broadcast_voice_state_update/3]).
-export([join_or_create_call/5, join_or_create_call/6]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type dm_state() :: map().
-type voice_state() :: map().
-type voice_state_map() :: #{binary() => voice_state()}.
-spec voice_state_update(map(), dm_state()) ->
{reply, map(), dm_state()} | {reply, {error, atom(), atom()}, dm_state()}.
voice_state_update(Request, State) ->
#{
user_id := UserId,
channel_id := ChannelId
} = Request,
ConnectionId = maps:get(connection_id, Request, undefined),
VoiceStates = maps:get(dm_voice_states, State, #{}),
case ChannelId of
null ->
handle_dm_disconnect(ConnectionId, UserId, VoiceStates, State);
ChannelIdValue ->
Channels = maps:get(channels, State, #{}),
UserId = maps:get(user_id, State),
logger:info(
"[dm_voice] Looking up channel ~p for user ~p, channels map has ~p entries",
[ChannelIdValue, UserId, maps:size(Channels)]
),
case maps:get(ChannelIdValue, Channels, undefined) of
undefined ->
logger:info(
"[dm_voice] Channel ~p not found locally for user ~p, trying RPC fallback",
[ChannelIdValue, UserId]
),
case fetch_dm_channel_via_rpc(ChannelIdValue, UserId) of
{ok, Channel} ->
NewChannels = maps:put(ChannelIdValue, Channel, Channels),
NewState = maps:put(channels, NewChannels, State),
logger:info(
"[dm_voice] RPC fallback found channel ~p for user ~p, added to local map",
[ChannelIdValue, UserId]
),
handle_dm_voice_with_channel(
Channel, ChannelIdValue, UserId, Request, NewState
);
{error, Reason} ->
logger:warning(
"[dm_voice] Channel ~p not found for user ~p via RPC: ~p",
[ChannelIdValue, UserId, Reason]
),
{error, _Reason} ->
{reply, gateway_errors:error(dm_channel_not_found), State}
end;
Channel ->
logger:info(
"[dm_voice] Found channel ~p for user ~p, type: ~p",
[ChannelIdValue, UserId, maps:get(<<"type">>, Channel, 0)]
),
handle_dm_voice_with_channel(Channel, ChannelIdValue, UserId, Request, State)
end
end.
-spec handle_dm_voice_with_channel(map(), integer(), integer(), map(), dm_state()) ->
{reply, map(), dm_state()} | {reply, {error, atom(), atom()}, dm_state()}.
handle_dm_voice_with_channel(Channel, ChannelIdValue, UserId, Request, State) ->
#{
session_id := SessionId,
@@ -86,11 +78,10 @@ handle_dm_voice_with_channel(Channel, ChannelIdValue, UserId, Request, State) ->
SelfStream = maps:get(self_stream, Request, false),
ConnectionId = maps:get(connection_id, Request, undefined),
IsMobile = maps:get(is_mobile, Request, false),
ViewerStreamKey = maps:get(viewer_stream_key, Request, undefined),
ViewerStreamKeys = maps:get(viewer_stream_keys, Request, undefined),
Latitude = maps:get(latitude, Request, null),
Longitude = maps:get(longitude, Request, null),
VoiceStates = maps:get(dm_voice_states, State, #{}),
ChannelType = maps:get(<<"type">>, Channel, 0),
case is_dm_channel_type(ChannelType) of
false ->
@@ -109,7 +100,7 @@ handle_dm_voice_with_channel(Channel, ChannelIdValue, UserId, Request, State) ->
SelfDeaf,
SelfVideo,
SelfStream,
ViewerStreamKey,
ViewerStreamKeys,
IsMobile,
Latitude,
Longitude,
@@ -119,6 +110,8 @@ handle_dm_voice_with_channel(Channel, ChannelIdValue, UserId, Request, State) ->
end
end.
-spec handle_dm_disconnect(binary() | undefined, integer(), voice_state_map(), dm_state()) ->
{reply, map(), dm_state()} | {reply, {error, atom(), atom()}, dm_state()}.
handle_dm_disconnect(undefined, _UserId, _VoiceStates, State) ->
{reply, gateway_errors:error(voice_missing_connection_id), State};
handle_dm_disconnect(ConnectionId, _UserId, VoiceStates, State) ->
@@ -128,13 +121,11 @@ handle_dm_disconnect(ConnectionId, _UserId, VoiceStates, State) ->
OldVoiceState ->
NewVoiceStates = maps:remove(ConnectionId, VoiceStates),
NewState = maps:put(dm_voice_states, NewVoiceStates, State),
OldChannelId = maps:get(<<"channel_id">>, OldVoiceState, null),
DisconnectVoiceState = maps:put(
<<"channel_id">>, null, maps:put(<<"connection_id">>, ConnectionId, OldVoiceState)
),
SessionId = maps:get(id, State),
case OldChannelId of
null ->
ok;
@@ -154,7 +145,6 @@ handle_dm_disconnect(ConnectionId, _UserId, VoiceStates, State) ->
end
end)
end,
case OldChannelId of
null ->
ok;
@@ -164,15 +154,29 @@ handle_dm_disconnect(ConnectionId, _UserId, VoiceStates, State) ->
broadcast_voice_state_update(
OldChannelIdInt, DisconnectVoiceState, NewState
);
{error, _, Reason} ->
logger:warning("[dm_voice] Invalid channel_id: ~p", [Reason]),
{error, _, _Reason} ->
ok
end
end,
{reply, #{success => true}, NewState}
end.
-spec handle_dm_connect_or_update(
binary() | undefined | null,
integer(),
integer(),
binary(),
boolean(),
boolean(),
boolean(),
boolean(),
term(),
boolean(),
term(),
term(),
voice_state_map(),
dm_state()
) -> {reply, map(), dm_state()} | {reply, {error, atom(), atom()}, dm_state()}.
handle_dm_connect_or_update(
ConnectionId,
ChannelIdValue,
@@ -182,7 +186,7 @@ handle_dm_connect_or_update(
SelfDeaf,
SelfVideo,
SelfStream,
ViewerStreamKey,
ViewerStreamKeys,
IsMobile,
Latitude,
Longitude,
@@ -190,7 +194,7 @@ handle_dm_connect_or_update(
State
) when ConnectionId =:= undefined; ConnectionId =:= null ->
VoiceStates = maps:get(dm_voice_states, State, #{}),
case validate_dm_viewer_stream_key(ViewerStreamKey, ChannelIdValue, VoiceStates) of
case validate_dm_viewer_stream_keys(ViewerStreamKeys, ChannelIdValue, VoiceStates) of
{error, ErrorAtom} ->
{reply, gateway_errors:error(ErrorAtom), State};
{ok, ParsedViewerKey} ->
@@ -218,7 +222,7 @@ handle_dm_connect_or_update(
SelfDeaf,
SelfVideo,
SelfStream,
ViewerStreamKey,
ViewerStreamKeys,
IsMobile,
_Latitude,
_Longitude,
@@ -229,43 +233,49 @@ handle_dm_connect_or_update(
undefined ->
{reply, gateway_errors:error(voice_connection_not_found), State};
ExistingVoiceState ->
ExistingSessionId = maps:get(<<"session_id">>, ExistingVoiceState, undefined),
EffectiveSessionId = resolve_effective_session_id(ExistingSessionId, SessionId),
ValidViewerKey = validate_dm_viewer_stream_key(
ViewerStreamKey, ChannelIdValue, VoiceStates
),
case ValidViewerKey of
{error, ErrorAtom} ->
{reply, gateway_errors:error(ErrorAtom), State};
{ok, ParsedViewerKey} ->
UpdatedVoiceState = ExistingVoiceState#{
<<"channel_id">> => integer_to_binary(ChannelIdValue),
<<"session_id">> => EffectiveSessionId,
<<"self_mute">> => SelfMute,
<<"self_deaf">> => SelfDeaf,
<<"self_video">> => SelfVideo,
<<"self_stream">> => SelfStream,
<<"is_mobile">> => IsMobile,
<<"viewer_stream_key">> => ParsedViewerKey
},
NewVoiceStates = maps:put(ConnectionId, UpdatedVoiceState, VoiceStates),
NewState = maps:put(dm_voice_states, NewVoiceStates, State),
broadcast_voice_state_update(ChannelIdValue, UpdatedVoiceState, NewState),
OldChannelId = maps:get(<<"channel_id">>, ExistingVoiceState, null),
NewChannelIdBin = integer_to_binary(ChannelIdValue),
NeedsToken = OldChannelId =/= NewChannelIdBin,
maybe_spawn_join_call(
NeedsToken, ChannelIdValue, UserId, UpdatedVoiceState, SessionId
case guild_voice_state:user_matches_voice_state(ExistingVoiceState, UserId) of
false ->
{reply, gateway_errors:error(voice_user_mismatch), State};
true ->
ExistingSessionId = maps:get(<<"session_id">>, ExistingVoiceState, undefined),
EffectiveSessionId = resolve_effective_session_id(ExistingSessionId, SessionId),
ValidViewerKey = validate_dm_viewer_stream_keys(
ViewerStreamKeys, ChannelIdValue, VoiceStates
),
{reply, #{success => true, needs_token => NeedsToken}, NewState}
case ValidViewerKey of
{error, ErrorAtom} ->
{reply, gateway_errors:error(ErrorAtom), State};
{ok, ParsedViewerKey} ->
UpdatedVoiceState = ExistingVoiceState#{
<<"channel_id">> => integer_to_binary(ChannelIdValue),
<<"session_id">> => EffectiveSessionId,
<<"self_mute">> => SelfMute,
<<"self_deaf">> => SelfDeaf,
<<"self_video">> => SelfVideo,
<<"self_stream">> => SelfStream,
<<"is_mobile">> => IsMobile,
<<"viewer_stream_keys">> => ParsedViewerKey
},
NewVoiceStates = maps:put(ConnectionId, UpdatedVoiceState, VoiceStates),
NewState = maps:put(dm_voice_states, NewVoiceStates, State),
broadcast_voice_state_update(ChannelIdValue, UpdatedVoiceState, NewState),
OldChannelId = maps:get(<<"channel_id">>, ExistingVoiceState, null),
NewChannelIdBin = integer_to_binary(ChannelIdValue),
NeedsToken = OldChannelId =/= NewChannelIdBin,
maybe_spawn_join_call(
NeedsToken,
ChannelIdValue,
UserId,
UpdatedVoiceState,
EffectiveSessionId,
State
),
{reply, #{success => true, needs_token => NeedsToken}, NewState}
end
end
end.
-spec normalize_session_id(term()) -> binary() | undefined.
normalize_session_id(undefined) ->
undefined;
normalize_session_id(SessionId) when is_binary(SessionId) ->
@@ -281,32 +291,50 @@ normalize_session_id(SessionId) ->
_:_ -> SessionId
end.
validate_dm_viewer_stream_key(RawKey, ChannelIdValue, VoiceStates) ->
case RawKey of
undefined ->
{ok, null};
null ->
{ok, null};
_ when not is_binary(RawKey) -> {error, voice_invalid_state};
_ ->
case voice_state_utils:parse_stream_key(RawKey) of
{ok, #{scope := dm, channel_id := ParsedChannelId, connection_id := ConnId}} when
ParsedChannelId =:= ChannelIdValue
->
case maps:get(ConnId, VoiceStates, undefined) of
undefined ->
{error, voice_connection_not_found};
StreamVS ->
case map_utils:get_integer(StreamVS, <<"channel_id">>, undefined) of
ChannelIdValue -> {ok, RawKey};
_ -> {error, voice_invalid_state}
end
end;
_ ->
{error, voice_invalid_state}
end
-spec validate_dm_viewer_stream_keys(term(), integer(), voice_state_map()) ->
{ok, list()} | {error, atom()}.
validate_dm_viewer_stream_keys(RawKeys, _ChannelIdValue, _VoiceStates) when RawKeys =:= undefined; RawKeys =:= null ->
{ok, []};
validate_dm_viewer_stream_keys(RawKeys, _ChannelIdValue, _VoiceStates) when not is_list(RawKeys) ->
{error, voice_invalid_state};
validate_dm_viewer_stream_keys(Keys, ChannelIdValue, VoiceStates) ->
validate_dm_viewer_stream_keys_list(Keys, ChannelIdValue, VoiceStates, []).
-spec validate_dm_viewer_stream_keys_list(list(), integer(), voice_state_map(), list()) ->
{ok, list()} | {error, atom()}.
validate_dm_viewer_stream_keys_list([], _ChannelIdValue, _VoiceStates, Acc) ->
{ok, lists:reverse(Acc)};
validate_dm_viewer_stream_keys_list([Key | Rest], ChannelIdValue, VoiceStates, Acc) ->
case validate_single_dm_viewer_stream_key(Key, ChannelIdValue, VoiceStates) of
{ok, ValidKey} ->
validate_dm_viewer_stream_keys_list(Rest, ChannelIdValue, VoiceStates, [ValidKey | Acc]);
{error, _} = Error ->
Error
end.
-spec validate_single_dm_viewer_stream_key(term(), integer(), voice_state_map()) ->
{ok, binary()} | {error, atom()}.
validate_single_dm_viewer_stream_key(RawKey, _ChannelIdValue, _VoiceStates) when not is_binary(RawKey) ->
{error, voice_invalid_state};
validate_single_dm_viewer_stream_key(RawKey, ChannelIdValue, VoiceStates) ->
case voice_state_utils:parse_stream_key(RawKey) of
{ok, #{scope := dm, channel_id := ParsedChannelId, connection_id := ConnId}} when
ParsedChannelId =:= ChannelIdValue
->
case maps:get(ConnId, VoiceStates, undefined) of
undefined ->
{error, voice_connection_not_found};
StreamVS ->
case map_utils:get_integer(StreamVS, <<"channel_id">>, undefined) of
ChannelIdValue -> {ok, RawKey};
_ -> {error, voice_invalid_state}
end
end;
_ ->
{error, voice_invalid_state}
end.
-spec resolve_effective_session_id(term(), term()) -> binary() | undefined.
resolve_effective_session_id(ExistingSessionId, RequestSessionId) ->
ExistingNormalized = normalize_session_id(ExistingSessionId),
RequestNormalized = normalize_session_id(RequestSessionId),
@@ -316,28 +344,43 @@ resolve_effective_session_id(ExistingSessionId, RequestSessionId) ->
_ -> ExistingNormalized
end.
maybe_spawn_join_call(false, _ChannelId, _UserId, _VoiceState, _SessionId) ->
-spec maybe_spawn_join_call(
boolean(), integer(), integer(), voice_state(), binary() | undefined, dm_state()
) -> ok.
maybe_spawn_join_call(false, _ChannelId, _UserId, _VoiceState, _SessionId, _State) ->
ok;
maybe_spawn_join_call(true, ChannelId, UserId, VoiceState, SessionId) ->
spawn(fun() ->
try
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, self())
catch
_:_ -> ok
end
end).
maybe_spawn_join_call(true, ChannelId, UserId, VoiceState, SessionId, State) when
is_binary(SessionId)
->
SessionPid = maps:get(session_pid, State, undefined),
case SessionPid of
Pid when is_pid(Pid) ->
spawn(fun() ->
try
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, Pid)
catch
_:_ -> ok
end
end),
ok;
_ ->
ok
end;
maybe_spawn_join_call(true, _ChannelId, _UserId, _VoiceState, _SessionId, _State) ->
ok.
-spec get_voice_token(integer(), integer(), binary(), pid(), term(), term()) -> ok | error.
get_voice_token(ChannelId, UserId, _SessionId, SessionPid, Latitude, Longitude) ->
Req = voice_utils:build_voice_token_rpc_request(
null, ChannelId, UserId, null, Latitude, Longitude
),
case rpc_client:call(Req) of
Region = resolve_call_region(ChannelId),
ReqWithRegion = voice_utils:add_rtc_region_to_request(Req, Region),
case rpc_client:call(ReqWithRegion) of
{ok, Data} ->
Token = maps:get(<<"token">>, Data),
Endpoint = maps:get(<<"endpoint">>, Data),
ConnectionId = maps:get(<<"connectionId">>, Data),
SessionPid !
{voice_server_update, #{
channel_id => integer_to_binary(ChannelId),
@@ -357,6 +400,20 @@ get_voice_token(ChannelId, UserId, _SessionId, SessionPid, Latitude, Longitude)
error
end.
-spec get_dm_voice_token_and_create_state(
integer(),
integer(),
binary(),
boolean(),
boolean(),
boolean(),
boolean(),
term(),
boolean(),
term(),
term(),
dm_state()
) -> {reply, map(), dm_state()} | {reply, {error, atom(), atom()}, dm_state()}.
get_dm_voice_token_and_create_state(
UserId,
ChannelId,
@@ -365,7 +422,7 @@ get_dm_voice_token_and_create_state(
SelfDeaf,
SelfVideo,
SelfStream,
ViewerStreamKey,
ViewerStreamKeys,
IsMobile,
Latitude,
Longitude,
@@ -374,8 +431,9 @@ get_dm_voice_token_and_create_state(
Req = voice_utils:build_voice_token_rpc_request(
null, ChannelId, UserId, null, Latitude, Longitude
),
case rpc_client:call(Req) of
Region = resolve_call_region(ChannelId, State),
ReqWithRegion = voice_utils:add_rtc_region_to_request(Req, Region),
case rpc_client:call(ReqWithRegion) of
{ok, Data} ->
handle_dm_token_success(
Data,
@@ -386,7 +444,7 @@ get_dm_voice_token_and_create_state(
SelfDeaf,
SelfVideo,
SelfStream,
ViewerStreamKey,
ViewerStreamKeys,
IsMobile,
State
);
@@ -399,6 +457,48 @@ get_dm_voice_token_and_create_state(
{reply, gateway_errors:error(voice_token_failed), State}
end.
-spec resolve_call_region(integer()) -> binary() | null.
resolve_call_region(ChannelId) ->
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
{ok, CallPid} ->
case gen_server:call(CallPid, {get_state}, 5000) of
{ok, CallData} -> maps:get(region, CallData, null);
_ -> null
end;
_ ->
null
end.
-spec resolve_call_region(integer(), dm_state()) -> binary() | null.
resolve_call_region(ChannelId, State) ->
Calls = maps:get(calls, State, #{}),
case maps:get(ChannelId, Calls, undefined) of
{CallPid, _Ref} when is_pid(CallPid) ->
try
case gen_server:call(CallPid, {get_state}, 250) of
{ok, CallData} -> maps:get(region, CallData, null);
_ -> null
end
catch
_:_ -> null
end;
_ ->
null
end.
-spec handle_dm_token_success(
map(),
integer(),
integer(),
binary(),
boolean(),
boolean(),
boolean(),
boolean(),
term(),
boolean(),
dm_state()
) -> {reply, map(), dm_state()}.
handle_dm_token_success(
Data,
UserId,
@@ -408,14 +508,13 @@ handle_dm_token_success(
SelfDeaf,
SelfVideo,
SelfStream,
ViewerStreamKey,
ViewerStreamKeys,
IsMobile,
State
) ->
Token = maps:get(<<"token">>, Data),
Endpoint = maps:get(<<"endpoint">>, Data),
ConnectionId = maps:get(<<"connectionId">>, Data),
VoiceState = #{
<<"user_id">> => integer_to_binary(UserId),
<<"channel_id">> => integer_to_binary(ChannelId),
@@ -426,15 +525,12 @@ handle_dm_token_success(
<<"self_deaf">> => SelfDeaf,
<<"self_video">> => SelfVideo,
<<"self_stream">> => SelfStream,
<<"viewer_stream_key">> => ViewerStreamKey
<<"viewer_stream_keys">> => ViewerStreamKeys
},
VoiceStates = maps:get(dm_voice_states, State, #{}),
NewVoiceStates = maps:put(ConnectionId, VoiceState, VoiceStates),
NewState = maps:put(dm_voice_states, NewVoiceStates, State),
broadcast_voice_state_update(ChannelId, VoiceState, NewState),
SessionPid = maps:get(session_pid, State),
VoiceServerUpdate = #{
<<"token">> => Token,
@@ -443,7 +539,6 @@ handle_dm_token_success(
<<"connection_id">> => ConnectionId
},
gen_server:cast(SessionPid, {dispatch, voice_server_update, VoiceServerUpdate}),
GatewaySessionId = maps:get(id, State),
spawn(fun() ->
try
@@ -452,23 +547,22 @@ handle_dm_token_success(
_:_ -> ok
end
end),
{reply, #{success => true, needs_token => false, connection_id => ConnectionId}, NewState}.
-spec get_voice_state(binary(), dm_state()) -> voice_state() | undefined.
get_voice_state(ConnectionId, State) ->
VoiceStates = maps:get(dm_voice_states, State, #{}),
maps:get(ConnectionId, VoiceStates, undefined).
-spec disconnect_voice_user(integer(), dm_state()) -> {reply, map(), dm_state()}.
disconnect_voice_user(UserId, State) ->
VoiceStates = maps:get(dm_voice_states, State, #{}),
UserVoiceStates = maps:filter(
fun(_ConnectionId, VoiceState) ->
maps:get(<<"user_id">>, VoiceState) =:= integer_to_binary(UserId)
end,
VoiceStates
),
case maps:size(UserVoiceStates) of
0 ->
{reply, #{success => true}, State};
@@ -481,71 +575,58 @@ disconnect_voice_user(UserId, State) ->
UserVoiceStates
),
NewState = maps:put(dm_voice_states, NewVoiceStates, State),
maps:foreach(
fun(_ConnectionId, VoiceState) ->
ChannelId = maps:get(<<"channel_id">>, VoiceState, null),
DisconnectVoiceState = maps:put(
<<"channel_id">>,
null,
maps:put(<<"connection_id">>, _ConnectionId, VoiceState)
),
case ChannelId of
null ->
ok;
_ ->
case validation:validate_snowflake(<<"channel_id">>, ChannelId) of
{ok, ChannelIdInt} ->
broadcast_voice_state_update(
ChannelIdInt, DisconnectVoiceState, NewState
);
{error, _, Reason} ->
logger:warning(
"[dm_voice] Invalid channel_id in voice state: ~p", [Reason]
),
ok
end
end
end,
UserVoiceStates
),
spawn(fun() ->
maps:foreach(
fun(ConnId, VoiceState) ->
ChannelId = maps:get(<<"channel_id">>, VoiceState, null),
DisconnectVoiceState = maps:put(
<<"channel_id">>,
null,
maps:put(<<"connection_id">>, ConnId, VoiceState)
),
case ChannelId of
null ->
ok;
_ ->
case validation:validate_snowflake(<<"channel_id">>, ChannelId) of
{ok, ChannelIdInt} ->
broadcast_voice_state_update(
ChannelIdInt, DisconnectVoiceState, NewState
);
{error, _, _Reason} ->
ok
end
end
end,
UserVoiceStates
)
end),
{reply, #{success => true}, NewState}
end.
parse_unclaimed_error(Body) when is_binary(Body) ->
try jsx:decode(Body, [return_maps]) of
#{<<"code">> := <<"UNCLAIMED_ACCOUNT_RESTRICTED">>} -> true;
#{<<"error">> := #{<<"code">> := <<"UNCLAIMED_ACCOUNT_RESTRICTED">>}} -> true;
_ -> false
catch
_:_ -> false
end;
parse_unclaimed_error(_) ->
false.
-spec broadcast_voice_state_update(integer(), voice_state(), dm_state()) -> ok.
broadcast_voice_state_update(ChannelId, VoiceState, State) ->
Channels = maps:get(channels, State, #{}),
case maps:get(ChannelId, Channels, undefined) of
undefined ->
ok;
Channel ->
Recipients = maps:get(<<"recipient_ids">>, Channel, []),
UserId = maps:get(user_id, State),
AllRecipients = lists:usort([UserId | Recipients]),
Event = voice_state_update,
lists:foreach(
fun(RecipientId) ->
presence_manager:dispatch_to_user(RecipientId, Event, VoiceState)
end,
AllRecipients
)
spawn(fun() ->
lists:foreach(
fun(RecipientId) ->
presence_manager:dispatch_to_user(RecipientId, Event, VoiceState)
end,
AllRecipients
)
end),
ok
end.
-spec check_recipient(integer(), integer(), dm_state()) -> boolean().
check_recipient(UserId, ChannelId, State) ->
Channels = maps:get(channels, State, #{}),
case maps:get(ChannelId, Channels, undefined) of
@@ -556,20 +637,24 @@ check_recipient(UserId, ChannelId, State) ->
is_dm_channel_type(ChannelType) andalso is_channel_recipient(UserId, Channel, State)
end.
-spec is_dm_channel_type(integer()) -> boolean().
is_dm_channel_type(1) -> true;
is_dm_channel_type(3) -> true;
is_dm_channel_type(_) -> false.
-spec is_channel_recipient(integer(), map(), dm_state()) -> boolean().
is_channel_recipient(UserId, Channel, State) ->
Recipients = maps:get(<<"recipient_ids">>, Channel, []),
CurrentUserId = maps:get(user_id, State),
lists:member(UserId, [CurrentUserId | Recipients]).
-spec join_or_create_call(integer(), integer(), voice_state(), binary(), pid()) -> ok.
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, SessionPid) ->
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, SessionPid, 10).
join_or_create_call(_ChannelId, UserId, _VoiceState, _SessionId, _SessionPid, 0) ->
logger:warning("[dm_voice] Failed to join call after retries, user ~p could not join", [UserId]),
-spec join_or_create_call(integer(), integer(), voice_state(), binary(), pid(), non_neg_integer()) ->
ok.
join_or_create_call(_ChannelId, _UserId, _VoiceState, _SessionId, _SessionPid, 0) ->
ok;
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, SessionPid, Retries) ->
ConnectionId = maps:get(<<"connection_id">>, VoiceState, undefined),
@@ -597,6 +682,7 @@ join_or_create_call(ChannelId, UserId, VoiceState, SessionId, SessionPid, Retrie
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, SessionPid, Retries - 1)
end.
-spec fetch_dm_channel_via_rpc(integer(), integer()) -> {ok, map()} | {error, term()}.
fetch_dm_channel_via_rpc(ChannelId, UserId) ->
Req = #{
<<"type">> => <<"get_dm_channel">>,
@@ -614,6 +700,7 @@ fetch_dm_channel_via_rpc(ChannelId, UserId) ->
{error, Reason}
end.
-spec convert_api_channel_to_gateway_format(map(), integer()) -> map().
convert_api_channel_to_gateway_format(Channel, CurrentUserId) ->
ChannelType = maps:get(<<"type">>, Channel, 0),
Recipients = maps:get(<<"recipients">>, Channel, []),
@@ -627,6 +714,7 @@ convert_api_channel_to_gateway_format(Channel, CurrentUserId) ->
<<"recipient_ids">> => RecipientIds
}.
-spec extract_recipient_id(term(), integer()) -> {true, integer()} | false.
extract_recipient_id(Recipient, CurrentUserId) when is_map(Recipient) ->
case maps:get(<<"id">>, Recipient, undefined) of
undefined -> false;
@@ -635,6 +723,7 @@ extract_recipient_id(Recipient, CurrentUserId) when is_map(Recipient) ->
extract_recipient_id(Id, CurrentUserId) ->
filter_recipient_id(parse_id(Id), CurrentUserId).
-spec parse_id(term()) -> integer() | null.
parse_id(Id) when is_integer(Id) -> Id;
parse_id(Id) when is_binary(Id) ->
case validation:validate_snowflake(<<"id">>, Id) of
@@ -644,6 +733,111 @@ parse_id(Id) when is_binary(Id) ->
parse_id(_) ->
null.
-spec filter_recipient_id(integer() | null, integer()) -> {true, integer()} | false.
filter_recipient_id(null, _CurrentUserId) -> false;
filter_recipient_id(Id, Id) -> false;
filter_recipient_id(Id, _CurrentUserId) -> {true, Id}.
-ifdef(TEST).
is_dm_channel_type_test() ->
?assert(is_dm_channel_type(1)),
?assert(is_dm_channel_type(3)),
?assertNot(is_dm_channel_type(0)),
?assertNot(is_dm_channel_type(2)),
?assertNot(is_dm_channel_type(4)).
normalize_session_id_test() ->
?assertEqual(undefined, normalize_session_id(undefined)),
?assertEqual(<<"abc">>, normalize_session_id(<<"abc">>)),
?assertEqual(<<"123">>, normalize_session_id(123)),
?assertEqual(<<"test">>, normalize_session_id("test")).
validate_dm_viewer_stream_keys_null_test() ->
?assertEqual({ok, []}, validate_dm_viewer_stream_keys(undefined, 123, #{})),
?assertEqual({ok, []}, validate_dm_viewer_stream_keys(null, 123, #{})).
validate_dm_viewer_stream_keys_invalid_type_test() ->
?assertEqual({error, voice_invalid_state}, validate_dm_viewer_stream_keys(123, 456, #{})).
resolve_effective_session_id_test() ->
?assertEqual(<<"req">>, resolve_effective_session_id(undefined, <<"req">>)),
?assertEqual(<<"existing">>, resolve_effective_session_id(<<"existing">>, <<"req">>)),
?assertEqual(<<"same">>, resolve_effective_session_id(<<"same">>, <<"same">>)).
filter_recipient_id_test() ->
?assertEqual(false, filter_recipient_id(null, 1)),
?assertEqual(false, filter_recipient_id(1, 1)),
?assertEqual({true, 2}, filter_recipient_id(2, 1)).
parse_id_test() ->
?assertEqual(123, parse_id(123)),
?assertEqual(null, parse_id(invalid)).
handle_dm_connect_or_update_user_mismatch_test() ->
VoiceStates = #{
<<"conn-1">> => #{
<<"channel_id">> => <<"100">>,
<<"user_id">> => <<"20">>,
<<"session_id">> => <<"sess">>
}
},
State = #{dm_voice_states => VoiceStates},
{reply, {error, validation_error, voice_user_mismatch}, _} =
handle_dm_connect_or_update(
<<"conn-1">>,
100,
10,
<<"sess">>,
false,
false,
false,
false,
undefined,
false,
null,
null,
VoiceStates,
State
).
handle_dm_connect_or_update_owner_match_proceeds_test() ->
VoiceStates = #{
<<"conn-1">> => #{
<<"channel_id">> => <<"100">>,
<<"user_id">> => <<"10">>,
<<"session_id">> => <<"sess">>
}
},
State = #{
dm_voice_states => VoiceStates,
channels => #{100 => #{<<"type">> => 1, <<"recipient_ids">> => [10]}},
user_id => 10,
id => <<"sess">>,
session_pid => self()
},
case
handle_dm_connect_or_update(
<<"conn-1">>,
100,
10,
<<"sess">>,
false,
false,
false,
false,
undefined,
false,
null,
null,
VoiceStates,
State
)
of
{reply, {error, validation_error, voice_user_mismatch}, _} ->
error(should_not_get_user_mismatch);
{reply, _, _} ->
ok
end.
-endif.

View File

@@ -26,7 +26,7 @@
-export([confirm_voice_connection_from_livekit/2]).
-export([move_member/2]).
-export([broadcast_voice_state_update/3]).
-export([broadcast_voice_server_update_to_session/6]).
-export([broadcast_voice_server_update_to_session/7]).
-export([send_voice_server_update_for_move/5]).
-export([send_voice_server_updates_for_move/4]).
-export([switch_voice_region_handler/2]).
@@ -35,67 +35,88 @@
-export([handle_virtual_channel_access_for_move/4]).
-export([cleanup_virtual_access_on_disconnect/2]).
voice_state_update(Request, State) ->
case guild_voice_connection:voice_state_update(Request, State) of
{reply, Response, NewState} ->
{reply, Response, NewState};
{error, Category, Message} ->
{reply, {error, Category, Message}, State}
end.
-type guild_state() :: map().
-type voice_state() :: map().
-spec voice_state_update(map(), guild_state()) ->
{reply, map(), guild_state()} | {reply, {error, atom(), atom()}, guild_state()}.
voice_state_update(Request, State) ->
guild_voice_connection:voice_state_update(Request, State).
-spec get_voice_state(map(), guild_state()) -> {reply, map(), guild_state()}.
get_voice_state(Request, State) ->
guild_voice_state:get_voice_state(Request, State).
-spec get_voice_states_list(guild_state()) -> [voice_state()].
get_voice_states_list(State) ->
guild_voice_state:get_voice_states_list(State).
-spec update_member_voice(map(), guild_state()) -> {reply, map(), guild_state()}.
update_member_voice(Request, State) ->
guild_voice_member:update_member_voice(Request, State).
-spec disconnect_voice_user(map(), guild_state()) -> {reply, map(), guild_state()}.
disconnect_voice_user(Request, State) ->
guild_voice_disconnect:disconnect_voice_user(Request, State).
-spec disconnect_voice_user_if_in_channel(map(), guild_state()) -> {reply, map(), guild_state()}.
disconnect_voice_user_if_in_channel(Request, State) ->
guild_voice_disconnect:disconnect_voice_user_if_in_channel(Request, State).
-spec disconnect_all_voice_users_in_channel(map(), guild_state()) -> {reply, map(), guild_state()}.
disconnect_all_voice_users_in_channel(Request, State) ->
guild_voice_disconnect:disconnect_all_voice_users_in_channel(Request, State).
-spec confirm_voice_connection_from_livekit(map(), guild_state()) ->
{reply, map(), guild_state()} | {error, atom(), atom()}.
confirm_voice_connection_from_livekit(Request, State) ->
case guild_voice_connection:confirm_voice_connection_from_livekit(Request, State) of
{reply, Response, NewState} ->
{reply, Response, NewState};
{error, Category, Message} ->
{reply, {error, Category, Message}, State}
end.
guild_voice_connection:confirm_voice_connection_from_livekit(Request, State).
-spec move_member(map(), guild_state()) -> {reply, map(), guild_state()}.
move_member(Request, State) ->
guild_voice_move:move_member(Request, State).
-spec send_voice_server_update_for_move(integer(), integer(), integer(), binary(), pid()) -> ok.
send_voice_server_update_for_move(GuildId, ChannelId, UserId, SessionId, GuildPid) ->
guild_voice_move:send_voice_server_update_for_move(
GuildId, ChannelId, UserId, SessionId, GuildPid
).
-spec send_voice_server_updates_for_move(integer(), integer(), [map()], pid()) -> ok.
send_voice_server_updates_for_move(GuildId, ChannelId, SessionDataList, GuildPid) ->
guild_voice_move:send_voice_server_updates_for_move(
GuildId, ChannelId, SessionDataList, GuildPid
).
-spec broadcast_voice_state_update(voice_state(), guild_state(), binary() | null) -> ok.
broadcast_voice_state_update(VoiceState, State, OldChannelIdBin) ->
guild_voice_broadcast:broadcast_voice_state_update(VoiceState, State, OldChannelIdBin).
broadcast_voice_server_update_to_session(GuildId, SessionId, Token, Endpoint, ConnectionId, State) ->
-spec broadcast_voice_server_update_to_session(
integer(), integer(), binary(), binary(), binary(), binary(), guild_state()
) -> ok.
broadcast_voice_server_update_to_session(
GuildId,
ChannelId,
SessionId,
Token,
Endpoint,
ConnectionId,
State
) ->
guild_voice_broadcast:broadcast_voice_server_update_to_session(
GuildId, SessionId, Token, Endpoint, ConnectionId, State
GuildId, ChannelId, SessionId, Token, Endpoint, ConnectionId, State
).
-spec switch_voice_region_handler(map(), guild_state()) -> {reply, map(), guild_state()}.
switch_voice_region_handler(Request, State) ->
guild_voice_region:switch_voice_region_handler(Request, State).
-spec switch_voice_region(integer(), integer(), pid()) -> ok | {error, term()}.
switch_voice_region(GuildId, ChannelId, GuildPid) ->
guild_voice_region:switch_voice_region(GuildId, ChannelId, GuildPid).
-spec handle_virtual_channel_access_for_move(integer(), integer(), map(), pid()) -> ok.
handle_virtual_channel_access_for_move(UserId, ChannelId, _ConnectionsToMove, GuildPid) ->
case gen_server:call(GuildPid, {get_sessions}, 10000) of
State when is_map(State) ->
@@ -104,16 +125,21 @@ handle_virtual_channel_access_for_move(UserId, ChannelId, _ConnectionsToMove, Gu
undefined ->
ok;
_ ->
HasViewPermission = guild_permissions:can_view_channel_by_permissions(
UserId, ChannelId, Member, State
Permissions = guild_permissions:get_member_permissions(
UserId, ChannelId, State
),
case HasViewPermission of
ViewPerm = constants:view_channel_permission(),
ConnectPerm = constants:connect_permission(),
HasView = (Permissions band ViewPerm) =:= ViewPerm,
HasConnect = (Permissions band ConnectPerm) =:= ConnectPerm,
case HasView andalso HasConnect of
true ->
ok;
false ->
gen_server:cast(
gen_server:call(
GuildPid,
{add_virtual_channel_access, UserId, ChannelId}
{add_virtual_channel_access, UserId, ChannelId},
10000
)
end
end;
@@ -121,5 +147,6 @@ handle_virtual_channel_access_for_move(UserId, ChannelId, _ConnectionsToMove, Gu
ok
end.
-spec cleanup_virtual_access_on_disconnect(integer(), pid()) -> ok.
cleanup_virtual_access_on_disconnect(UserId, GuildPid) ->
gen_server:cast(GuildPid, {cleanup_virtual_access_for_user, UserId}).

View File

@@ -18,28 +18,23 @@
-module(guild_voice_broadcast).
-export([broadcast_voice_state_update/3]).
-export([broadcast_voice_server_update_to_session/6]).
-export([broadcast_voice_server_update_to_session/7]).
-ifdef(TEST).
-define(WARN_MISSING_CONN(_VoiceState), ok).
-else.
-define(WARN_MISSING_CONN(VoiceState),
logger:warning(
"[guild_voice_broadcast] Skipping VOICE_STATE_UPDATE broadcast - missing connection_id: ~p",
[VoiceState]
)
).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type guild_state() :: map().
-type voice_state() :: map().
-spec broadcast_voice_state_update(voice_state(), guild_state(), binary() | null) -> ok.
broadcast_voice_state_update(VoiceState, State, OldChannelIdBin) ->
case maps:get(<<"connection_id">>, VoiceState, undefined) of
undefined ->
?WARN_MISSING_CONN(VoiceState),
ok;
ConnectionId ->
_ConnectionId ->
Sessions = maps:get(sessions, State, #{}),
ChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
FilterChannelIdBin =
case ChannelIdBin of
null ->
@@ -47,71 +42,100 @@ broadcast_voice_state_update(VoiceState, State, OldChannelIdBin) ->
_ ->
ChannelIdBin
end,
FilterChannelId = utils:binary_to_integer_safe(FilterChannelIdBin),
UserId = maps:get(<<"user_id">>, VoiceState, <<"unknown">>),
GuildId = maps:get(id, State, 0),
AllSessionDetails = [{Sid, maps:get(user_id, S)} || {Sid, S} <- maps:to_list(Sessions)],
logger:info(
"[guild_voice_broadcast] Broadcasting voice state update: "
"guild_id=~p user_id=~p channel_id=~p connection_id=~p "
"total_sessions=~p all_sessions=~p filter_channel_id=~p",
[
GuildId,
UserId,
ChannelIdBin,
ConnectionId,
maps:size(Sessions),
AllSessionDetails,
FilterChannelId
]
),
FilteredSessions = guild_sessions:filter_sessions_for_channel(
Sessions, FilterChannelId, undefined, State
),
SessionDetails = [{Sid, maps:get(user_id, S)} || {Sid, S} <- FilteredSessions],
Pids = [maps:get(pid, S) || {_Sid, S} <- FilteredSessions],
logger:info(
"[guild_voice_broadcast] Filtered sessions: "
"guild_id=~p user_id=~p filtered_count=~p session_details=~p pids=~p",
[GuildId, UserId, length(FilteredSessions), SessionDetails, Pids]
),
lists:foreach(
fun(Pid) when is_pid(Pid) ->
logger:info(
"[guild_voice_broadcast] Sending voice_state_update to session pid ~p",
[Pid]
),
gen_server:cast(Pid, {dispatch, voice_state_update, VoiceState})
end,
Pids
)
),
maybe_relay_voice_state_update(VoiceState, OldChannelIdBin, State),
ok
end.
broadcast_voice_server_update_to_session(GuildId, SessionId, Token, Endpoint, ConnectionId, State) ->
-spec broadcast_voice_server_update_to_session(
integer(), integer(), binary(), binary(), binary(), binary(), guild_state()
) -> ok.
broadcast_voice_server_update_to_session(
GuildId,
ChannelId,
SessionId,
Token,
Endpoint,
ConnectionId,
State
) ->
VoiceServerUpdate = #{
<<"token">> => Token,
<<"endpoint">> => Endpoint,
<<"guild_id">> => integer_to_binary(GuildId),
<<"channel_id">> => integer_to_binary(ChannelId),
<<"connection_id">> => ConnectionId
},
Sessions = maps:get(sessions, State, #{}),
case maps:get(SessionId, Sessions, undefined) of
undefined ->
maybe_relay_voice_server_update(
GuildId, ChannelId, SessionId, Token, Endpoint, ConnectionId, State
),
ok;
SessionData ->
SessionPid = maps:get(pid, SessionData, null),
case SessionPid of
Pid when is_pid(Pid) ->
gen_server:cast(Pid, {dispatch, voice_server_update, VoiceServerUpdate});
gen_server:cast(Pid, {dispatch, voice_server_update, VoiceServerUpdate}),
ok;
_ ->
maybe_relay_voice_server_update(
GuildId, ChannelId, SessionId, Token, Endpoint, ConnectionId, State
),
ok
end
end.
-spec maybe_relay_voice_state_update(map(), binary() | null, guild_state()) -> ok.
maybe_relay_voice_state_update(VoiceState, OldChannelIdBin, State) ->
case {maps:get(very_large_guild_coordinator_pid, State, undefined),
maps:get(very_large_guild_shard_index, State, undefined)}
of
{CoordPid, ShardIndex} when is_pid(CoordPid), is_integer(ShardIndex) ->
CoordPid ! {very_large_guild_voice_state_update, ShardIndex, VoiceState, OldChannelIdBin},
ok;
_ ->
ok
end.
-spec maybe_relay_voice_server_update(
integer(),
integer(),
binary(),
binary(),
binary(),
binary(),
guild_state()
) -> ok.
maybe_relay_voice_server_update(GuildId, ChannelId, SessionId, Token, Endpoint, ConnectionId, State) ->
case {maps:get(very_large_guild_coordinator_pid, State, undefined),
maps:get(very_large_guild_shard_index, State, undefined)}
of
{CoordPid, ShardIndex} when is_pid(CoordPid), is_integer(ShardIndex) ->
CoordPid !
{very_large_guild_voice_server_update, ShardIndex, GuildId, ChannelId, SessionId,
Token, Endpoint, ConnectionId},
ok;
_ ->
ok
end.
-ifdef(TEST).
broadcast_voice_state_update_missing_connection_id_test() ->
VoiceState = #{<<"user_id">> => <<"1">>},
State = #{sessions => #{}},
?assertEqual(ok, broadcast_voice_state_update(VoiceState, State, null)).
-endif.

File diff suppressed because it is too large Load Diff

View File

@@ -23,15 +23,18 @@
-export([disconnect_voice_user_if_in_channel/2]).
-export([disconnect_all_voice_users_in_channel/2]).
-export([cleanup_virtual_channel_access_for_user/2]).
-type guild_state() :: map().
-type voice_state() :: map().
-type voice_state_map() :: #{binary() => voice_state()}.
-export([recently_disconnected_voice_states/1]).
-export([clear_recently_disconnected/2]).
-export([clear_recently_disconnected_for_channel/2]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type guild_state() :: map().
-type voice_state() :: map().
-type voice_state_map() :: #{binary() => voice_state()}.
-spec handle_voice_disconnect(
binary() | undefined,
term(),
@@ -45,7 +48,10 @@ handle_voice_disconnect(ConnectionId, _SessionId, UserId, VoiceStates0, State) -
VoiceStates = voice_state_utils:ensure_voice_states(VoiceStates0),
case maps:get(ConnectionId, VoiceStates, undefined) of
undefined ->
{reply, #{success => true}, State};
%% Voice state not in voice_states - check if it's still pending
%% (user disconnected before LiveKit confirmation)
State1 = clear_pending_voice_connection(ConnectionId, State),
{reply, #{success => true}, State1};
OldVoiceState ->
case guild_voice_state:user_matches_voice_state(OldVoiceState, UserId) of
false ->
@@ -64,16 +70,43 @@ handle_voice_disconnect(ConnectionId, _SessionId, UserId, VoiceStates0, State) -
{GuildId, ChannelId} ->
maybe_force_disconnect(GuildId, ChannelId, UserId, ConnectionId, State),
NewVoiceStates = maps:remove(ConnectionId, VoiceStates),
NewState = maps:put(voice_states, NewVoiceStates, State),
NewState0 = maps:put(voice_states, NewVoiceStates, State),
NewState = clear_recently_disconnected(ConnectionId, NewState0),
voice_state_utils:broadcast_disconnects(
#{ConnectionId => OldVoiceState}, NewState
),
FinalState = cleanup_virtual_channel_access_for_user(UserId, NewState),
FinalState = maybe_cleanup_after_disconnect(
UserId, ChannelId, NewState
),
{reply, #{success => true}, FinalState}
end
end
end.
-spec clear_pending_voice_connection(binary(), guild_state()) -> guild_state().
clear_pending_voice_connection(ConnectionId, State) ->
PendingConnections = maps:get(pending_voice_connections, State, #{}),
case maps:is_key(ConnectionId, PendingConnections) of
false ->
State;
true ->
NewPendingConnections = maps:remove(ConnectionId, PendingConnections),
maps:put(pending_voice_connections, NewPendingConnections, State)
end.
-spec maybe_cleanup_after_disconnect(integer(), integer(), guild_state()) -> guild_state().
maybe_cleanup_after_disconnect(UserId, ChannelId, State) ->
case
guild_virtual_channel_access:is_pending_join(UserId, ChannelId, State) orelse
guild_virtual_channel_access:has_preserve(UserId, ChannelId, State) orelse
guild_virtual_channel_access:is_move_pending(UserId, ChannelId, State)
of
true ->
State;
false ->
cleanup_virtual_channel_access_for_user(UserId, State)
end.
-spec disconnect_voice_user(map(), guild_state()) -> {reply, map(), guild_state()}.
disconnect_voice_user(#{user_id := UserId} = Request, State) ->
ConnectionId = maps:get(connection_id, Request, null),
@@ -85,12 +118,22 @@ disconnect_voice_user(#{user_id := UserId} = Request, State) ->
end),
case maps:size(UserVoiceStates) of
0 ->
{reply, #{success => true}, State};
%% No active voice states - also clean up any pending connections
State1 = clear_pending_voice_connections_for_user(UserId, State),
{reply, #{success => true}, State1};
_ ->
maybe_force_disconnect_voice_states(UserVoiceStates, State),
NewVoiceStates = voice_state_utils:drop_voice_states(
UserVoiceStates, VoiceStates
),
NewState = maps:put(voice_states, NewVoiceStates, State),
NewState0 = maps:put(voice_states, NewVoiceStates, State),
NewState = maps:fold(
fun(ConnId, _, AccState) ->
clear_recently_disconnected(ConnId, AccState)
end,
NewState0,
UserVoiceStates
),
voice_state_utils:broadcast_disconnects(UserVoiceStates, NewState),
FinalState = cleanup_virtual_channel_access_for_user(UserId, NewState),
{reply, #{success => true}, FinalState}
@@ -98,14 +141,18 @@ disconnect_voice_user(#{user_id := UserId} = Request, State) ->
SpecificConnection ->
case maps:get(SpecificConnection, VoiceStates, undefined) of
undefined ->
{reply, #{success => true}, State};
%% Not found in voice_states - also clean up pending connection
State1 = clear_pending_voice_connection(SpecificConnection, State),
{reply, #{success => true}, State1};
VoiceState ->
case voice_state_utils:voice_state_user_id(VoiceState) of
undefined ->
{reply, gateway_errors:error(voice_invalid_state), State};
VoiceStateUserId when VoiceStateUserId =:= UserId ->
maybe_force_disconnect_voice_state(SpecificConnection, VoiceState, State),
NewVoiceStates = maps:remove(SpecificConnection, VoiceStates),
NewState = maps:put(voice_states, NewVoiceStates, State),
NewState0 = maps:put(voice_states, NewVoiceStates, State),
NewState = clear_recently_disconnected(SpecificConnection, NewState0),
voice_state_utils:broadcast_disconnects(
#{SpecificConnection => VoiceState}, NewState
),
@@ -117,6 +164,19 @@ disconnect_voice_user(#{user_id := UserId} = Request, State) ->
end
end.
-spec clear_pending_voice_connections_for_user(integer(), guild_state()) -> guild_state().
clear_pending_voice_connections_for_user(UserId, State) ->
PendingConnections = maps:get(pending_voice_connections, State, #{}),
FilteredPending = maps:filter(
fun(_ConnId, PendingData) ->
PendingUserId = maps:get(user_id, PendingData, undefined),
PendingUserId =/= UserId
end,
PendingConnections
),
maps:put(pending_voice_connections, FilteredPending, State).
-spec disconnect_voice_user_if_in_channel(map(), guild_state()) -> {reply, map(), guild_state()}.
disconnect_voice_user_if_in_channel(
#{user_id := UserId, expected_channel_id := ExpectedChannelId} = Request,
State
@@ -131,27 +191,36 @@ disconnect_voice_user_if_in_channel(
end),
case maps:size(UserVoiceStates) of
0 ->
%% Not found in voice_states - also clean up any pending connections
%% for this user/channel (user disconnected before LiveKit confirmation)
State1 = clear_pending_voice_connections_for_user_channel(
UserId, ExpectedChannelId, State
),
{reply,
#{
success => true,
ignored => true,
reason => <<"not_in_expected_channel">>
},
State};
State1};
_ ->
NewVoiceStates = voice_state_utils:drop_voice_states(
UserVoiceStates, VoiceStates
),
NewState = maps:put(voice_states, NewVoiceStates, State),
NewState0 = maps:put(voice_states, NewVoiceStates, State),
NewState = cache_recently_disconnected(UserVoiceStates, NewState0),
voice_state_utils:broadcast_disconnects(UserVoiceStates, NewState),
{reply, #{success => true}, NewState}
end;
ConnId ->
case maps:get(ConnId, VoiceStates, undefined) of
undefined ->
%% Not found in voice_states - also clean up pending connection
%% (user disconnected before LiveKit confirmation)
State1 = clear_pending_voice_connection(ConnId, State),
{reply,
#{success => true, ignored => true, reason => <<"connection_not_found">>},
State};
State1};
VoiceState ->
case
{
@@ -161,7 +230,10 @@ disconnect_voice_user_if_in_channel(
of
{UserId, ExpectedChannelId} ->
NewVoiceStates = maps:remove(ConnId, VoiceStates),
NewState = maps:put(voice_states, NewVoiceStates, State),
NewState0 = maps:put(voice_states, NewVoiceStates, State),
NewState = cache_recently_disconnected(
#{ConnId => VoiceState}, NewState0
),
voice_state_utils:broadcast_disconnects(
#{ConnId => VoiceState}, NewState
),
@@ -178,105 +250,157 @@ disconnect_voice_user_if_in_channel(
end
end.
-spec clear_pending_voice_connections_for_user_channel(integer(), integer(), guild_state()) ->
guild_state().
clear_pending_voice_connections_for_user_channel(UserId, ChannelId, State) ->
PendingConnections = maps:get(pending_voice_connections, State, #{}),
FilteredPending = maps:filter(
fun(_ConnId, PendingData) ->
PendingUserId = maps:get(user_id, PendingData, undefined),
PendingChannelId = maps:get(channel_id, PendingData, undefined),
not (PendingUserId =:= UserId andalso PendingChannelId =:= ChannelId)
end,
PendingConnections
),
maps:put(pending_voice_connections, FilteredPending, State).
-spec clear_pending_voice_connections_for_channel(integer(), guild_state()) -> guild_state().
clear_pending_voice_connections_for_channel(ChannelId, State) ->
PendingConnections = maps:get(pending_voice_connections, State, #{}),
FilteredPending = maps:filter(
fun(_ConnId, PendingData) ->
PendingChannelId = maps:get(channel_id, PendingData, undefined),
PendingChannelId =/= ChannelId
end,
PendingConnections
),
maps:put(pending_voice_connections, FilteredPending, State).
-spec disconnect_all_voice_users_in_channel(map(), guild_state()) -> {reply, map(), guild_state()}.
disconnect_all_voice_users_in_channel(#{channel_id := ChannelId}, State) ->
VoiceStates = voice_state_utils:voice_states(State),
ChannelVoiceStates = voice_state_utils:filter_voice_states(VoiceStates, fun(_, V) ->
voice_state_utils:voice_state_channel_id(V) =:= ChannelId
end),
%% Also clean up any pending connections for this channel
%% (users that requested tokens but haven't confirmed via LiveKit yet)
State1 = clear_pending_voice_connections_for_channel(ChannelId, State),
case maps:size(ChannelVoiceStates) of
0 ->
{reply, #{success => true, disconnected_count => 0}, State};
{reply, #{success => true, disconnected_count => 0}, State1};
Count ->
maybe_force_disconnect_voice_states(ChannelVoiceStates, State1),
NewVoiceStates = voice_state_utils:drop_voice_states(ChannelVoiceStates, VoiceStates),
NewState = maps:put(voice_states, NewVoiceStates, State),
NewState0 = maps:put(voice_states, NewVoiceStates, State1),
NewState = clear_recently_disconnected_for_channel(ChannelId, NewState0),
voice_state_utils:broadcast_disconnects(ChannelVoiceStates, NewState),
{reply, #{success => true, disconnected_count => Count}, NewState}
end.
-spec maybe_force_disconnect_voice_states(voice_state_map(), guild_state()) -> ok.
maybe_force_disconnect_voice_states(VoiceStates, State) ->
maps:foreach(
fun(ConnId, VoiceState) ->
maybe_force_disconnect_voice_state(ConnId, VoiceState, State)
end,
VoiceStates
),
ok.
-spec maybe_force_disconnect_voice_state(binary(), voice_state(), guild_state()) -> ok.
maybe_force_disconnect_voice_state(ConnectionId, VoiceState, State) ->
UserId = voice_state_utils:voice_state_user_id(VoiceState),
ChannelId = voice_state_utils:voice_state_channel_id(VoiceState),
GuildId = resolve_guild_id(VoiceState, State),
case {GuildId, ChannelId, UserId} of
{GId, CId, UId} when is_integer(GId), is_integer(CId), is_integer(UId) ->
_ = maybe_force_disconnect(GId, CId, UId, ConnectionId, State),
ok;
_ ->
ok
end.
-spec resolve_guild_id(voice_state(), guild_state()) -> integer() | undefined.
resolve_guild_id(VoiceState, State) ->
case voice_state_utils:voice_state_guild_id(VoiceState) of
undefined ->
map_utils:get_integer(State, id, undefined);
GuildId ->
GuildId
end.
-spec force_disconnect_participant(integer(), integer(), integer(), binary()) ->
{ok, map()} | {error, term()}.
force_disconnect_participant(GuildId, ChannelId, UserId, ConnectionId) ->
Req = voice_utils:build_force_disconnect_rpc_request(GuildId, ChannelId, UserId, ConnectionId),
case rpc_client:call(Req) of
{ok, _Data} ->
logger:debug(
"[guild_voice_disconnect] Force disconnected participant via RPC ~p",
[
[
{guildId, GuildId},
{channelId, ChannelId},
{userId, UserId},
{connectionId, ConnectionId}
]
]
),
{ok, #{success => true}};
{error, Reason} ->
logger:error(
"[guild_voice_disconnect] Failed to force disconnect participant via RPC ~p",
[
[
{guildId, GuildId},
{channelId, ChannelId},
{userId, UserId},
{connectionId, ConnectionId},
{error, Reason}
]
]
),
{error, Reason}
end.
-spec cleanup_virtual_channel_access_for_user(integer(), guild_state()) -> guild_state().
cleanup_virtual_channel_access_for_user(UserId, State) ->
VoiceStates = voice_state_utils:voice_states(State),
HasVoiceConnection = maps:fold(
VirtualChannels = guild_virtual_channel_access:get_virtual_channels_for_user(UserId, State),
lists:foldl(
fun(ChannelId, AccState) ->
case user_has_voice_connection_in_channel(UserId, ChannelId, VoiceStates) of
true ->
AccState;
false ->
case
guild_virtual_channel_access:is_pending_join(UserId, ChannelId, AccState) orelse
guild_virtual_channel_access:has_preserve(UserId, ChannelId, AccState) orelse
guild_virtual_channel_access:is_move_pending(
UserId, ChannelId, AccState
)
of
true ->
AccState;
false ->
ok = maybe_dispatch_visibility_remove(UserId, ChannelId, AccState),
guild_virtual_channel_access:remove_virtual_access(
UserId, ChannelId, AccState
)
end
end
end,
State,
VirtualChannels
).
-spec user_has_voice_connection_in_channel(integer(), integer(), voice_state_map()) -> boolean().
user_has_voice_connection_in_channel(UserId, ChannelId, VoiceStates) ->
ChannelIdBin = integer_to_binary(ChannelId),
maps:fold(
fun(_ConnId, VoiceState, Acc) ->
case Acc of
true -> true;
false -> voice_state_utils:voice_state_user_id(VoiceState) =:= UserId
true ->
true;
false ->
voice_state_utils:voice_state_user_id(VoiceState) =:= UserId andalso
maps:get(<<"channel_id">>, VoiceState, null) =:= ChannelIdBin
end
end,
false,
VoiceStates
),
case HasVoiceConnection of
).
-spec maybe_dispatch_visibility_remove(integer(), integer(), guild_state()) -> ok.
maybe_dispatch_visibility_remove(UserId, ChannelId, State) ->
case guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, State) of
true ->
State;
guild_virtual_channel_access:dispatch_channel_visibility_change(
UserId, ChannelId, remove, State
);
false ->
VirtualChannels = guild_virtual_channel_access:get_virtual_channels_for_user(
UserId, State
),
lists:foldl(
fun(ChannelId, AccState) ->
Member = guild_permissions:find_member_by_user_id(UserId, AccState),
case Member of
undefined ->
AccState;
_ ->
HasViewPermission = guild_permissions:can_view_channel_by_permissions(
UserId, ChannelId, Member, AccState
),
case HasViewPermission of
true ->
guild_virtual_channel_access:remove_virtual_access(
UserId, ChannelId, AccState
);
false ->
guild_virtual_channel_access:dispatch_channel_visibility_change(
UserId, ChannelId, remove, AccState
),
guild_virtual_channel_access:remove_virtual_access(
UserId, ChannelId, AccState
)
end
end
end,
State,
VirtualChannels
)
ok
end.
-spec maybe_force_disconnect(integer(), integer(), integer(), binary(), guild_state()) ->
{ok, map()} | {error, term()}.
maybe_force_disconnect(GuildId, ChannelId, UserId, ConnectionId, State) ->
case maps:get(test_force_disconnect_fun, State, undefined) of
Fun when is_function(Fun, 4) ->
@@ -285,6 +409,59 @@ maybe_force_disconnect(GuildId, ChannelId, UserId, ConnectionId, State) ->
force_disconnect_participant(GuildId, ChannelId, UserId, ConnectionId)
end.
-define(RECENTLY_DISCONNECTED_TTL_MS, 60000).
-spec recently_disconnected_voice_states(guild_state()) -> map().
recently_disconnected_voice_states(State) ->
case maps:get(recently_disconnected_voice_states, State, undefined) of
Map when is_map(Map) -> Map;
_ -> #{}
end.
-spec cache_recently_disconnected(voice_state_map(), guild_state()) -> guild_state().
cache_recently_disconnected(VoiceStatesToCache, State) ->
Now = erlang:system_time(millisecond),
Existing = recently_disconnected_voice_states(State),
Swept = sweep_expired_recently_disconnected(Existing, Now),
NewEntries = maps:fold(
fun(ConnId, VoiceState, Acc) ->
maps:put(ConnId, #{voice_state => VoiceState, disconnected_at => Now}, Acc)
end,
Swept,
VoiceStatesToCache
),
maps:put(recently_disconnected_voice_states, NewEntries, State).
-spec sweep_expired_recently_disconnected(map(), integer()) -> map().
sweep_expired_recently_disconnected(Cache, Now) ->
maps:filter(
fun(_ConnId, #{disconnected_at := DisconnectedAt}) ->
(Now - DisconnectedAt) < ?RECENTLY_DISCONNECTED_TTL_MS;
(_ConnId, _) ->
false
end,
Cache
).
-spec clear_recently_disconnected(binary(), guild_state()) -> guild_state().
clear_recently_disconnected(ConnectionId, State) ->
Cache = recently_disconnected_voice_states(State),
NewCache = maps:remove(ConnectionId, Cache),
maps:put(recently_disconnected_voice_states, NewCache, State).
-spec clear_recently_disconnected_for_channel(integer(), guild_state()) -> guild_state().
clear_recently_disconnected_for_channel(ChannelId, State) ->
Cache = recently_disconnected_voice_states(State),
NewCache = maps:filter(
fun(_ConnId, #{voice_state := VS}) ->
voice_state_utils:voice_state_channel_id(VS) =/= ChannelId;
(_ConnId, _) ->
false
end,
Cache
),
maps:put(recently_disconnected_voice_states, NewCache, State).
-ifdef(TEST).
disconnect_voice_user_removes_all_connections_test() ->
@@ -292,7 +469,10 @@ disconnect_voice_user_removes_all_connections_test() ->
<<"a">> => voice_state_fixture(5, 10, 20),
<<"b">> => voice_state_fixture(5, 10, 21)
},
State = #{voice_states => VoiceStates},
State = #{
voice_states => VoiceStates,
test_force_disconnect_fun => fun(_, _, _, _) -> {ok, #{success => true}} end
},
{reply, #{success := true}, #{voice_states := #{}}} =
disconnect_voice_user(#{user_id => 5, connection_id => null}, State).
@@ -309,6 +489,353 @@ disconnect_voice_user_if_in_channel_ignored_test() ->
{reply, #{ignored := true}, _} =
disconnect_voice_user_if_in_channel(#{user_id => 5, expected_channel_id => 99}, State).
user_has_voice_connection_in_channel_test() ->
VoiceStates = #{
<<"conn">> => #{<<"user_id">> => <<"10">>, <<"channel_id">> => <<"100">>}
},
?assert(user_has_voice_connection_in_channel(10, 100, VoiceStates)),
?assertNot(user_has_voice_connection_in_channel(10, 200, VoiceStates)),
?assertNot(user_has_voice_connection_in_channel(20, 100, VoiceStates)).
recently_disconnected_voice_states_default_test() ->
?assertEqual(#{}, recently_disconnected_voice_states(#{})).
recently_disconnected_voice_states_returns_map_test() ->
Cache = #{<<"conn">> => #{voice_state => #{}, disconnected_at => 1000}},
State = #{recently_disconnected_voice_states => Cache},
?assertEqual(Cache, recently_disconnected_voice_states(State)).
cache_recently_disconnected_test() ->
VS = voice_state_fixture(5, 10, 20),
State = #{},
NewState = cache_recently_disconnected(#{<<"conn">> => VS}, State),
Cache = recently_disconnected_voice_states(NewState),
?assert(maps:is_key(<<"conn">>, Cache)),
#{<<"conn">> := #{voice_state := CachedVS}} = Cache,
?assertEqual(VS, CachedVS).
clear_recently_disconnected_test() ->
VS = voice_state_fixture(5, 10, 20),
State0 = cache_recently_disconnected(#{<<"conn">> => VS}, #{}),
State1 = clear_recently_disconnected(<<"conn">>, State0),
?assertEqual(#{}, recently_disconnected_voice_states(State1)).
clear_recently_disconnected_for_channel_test() ->
VS1 = voice_state_fixture(5, 10, 20),
VS2 = voice_state_fixture(6, 10, 30),
State0 = cache_recently_disconnected(#{<<"a">> => VS1, <<"b">> => VS2}, #{}),
State1 = clear_recently_disconnected_for_channel(20, State0),
Cache = recently_disconnected_voice_states(State1),
?assertNot(maps:is_key(<<"a">>, Cache)),
?assert(maps:is_key(<<"b">>, Cache)).
sweep_expired_recently_disconnected_test() ->
Now = erlang:system_time(millisecond),
Cache = #{
<<"fresh">> => #{voice_state => #{}, disconnected_at => Now - 1000},
<<"expired">> => #{voice_state => #{}, disconnected_at => Now - 70000}
},
Swept = sweep_expired_recently_disconnected(Cache, Now),
?assert(maps:is_key(<<"fresh">>, Swept)),
?assertNot(maps:is_key(<<"expired">>, Swept)).
disconnect_voice_user_if_in_channel_caches_by_connection_test() ->
VS = voice_state_fixture(5, 10, 20),
VoiceStates = #{<<"conn">> => VS},
State = #{voice_states => VoiceStates},
{reply, #{success := true}, NewState} =
disconnect_voice_user_if_in_channel(
#{user_id => 5, expected_channel_id => 20, connection_id => <<"conn">>},
State
),
Cache = recently_disconnected_voice_states(NewState),
?assert(maps:is_key(<<"conn">>, Cache)).
disconnect_voice_user_calls_force_disconnect_for_all_connections_test() ->
Self = self(),
TestFun = fun(GuildId, ChannelId, UserId, ConnectionId) ->
Self ! {force_disconnect, GuildId, ChannelId, UserId, ConnectionId},
{ok, #{success => true}}
end,
VoiceStates = #{
<<"a">> => voice_state_fixture(5, 10, 20),
<<"b">> => voice_state_fixture(5, 10, 21)
},
State = #{
id => 10,
voice_states => VoiceStates,
test_force_disconnect_fun => TestFun
},
{reply, #{success := true}, #{voice_states := #{}}} =
disconnect_voice_user(#{user_id => 5, connection_id => null}, State),
Msgs = collect_force_disconnect_messages(2),
?assertEqual(2, length(Msgs)),
?assert(lists:member({force_disconnect, 10, 20, 5, <<"a">>}, Msgs)),
?assert(lists:member({force_disconnect, 10, 21, 5, <<"b">>}, Msgs)).
disconnect_voice_user_calls_force_disconnect_for_specific_connection_test() ->
Self = self(),
TestFun = fun(GuildId, ChannelId, UserId, ConnectionId) ->
Self ! {force_disconnect, GuildId, ChannelId, UserId, ConnectionId},
{ok, #{success => true}}
end,
VoiceStates = #{
<<"a">> => voice_state_fixture(5, 10, 20),
<<"b">> => voice_state_fixture(5, 10, 21)
},
State = #{
id => 10,
voice_states => VoiceStates,
test_force_disconnect_fun => TestFun
},
{reply, #{success := true}, NewState} =
disconnect_voice_user(#{user_id => 5, connection_id => <<"a">>}, State),
Msgs = collect_force_disconnect_messages(1),
?assertEqual(1, length(Msgs)),
?assert(lists:member({force_disconnect, 10, 20, 5, <<"a">>}, Msgs)),
Remaining = maps:get(voice_states, NewState),
?assert(maps:is_key(<<"b">>, Remaining)),
?assertNot(maps:is_key(<<"a">>, Remaining)).
disconnect_all_voice_users_in_channel_calls_force_disconnect_test() ->
Self = self(),
TestFun = fun(GuildId, ChannelId, UserId, ConnectionId) ->
Self ! {force_disconnect, GuildId, ChannelId, UserId, ConnectionId},
{ok, #{success => true}}
end,
VoiceStates = #{
<<"a">> => voice_state_fixture(5, 10, 20),
<<"b">> => voice_state_fixture(6, 10, 20),
<<"c">> => voice_state_fixture(7, 10, 30)
},
State = #{
id => 10,
voice_states => VoiceStates,
test_force_disconnect_fun => TestFun
},
{reply, #{success := true, disconnected_count := 2}, NewState} =
disconnect_all_voice_users_in_channel(#{channel_id => 20}, State),
Msgs = collect_force_disconnect_messages(2),
?assertEqual(2, length(Msgs)),
?assert(lists:member({force_disconnect, 10, 20, 5, <<"a">>}, Msgs)),
?assert(lists:member({force_disconnect, 10, 20, 6, <<"b">>}, Msgs)),
Remaining = maps:get(voice_states, NewState),
?assert(maps:is_key(<<"c">>, Remaining)),
?assertNot(maps:is_key(<<"a">>, Remaining)),
?assertNot(maps:is_key(<<"b">>, Remaining)).
disconnect_voice_user_if_in_channel_skips_force_disconnect_test() ->
Self = self(),
TestFun = fun(_, _, _, _) ->
Self ! force_disconnect_called,
{ok, #{success => true}}
end,
VoiceStates = #{<<"a">> => voice_state_fixture(5, 10, 20)},
State = #{
voice_states => VoiceStates,
test_force_disconnect_fun => TestFun
},
{reply, #{success := true}, _} =
disconnect_voice_user_if_in_channel(
#{user_id => 5, expected_channel_id => 20, connection_id => <<"a">>},
State
),
receive
force_disconnect_called -> ?assert(false)
after 0 ->
ok
end.
%% Tests for clear_pending_voice_connection/2
clear_pending_voice_connection_removes_connection_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 1, channel_id => 100},
<<"conn2">> => #{user_id => 2, channel_id => 200}
},
State = #{pending_voice_connections => PendingConnections},
NewState = clear_pending_voice_connection(<<"conn1">>, State),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
?assert(maps:is_key(<<"conn2">>, NewPending)).
clear_pending_voice_connection_ignores_missing_test() ->
PendingConnections = #{<<"conn1">> => #{user_id => 1, channel_id => 100}},
State = #{pending_voice_connections => PendingConnections},
NewState = clear_pending_voice_connection(<<"missing">>, State),
?assertEqual(State, NewState).
clear_pending_voice_connection_handles_empty_pending_test() ->
State = #{voice_states => #{}},
NewState = clear_pending_voice_connection(<<"conn">>, State),
?assertEqual(#{}, maps:get(pending_voice_connections, NewState, #{})).
%% Tests for clear_pending_voice_connections_for_user/2
clear_pending_voice_connections_for_user_removes_all_user_connections_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 5, channel_id => 100},
<<"conn2">> => #{user_id => 5, channel_id => 200},
<<"conn3">> => #{user_id => 6, channel_id => 100}
},
State = #{pending_voice_connections => PendingConnections},
NewState = clear_pending_voice_connections_for_user(5, State),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
?assertNot(maps:is_key(<<"conn2">>, NewPending)),
?assert(maps:is_key(<<"conn3">>, NewPending)).
%% Tests for clear_pending_voice_connections_for_user_channel/3
clear_pending_voice_connections_for_user_channel_removes_matching_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 5, channel_id => 100},
<<"conn2">> => #{user_id => 5, channel_id => 200},
<<"conn3">> => #{user_id => 6, channel_id => 100}
},
State = #{pending_voice_connections => PendingConnections},
NewState = clear_pending_voice_connections_for_user_channel(5, 100, State),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
?assert(maps:is_key(<<"conn2">>, NewPending)),
?assert(maps:is_key(<<"conn3">>, NewPending)).
%% Tests for clear_pending_voice_connections_for_channel/2
clear_pending_voice_connections_for_channel_removes_all_channel_connections_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 5, channel_id => 100},
<<"conn2">> => #{user_id => 6, channel_id => 100},
<<"conn3">> => #{user_id => 7, channel_id => 200}
},
State = #{pending_voice_connections => PendingConnections},
NewState = clear_pending_voice_connections_for_channel(100, State),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
?assertNot(maps:is_key(<<"conn2">>, NewPending)),
?assert(maps:is_key(<<"conn3">>, NewPending)).
%% Tests for disconnect handlers cleaning up pending connections
handle_voice_disconnect_cleans_pending_when_not_in_voice_states_test() ->
PendingConnections = #{<<"conn1">> => #{user_id => 5, channel_id => 100}},
State = #{
voice_states => #{},
pending_voice_connections => PendingConnections
},
{reply, #{success := true}, NewState} =
handle_voice_disconnect(<<"conn1">>, undefined, 5, #{}, State),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"conn1">>, NewPending)).
disconnect_voice_user_cleans_pending_when_no_active_states_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 5, channel_id => 100},
<<"conn2">> => #{user_id => 6, channel_id => 100}
},
State = #{
id => 10,
voice_states => #{},
pending_voice_connections => PendingConnections
},
{reply, #{success := true}, NewState} =
disconnect_voice_user(#{user_id => 5}, State),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
?assert(maps:is_key(<<"conn2">>, NewPending)).
disconnect_voice_user_cleans_pending_for_specific_connection_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 5, channel_id => 100},
<<"conn2">> => #{user_id => 5, channel_id => 200}
},
State = #{
id => 10,
voice_states => #{},
pending_voice_connections => PendingConnections
},
{reply, #{success := true}, NewState} =
disconnect_voice_user(#{user_id => 5, connection_id => <<"conn1">>}, State),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
?assert(maps:is_key(<<"conn2">>, NewPending)).
disconnect_voice_user_if_in_channel_cleans_pending_when_not_found_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 5, channel_id => 100},
<<"conn2">> => #{user_id => 6, channel_id => 100}
},
State = #{
id => 10,
voice_states => #{},
pending_voice_connections => PendingConnections
},
{reply, #{success := true, ignored := true}, NewState} =
disconnect_voice_user_if_in_channel(
#{user_id => 5, expected_channel_id => 100},
State
),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
?assert(maps:is_key(<<"conn2">>, NewPending)).
disconnect_all_voice_users_in_channel_cleans_pending_test() ->
Self = self(),
TestFun = fun(GuildId, ChannelId, UserId, ConnectionId) ->
Self ! {force_disconnect, GuildId, ChannelId, UserId, ConnectionId},
{ok, #{success => true}}
end,
VoiceStates = #{
<<"a">> => voice_state_fixture(5, 10, 20)
},
PendingConnections = #{
<<"pending1">> => #{user_id => 6, channel_id => 20},
<<"pending2">> => #{user_id => 7, channel_id => 30}
},
State = #{
id => 10,
voice_states => VoiceStates,
pending_voice_connections => PendingConnections,
test_force_disconnect_fun => TestFun
},
{reply, #{success := true, disconnected_count := 1}, NewState} =
disconnect_all_voice_users_in_channel(#{channel_id => 20}, State),
_ = collect_force_disconnect_messages(1),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"pending1">>, NewPending)),
?assert(maps:is_key(<<"pending2">>, NewPending)).
disconnect_all_voice_users_in_channel_cleans_pending_when_no_active_states_test() ->
PendingConnections = #{
<<"pending1">> => #{user_id => 5, channel_id => 20},
<<"pending2">> => #{user_id => 6, channel_id => 30}
},
State = #{
id => 10,
voice_states => #{},
pending_voice_connections => PendingConnections
},
{reply, #{success := true, disconnected_count := 0}, NewState} =
disconnect_all_voice_users_in_channel(#{channel_id => 20}, State),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"pending1">>, NewPending)),
?assert(maps:is_key(<<"pending2">>, NewPending)).
collect_force_disconnect_messages(Count) ->
collect_force_disconnect_messages(Count, []).
collect_force_disconnect_messages(0, Acc) ->
lists:reverse(Acc);
collect_force_disconnect_messages(Count, Acc) when Count > 0 ->
receive
{force_disconnect, _, _, _, _} = Msg ->
collect_force_disconnect_messages(Count - 1, [Msg | Acc]);
_Other ->
collect_force_disconnect_messages(Count, Acc)
after 200 ->
lists:reverse(Acc)
end.
voice_state_fixture(UserId, GuildId, ChannelId) ->
#{
<<"user_id">> => integer_to_binary(UserId),

View File

@@ -40,7 +40,6 @@ update_member_voice(Request, State) ->
#{user_id := UserId, mute := Mute, deaf := Deaf} = Request,
VoiceStates = voice_state_utils:voice_states(State),
GuildId = map_utils:get_integer(State, id, 0),
case find_member_by_user_id(UserId, State) of
undefined ->
{reply, gateway_errors:error(voice_member_not_found), State};
@@ -48,7 +47,6 @@ update_member_voice(Request, State) ->
UpdatedMember = set_member_voice_flags(Member, Mute, Deaf),
StateWithUpdatedMember = store_member(UpdatedMember, State),
UserVoiceStates = user_voice_states(UserId, VoiceStates),
case maps:size(UserVoiceStates) of
0 ->
{reply, #{success => true}, StateWithUpdatedMember};
@@ -64,43 +62,22 @@ update_member_voice(Request, State) ->
end
end.
-spec find_member_by_user_id(integer(), guild_state()) -> member() | undefined.
find_member_by_user_id(UserId, State) ->
guild_permissions:find_member_by_user_id(UserId, State).
-spec find_channel_by_id(integer(), guild_state()) -> map() | undefined.
find_channel_by_id(ChannelId, State) ->
guild_permissions:find_channel_by_id(ChannelId, State).
-spec enforce_participant_state_in_livekit(integer(), integer(), integer(), boolean(), boolean()) ->
ok.
enforce_participant_state_in_livekit(GuildId, ChannelId, UserId, Mute, Deaf) ->
Req = voice_utils:build_update_participant_rpc_request(GuildId, ChannelId, UserId, Mute, Deaf),
case rpc_client:call(Req) of
{ok, _Data} ->
logger:debug(
"[guild_voice_member] Enforced participant state in LiveKit ~p",
[
[
{guildId, GuildId},
{channelId, ChannelId},
{userId, UserId},
{mute, Mute},
{deaf, Deaf}
]
]
),
ok;
{error, Reason} ->
logger:warning(
"[guild_voice_member] Failed to enforce participant state in LiveKit ~p",
[
[
{guildId, GuildId},
{channelId, ChannelId},
{userId, UserId},
{mute, Mute},
{deaf, Deaf},
{error, Reason}
]
]
),
{error, _Reason} ->
ok
end.
@@ -108,16 +85,10 @@ enforce_participant_state_in_livekit(GuildId, ChannelId, UserId, Mute, Deaf) ->
guild_data(State) ->
map_utils:ensure_map(map_utils:get_safe(State, data, #{})).
-spec guild_members(guild_state()) -> [member()].
guild_members(State) ->
map_utils:ensure_list(maps:get(<<"members">>, guild_data(State), [])).
-spec member_user_id(member()) -> integer() | undefined.
member_user_id(Member) when is_map(Member) ->
User = map_utils:ensure_map(maps:get(<<"user">>, Member, #{})),
map_utils:get_integer(User, <<"id">>, undefined);
member_user_id(_) ->
undefined.
map_utils:get_integer(User, <<"id">>, undefined).
-spec set_member_voice_flags(member(), boolean(), boolean()) -> member().
set_member_voice_flags(Member, Mute, Deaf) ->
@@ -128,19 +99,9 @@ store_member(Member, State) ->
case member_user_id(Member) of
undefined ->
State;
TargetId ->
_TargetId ->
Data = guild_data(State),
Members = guild_members(State),
UpdatedMembers = lists:map(
fun(Current) ->
case member_user_id(Current) of
TargetId -> Member;
_ -> Current
end
end,
Members
),
UpdatedData = maps:put(<<"members">>, UpdatedMembers, Data),
UpdatedData = guild_data_index:put_member(Member, Data),
maps:put(data, UpdatedData, State)
end.
@@ -246,9 +207,17 @@ update_member_voice_updates_voice_states_test() ->
?assert(false)
end.
update_member_voice_member_not_found_test() ->
State = voice_member_test_state(#{}),
Request = #{user_id => 999, mute => true, deaf => false},
{reply, Error, _} = update_member_voice(Request, State),
?assertEqual({error, not_found, voice_member_not_found}, Error).
voice_member_test_state(Overrides) ->
BaseData = #{
<<"members">> => [member_fixture(10)]
<<"members">> => #{
10 => member_fixture(10)
}
},
BaseState = #{
id => 42,

View File

@@ -19,9 +19,16 @@
-export([move_member/2]).
-export([send_voice_server_update_for_move/5]).
-export([send_voice_server_update_for_move/6]).
-export([send_voice_server_updates_for_move/4]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type guild_state() :: map().
-type voice_state() :: map().
-type voice_state_map() :: #{binary() => voice_state()}.
-type move_request() :: #{
user_id := integer(),
moderator_id := integer(),
@@ -31,10 +38,6 @@
deaf := boolean()
}.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec move_member(move_request(), guild_state()) -> {reply, map(), guild_state()}.
move_member(Request, State) ->
#{
@@ -44,10 +47,17 @@ move_member(Request, State) ->
} = Request,
ConnectionId = maps:get(connection_id, Request, null),
ChannelId = normalize_channel_id(ChannelIdRaw),
logger:debug(
"Handling voice move_member request",
#{
user_id => UserId,
moderator_id => ModeratorId,
channel_id => ChannelId,
connection_id => ConnectionId
}
),
VoiceStates = voice_state_utils:voice_states(State),
UserVoiceStates = find_user_voice_states(UserId, VoiceStates),
case maps:size(UserVoiceStates) of
0 ->
{reply, gateway_errors:error(voice_user_not_in_voice), State};
@@ -55,11 +65,20 @@ move_member(Request, State) ->
ConnectionsToMove = select_connections_to_move(
ConnectionId, UserId, VoiceStates, UserVoiceStates
),
logger:debug(
"Selected voice connections to move",
#{
user_id => UserId,
connection_id => ConnectionId,
connections_to_move_count => maps:size(ConnectionsToMove)
}
),
handle_move(
ConnectionsToMove, ChannelId, UserId, ModeratorId, ConnectionId, VoiceStates, State
)
end.
-spec find_user_voice_states(integer(), voice_state_map()) -> voice_state_map().
find_user_voice_states(UserId, VoiceStates) ->
maps:filter(
fun(_ConnId, VoiceState) ->
@@ -68,6 +87,8 @@ find_user_voice_states(UserId, VoiceStates) ->
VoiceStates
).
-spec select_connections_to_move(binary() | null, integer(), voice_state_map(), voice_state_map()) ->
voice_state_map().
select_connections_to_move(null, _UserId, _VoiceStates, UserVoiceStates) ->
UserVoiceStates;
select_connections_to_move(ConnectionId, UserId, VoiceStates, _UserVoiceStates) ->
@@ -83,11 +104,16 @@ select_connections_to_move(ConnectionId, UserId, VoiceStates, _UserVoiceStates)
end
end.
-spec handle_move(
voice_state_map(),
integer() | null,
integer(),
integer(),
binary() | null,
voice_state_map(),
guild_state()
) -> {reply, map(), guild_state()}.
handle_move(ConnectionsToMove, ChannelId, UserId, ModeratorId, ConnectionId, VoiceStates, State) ->
logger:info(
"[guild_voice_move] handle_move user_id=~p moderator_id=~p channel_id=~p connection_id=~p connections=~p",
[UserId, ModeratorId, ChannelId, ConnectionId, maps:keys(ConnectionsToMove)]
),
case maps:size(ConnectionsToMove) of
0 ->
Error =
@@ -99,14 +125,24 @@ handle_move(ConnectionsToMove, ChannelId, UserId, ModeratorId, ConnectionId, Voi
_ ->
case ChannelId of
null ->
logger:debug(
"Disconnect move requested",
#{user_id => UserId, connection_id => ConnectionId}
),
handle_disconnect_move(ConnectionsToMove, UserId, VoiceStates, State);
ChannelIdValue ->
logger:debug(
"Channel move requested",
#{user_id => UserId, channel_id => ChannelIdValue, connection_id => ConnectionId}
),
handle_channel_move(
ConnectionsToMove, ChannelIdValue, UserId, ModeratorId, VoiceStates, State
)
end
end.
-spec handle_disconnect_move(voice_state_map(), integer(), voice_state_map(), guild_state()) ->
{reply, map(), guild_state()}.
handle_disconnect_move(ConnectionsToMove, UserId, VoiceStates, State) ->
NewVoiceStates = maps:fold(
fun(ConnId, _VoiceState, Acc) -> maps:remove(ConnId, Acc) end,
@@ -114,79 +150,107 @@ handle_disconnect_move(ConnectionsToMove, UserId, VoiceStates, State) ->
ConnectionsToMove
),
NewState = maps:put(voice_states, NewVoiceStates, State),
maps:foreach(
fun(_ConnId, VoiceState) ->
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
DisconnectVoiceState = maps:put(<<"channel_id">>, null, VoiceState),
guild_voice_broadcast:broadcast_voice_state_update(
DisconnectVoiceState, NewState, OldChannelIdBin
)
end,
ConnectionsToMove
),
spawn(fun() ->
maps:foreach(
fun(_ConnId, VoiceState) ->
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
DisconnectVoiceState = maps:put(<<"channel_id">>, null, VoiceState),
guild_voice_broadcast:broadcast_voice_state_update(
DisconnectVoiceState, NewState, OldChannelIdBin
)
end,
ConnectionsToMove
)
end),
{reply, #{success => true, user_id => UserId, connections_moved => ConnectionsToMove},
NewState}.
-spec handle_channel_move(
voice_state_map(), integer(), integer(), integer(), voice_state_map(), guild_state()
) -> {reply, map(), guild_state()}.
handle_channel_move(ConnectionsToMove, ChannelIdValue, UserId, ModeratorId, VoiceStates, State) ->
logger:info(
"[guild_voice_move] handle_channel_move user_id=~p moderator_id=~p target_channel_id=~p connections=~p",
[UserId, ModeratorId, ChannelIdValue, maps:keys(ConnectionsToMove)]
),
Channel = guild_voice_member:find_channel_by_id(ChannelIdValue, State),
case Channel of
undefined ->
{reply, gateway_errors:error(voice_channel_not_found), State};
_ ->
StateWithPending0 = guild_virtual_channel_access:mark_pending_join(
UserId, ChannelIdValue, State
),
StateWithPending1 = guild_virtual_channel_access:mark_preserve(
UserId, ChannelIdValue, StateWithPending0
),
StateWithPending2 = guild_virtual_channel_access:mark_move_pending(
UserId, ChannelIdValue, StateWithPending1
),
ChannelType = maps:get(<<"type">>, Channel, 0),
case ChannelType of
2 ->
check_move_permissions_and_execute(
ConnectionsToMove, ChannelIdValue, UserId, ModeratorId, VoiceStates, State
ConnectionsToMove,
ChannelIdValue,
UserId,
ModeratorId,
VoiceStates,
StateWithPending2
);
_ ->
{reply, gateway_errors:error(voice_channel_not_voice), State}
end
end.
-spec check_move_permissions_and_execute(
voice_state_map(), integer(), integer(), integer(), voice_state_map(), guild_state()
) -> {reply, map(), guild_state()}.
check_move_permissions_and_execute(
ConnectionsToMove, ChannelIdValue, _UserId, ModeratorId, VoiceStates, State
ConnectionsToMove, ChannelIdValue, UserId, ModeratorId, VoiceStates, State
) ->
ViewPerm = constants:view_channel_permission(),
ConnectPerm = constants:connect_permission(),
ModPerms = guild_permissions:get_member_permissions(ModeratorId, ChannelIdValue, State),
ModHasConnect = (ModPerms band ConnectPerm) =:= ConnectPerm,
ModHasView = (ModPerms band ViewPerm) =:= ViewPerm,
case ModHasConnect andalso ModHasView of
false ->
{reply, gateway_errors:error(voice_moderator_missing_connect), State};
true ->
execute_move(ConnectionsToMove, VoiceStates, State)
execute_move(ConnectionsToMove, ChannelIdValue, UserId, VoiceStates, State)
end.
execute_move(ConnectionsToMove, VoiceStates, State) ->
-spec execute_move(voice_state_map(), integer(), integer(), voice_state_map(), guild_state()) ->
{reply, map(), guild_state()}.
execute_move(ConnectionsToMove, ChannelIdValue, UserId, VoiceStates, State) ->
StatePending = guild_virtual_channel_access:mark_pending_join(UserId, ChannelIdValue, State),
StatePending2 = guild_virtual_channel_access:mark_preserve(
UserId, ChannelIdValue, StatePending
),
StatePending3 = guild_virtual_channel_access:mark_move_pending(
UserId, ChannelIdValue, StatePending2
),
logger:debug(
"Executing voice channel move",
#{user_id => UserId, channel_id => ChannelIdValue}
),
NewVoiceStates = maps:fold(
fun(ConnId, _VoiceState, Acc) -> maps:remove(ConnId, Acc) end,
VoiceStates,
ConnectionsToMove
),
StateAfterDisconnect = maps:put(voice_states, NewVoiceStates, State),
maps:foreach(
fun(_ConnId, VoiceState) ->
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
DisconnectVoiceState = maps:put(<<"channel_id">>, null, VoiceState),
guild_voice_broadcast:broadcast_voice_state_update(
DisconnectVoiceState, StateAfterDisconnect, OldChannelIdBin
)
end,
ConnectionsToMove
),
StateAfterDisconnect = maps:put(voice_states, NewVoiceStates, StatePending3),
StateWithVirtualAccess = maybe_add_virtual_access(UserId, ChannelIdValue, StateAfterDisconnect),
spawn(fun() ->
maps:foreach(
fun(_ConnId, VoiceState) ->
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
DisconnectVoiceState = maps:put(<<"channel_id">>, null, VoiceState),
guild_voice_broadcast:broadcast_voice_state_update(
DisconnectVoiceState, StateWithVirtualAccess, OldChannelIdBin
)
end,
ConnectionsToMove
)
end),
SessionData = extract_session_data(ConnectionsToMove),
{reply,
#{
success => true,
@@ -194,8 +258,9 @@ execute_move(ConnectionsToMove, VoiceStates, State) ->
session_data => SessionData,
connections_to_move => ConnectionsToMove
},
StateAfterDisconnect}.
StateWithVirtualAccess}.
-spec extract_session_data(voice_state_map()) -> [map()].
extract_session_data(ConnectionsToMove) ->
{_ConnectionIds, SessionData} = maps:fold(
fun(ConnId, VoiceState, {AccConnIds, AccSessionData}) ->
@@ -223,6 +288,159 @@ member_user_id(Member) ->
User = map_utils:ensure_map(maps:get(<<"user">>, map_utils:ensure_map(Member), #{})),
map_utils:get_integer(User, <<"id">>, undefined).
-spec send_voice_server_update_for_move(
integer(), integer(), integer(), binary() | undefined, pid()
) -> ok.
send_voice_server_update_for_move(GuildId, ChannelId, UserId, SessionId, GuildPid) ->
send_voice_server_update_for_move(GuildId, ChannelId, UserId, SessionId, null, GuildPid).
-spec send_voice_server_update_for_move(
integer(), integer(), integer(), binary() | undefined, binary() | null, pid()
) -> ok.
send_voice_server_update_for_move(GuildId, ChannelId, UserId, SessionId, OldConnectionId, GuildPid) ->
case SessionId of
undefined ->
ok;
_ ->
spawn(fun() ->
case gen_server:call(GuildPid, {get_sessions}, 10000) of
State when is_map(State) ->
VoicePermissions = voice_utils:compute_voice_permissions(
UserId, ChannelId, State
),
case
guild_voice_connection:request_voice_token(
GuildId, ChannelId, UserId, OldConnectionId, VoicePermissions
)
of
{ok, TokenData} ->
Token = maps:get(token, TokenData),
Endpoint = maps:get(endpoint, TokenData),
ConnectionId = maps:get(connection_id, TokenData),
guild_voice_broadcast:broadcast_voice_server_update_to_session(
GuildId,
ChannelId,
SessionId,
Token,
Endpoint,
ConnectionId,
State
);
{error, _Reason} ->
ok
end;
_ ->
ok
end
end),
ok
end.
-spec maybe_add_virtual_access(integer(), integer(), guild_state()) -> guild_state().
maybe_add_virtual_access(UserId, ChannelId, State) ->
Member = guild_permissions:find_member_by_user_id(UserId, State),
case Member of
undefined ->
State;
_ ->
Permissions = guild_permissions:get_member_permissions(UserId, ChannelId, State),
ViewPerm = constants:view_channel_permission(),
ConnectPerm = constants:connect_permission(),
HasView = (Permissions band ViewPerm) =:= ViewPerm,
HasConnect = (Permissions band ConnectPerm) =:= ConnectPerm,
case HasView andalso HasConnect of
true ->
State;
false ->
NewState = guild_virtual_channel_access:add_virtual_access(
UserId, ChannelId, State
),
guild_virtual_channel_access:dispatch_channel_visibility_change(
UserId, ChannelId, add, NewState
),
NewState
end
end.
-spec send_voice_server_updates_for_move(integer(), integer(), [map()], pid()) -> ok.
send_voice_server_updates_for_move(GuildId, ChannelId, SessionDataList, GuildPid) ->
spawn(fun() ->
lists:foreach(
fun(SessionInfo) ->
send_single_voice_server_update(GuildId, ChannelId, SessionInfo, GuildPid)
end,
SessionDataList
)
end),
ok.
-spec send_single_voice_server_update(integer(), integer(), map(), pid()) -> ok.
send_single_voice_server_update(GuildId, ChannelId, SessionInfo, GuildPid) ->
SessionId = maps:get(session_id, SessionInfo),
SelfMute = maps:get(self_mute, SessionInfo),
SelfDeaf = maps:get(self_deaf, SessionInfo),
SelfVideo = maps:get(self_video, SessionInfo),
SelfStream = maps:get(self_stream, SessionInfo),
IsMobile = maps:get(is_mobile, SessionInfo),
OldConnectionId = maps:get(connection_id, SessionInfo, null),
Member = maps:get(member, SessionInfo),
ServerMute = maps:get(<<"mute">>, Member, false),
ServerDeaf = maps:get(<<"deaf">>, Member, false),
case member_user_id(Member) of
undefined ->
ok;
UserId ->
case gen_server:call(GuildPid, {get_sessions}, 10000) of
StateData when is_map(StateData) ->
VoicePermissions = voice_utils:compute_voice_permissions(
UserId, ChannelId, StateData
),
case
guild_voice_connection:request_voice_token(
GuildId, ChannelId, UserId, OldConnectionId, VoicePermissions
)
of
{ok, TokenData} ->
Token = maps:get(token, TokenData),
Endpoint = maps:get(endpoint, TokenData),
NewConnectionId = maps:get(connection_id, TokenData),
PendingMetadata = #{
<<"user_id">> => UserId,
<<"guild_id">> => GuildId,
<<"channel_id">> => ChannelId,
<<"connection_id">> => NewConnectionId,
<<"session_id">> => SessionId,
<<"self_mute">> => SelfMute,
<<"self_deaf">> => SelfDeaf,
<<"self_video">> => SelfVideo,
<<"self_stream">> => SelfStream,
<<"is_mobile">> => IsMobile,
<<"server_mute">> => ServerMute,
<<"server_deaf">> => ServerDeaf,
<<"member">> => Member
},
_ = gen_server:call(
GuildPid,
{store_pending_connection, NewConnectionId, PendingMetadata},
10000
),
guild_voice_broadcast:broadcast_voice_server_update_to_session(
GuildId,
ChannelId,
SessionId,
Token,
Endpoint,
NewConnectionId,
StateData
);
{error, _Reason} ->
ok
end;
_ ->
ok
end
end.
-ifdef(TEST).
move_member_user_not_in_voice_test() ->
@@ -253,6 +471,12 @@ select_connections_to_move_specific_connection_test() ->
?assertEqual(#{<<"conn-b">> => maps:get(<<"conn-b">>, VoiceStates)}, Selected),
?assertEqual(#{}, select_connections_to_move(<<"conn-b">>, 10, VoiceStates, #{})).
normalize_channel_id_test() ->
?assertEqual(null, normalize_channel_id(null)),
?assertEqual(123, normalize_channel_id(123)),
?assertEqual(456, normalize_channel_id(<<"456">>)),
?assertEqual(null, normalize_channel_id(undefined)).
test_state(VoiceStates) ->
#{
id => 1,
@@ -274,105 +498,3 @@ voice_state_fixture(UserId, ChannelId, ConnId) ->
}.
-endif.
send_voice_server_update_for_move(GuildId, ChannelId, UserId, SessionId, GuildPid) ->
case SessionId of
undefined ->
ok;
_ ->
case gen_server:call(GuildPid, {get_sessions}, 10000) of
State when is_map(State) ->
VoicePermissions = voice_utils:compute_voice_permissions(
UserId, ChannelId, State
),
case
guild_voice_connection:request_voice_token(
GuildId, ChannelId, UserId, VoicePermissions
)
of
{ok, TokenData} ->
Token = maps:get(token, TokenData),
Endpoint = maps:get(endpoint, TokenData),
ConnectionId = maps:get(connection_id, TokenData),
guild_voice_broadcast:broadcast_voice_server_update_to_session(
GuildId, SessionId, Token, Endpoint, ConnectionId, State
);
{error, _Reason} ->
ok
end;
_ ->
ok
end
end.
send_voice_server_updates_for_move(GuildId, ChannelId, SessionDataList, GuildPid) ->
lists:foreach(
fun(SessionInfo) ->
send_single_voice_server_update(GuildId, ChannelId, SessionInfo, GuildPid)
end,
SessionDataList
).
send_single_voice_server_update(GuildId, ChannelId, SessionInfo, GuildPid) ->
SessionId = maps:get(session_id, SessionInfo),
SelfMute = maps:get(self_mute, SessionInfo),
SelfDeaf = maps:get(self_deaf, SessionInfo),
SelfVideo = maps:get(self_video, SessionInfo),
SelfStream = maps:get(self_stream, SessionInfo),
IsMobile = maps:get(is_mobile, SessionInfo),
Member = maps:get(member, SessionInfo),
ServerMute = maps:get(<<"mute">>, Member, false),
ServerDeaf = maps:get(<<"deaf">>, Member, false),
case member_user_id(Member) of
undefined ->
logger:warning(
"[guild_voice_move] Missing user_id in member while sending voice server update: ~p",
[SessionInfo]
),
ok;
UserId ->
case gen_server:call(GuildPid, {get_sessions}, 10000) of
StateData when is_map(StateData) ->
VoicePermissions = voice_utils:compute_voice_permissions(
UserId, ChannelId, StateData
),
case
guild_voice_connection:request_voice_token(
GuildId, ChannelId, UserId, VoicePermissions
)
of
{ok, TokenData} ->
Token = maps:get(token, TokenData),
Endpoint = maps:get(endpoint, TokenData),
NewConnectionId = maps:get(connection_id, TokenData),
PendingMetadata = #{
<<"user_id">> => UserId,
<<"guild_id">> => GuildId,
<<"channel_id">> => ChannelId,
<<"connection_id">> => NewConnectionId,
<<"session_id">> => SessionId,
<<"self_mute">> => SelfMute,
<<"self_deaf">> => SelfDeaf,
<<"self_video">> => SelfVideo,
<<"self_stream">> => SelfStream,
<<"is_mobile">> => IsMobile,
<<"server_mute">> => ServerMute,
<<"server_deaf">> => ServerDeaf,
<<"member">> => Member
},
gen_server:cast(
GuildPid,
{store_pending_connection, NewConnectionId, PendingMetadata}
),
guild_voice_broadcast:broadcast_voice_server_update_to_session(
GuildId, SessionId, Token, Endpoint, NewConnectionId, StateData
);
{error, _Reason} ->
ok
end;
_ ->
ok
end
end.

View File

@@ -26,23 +26,23 @@
-type guild_state() :: map().
-type voice_state() :: map().
-type user_id() :: integer().
-type channel_id() :: integer().
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec sync_user_voice_permissions(integer(), guild_state()) -> ok.
-spec sync_user_voice_permissions(user_id(), guild_state()) -> ok.
sync_user_voice_permissions(UserId, State) ->
VoiceStates = voice_state_utils:voice_states(State),
GuildId = map_utils:get_integer(State, id, 0),
UserVoiceStates = maps:filter(
fun(_ConnId, VoiceState) ->
voice_state_utils:voice_state_user_id(VoiceState) =:= UserId
end,
VoiceStates
),
maps:foreach(
fun(_ConnId, VoiceState) ->
sync_voice_state_permissions(GuildId, UserId, VoiceState, State)
@@ -51,18 +51,16 @@ sync_user_voice_permissions(UserId, State) ->
),
ok.
-spec sync_all_voice_permissions_for_channel(integer(), guild_state()) -> ok.
-spec sync_all_voice_permissions_for_channel(channel_id(), guild_state()) -> ok.
sync_all_voice_permissions_for_channel(ChannelId, State) ->
VoiceStates = voice_state_utils:voice_states(State),
GuildId = map_utils:get_integer(State, id, 0),
ChannelVoiceStates = maps:filter(
fun(_ConnId, VoiceState) ->
voice_state_utils:voice_state_channel_id(VoiceState) =:= ChannelId
end,
VoiceStates
),
maps:foreach(
fun(_ConnId, VoiceState) ->
UserId = voice_state_utils:voice_state_user_id(VoiceState),
@@ -84,15 +82,12 @@ maybe_sync_permissions_on_role_update(RoleUpdate, State) ->
_ ->
OldPermissions = maps:get(<<"old_permissions">>, RoleUpdate, 0),
NewPermissions = maps:get(<<"permissions">>, RoleUpdate, 0),
AdminPerm = constants:administrator_permission(),
SpeakPerm = constants:speak_permission(),
StreamPerm = constants:stream_permission(),
VoicePerms = AdminPerm bor SpeakPerm bor StreamPerm,
OldVoicePerms = OldPermissions band VoicePerms,
NewVoicePerms = NewPermissions band VoicePerms,
case OldVoicePerms =/= NewVoicePerms of
true ->
sync_users_with_role(RoleId, State);
@@ -110,7 +105,6 @@ maybe_sync_permissions_on_member_update(MemberUpdate, State) ->
_ ->
OldRoles = maps:get(<<"old_roles">>, MemberUpdate, []),
NewRoles = maps:get(<<"roles">>, MemberUpdate, []),
case OldRoles =/= NewRoles of
true ->
sync_user_voice_permissions(UserId, State);
@@ -119,11 +113,10 @@ maybe_sync_permissions_on_member_update(MemberUpdate, State) ->
end
end.
-spec sync_voice_state_permissions(integer(), integer(), voice_state(), guild_state()) -> ok.
-spec sync_voice_state_permissions(integer(), user_id(), voice_state(), guild_state()) -> ok.
sync_voice_state_permissions(GuildId, UserId, VoiceState, State) ->
ChannelId = voice_state_utils:voice_state_channel_id(VoiceState),
ConnectionId = maps:get(<<"connection_id">>, VoiceState, undefined),
case {ChannelId, ConnectionId} of
{undefined, _} ->
ok;
@@ -134,7 +127,9 @@ sync_voice_state_permissions(GuildId, UserId, VoiceState, State) ->
dispatch_permission_update(GuildId, ChId, UserId, ConnId, VoicePermissions, State)
end.
-spec dispatch_permission_update(integer(), integer(), integer(), binary(), map(), guild_state()) ->
-spec dispatch_permission_update(
integer(), channel_id(), user_id(), binary(), map(), guild_state()
) ->
ok.
dispatch_permission_update(GuildId, ChannelId, UserId, ConnectionId, VoicePermissions, State) ->
case maps:get(test_permission_sync_fun, State, undefined) of
@@ -149,7 +144,7 @@ dispatch_permission_update(GuildId, ChannelId, UserId, ConnectionId, VoicePermis
end.
-spec enforce_voice_permissions_in_livekit(
integer(), integer(), integer(), binary(), map()
integer(), channel_id(), user_id(), binary(), map()
) -> ok.
enforce_voice_permissions_in_livekit(GuildId, ChannelId, UserId, ConnectionId, VoicePermissions) ->
Req = voice_utils:build_update_participant_permissions_rpc_request(
@@ -157,33 +152,8 @@ enforce_voice_permissions_in_livekit(GuildId, ChannelId, UserId, ConnectionId, V
),
case rpc_client:call(Req) of
{ok, _Data} ->
logger:debug(
"[guild_voice_permission_sync] Synced voice permissions ~p",
[
[
{guildId, GuildId},
{channelId, ChannelId},
{userId, UserId},
{connectionId, ConnectionId},
{permissions, VoicePermissions}
]
]
),
ok;
{error, Reason} ->
logger:warning(
"[guild_voice_permission_sync] Failed to sync voice permissions ~p",
[
[
{guildId, GuildId},
{channelId, ChannelId},
{userId, UserId},
{connectionId, ConnectionId},
{permissions, VoicePermissions},
{error, Reason}
]
]
),
{error, _Reason} ->
ok
end.
@@ -192,7 +162,6 @@ sync_users_with_role(RoleId, State) ->
RoleIdBin = ensure_binary(RoleId),
VoiceStates = voice_state_utils:voice_states(State),
GuildId = map_utils:get_integer(State, id, 0),
maps:foreach(
fun(_ConnId, VoiceState) ->
UserId = voice_state_utils:voice_state_user_id(VoiceState),
@@ -212,7 +181,7 @@ sync_users_with_role(RoleId, State) ->
),
ok.
-spec user_has_role(integer(), binary(), guild_state()) -> boolean().
-spec user_has_role(user_id(), binary(), guild_state()) -> boolean().
user_has_role(UserId, RoleIdBin, State) ->
case guild_voice_member:find_member_by_user_id(UserId, State) of
undefined ->
@@ -222,7 +191,7 @@ user_has_role(UserId, RoleIdBin, State) ->
lists:member(RoleIdBin, Roles)
end.
-spec get_member_user_id(map()) -> integer() | undefined.
-spec get_member_user_id(map()) -> user_id() | undefined.
get_member_user_id(MemberUpdate) ->
User = maps:get(<<"user">>, MemberUpdate, #{}),
map_utils:get_integer(User, <<"id">>, undefined).
@@ -242,19 +211,16 @@ sync_user_voice_permissions_syncs_connected_user_test() ->
ChannelId = 500,
GuildId = 42,
RoleId = 999,
VoiceState = #{
<<"user_id">> => integer_to_binary(UserId),
<<"channel_id">> => integer_to_binary(ChannelId),
<<"connection_id">> => <<"test-conn">>
},
Permissions =
constants:view_channel_permission() bor
constants:connect_permission() bor
constants:speak_permission() bor
constants:stream_permission(),
State = #{
id => GuildId,
voice_states => #{<<"conn">> => VoiceState},
@@ -301,4 +267,17 @@ sync_user_voice_permissions_no_voice_state_test() ->
},
ok = sync_user_voice_permissions(10, State).
maybe_sync_permissions_on_member_update_no_role_change_test() ->
State = #{id => 42, voice_states => #{}},
MemberUpdate = #{
<<"user">> => #{<<"id">> => <<"10">>},
<<"roles">> => [<<"1">>],
<<"old_roles">> => [<<"1">>]
},
?assertEqual(ok, maybe_sync_permissions_on_member_update(MemberUpdate, State)).
ensure_binary_test() ->
?assertEqual(<<"123">>, ensure_binary(123)),
?assertEqual(<<"abc">>, ensure_binary(<<"abc">>)).
-endif.

View File

@@ -19,15 +19,15 @@
-export([check_voice_permissions_and_limits/6]).
-type guild_state() :: map().
-type voice_state_map() :: #{binary() => map()}.
-type channel() :: map().
-import(utils, [parse_iso8601_to_unix_ms/1]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-import(utils, [parse_iso8601_to_unix_ms/1]).
-type guild_state() :: map().
-type voice_state_map() :: #{binary() => map()}.
-type channel() :: map().
-spec check_voice_permissions_and_limits(
integer(), integer(), channel(), voice_state_map(), guild_state(), boolean()
@@ -45,8 +45,10 @@ check_voice_permissions_and_limits(UserId, ChannelIdValue, Channel, VoiceStates,
case
channel_has_capacity(UserId, ChannelIdValue, Channel, VoiceStates, IsUpdate)
of
true -> {ok, allowed};
false -> gateway_errors:error(voice_channel_full)
true ->
{ok, allowed};
false ->
gateway_errors:error(voice_channel_full)
end
end
end.
@@ -57,18 +59,26 @@ has_view_and_connect_perms(UserId, ChannelIdValue, State) ->
true ->
true;
false ->
Permissions = resolve_permissions(UserId, ChannelIdValue, State),
ViewPerm = constants:view_channel_permission(),
ConnectPerm = constants:connect_permission(),
(Permissions band ViewPerm) =:= ViewPerm andalso
(Permissions band ConnectPerm) =:= ConnectPerm
case guild_virtual_channel_access:is_move_pending(UserId, ChannelIdValue, State) of
true ->
true;
false ->
Permissions = resolve_permissions(UserId, ChannelIdValue, State),
ViewPerm = constants:view_channel_permission(),
ConnectPerm = constants:connect_permission(),
HasView = (Permissions band ViewPerm) =:= ViewPerm,
HasConnect = (Permissions band ConnectPerm) =:= ConnectPerm,
HasView andalso HasConnect
end
end.
-spec channel_has_capacity(integer(), integer(), channel(), voice_state_map(), boolean()) ->
boolean().
channel_has_capacity(UserId, ChannelIdValue, Channel, VoiceStates, IsUpdate) ->
UserLimit = maps:get(<<"user_limit">>, Channel, 0),
case UserLimit of
AnyCameraActive = any_camera_active_in_channel(ChannelIdValue, VoiceStates),
EffectiveLimit = effective_user_limit(UserLimit, AnyCameraActive),
case EffectiveLimit of
0 ->
true;
Limit when Limit > 0 ->
@@ -85,6 +95,26 @@ channel_has_capacity(UserId, ChannelIdValue, Channel, VoiceStates, IsUpdate) ->
true
end.
-spec any_camera_active_in_channel(integer(), voice_state_map()) -> boolean().
any_camera_active_in_channel(ChannelIdValue, VoiceStates) ->
lists:any(
fun({_ConnId, VS}) ->
case map_utils:get_integer(VS, <<"channel_id">>, undefined) of
ChannelIdValue ->
maps:get(<<"self_video">>, VS, false) =:= true;
_ ->
false
end
end,
maps:to_list(VoiceStates)
).
-spec effective_user_limit(integer(), boolean()) -> integer().
effective_user_limit(0, false) -> 0;
effective_user_limit(0, true) -> 25;
effective_user_limit(Limit, false) -> Limit;
effective_user_limit(Limit, true) -> min(Limit, 25).
-spec is_member_timed_out(integer(), guild_state()) -> boolean().
is_member_timed_out(UserId, State) ->
case guild_permissions:find_member_by_user_id(UserId, State) of
@@ -98,13 +128,11 @@ is_member_timed_out(UserId, State) ->
undefined ->
false;
Value when is_integer(Value) ->
Value > erlang:system_time(millisecond);
_ ->
false
Value > erlang:system_time(millisecond)
end
end.
-spec users_in_channel(integer(), voice_state_map()) -> sets:set().
-spec users_in_channel(integer(), voice_state_map()) -> sets:set(integer()).
users_in_channel(ChannelIdValue, VoiceStates0) ->
VoiceStates = voice_state_utils:ensure_voice_states(VoiceStates0),
maps:fold(
@@ -161,6 +189,18 @@ voice_permissions_existing_user_update_test() ->
),
?assertEqual({ok, allowed}, Result).
users_in_channel_test() ->
VoiceStates = #{
<<"conn1">> => #{<<"channel_id">> => <<"10">>, <<"user_id">> => <<"1">>},
<<"conn2">> => #{<<"channel_id">> => <<"10">>, <<"user_id">> => <<"2">>},
<<"conn3">> => #{<<"channel_id">> => <<"20">>, <<"user_id">> => <<"3">>}
},
Result = users_in_channel(10, VoiceStates),
?assertEqual(2, sets:size(Result)),
?assert(sets:is_element(1, Result)),
?assert(sets:is_element(2, Result)),
?assertNot(sets:is_element(3, Result)).
required_voice_perms() ->
constants:view_channel_permission() bor constants:connect_permission().

View File

@@ -20,9 +20,17 @@
-export([switch_voice_region_handler/2]).
-export([switch_voice_region/3]).
-type guild_state() :: map().
-type guild_reply(T) :: {reply, T, guild_state()}.
-type voice_state() :: map().
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec switch_voice_region_handler(map(), guild_state()) -> guild_reply(map()).
switch_voice_region_handler(Request, State) ->
#{channel_id := ChannelId} = Request,
Channel = guild_voice_member:find_channel_by_id(ChannelId, State),
case Channel of
undefined ->
@@ -37,42 +45,21 @@ switch_voice_region_handler(Request, State) ->
end
end.
-spec switch_voice_region(integer(), integer(), pid()) -> ok.
switch_voice_region(GuildId, ChannelId, GuildPid) ->
case gen_server:call(GuildPid, {get_sessions}, 10000) of
State when is_map(State) ->
VoiceStates = voice_state_utils:voice_states(State),
UsersInChannel = maps:fold(
fun(ConnectionId, VoiceState, Acc) ->
case voice_state_utils:voice_state_channel_id(VoiceState) of
ChannelId ->
case voice_state_utils:voice_state_user_id(VoiceState) of
undefined ->
logger:warning(
"[guild_voice_region] Missing user_id for connection ~p",
[ConnectionId]
),
Acc;
UserId ->
SessionId = maps:get(<<"session_id">>, VoiceState, undefined),
[{UserId, SessionId, VoiceState} | Acc]
end;
_ ->
Acc
end
end,
[],
VoiceStates
),
UsersInChannel = collect_users_in_channel(VoiceStates, ChannelId),
lists:foreach(
fun({UserId, SessionId, VoiceState}) ->
fun({UserId, SessionId, ExistingConnectionId, VoiceState}) ->
case SessionId of
undefined ->
ok;
_ ->
send_voice_server_update_for_region_switch(
GuildId, ChannelId, UserId, SessionId, VoiceState, GuildPid
GuildId, ChannelId, UserId, SessionId, ExistingConnectionId,
VoiceState, GuildPid
)
end
end,
@@ -82,42 +69,61 @@ switch_voice_region(GuildId, ChannelId, GuildPid) ->
ok
end.
-spec collect_users_in_channel(map(), integer()) ->
[{integer(), binary() | undefined, binary(), voice_state()}].
collect_users_in_channel(VoiceStates, ChannelId) ->
maps:fold(
fun(ConnectionId, VoiceState, Acc) ->
case voice_state_utils:voice_state_channel_id(VoiceState) of
ChannelId ->
case voice_state_utils:voice_state_user_id(VoiceState) of
undefined ->
Acc;
UserId ->
SessionId = maps:get(<<"session_id">>, VoiceState, undefined),
[{UserId, SessionId, ConnectionId, VoiceState} | Acc]
end;
_ ->
Acc
end
end,
[],
VoiceStates
).
-spec send_voice_server_update_for_region_switch(
integer(), integer(), integer(), binary(), binary(), voice_state(), pid()
) -> ok.
send_voice_server_update_for_region_switch(
GuildId, ChannelId, UserId, SessionId, ExistingVoiceState, GuildPid
GuildId, ChannelId, UserId, SessionId, ExistingConnectionId, ExistingVoiceState, GuildPid
) ->
case gen_server:call(GuildPid, {get_sessions}, 10000) of
State when is_map(State) ->
VoicePermissions = voice_utils:compute_voice_permissions(UserId, ChannelId, State),
TokenNonce = voice_utils:generate_token_nonce(),
case
guild_voice_connection:request_voice_token(
GuildId, ChannelId, UserId, VoicePermissions
GuildId, ChannelId, UserId, ExistingConnectionId, VoicePermissions, TokenNonce
)
of
{ok, TokenData} ->
Token = maps:get(token, TokenData),
Endpoint = maps:get(endpoint, TokenData),
ConnectionId = maps:get(connection_id, TokenData),
PendingMetadata = #{
user_id => UserId,
guild_id => GuildId,
channel_id => ChannelId,
session_id => SessionId,
self_mute => maps:get(<<"self_mute">>, ExistingVoiceState, false),
self_deaf => maps:get(<<"self_deaf">>, ExistingVoiceState, false),
self_video => maps:get(<<"self_video">>, ExistingVoiceState, false),
self_stream => maps:get(<<"self_stream">>, ExistingVoiceState, false),
is_mobile => maps:get(<<"is_mobile">>, ExistingVoiceState, false),
server_mute => maps:get(<<"mute">>, ExistingVoiceState, false),
server_deaf => maps:get(<<"deaf">>, ExistingVoiceState, false),
member => maps:get(<<"member">>, ExistingVoiceState, #{})
},
gen_server:cast(
GuildPid, {store_pending_connection, ConnectionId, PendingMetadata}
PendingMetadata = build_pending_metadata(
UserId, GuildId, ChannelId, SessionId, ExistingVoiceState, TokenNonce
),
_ = gen_server:call(
GuildPid, {store_pending_connection, ConnectionId, PendingMetadata}, 10000
),
guild_voice_broadcast:broadcast_voice_server_update_to_session(
GuildId, SessionId, Token, Endpoint, ConnectionId, State
GuildId,
ChannelId,
SessionId,
Token,
Endpoint,
ConnectionId,
State
);
{error, _Reason} ->
ok
@@ -125,3 +131,73 @@ send_voice_server_update_for_region_switch(
_ ->
ok
end.
-spec build_pending_metadata(integer(), integer(), integer(), binary(), voice_state(), binary()) -> map().
build_pending_metadata(UserId, GuildId, ChannelId, SessionId, ExistingVoiceState, TokenNonce) ->
Now = erlang:system_time(millisecond),
#{
user_id => UserId,
guild_id => GuildId,
channel_id => ChannelId,
session_id => SessionId,
self_mute => maps:get(<<"self_mute">>, ExistingVoiceState, false),
self_deaf => maps:get(<<"self_deaf">>, ExistingVoiceState, false),
self_video => maps:get(<<"self_video">>, ExistingVoiceState, false),
self_stream => maps:get(<<"self_stream">>, ExistingVoiceState, false),
is_mobile => maps:get(<<"is_mobile">>, ExistingVoiceState, false),
server_mute => maps:get(<<"mute">>, ExistingVoiceState, false),
server_deaf => maps:get(<<"deaf">>, ExistingVoiceState, false),
member => maps:get(<<"member">>, ExistingVoiceState, #{}),
viewer_stream_keys => [],
token_nonce => TokenNonce,
created_at => Now,
expires_at => Now + 30000
}.
-ifdef(TEST).
switch_voice_region_handler_not_found_test() ->
State = #{data => #{<<"channels">> => []}},
Request = #{channel_id => 999},
{reply, Error, _} = switch_voice_region_handler(Request, State),
?assertEqual({error, not_found, voice_channel_not_found}, Error).
switch_voice_region_handler_not_voice_test() ->
State = #{
data => #{
<<"channels">> => [
#{<<"id">> => <<"100">>, <<"type">> => 0}
]
}
},
Request = #{channel_id => 100},
{reply, Error, _} = switch_voice_region_handler(Request, State),
?assertEqual({error, validation_error, voice_channel_not_voice}, Error).
switch_voice_region_handler_success_test() ->
State = #{
data => #{
<<"channels">> => [
#{<<"id">> => <<"100">>, <<"type">> => 2}
]
}
},
Request = #{channel_id => 100},
{reply, Reply, _} = switch_voice_region_handler(Request, State),
?assertEqual(true, maps:get(success, Reply)).
collect_users_in_channel_test() ->
VoiceState = #{
<<"channel_id">> => <<"100">>,
<<"user_id">> => <<"10">>,
<<"session_id">> => <<"sess1">>
},
VoiceStates = #{<<"conn1">> => VoiceState},
Result = collect_users_in_channel(VoiceStates, 100),
?assertEqual(1, length(Result)),
[{UserId, SessionId, ConnectionId, _}] = Result,
?assertEqual(10, UserId),
?assertEqual(<<"sess1">>, SessionId),
?assertEqual(<<"conn1">>, ConnectionId).
-endif.

View File

@@ -26,14 +26,14 @@
-export([create_voice_state/8]).
-export([extract_session_info_from_voice_state/2]).
-type guild_state() :: map().
-type voice_state() :: map().
-type voice_state_map() :: #{binary() => voice_state()}.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type guild_state() :: map().
-type voice_state() :: map().
-type voice_state_map() :: #{binary() => voice_state()}.
-spec get_voice_state(map(), guild_state()) -> {reply, map(), guild_state()}.
get_voice_state(Request, State) ->
case maps:get(connection_id, Request, null) of
@@ -58,7 +58,7 @@ get_voice_states_list(State) ->
voice_state_map(),
guild_state(),
boolean(),
term()
list()
) -> {reply, map(), guild_state()}.
update_voice_state_data(
ConnectionId,
@@ -69,39 +69,85 @@ update_voice_state_data(
VoiceStates,
State,
NeedsToken,
ViewerStreamKey
ViewerStreamKeys
) ->
#voice_flags{
self_mute = SelfMute,
self_deaf = SelfDeaf,
self_video = SelfVideo,
self_stream = SelfStream,
is_mobile = IsMobile
#{
self_mute := SelfMute,
self_deaf := SelfDeaf,
self_video := SelfVideo,
self_stream := SelfStream,
is_mobile := IsMobile
} = Flags,
ServerMute = maps:get(<<"mute">>, Member, false),
ServerDeaf = maps:get(<<"deaf">>, Member, false),
OldVersion = maps:get(<<"version">>, ExistingVoiceState, 0),
UpdatedVoiceState = ExistingVoiceState#{
<<"channel_id">> => ChannelIdBin,
<<"mute">> => ServerMute,
<<"deaf">> => ServerDeaf,
<<"self_mute">> => SelfMute,
<<"self_deaf">> => SelfDeaf,
<<"self_video">> => SelfVideo,
<<"self_stream">> => SelfStream,
<<"is_mobile">> => IsMobile,
<<"viewer_stream_key">> => ViewerStreamKey,
<<"version">> => OldVersion + 1
},
NewVoiceStates = maps:put(ConnectionId, UpdatedVoiceState, VoiceStates),
NewState = maps:put(voice_states, NewVoiceStates, State),
guild_voice_broadcast:broadcast_voice_state_update(UpdatedVoiceState, NewState, ChannelIdBin),
Reply =
case NeedsToken of
true -> #{success => true, voice_state => UpdatedVoiceState, needs_token => true};
false -> #{success => true, voice_state => UpdatedVoiceState}
end,
{reply, Reply, NewState}.
OldChannelIdBin = maps:get(<<"channel_id">>, ExistingVoiceState, null),
IsChannelChange = OldChannelIdBin =/= ChannelIdBin,
HasStateChange = has_voice_state_change(
ExistingVoiceState, ChannelIdBin, ServerMute, ServerDeaf,
SelfMute, SelfDeaf, SelfVideo, SelfStream, IsMobile, ViewerStreamKeys
),
case HasStateChange of
false ->
Reply = #{success => true, voice_state => ExistingVoiceState},
{reply, Reply, State};
true ->
OldVersion = maps:get(<<"version">>, ExistingVoiceState, 0),
UpdatedVoiceState = ExistingVoiceState#{
<<"channel_id">> => ChannelIdBin,
<<"mute">> => ServerMute,
<<"deaf">> => ServerDeaf,
<<"self_mute">> => SelfMute,
<<"self_deaf">> => SelfDeaf,
<<"self_video">> => SelfVideo,
<<"self_stream">> => SelfStream,
<<"is_mobile">> => IsMobile,
<<"viewer_stream_keys">> => ViewerStreamKeys,
<<"version">> => OldVersion + 1
},
NewVoiceStates = maps:put(ConnectionId, UpdatedVoiceState, VoiceStates),
NewState = maps:put(voice_states, NewVoiceStates, State),
case IsChannelChange of
true ->
DisconnectState = ExistingVoiceState#{
<<"channel_id">> => null,
<<"connection_id">> => ConnectionId
},
guild_voice_broadcast:broadcast_voice_state_update(
DisconnectState, NewState, OldChannelIdBin
),
guild_voice_broadcast:broadcast_voice_state_update(
UpdatedVoiceState, NewState, ChannelIdBin
);
false ->
guild_voice_broadcast:broadcast_voice_state_update(
UpdatedVoiceState, NewState, ChannelIdBin
)
end,
Reply =
case NeedsToken of
true -> #{success => true, voice_state => UpdatedVoiceState, needs_token => true};
false -> #{success => true, voice_state => UpdatedVoiceState}
end,
{reply, Reply, NewState}
end.
-spec has_voice_state_change(
voice_state(), binary(), boolean(), boolean(),
boolean(), boolean(), boolean(), boolean(), boolean(), term()
) -> boolean().
has_voice_state_change(
ExistingVoiceState, ChannelIdBin, ServerMute, ServerDeaf,
SelfMute, SelfDeaf, SelfVideo, SelfStream, IsMobile, ViewerStreamKeys
) ->
maps:get(<<"channel_id">>, ExistingVoiceState, null) =/= ChannelIdBin orelse
maps:get(<<"mute">>, ExistingVoiceState, false) =/= ServerMute orelse
maps:get(<<"deaf">>, ExistingVoiceState, false) =/= ServerDeaf orelse
maps:get(<<"self_mute">>, ExistingVoiceState, false) =/= SelfMute orelse
maps:get(<<"self_deaf">>, ExistingVoiceState, false) =/= SelfDeaf orelse
maps:get(<<"self_video">>, ExistingVoiceState, false) =/= SelfVideo orelse
maps:get(<<"self_stream">>, ExistingVoiceState, false) =/= SelfStream orelse
maps:get(<<"is_mobile">>, ExistingVoiceState, false) =/= IsMobile orelse
maps:get(<<"viewer_stream_keys">>, ExistingVoiceState, []) =/= ViewerStreamKeys.
-spec user_matches_voice_state(voice_state(), integer() | binary()) -> boolean().
user_matches_voice_state(VoiceState, UserId) when is_integer(UserId) ->
@@ -122,7 +168,7 @@ user_matches_voice_state(_VoiceState, _UserId) ->
boolean(),
boolean(),
voice_flags(),
term()
list()
) -> voice_state().
create_voice_state(
GuildIdBin,
@@ -132,14 +178,14 @@ create_voice_state(
ServerMute,
ServerDeaf,
Flags,
ViewerStreamKey
ViewerStreamKeys
) ->
#voice_flags{
self_mute = SelfMute,
self_deaf = SelfDeaf,
self_video = SelfVideo,
self_stream = SelfStream,
is_mobile = IsMobile
#{
self_mute := SelfMute,
self_deaf := SelfDeaf,
self_video := SelfVideo,
self_stream := SelfStream,
is_mobile := IsMobile
} = Flags,
#{
<<"guild_id">> => GuildIdBin,
@@ -153,7 +199,7 @@ create_voice_state(
<<"self_video">> => SelfVideo,
<<"self_stream">> => SelfStream,
<<"is_mobile">> => IsMobile,
<<"viewer_stream_key">> => ViewerStreamKey,
<<"viewer_stream_keys">> => ViewerStreamKeys,
<<"version">> => 0
}.
@@ -177,29 +223,138 @@ user_matches_voice_state_integer_test() ->
?assert(user_matches_voice_state(VoiceState, 10)),
?assertNot(user_matches_voice_state(VoiceState, 11)).
update_voice_state_data_updates_version_test() ->
VoiceState = #{<<"version">> => 1, <<"channel_id">> => <<"1">>},
Member = #{<<"mute">> => true, <<"deaf">> => false},
Flags = #voice_flags{
self_mute = true,
self_deaf = false,
self_video = false,
self_stream = false,
is_mobile = false
user_matches_voice_state_binary_test() ->
VoiceState = #{<<"user_id">> => <<"123">>},
?assert(user_matches_voice_state(VoiceState, <<"123">>)),
?assertNot(user_matches_voice_state(VoiceState, <<"456">>)).
user_matches_voice_state_undefined_test() ->
VoiceState = #{},
?assertNot(user_matches_voice_state(VoiceState, 10)).
create_voice_state_test() ->
Flags = #{
self_mute => true,
self_deaf => false,
self_video => true,
self_stream => false,
is_mobile => true
},
{reply, #{voice_state := Updated}, _} =
update_voice_state_data(
<<"conn">>,
<<"2">>,
Flags,
Member,
VoiceState,
#{<<"conn">> => VoiceState},
#{voice_states => #{}},
false,
null
),
?assertEqual(2, maps:get(<<"version">>, Updated)),
?assertEqual(<<"2">>, maps:get(<<"channel_id">>, Updated)).
VS = create_voice_state(
<<"1">>, <<"2">>, <<"3">>, <<"conn">>, false, false, Flags, []
),
?assertEqual(<<"1">>, maps:get(<<"guild_id">>, VS)),
?assertEqual(<<"2">>, maps:get(<<"channel_id">>, VS)),
?assertEqual(<<"3">>, maps:get(<<"user_id">>, VS)),
?assertEqual(<<"conn">>, maps:get(<<"connection_id">>, VS)),
?assertEqual(true, maps:get(<<"self_mute">>, VS)),
?assertEqual(false, maps:get(<<"self_deaf">>, VS)),
?assertEqual(true, maps:get(<<"self_video">>, VS)),
?assertEqual(false, maps:get(<<"self_stream">>, VS)),
?assertEqual(true, maps:get(<<"is_mobile">>, VS)),
?assertEqual(0, maps:get(<<"version">>, VS)).
extract_session_info_from_voice_state_test() ->
VoiceState = #{
<<"session_id">> => <<"sess">>,
<<"self_mute">> => true,
<<"self_deaf">> => false,
<<"self_video">> => true,
<<"self_stream">> => false,
<<"is_mobile">> => true,
<<"member">> => #{<<"id">> => <<"m">>}
},
Info = extract_session_info_from_voice_state(<<"conn">>, VoiceState),
?assertEqual(<<"conn">>, maps:get(connection_id, Info)),
?assertEqual(<<"sess">>, maps:get(session_id, Info)),
?assertEqual(true, maps:get(self_mute, Info)),
?assertEqual(#{<<"id">> => <<"m">>}, maps:get(member, Info)).
has_voice_state_change_no_change_test() ->
ExistingVoiceState = #{
<<"channel_id">> => <<"100">>,
<<"mute">> => false,
<<"deaf">> => false,
<<"self_mute">> => true,
<<"self_deaf">> => false,
<<"self_video">> => false,
<<"self_stream">> => false,
<<"is_mobile">> => false,
<<"viewer_stream_keys">> => []
},
?assertNot(has_voice_state_change(
ExistingVoiceState, <<"100">>, false, false, true, false, false, false, false, []
)).
has_voice_state_change_channel_change_test() ->
ExistingVoiceState = #{
<<"channel_id">> => <<"100">>,
<<"mute">> => false,
<<"deaf">> => false,
<<"self_mute">> => false,
<<"self_deaf">> => false,
<<"self_video">> => false,
<<"self_stream">> => false,
<<"is_mobile">> => false,
<<"viewer_stream_keys">> => []
},
?assert(has_voice_state_change(
ExistingVoiceState, <<"200">>, false, false, false, false, false, false, false, []
)).
has_voice_state_change_self_mute_change_test() ->
ExistingVoiceState = #{
<<"channel_id">> => <<"100">>,
<<"mute">> => false,
<<"deaf">> => false,
<<"self_mute">> => false,
<<"self_deaf">> => false,
<<"self_video">> => false,
<<"self_stream">> => false,
<<"is_mobile">> => false,
<<"viewer_stream_keys">> => []
},
?assert(has_voice_state_change(
ExistingVoiceState, <<"100">>, false, false, true, false, false, false, false, []
)).
has_voice_state_change_server_mute_change_test() ->
ExistingVoiceState = #{
<<"channel_id">> => <<"100">>,
<<"mute">> => false,
<<"deaf">> => false,
<<"self_mute">> => false,
<<"self_deaf">> => false,
<<"self_video">> => false,
<<"self_stream">> => false,
<<"is_mobile">> => false,
<<"viewer_stream_keys">> => []
},
?assert(has_voice_state_change(
ExistingVoiceState, <<"100">>, true, false, false, false, false, false, false, []
)).
has_voice_state_change_viewer_stream_keys_change_test() ->
ExistingVoiceState = #{
<<"channel_id">> => <<"100">>,
<<"mute">> => false,
<<"deaf">> => false,
<<"self_mute">> => false,
<<"self_deaf">> => false,
<<"self_video">> => false,
<<"self_stream">> => false,
<<"is_mobile">> => false,
<<"viewer_stream_keys">> => []
},
?assert(has_voice_state_change(
ExistingVoiceState, <<"100">>, false, false, false, false, false, false, false,
[<<"999:100:conn">>]
)).
has_voice_state_change_defaults_test() ->
ExistingVoiceState = #{},
?assertNot(has_voice_state_change(
ExistingVoiceState, null, false, false, false, false, false, false, false, []
)).
-endif.

View File

@@ -0,0 +1,99 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(guild_voice_unclaimed_account_utils).
-export([parse_unclaimed_error/1]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec parse_unclaimed_error(iodata() | term()) -> boolean().
parse_unclaimed_error(Body) when is_binary(Body); is_list(Body) ->
try json:decode(iolist_to_binary(Body)) of
Map when is_map(Map) ->
case get_unclaimed_error_code(Map) of
Code when is_binary(Code) -> is_voice_unclaimed_error_code(Code);
_ -> false
end;
_ ->
false
catch
_:_ -> false
end;
parse_unclaimed_error(_) ->
false.
-spec get_unclaimed_error_code(map()) -> binary() | undefined.
get_unclaimed_error_code(Map) when is_map(Map) ->
case maps:get(<<"code">>, Map, undefined) of
Code when is_binary(Code) ->
Code;
_ ->
case maps:get(<<"error">>, Map, undefined) of
Error when is_map(Error) ->
maps:get(<<"code">>, Error, undefined);
_ ->
undefined
end
end.
-spec is_voice_unclaimed_error_code(binary()) -> boolean().
is_voice_unclaimed_error_code(Code) when is_binary(Code) ->
lists:member(
Code,
[
<<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_ONE_ON_ONE_VOICE_CALLS">>,
<<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_VOICE_CHANNELS">>
]
).
-ifdef(TEST).
parse_unclaimed_error_with_direct_code_test() ->
Body = json:encode(#{<<"code">> => <<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_VOICE_CHANNELS">>}),
?assertEqual(true, parse_unclaimed_error(Body)).
parse_unclaimed_error_with_nested_code_test() ->
Body = json:encode(#{
<<"error">> => #{<<"code">> => <<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_ONE_ON_ONE_VOICE_CALLS">>}
}),
?assertEqual(true, parse_unclaimed_error(Body)).
parse_unclaimed_error_with_unknown_code_test() ->
Body = json:encode(#{<<"code">> => <<"SOME_OTHER_ERROR">>}),
?assertEqual(false, parse_unclaimed_error(Body)).
parse_unclaimed_error_with_invalid_json_test() ->
?assertEqual(false, parse_unclaimed_error(<<"not json">>)).
parse_unclaimed_error_with_non_binary_test() ->
?assertEqual(false, parse_unclaimed_error(undefined)),
?assertEqual(false, parse_unclaimed_error(123)).
is_voice_unclaimed_error_code_test() ->
?assertEqual(
true, is_voice_unclaimed_error_code(<<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_VOICE_CHANNELS">>)
),
?assertEqual(
true,
is_voice_unclaimed_error_code(<<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_ONE_ON_ONE_VOICE_CALLS">>)
),
?assertEqual(false, is_voice_unclaimed_error_code(<<"OTHER_ERROR">>)).
-endif.

View File

@@ -24,6 +24,10 @@
channel_has_capacity/3
]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type user_id() :: integer().
-type session_id() :: binary().
-type session_pid() :: pid().
@@ -103,3 +107,56 @@ channel_has_capacity(ChannelId, UserLimit, VoiceStates) ->
-spec ensure_binary(binary() | integer()) -> binary().
ensure_binary(Value) when is_binary(Value) -> Value;
ensure_binary(Value) when is_integer(Value) -> integer_to_binary(Value).
-ifdef(TEST).
find_session_by_user_id_test() ->
Pid = self(),
Ref = make_ref(),
Sessions = #{
<<"session1">> => {100, Pid, Ref},
<<"session2">> => {200, Pid, make_ref()}
},
?assertMatch({ok, <<"session1">>, _, _}, find_session_by_user_id(100, Sessions)),
?assertEqual(not_found, find_session_by_user_id(999, Sessions)).
disconnect_user_not_found_test() ->
VoiceStates = #{},
Sessions = #{},
CleanupFun = fun(_, _) -> ok end,
?assertMatch({not_found, _, _}, disconnect_user(100, VoiceStates, Sessions, CleanupFun)).
disconnect_user_if_in_channel_mismatch_test() ->
VoiceStates = #{100 => #{<<"channel_id">> => <<"999">>}},
Sessions = #{},
CleanupFun = fun(_, _) -> ok end,
Result = disconnect_user_if_in_channel(100, 123, VoiceStates, Sessions, CleanupFun),
?assertMatch({channel_mismatch, _, _}, Result).
disconnect_user_if_in_channel_not_found_test() ->
VoiceStates = #{},
Sessions = #{},
CleanupFun = fun(_, _) -> ok end,
Result = disconnect_user_if_in_channel(100, 123, VoiceStates, Sessions, CleanupFun),
?assertMatch({not_found, _, _}, Result).
channel_has_capacity_unlimited_test() ->
VoiceStates = #{
1 => #{<<"channel_id">> => <<"100">>},
2 => #{<<"channel_id">> => <<"100">>}
},
?assert(channel_has_capacity(100, 0, VoiceStates)).
channel_has_capacity_limited_test() ->
VoiceStates = #{
1 => #{<<"channel_id">> => <<"100">>},
2 => #{<<"channel_id">> => <<"100">>}
},
?assertNot(channel_has_capacity(100, 2, VoiceStates)),
?assert(channel_has_capacity(100, 3, VoiceStates)).
ensure_binary_test() ->
?assertEqual(<<"123">>, ensure_binary(123)),
?assertEqual(<<"abc">>, ensure_binary(<<"abc">>)).
-endif.

View File

@@ -24,6 +24,10 @@
confirm_pending_connection/2
]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type connection_id() :: binary().
-type pending_metadata() :: map().
-type pending_map() :: #{connection_id() => pending_metadata()}.
@@ -56,3 +60,35 @@ confirm_pending_connection(ConnectionId, PendingMap) ->
_Metadata ->
{confirmed, maps:remove(ConnectionId, PendingMap)}
end.
-ifdef(TEST).
add_pending_connection_test() ->
PendingMap = #{},
Metadata = #{user_id => 1, channel_id => 2},
Result = add_pending_connection(<<"conn">>, Metadata, PendingMap),
?assert(maps:is_key(<<"conn">>, Result)),
StoredMetadata = maps:get(<<"conn">>, Result),
?assertEqual(1, maps:get(user_id, StoredMetadata)),
?assertEqual(2, maps:get(channel_id, StoredMetadata)),
?assert(maps:is_key(joined_at, StoredMetadata)).
remove_pending_connection_test() ->
PendingMap = #{<<"conn">> => #{user_id => 1}},
?assertEqual(#{}, remove_pending_connection(<<"conn">>, PendingMap)),
?assertEqual(PendingMap, remove_pending_connection(undefined, PendingMap)),
?assertEqual(PendingMap, remove_pending_connection(<<"other">>, PendingMap)).
get_pending_connection_test() ->
PendingMap = #{<<"conn">> => #{user_id => 1}},
?assertEqual(#{user_id => 1}, get_pending_connection(<<"conn">>, PendingMap)),
?assertEqual(undefined, get_pending_connection(<<"other">>, PendingMap)),
?assertEqual(undefined, get_pending_connection(undefined, PendingMap)).
confirm_pending_connection_test() ->
PendingMap = #{<<"conn">> => #{user_id => 1}},
?assertMatch({confirmed, #{}}, confirm_pending_connection(<<"conn">>, PendingMap)),
?assertMatch({not_found, _}, confirm_pending_connection(<<"other">>, PendingMap)),
?assertMatch({not_found, _}, confirm_pending_connection(undefined, PendingMap)).
-endif.

View File

@@ -33,66 +33,86 @@
build_stream_key/3
]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type voice_state() :: map().
-type voice_state_map() :: #{binary() => voice_state()}.
-type guild_state() :: map().
-type stream_key_result() :: #{
scope := guild | dm,
guild_id := integer() | undefined,
channel_id := integer(),
connection_id := binary()
}.
-spec voice_states(guild_state()) -> voice_state_map().
voice_states(State) when is_map(State) ->
case maps:get(voice_states, State, undefined) of
Map when is_map(Map) -> Map;
_ -> #{}
end.
-spec ensure_voice_states(term()) -> voice_state_map().
ensure_voice_states(Map) when is_map(Map) ->
Map;
ensure_voice_states(_) ->
#{}.
-spec voice_state_user_id(voice_state()) -> integer() | undefined.
voice_state_user_id(VoiceState) ->
map_utils:get_integer(VoiceState, <<"user_id">>, undefined).
-spec voice_state_channel_id(voice_state()) -> integer() | undefined.
voice_state_channel_id(VoiceState) ->
map_utils:get_integer(VoiceState, <<"channel_id">>, undefined).
-spec voice_state_guild_id(voice_state()) -> integer() | undefined.
voice_state_guild_id(VoiceState) ->
map_utils:get_integer(VoiceState, <<"guild_id">>, undefined).
-spec filter_voice_states(voice_state_map(), fun((binary(), voice_state()) -> boolean())) ->
voice_state_map().
filter_voice_states(VoiceStates, Predicate) when is_map(VoiceStates) ->
maps:filter(Predicate, VoiceStates);
filter_voice_states(_, _) ->
#{}.
-spec drop_voice_states(voice_state_map(), voice_state_map()) -> voice_state_map().
drop_voice_states(ToDrop, VoiceStates) ->
maps:fold(fun(ConnId, _VoiceState, Acc) -> maps:remove(ConnId, Acc) end, VoiceStates, ToDrop).
-spec broadcast_disconnects(voice_state_map(), guild_state()) -> ok.
broadcast_disconnects(VoiceStates, State) ->
maps:foreach(
fun(ConnId, VoiceState) ->
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
DisconnectVoiceState = VoiceState#{
<<"channel_id">> => null,
<<"connection_id">> => ConnId
},
guild_voice_broadcast:broadcast_voice_state_update(
DisconnectVoiceState, State, OldChannelIdBin
)
end,
VoiceStates
).
spawn(fun() ->
maps:foreach(
fun(ConnId, VoiceState) ->
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
DisconnectVoiceState = VoiceState#{
<<"channel_id">> => null,
<<"connection_id">> => ConnId
},
guild_voice_broadcast:broadcast_voice_state_update(
DisconnectVoiceState, State, OldChannelIdBin
)
end,
VoiceStates
)
end),
ok.
-spec voice_flags_from_context(map()) -> voice_flags().
voice_flags_from_context(Context) ->
#voice_flags{
self_mute = maps:get(self_mute, Context, false),
self_deaf = maps:get(self_deaf, Context, false),
self_video = maps:get(self_video, Context, false),
self_stream = maps:get(self_stream, Context, false),
is_mobile = maps:get(is_mobile, Context, false)
#{
self_mute => maps:get(self_mute, Context, false),
self_deaf => maps:get(self_deaf, Context, false),
self_video => maps:get(self_video, Context, false),
self_stream => maps:get(self_stream, Context, false),
is_mobile => maps:get(is_mobile, Context, false)
}.
-spec parse_stream_key(term()) ->
{ok, #{
scope := guild | dm,
guild_id := integer() | undefined,
channel_id := integer(),
connection_id := binary()
}}
| {error, invalid_stream_key}.
-spec parse_stream_key(term()) -> {ok, stream_key_result()} | {error, invalid_stream_key}.
parse_stream_key(StreamKey) when is_binary(StreamKey) ->
Parts = binary:split(StreamKey, <<":">>, [global]),
case Parts of
@@ -126,12 +146,7 @@ parse_channel_bin(ChannelBin) ->
Chan.
-spec build_stream_key_result({dm, undefined} | {guild, integer()}, integer(), binary()) ->
{ok, #{
scope := guild | dm,
guild_id := integer() | undefined,
channel_id := integer(),
connection_id := binary()
}}.
{ok, stream_key_result()}.
build_stream_key_result({dm, _}, ChannelId, ConnId) ->
{ok, #{
scope => dm,
@@ -162,3 +177,84 @@ build_stream_key(GuildId, ChannelId, ConnectionId) when
":",
ConnectionId/binary
>>.
-ifdef(TEST).
voice_states_returns_map_test() ->
State = #{voice_states => #{<<"a">> => #{}}},
?assertEqual(#{<<"a">> => #{}}, voice_states(State)),
?assertEqual(#{}, voice_states(#{})),
?assertEqual(#{}, voice_states(#{voice_states => not_a_map})).
ensure_voice_states_test() ->
?assertEqual(#{<<"a">> => 1}, ensure_voice_states(#{<<"a">> => 1})),
?assertEqual(#{}, ensure_voice_states(not_a_map)).
voice_state_user_id_test() ->
?assertEqual(123, voice_state_user_id(#{<<"user_id">> => <<"123">>})),
?assertEqual(undefined, voice_state_user_id(#{})).
voice_state_channel_id_test() ->
?assertEqual(456, voice_state_channel_id(#{<<"channel_id">> => <<"456">>})),
?assertEqual(undefined, voice_state_channel_id(#{})).
voice_state_guild_id_test() ->
?assertEqual(789, voice_state_guild_id(#{<<"guild_id">> => <<"789">>})),
?assertEqual(undefined, voice_state_guild_id(#{})).
filter_voice_states_test() ->
VoiceStates = #{
<<"a">> => #{<<"user_id">> => <<"1">>},
<<"b">> => #{<<"user_id">> => <<"2">>}
},
Filtered = filter_voice_states(VoiceStates, fun(_, V) ->
maps:get(<<"user_id">>, V) =:= <<"1">>
end),
?assertEqual(#{<<"a">> => #{<<"user_id">> => <<"1">>}}, Filtered).
drop_voice_states_test() ->
VoiceStates = #{<<"a">> => #{}, <<"b">> => #{}, <<"c">> => #{}},
ToDrop = #{<<"a">> => #{}, <<"c">> => #{}},
Result = drop_voice_states(ToDrop, VoiceStates),
?assertEqual(#{<<"b">> => #{}}, Result).
voice_flags_from_context_test() ->
Context = #{
self_mute => true,
self_deaf => false,
self_video => true,
self_stream => false,
is_mobile => true
},
Flags = voice_flags_from_context(Context),
?assertEqual(true, maps:get(self_mute, Flags)),
?assertEqual(false, maps:get(self_deaf, Flags)),
?assertEqual(true, maps:get(self_video, Flags)),
?assertEqual(false, maps:get(self_stream, Flags)),
?assertEqual(true, maps:get(is_mobile, Flags)).
parse_stream_key_dm_test() ->
Result = parse_stream_key(<<"dm:123:conn-id">>),
?assertMatch({ok, #{scope := dm, channel_id := 123, connection_id := <<"conn-id">>}}, Result).
parse_stream_key_guild_test() ->
Result = parse_stream_key(<<"999:123:conn-id">>),
?assertMatch(
{ok, #{scope := guild, guild_id := 999, channel_id := 123, connection_id := <<"conn-id">>}},
Result
).
parse_stream_key_invalid_test() ->
?assertEqual({error, invalid_stream_key}, parse_stream_key(<<"invalid">>)),
?assertEqual({error, invalid_stream_key}, parse_stream_key(<<"a:b">>)),
?assertEqual({error, invalid_stream_key}, parse_stream_key(123)).
build_stream_key_dm_test() ->
Result = build_stream_key(undefined, 123, <<"conn">>),
?assertEqual(<<"dm:123:conn">>, Result).
build_stream_key_guild_test() ->
Result = build_stream_key(999, 123, <<"conn">>),
?assertEqual(<<"999:123:conn">>, Result).
-endif.

View File

@@ -20,45 +20,56 @@
-export([
build_voice_token_rpc_request/6,
build_voice_token_rpc_request/7,
build_voice_token_rpc_request/8,
build_force_disconnect_rpc_request/4,
build_update_participant_rpc_request/5,
build_update_participant_permissions_rpc_request/5,
add_geolocation_to_request/3,
compute_voice_permissions/3
add_rtc_region_to_request/2,
compute_voice_permissions/3,
generate_token_nonce/0
]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type guild_state() :: map().
-type voice_permissions() :: #{
can_speak := boolean(),
can_stream := boolean(),
can_video := boolean()
}.
-spec build_voice_token_rpc_request(
integer() | null,
integer(),
integer(),
binary() | integer() | null,
binary() | number() | list() | null,
binary() | number() | list() | null
) -> map().
build_voice_token_rpc_request(GuildId, ChannelId, UserId, ConnectionId, Latitude, Longitude) ->
BaseReq0 = #{
<<"type">> => <<"voice_get_token">>,
<<"channel_id">> => integer_to_binary(ChannelId),
<<"user_id">> => integer_to_binary(UserId)
},
BaseReq =
case GuildId of
null ->
#{
<<"type">> => <<"voice_get_token">>,
<<"channel_id">> => integer_to_binary(ChannelId),
<<"user_id">> => integer_to_binary(UserId)
};
_ ->
BaseMap = #{
<<"type">> => <<"voice_get_token">>,
<<"guild_id">> => integer_to_binary(GuildId),
<<"channel_id">> => integer_to_binary(ChannelId),
<<"user_id">> => integer_to_binary(UserId)
},
case ConnectionId of
null ->
BaseMap;
ConnectionId when is_binary(ConnectionId) ->
maps:put(<<"connection_id">>, ConnectionId, BaseMap);
ConnectionId when is_integer(ConnectionId) ->
maps:put(<<"connection_id">>, integer_to_binary(ConnectionId), BaseMap);
_ ->
BaseMap
end
null -> BaseReq0;
_ -> maps:put(<<"guild_id">>, integer_to_binary(GuildId), BaseReq0)
end,
WithConnection = add_connection_id_to_request(BaseReq, ConnectionId),
add_geolocation_to_request(WithConnection, Latitude, Longitude).
add_geolocation_to_request(BaseReq, Latitude, Longitude).
-spec add_geolocation_to_request(
map(),
binary() | number() | list() | null,
binary() | number() | list() | null
) -> map().
add_geolocation_to_request(RequestMap, Latitude, Longitude) ->
case {Latitude, Longitude} of
case {normalise_coordinate(Latitude), normalise_coordinate(Longitude)} of
{Lat, Long} when is_binary(Lat) andalso is_binary(Long) ->
maps:merge(RequestMap, #{
<<"latitude">> => Lat,
@@ -68,6 +79,47 @@ add_geolocation_to_request(RequestMap, Latitude, Longitude) ->
RequestMap
end.
-spec normalise_coordinate(binary() | number() | list() | null) -> binary() | undefined.
normalise_coordinate(null) ->
undefined;
normalise_coordinate(Value) when is_binary(Value) ->
Value;
normalise_coordinate(Value) when is_integer(Value) ->
integer_to_binary(Value);
normalise_coordinate(Value) when is_float(Value) ->
float_to_binary(Value, [short]);
normalise_coordinate(Value) when is_list(Value) ->
try
list_to_binary(Value)
catch
error:badarg -> undefined
end;
normalise_coordinate(_Value) ->
undefined.
-spec add_rtc_region_to_request(map(), binary() | null) -> map().
add_rtc_region_to_request(RequestMap, Region) ->
case Region of
RegionBin when is_binary(RegionBin) ->
maps:put(<<"rtc_region">>, RegionBin, RequestMap);
_ ->
RequestMap
end.
-spec add_connection_id_to_request(map(), binary() | integer() | null) -> map().
add_connection_id_to_request(RequestMap, ConnectionId) ->
case ConnectionId of
null ->
RequestMap;
ConnectionIdBin when is_binary(ConnectionIdBin) ->
maps:put(<<"connection_id">>, ConnectionIdBin, RequestMap);
ConnectionIdInt when is_integer(ConnectionIdInt) ->
maps:put(<<"connection_id">>, integer_to_binary(ConnectionIdInt), RequestMap);
_ ->
RequestMap
end.
-spec build_force_disconnect_rpc_request(integer() | null, integer(), integer(), binary()) -> map().
build_force_disconnect_rpc_request(GuildId, ChannelId, UserId, ConnectionId) ->
BaseReq = #{
<<"type">> => <<"voice_force_disconnect_participant">>,
@@ -82,6 +134,9 @@ build_force_disconnect_rpc_request(GuildId, ChannelId, UserId, ConnectionId) ->
maps:put(<<"guild_id">>, integer_to_binary(GuildId), BaseReq)
end.
-spec build_update_participant_rpc_request(
integer() | null, integer(), integer(), boolean(), boolean()
) -> map().
build_update_participant_rpc_request(GuildId, ChannelId, UserId, Mute, Deaf) ->
BaseReq = #{
<<"type">> => <<"voice_update_participant">>,
@@ -97,6 +152,9 @@ build_update_participant_rpc_request(GuildId, ChannelId, UserId, Mute, Deaf) ->
maps:put(<<"guild_id">>, integer_to_binary(GuildId), BaseReq)
end.
-spec build_update_participant_permissions_rpc_request(
integer() | null, integer(), integer(), binary(), voice_permissions()
) -> map().
build_update_participant_permissions_rpc_request(
GuildId, ChannelId, UserId, ConnectionId, VoicePermissions
) ->
@@ -116,35 +174,181 @@ build_update_participant_permissions_rpc_request(
maps:put(<<"guild_id">>, integer_to_binary(GuildId), BaseReq)
end.
-spec compute_voice_permissions(integer(), integer(), map()) -> map().
-spec compute_voice_permissions(integer(), integer(), guild_state()) -> voice_permissions().
compute_voice_permissions(UserId, ChannelId, State) ->
Permissions = guild_permissions:get_member_permissions(UserId, ChannelId, State),
SpeakPerm = constants:speak_permission(),
StreamPerm = constants:stream_permission(),
AdminPerm = constants:administrator_permission(),
IsAdmin = (Permissions band AdminPerm) =:= AdminPerm,
CanSpeak = IsAdmin orelse ((Permissions band SpeakPerm) =:= SpeakPerm),
CanStream = IsAdmin orelse ((Permissions band StreamPerm) =:= StreamPerm),
HasVirtualAccess = guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, State),
FinalCanSpeak = CanSpeak orelse HasVirtualAccess,
FinalCanStream = CanStream orelse HasVirtualAccess,
#{
can_speak => FinalCanSpeak,
can_stream => FinalCanStream,
can_video => FinalCanStream
}.
-spec build_voice_token_rpc_request(
integer() | null,
integer(),
integer(),
binary() | integer() | null,
binary() | null,
binary() | null,
voice_permissions()
) -> map().
build_voice_token_rpc_request(
GuildId, ChannelId, UserId, ConnectionId, Latitude, Longitude, VoicePermissions
) ->
build_voice_token_rpc_request(
GuildId, ChannelId, UserId, ConnectionId, Latitude, Longitude, VoicePermissions, null
).
-spec build_voice_token_rpc_request(
integer() | null,
integer(),
integer(),
binary() | integer() | null,
binary() | null,
binary() | null,
voice_permissions(),
binary() | null
) -> map().
build_voice_token_rpc_request(
GuildId, ChannelId, UserId, ConnectionId, Latitude, Longitude, VoicePermissions, TokenNonce
) ->
BaseReq = build_voice_token_rpc_request(
GuildId, ChannelId, UserId, ConnectionId, Latitude, Longitude
),
maps:merge(BaseReq, #{
Req0 = maps:merge(BaseReq, #{
<<"can_speak">> => maps:get(can_speak, VoicePermissions, true),
<<"can_stream">> => maps:get(can_stream, VoicePermissions, true),
<<"can_video">> => maps:get(can_video, VoicePermissions, true)
}).
}),
case TokenNonce of
null -> Req0;
undefined -> Req0;
_ when is_binary(TokenNonce) -> maps:put(<<"token_nonce">>, TokenNonce, Req0);
_ -> Req0
end.
-spec generate_token_nonce() -> binary().
generate_token_nonce() ->
Bytes = crypto:strong_rand_bytes(16),
binary:encode_hex(Bytes, lowercase).
-ifdef(TEST).
build_voice_token_rpc_request_guild_test() ->
Req = build_voice_token_rpc_request(123, 456, 789, null, null, null),
?assertEqual(<<"voice_get_token">>, maps:get(<<"type">>, Req)),
?assertEqual(<<"123">>, maps:get(<<"guild_id">>, Req)),
?assertEqual(<<"456">>, maps:get(<<"channel_id">>, Req)),
?assertEqual(<<"789">>, maps:get(<<"user_id">>, Req)),
?assertNot(maps:is_key(<<"connection_id">>, Req)).
build_voice_token_rpc_request_dm_test() ->
Req = build_voice_token_rpc_request(null, 456, 789, null, null, null),
?assertEqual(<<"voice_get_token">>, maps:get(<<"type">>, Req)),
?assertNot(maps:is_key(<<"guild_id">>, Req)),
?assertEqual(<<"456">>, maps:get(<<"channel_id">>, Req)).
build_voice_token_rpc_request_with_connection_test() ->
Req = build_voice_token_rpc_request(123, 456, 789, <<"conn-id">>, null, null),
?assertEqual(<<"conn-id">>, maps:get(<<"connection_id">>, Req)).
build_voice_token_rpc_request_dm_with_connection_test() ->
Req = build_voice_token_rpc_request(null, 456, 789, <<"conn-id">>, null, null),
?assertEqual(<<"conn-id">>, maps:get(<<"connection_id">>, Req)).
add_geolocation_to_request_test() ->
BaseReq = #{<<"type">> => <<"test">>},
WithGeo = add_geolocation_to_request(BaseReq, <<"1.0">>, <<"2.0">>),
?assertEqual(<<"1.0">>, maps:get(<<"latitude">>, WithGeo)),
?assertEqual(<<"2.0">>, maps:get(<<"longitude">>, WithGeo)),
WithoutGeo = add_geolocation_to_request(BaseReq, null, null),
?assertNot(maps:is_key(<<"latitude">>, WithoutGeo)).
add_geolocation_to_request_number_test() ->
BaseReq = #{<<"type">> => <<"test">>},
WithGeo = add_geolocation_to_request(BaseReq, 1.5, 2.25),
?assertEqual(<<"1.5">>, maps:get(<<"latitude">>, WithGeo)),
?assertEqual(<<"2.25">>, maps:get(<<"longitude">>, WithGeo)).
add_rtc_region_to_request_test() ->
BaseReq = #{<<"type">> => <<"test">>},
WithRegion = add_rtc_region_to_request(BaseReq, <<"us-east">>),
?assertEqual(<<"us-east">>, maps:get(<<"rtc_region">>, WithRegion)),
WithoutRegion = add_rtc_region_to_request(BaseReq, null),
?assertNot(maps:is_key(<<"rtc_region">>, WithoutRegion)).
build_force_disconnect_rpc_request_test() ->
Req = build_force_disconnect_rpc_request(123, 456, 789, <<"conn">>),
?assertEqual(<<"voice_force_disconnect_participant">>, maps:get(<<"type">>, Req)),
?assertEqual(<<"123">>, maps:get(<<"guild_id">>, Req)),
?assertEqual(<<"conn">>, maps:get(<<"connection_id">>, Req)).
build_update_participant_rpc_request_test() ->
Req = build_update_participant_rpc_request(123, 456, 789, true, false),
?assertEqual(<<"voice_update_participant">>, maps:get(<<"type">>, Req)),
?assertEqual(true, maps:get(<<"mute">>, Req)),
?assertEqual(false, maps:get(<<"deaf">>, Req)).
generate_token_nonce_format_test() ->
Nonce = generate_token_nonce(),
?assert(is_binary(Nonce)),
?assertEqual(32, byte_size(Nonce)),
?assert(lists:all(fun(C) ->
(C >= $0 andalso C =< $9) orelse (C >= $a andalso C =< $f)
end, binary_to_list(Nonce))).
generate_token_nonce_unique_test() ->
Nonce1 = generate_token_nonce(),
Nonce2 = generate_token_nonce(),
Nonce3 = generate_token_nonce(),
?assertNot(Nonce1 =:= Nonce2),
?assertNot(Nonce2 =:= Nonce3),
?assertNot(Nonce1 =:= Nonce3).
build_voice_token_rpc_request_with_nonce_test() ->
VoicePerms = #{
can_speak => true,
can_stream => false,
can_video => false
},
Req = build_voice_token_rpc_request(
123, 456, 789, null, null, null, VoicePerms, <<"test-nonce-123">>
),
?assertEqual(<<"test-nonce-123">>, maps:get(<<"token_nonce">>, Req)),
?assertEqual(true, maps:get(<<"can_speak">>, Req)),
?assertEqual(false, maps:get(<<"can_stream">>, Req)).
build_voice_token_rpc_request_without_nonce_test() ->
VoicePerms = #{
can_speak => true,
can_stream => true,
can_video => true
},
Req = build_voice_token_rpc_request(
123, 456, 789, null, null, null, VoicePerms, null
),
?assertNot(maps:is_key(<<"token_nonce">>, Req)),
?assertEqual(true, maps:get(<<"can_speak">>, Req)).
build_voice_token_rpc_request_undefined_nonce_test() ->
VoicePerms = #{
can_speak => false,
can_stream => true,
can_video => true
},
Req = build_voice_token_rpc_request(
123, 456, 789, null, null, null, VoicePerms, undefined
),
?assertNot(maps:is_key(<<"token_nonce">>, Req)),
?assertEqual(false, maps:get(<<"can_speak">>, Req)).
-endif.