refactor progress
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -21,123 +21,533 @@
|
||||
is_guild_unavailable_for_user/2,
|
||||
is_user_staff/2,
|
||||
check_unavailability_transition/2,
|
||||
handle_unavailability_transition/2
|
||||
handle_unavailability_transition/2,
|
||||
get_cached_unavailability_mode/1,
|
||||
is_guild_unavailable_for_user_from_cache/2,
|
||||
update_unavailability_cache_for_state/1
|
||||
]).
|
||||
|
||||
-import(guild_permissions, [find_member_by_user_id/2]).
|
||||
-import(guild_data, [get_guild_state/2]).
|
||||
-type guild_state() :: map().
|
||||
-type user_id() :: integer().
|
||||
-type guild_id() :: integer().
|
||||
-type unavailability_mode() ::
|
||||
available
|
||||
| unavailable_for_everyone
|
||||
| unavailable_for_everyone_but_staff.
|
||||
-type transition_result() :: {unavailable_enabled, boolean()} | unavailable_disabled | no_change.
|
||||
|
||||
-define(GUILD_UNAVAILABILITY_CACHE, guild_unavailability_cache).
|
||||
-define(STAFF_USER_FLAG, 16#1).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-spec is_guild_unavailable_for_user(user_id(), guild_state()) -> boolean().
|
||||
is_guild_unavailable_for_user(UserId, State) ->
|
||||
Data = maps:get(data, State),
|
||||
Guild = maps:get(<<"guild">>, Data),
|
||||
Features = maps:get(<<"features">>, Guild, []),
|
||||
|
||||
HasUnavailableForEveryone = lists:member(<<"UNAVAILABLE_FOR_EVERYONE">>, Features),
|
||||
HasUnavailableForEveryoneButStaff =
|
||||
lists:member(<<"UNAVAILABLE_FOR_EVERYONE_BUT_STAFF">>, Features),
|
||||
|
||||
case {HasUnavailableForEveryone, HasUnavailableForEveryoneButStaff} of
|
||||
{true, _} ->
|
||||
case get_unavailability_mode_from_state(State) of
|
||||
unavailable_for_everyone ->
|
||||
true;
|
||||
{false, true} ->
|
||||
unavailable_for_everyone_but_staff ->
|
||||
not is_user_staff(UserId, State);
|
||||
{false, false} ->
|
||||
available ->
|
||||
false
|
||||
end.
|
||||
|
||||
-spec is_user_staff(user_id(), guild_state()) -> boolean().
|
||||
is_user_staff(UserId, State) ->
|
||||
case find_member_by_user_id(UserId, State) of
|
||||
undefined ->
|
||||
case is_user_staff_from_sessions(UserId, State) of
|
||||
true ->
|
||||
true;
|
||||
false ->
|
||||
false;
|
||||
Member ->
|
||||
User = maps:get(<<"user">>, Member, #{}),
|
||||
Flags = utils:binary_to_integer_safe(maps:get(<<"flags">>, User, <<"0">>)),
|
||||
(Flags band 16#1) =:= 16#1
|
||||
end.
|
||||
|
||||
check_unavailability_transition(OldState, NewState) ->
|
||||
OldData = maps:get(data, OldState),
|
||||
OldGuild = maps:get(<<"guild">>, OldData),
|
||||
OldFeatures = maps:get(<<"features">>, OldGuild, []),
|
||||
|
||||
NewData = maps:get(data, NewState),
|
||||
NewGuild = maps:get(<<"guild">>, NewData),
|
||||
NewFeatures = maps:get(<<"features">>, NewGuild, []),
|
||||
|
||||
OldUnavailableForEveryone = lists:member(<<"UNAVAILABLE_FOR_EVERYONE">>, OldFeatures),
|
||||
NewUnavailableForEveryone = lists:member(<<"UNAVAILABLE_FOR_EVERYONE">>, NewFeatures),
|
||||
|
||||
OldUnavailableForEveryoneButStaff =
|
||||
lists:member(<<"UNAVAILABLE_FOR_EVERYONE_BUT_STAFF">>, OldFeatures),
|
||||
NewUnavailableForEveryoneButStaff =
|
||||
lists:member(<<"UNAVAILABLE_FOR_EVERYONE_BUT_STAFF">>, NewFeatures),
|
||||
|
||||
OldIsUnavailable = OldUnavailableForEveryone orelse OldUnavailableForEveryoneButStaff,
|
||||
NewIsUnavailable = NewUnavailableForEveryone orelse NewUnavailableForEveryoneButStaff,
|
||||
|
||||
case {OldIsUnavailable, NewIsUnavailable} of
|
||||
{false, true} ->
|
||||
{unavailable_enabled, NewUnavailableForEveryoneButStaff};
|
||||
{true, false} ->
|
||||
unavailable_disabled;
|
||||
_ ->
|
||||
case
|
||||
{OldUnavailableForEveryoneButStaff, NewUnavailableForEveryoneButStaff,
|
||||
OldUnavailableForEveryone, NewUnavailableForEveryone}
|
||||
of
|
||||
{true, false, false, true} ->
|
||||
{unavailable_enabled, false};
|
||||
{false, true, true, false} ->
|
||||
{unavailable_enabled, true};
|
||||
_ ->
|
||||
no_change
|
||||
undefined ->
|
||||
case guild_permissions:find_member_by_user_id(UserId, State) of
|
||||
undefined ->
|
||||
false;
|
||||
Member ->
|
||||
User = maps:get(<<"user">>, Member, #{}),
|
||||
is_user_staff_from_user_data(User)
|
||||
end
|
||||
end.
|
||||
|
||||
handle_unavailability_transition(OldState, NewState) ->
|
||||
GuildId = maps:get(id, NewState),
|
||||
UnavailablePayload = #{
|
||||
<<"id">> => integer_to_binary(GuildId),
|
||||
<<"unavailable">> => true
|
||||
},
|
||||
-spec is_user_staff_from_sessions(user_id(), guild_state()) -> boolean() | undefined.
|
||||
is_user_staff_from_sessions(UserId, State) ->
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
maps:fold(
|
||||
fun(_SessionId, SessionData, Acc) ->
|
||||
case Acc of
|
||||
undefined ->
|
||||
SessionUserId = maps:get(user_id, SessionData, undefined),
|
||||
SessionIsStaff = maps:get(is_staff, SessionData, undefined),
|
||||
case {SessionUserId =:= UserId, SessionIsStaff} of
|
||||
{true, true} ->
|
||||
true;
|
||||
{true, false} ->
|
||||
false;
|
||||
_ ->
|
||||
undefined
|
||||
end;
|
||||
_ ->
|
||||
Acc
|
||||
end
|
||||
end,
|
||||
undefined,
|
||||
Sessions
|
||||
).
|
||||
|
||||
-spec get_cached_unavailability_mode(guild_id()) -> unavailability_mode().
|
||||
get_cached_unavailability_mode(GuildId) ->
|
||||
ensure_unavailability_cache_table(),
|
||||
case ets:lookup(?GUILD_UNAVAILABILITY_CACHE, GuildId) of
|
||||
[{GuildId, Mode}] ->
|
||||
normalize_unavailability_mode(Mode);
|
||||
[] ->
|
||||
available
|
||||
end.
|
||||
|
||||
-spec is_guild_unavailable_for_user_from_cache(guild_id(), map()) -> boolean().
|
||||
is_guild_unavailable_for_user_from_cache(GuildId, UserData) ->
|
||||
case get_cached_unavailability_mode(GuildId) of
|
||||
unavailable_for_everyone ->
|
||||
true;
|
||||
unavailable_for_everyone_but_staff ->
|
||||
not is_user_staff_from_user_data(UserData);
|
||||
available ->
|
||||
false
|
||||
end.
|
||||
|
||||
-spec update_unavailability_cache_for_state(guild_state()) -> unavailability_mode().
|
||||
update_unavailability_cache_for_state(State) ->
|
||||
GuildId = maps:get(id, State),
|
||||
Mode = get_unavailability_mode_from_state(State),
|
||||
set_cached_unavailability_mode(GuildId, Mode),
|
||||
Mode.
|
||||
|
||||
-spec check_unavailability_transition(guild_state(), guild_state()) -> transition_result().
|
||||
check_unavailability_transition(OldState, NewState) ->
|
||||
OldMode = get_unavailability_mode_from_state(OldState),
|
||||
NewMode = get_unavailability_mode_from_state(NewState),
|
||||
case {OldMode, NewMode} of
|
||||
{available, unavailable_for_everyone} ->
|
||||
{unavailable_enabled, false};
|
||||
{available, unavailable_for_everyone_but_staff} ->
|
||||
{unavailable_enabled, true};
|
||||
{unavailable_for_everyone, available} ->
|
||||
unavailable_disabled;
|
||||
{unavailable_for_everyone_but_staff, available} ->
|
||||
unavailable_disabled;
|
||||
{unavailable_for_everyone_but_staff, unavailable_for_everyone} ->
|
||||
{unavailable_enabled, false};
|
||||
{unavailable_for_everyone, unavailable_for_everyone_but_staff} ->
|
||||
{unavailable_enabled, true};
|
||||
_ ->
|
||||
no_change
|
||||
end.
|
||||
|
||||
-spec handle_unavailability_transition(guild_state(), guild_state()) -> guild_state().
|
||||
handle_unavailability_transition(OldState, NewState) ->
|
||||
_ = update_unavailability_cache_for_state(NewState),
|
||||
GuildId = maps:get(id, NewState),
|
||||
case check_unavailability_transition(OldState, NewState) of
|
||||
{unavailable_enabled, StaffOnly} ->
|
||||
Sessions = maps:get(sessions, NewState, #{}),
|
||||
lists:foreach(
|
||||
fun({_SessionId, SessionData}) ->
|
||||
UserId = maps:get(user_id, SessionData),
|
||||
Pid = maps:get(pid, SessionData),
|
||||
|
||||
ShouldBeUnavailable =
|
||||
case StaffOnly of
|
||||
true -> not is_user_staff(UserId, NewState);
|
||||
false -> true
|
||||
end,
|
||||
|
||||
case ShouldBeUnavailable of
|
||||
true ->
|
||||
gen_server:cast(Pid, {dispatch, guild_delete, UnavailablePayload});
|
||||
false ->
|
||||
ok
|
||||
end
|
||||
end,
|
||||
maps:to_list(Sessions)
|
||||
);
|
||||
disconnect_ineligible_sessions(StaffOnly, NewState, GuildId);
|
||||
unavailable_disabled ->
|
||||
Sessions = maps:get(sessions, NewState, #{}),
|
||||
GuildId = maps:get(id, NewState),
|
||||
BulkPresences = presence_utils:collect_guild_member_presences(NewState),
|
||||
lists:foreach(
|
||||
fun({_SessionId, SessionData}) ->
|
||||
UserId = maps:get(user_id, SessionData),
|
||||
Pid = maps:get(pid, SessionData),
|
||||
GuildState = get_guild_state(UserId, NewState),
|
||||
gen_server:cast(Pid, {dispatch, guild_create, GuildState}),
|
||||
presence_utils:send_presence_bulk(Pid, GuildId, UserId, BulkPresences)
|
||||
case maps:get(pending_connect, SessionData, false) of
|
||||
true ->
|
||||
ok;
|
||||
false ->
|
||||
UserId = maps:get(user_id, SessionData),
|
||||
Pid = maps:get(pid, SessionData),
|
||||
GuildState = guild_data:get_guild_state(UserId, NewState),
|
||||
gen_server:cast(Pid, {dispatch, guild_create, GuildState}),
|
||||
presence_utils:send_presence_bulk(Pid, GuildId, UserId, BulkPresences)
|
||||
end
|
||||
end,
|
||||
maps:to_list(Sessions)
|
||||
);
|
||||
),
|
||||
NewState;
|
||||
no_change ->
|
||||
NewState
|
||||
end.
|
||||
|
||||
-spec disconnect_ineligible_sessions(boolean(), guild_state(), guild_id()) -> guild_state().
|
||||
disconnect_ineligible_sessions(StaffOnly, State, GuildId) ->
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
{FinalState, _DisconnectedUsers} = lists:foldl(
|
||||
fun({SessionId, SessionData}, {AccState, ProcessedUsers}) ->
|
||||
UserId = maps:get(user_id, SessionData),
|
||||
case should_disconnect_user(UserId, StaffOnly, AccState) of
|
||||
true ->
|
||||
Pid = maps:get(pid, SessionData, undefined),
|
||||
maybe_send_guild_leave(Pid, GuildId),
|
||||
{VoiceState, UpdatedUsers} =
|
||||
maybe_disconnect_voice_for_user(UserId, ProcessedUsers, AccState),
|
||||
NewState = guild_sessions:remove_session(SessionId, VoiceState),
|
||||
{NewState, UpdatedUsers};
|
||||
false ->
|
||||
{AccState, ProcessedUsers}
|
||||
end
|
||||
end,
|
||||
{State, sets:new()},
|
||||
maps:to_list(Sessions)
|
||||
),
|
||||
FinalState.
|
||||
|
||||
-spec should_disconnect_user(user_id(), boolean(), guild_state()) -> boolean().
|
||||
should_disconnect_user(UserId, true, State) ->
|
||||
not is_user_staff(UserId, State);
|
||||
should_disconnect_user(_UserId, false, _State) ->
|
||||
true.
|
||||
|
||||
-spec maybe_send_guild_leave(pid() | undefined, guild_id()) -> ok.
|
||||
maybe_send_guild_leave(Pid, GuildId) when is_pid(Pid) ->
|
||||
gen_server:cast(Pid, {guild_leave, GuildId, forced_unavailable}),
|
||||
ok;
|
||||
maybe_send_guild_leave(_Pid, _GuildId) ->
|
||||
ok.
|
||||
|
||||
-spec maybe_disconnect_voice_for_user(user_id(), sets:set(user_id()), guild_state()) ->
|
||||
{guild_state(), sets:set(user_id())}.
|
||||
maybe_disconnect_voice_for_user(UserId, ProcessedUsers, State) ->
|
||||
case sets:is_element(UserId, ProcessedUsers) of
|
||||
true ->
|
||||
{State, ProcessedUsers};
|
||||
false ->
|
||||
{reply, _Result, VoiceState} = guild_voice_disconnect:disconnect_voice_user(
|
||||
#{user_id => UserId, connection_id => null},
|
||||
State
|
||||
),
|
||||
{VoiceState, sets:add_element(UserId, ProcessedUsers)}
|
||||
end.
|
||||
|
||||
-spec ensure_unavailability_cache_table() -> ok.
|
||||
ensure_unavailability_cache_table() ->
|
||||
case ets:whereis(?GUILD_UNAVAILABILITY_CACHE) of
|
||||
undefined ->
|
||||
try ets:new(?GUILD_UNAVAILABILITY_CACHE, [named_table, public, set, {read_concurrency, true}]) of
|
||||
_ -> ok
|
||||
catch
|
||||
error:badarg -> ok
|
||||
end;
|
||||
_ ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec set_cached_unavailability_mode(guild_id(), unavailability_mode()) -> ok.
|
||||
set_cached_unavailability_mode(GuildId, available) ->
|
||||
ensure_unavailability_cache_table(),
|
||||
ets:delete(?GUILD_UNAVAILABILITY_CACHE, GuildId),
|
||||
ok;
|
||||
set_cached_unavailability_mode(GuildId, Mode) ->
|
||||
ensure_unavailability_cache_table(),
|
||||
ets:insert(?GUILD_UNAVAILABILITY_CACHE, {GuildId, Mode}),
|
||||
ok.
|
||||
|
||||
-spec normalize_unavailability_mode(term()) -> unavailability_mode().
|
||||
normalize_unavailability_mode(unavailable_for_everyone) ->
|
||||
unavailable_for_everyone;
|
||||
normalize_unavailability_mode(unavailable_for_everyone_but_staff) ->
|
||||
unavailable_for_everyone_but_staff;
|
||||
normalize_unavailability_mode(_) ->
|
||||
available.
|
||||
|
||||
-spec get_unavailability_mode_from_state(guild_state()) -> unavailability_mode().
|
||||
get_unavailability_mode_from_state(State) ->
|
||||
Data = maps:get(data, State, #{}),
|
||||
Guild = maps:get(<<"guild">>, Data, #{}),
|
||||
Features = maps:get(<<"features">>, Guild, []),
|
||||
get_unavailability_mode_from_features(Features).
|
||||
|
||||
-spec get_unavailability_mode_from_features(term()) -> unavailability_mode().
|
||||
get_unavailability_mode_from_features(Features) when is_list(Features) ->
|
||||
HasUnavailableForEveryone = lists:member(<<"UNAVAILABLE_FOR_EVERYONE">>, Features),
|
||||
HasUnavailableForEveryoneButStaff =
|
||||
lists:member(<<"UNAVAILABLE_FOR_EVERYONE_BUT_STAFF">>, Features),
|
||||
case {HasUnavailableForEveryone, HasUnavailableForEveryoneButStaff} of
|
||||
{true, _} ->
|
||||
unavailable_for_everyone;
|
||||
{false, true} ->
|
||||
unavailable_for_everyone_but_staff;
|
||||
{false, false} ->
|
||||
available
|
||||
end;
|
||||
get_unavailability_mode_from_features(_) ->
|
||||
available.
|
||||
|
||||
-spec is_user_staff_from_user_data(map()) -> boolean().
|
||||
is_user_staff_from_user_data(UserData) when is_map(UserData) ->
|
||||
case parse_is_staff_value(maps:get(<<"is_staff">>, UserData, undefined)) of
|
||||
undefined ->
|
||||
is_user_staff_from_flags(UserData);
|
||||
IsStaff ->
|
||||
IsStaff
|
||||
end;
|
||||
is_user_staff_from_user_data(_) ->
|
||||
false.
|
||||
|
||||
-spec parse_is_staff_value(term()) -> boolean() | undefined.
|
||||
parse_is_staff_value(true) ->
|
||||
true;
|
||||
parse_is_staff_value(false) ->
|
||||
false;
|
||||
parse_is_staff_value(<<"true">>) ->
|
||||
true;
|
||||
parse_is_staff_value(<<"false">>) ->
|
||||
false;
|
||||
parse_is_staff_value(_) ->
|
||||
undefined.
|
||||
|
||||
-spec is_user_staff_from_flags(map()) -> boolean().
|
||||
is_user_staff_from_flags(UserData) ->
|
||||
FlagsValue = maps:get(<<"flags">>, UserData, 0),
|
||||
Flags = type_conv:to_integer(FlagsValue),
|
||||
case Flags of
|
||||
undefined ->
|
||||
false;
|
||||
Value when is_integer(Value) ->
|
||||
(Value band ?STAFF_USER_FLAG) =:= ?STAFF_USER_FLAG
|
||||
end.
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
-spec cleanup_unavailability_cache(guild_id()) -> ok.
|
||||
cleanup_unavailability_cache(GuildId) ->
|
||||
set_cached_unavailability_mode(GuildId, available).
|
||||
|
||||
disconnect_ineligible_sessions_staff_only_test() ->
|
||||
Parent = self(),
|
||||
GuildId = 99001,
|
||||
NonStaffPid = start_session_capture(non_staff, Parent),
|
||||
StaffPid = start_session_capture(staff, Parent),
|
||||
try
|
||||
BaseState = state_for_unavailability_transition_test(GuildId, NonStaffPid, StaffPid),
|
||||
OldState = BaseState,
|
||||
NewState = BaseState#{
|
||||
data => #{
|
||||
<<"guild">> => #{
|
||||
<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE_BUT_STAFF">>]
|
||||
},
|
||||
<<"members">> => [
|
||||
#{<<"user">> => #{<<"id">> => <<"1001">>, <<"flags">> => <<"0">>}},
|
||||
#{<<"user">> => #{<<"id">> => <<"1002">>, <<"flags">> => <<"1">>}}
|
||||
]
|
||||
}
|
||||
},
|
||||
UpdatedState = handle_unavailability_transition(OldState, NewState),
|
||||
Sessions = maps:get(sessions, UpdatedState, #{}),
|
||||
?assertEqual(1, map_size(Sessions)),
|
||||
?assert(maps:is_key(<<"staff">>, Sessions)),
|
||||
?assertEqual(unavailable_for_everyone_but_staff, get_cached_unavailability_mode(GuildId)),
|
||||
receive
|
||||
{non_staff, {'$gen_cast', {guild_leave, GuildId, forced_unavailable}}} -> ok
|
||||
after 1000 ->
|
||||
?assert(false)
|
||||
end,
|
||||
receive
|
||||
{staff, {'$gen_cast', {guild_leave, GuildId, forced_unavailable}}} ->
|
||||
?assert(false)
|
||||
after 200 ->
|
||||
ok
|
||||
end
|
||||
after
|
||||
cleanup_unavailability_cache(GuildId),
|
||||
NonStaffPid ! stop,
|
||||
StaffPid ! stop
|
||||
end.
|
||||
|
||||
disconnect_ineligible_sessions_everyone_test() ->
|
||||
Parent = self(),
|
||||
GuildId = 99002,
|
||||
UserOnePid = start_session_capture(user_one, Parent),
|
||||
UserTwoPid = start_session_capture(user_two, Parent),
|
||||
try
|
||||
BaseState = state_for_unavailability_transition_test(GuildId, UserOnePid, UserTwoPid),
|
||||
OldState = BaseState,
|
||||
NewState = BaseState#{
|
||||
data => #{
|
||||
<<"guild">> => #{
|
||||
<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]
|
||||
},
|
||||
<<"members">> => [
|
||||
#{<<"user">> => #{<<"id">> => <<"1001">>, <<"flags">> => <<"0">>}},
|
||||
#{<<"user">> => #{<<"id">> => <<"1002">>, <<"flags">> => <<"1">>}}
|
||||
]
|
||||
}
|
||||
},
|
||||
UpdatedState = handle_unavailability_transition(OldState, NewState),
|
||||
?assertEqual(#{}, maps:get(sessions, UpdatedState, #{})),
|
||||
?assertEqual(unavailable_for_everyone, get_cached_unavailability_mode(GuildId)),
|
||||
receive
|
||||
{user_one, {'$gen_cast', {guild_leave, GuildId, forced_unavailable}}} -> ok
|
||||
after 1000 ->
|
||||
?assert(false)
|
||||
end,
|
||||
receive
|
||||
{user_two, {'$gen_cast', {guild_leave, GuildId, forced_unavailable}}} -> ok
|
||||
after 1000 ->
|
||||
?assert(false)
|
||||
end
|
||||
after
|
||||
cleanup_unavailability_cache(GuildId),
|
||||
UserOnePid ! stop,
|
||||
UserTwoPid ! stop
|
||||
end.
|
||||
|
||||
is_guild_unavailable_for_user_from_cache_is_staff_test() ->
|
||||
GuildId = 99003,
|
||||
try
|
||||
set_cached_unavailability_mode(GuildId, unavailable_for_everyone_but_staff),
|
||||
?assertEqual(
|
||||
true,
|
||||
is_guild_unavailable_for_user_from_cache(
|
||||
GuildId,
|
||||
#{<<"is_staff">> => false}
|
||||
)
|
||||
),
|
||||
?assertEqual(
|
||||
false,
|
||||
is_guild_unavailable_for_user_from_cache(
|
||||
GuildId,
|
||||
#{<<"is_staff">> => true}
|
||||
)
|
||||
)
|
||||
after
|
||||
cleanup_unavailability_cache(GuildId)
|
||||
end.
|
||||
|
||||
-spec start_session_capture(atom(), pid()) -> pid().
|
||||
start_session_capture(Tag, Parent) ->
|
||||
spawn(fun() -> session_capture_loop(Tag, Parent) end).
|
||||
|
||||
-spec session_capture_loop(atom(), pid()) -> ok.
|
||||
session_capture_loop(Tag, Parent) ->
|
||||
receive
|
||||
stop ->
|
||||
ok;
|
||||
{'$gen_cast', Msg} ->
|
||||
Parent ! {Tag, {'$gen_cast', Msg}},
|
||||
session_capture_loop(Tag, Parent);
|
||||
_Other ->
|
||||
session_capture_loop(Tag, Parent)
|
||||
end.
|
||||
|
||||
-spec state_for_unavailability_transition_test(guild_id(), pid(), pid()) -> guild_state().
|
||||
state_for_unavailability_transition_test(GuildId, NonStaffPid, StaffPid) ->
|
||||
#{
|
||||
id => GuildId,
|
||||
sessions => #{
|
||||
<<"non_staff">> => #{
|
||||
session_id => <<"non_staff">>,
|
||||
user_id => 1001,
|
||||
pid => NonStaffPid,
|
||||
mref => make_ref(),
|
||||
active_guilds => sets:new(),
|
||||
user_roles => [],
|
||||
bot => false,
|
||||
is_staff => false,
|
||||
previous_passive_updates => #{},
|
||||
previous_passive_channel_versions => #{},
|
||||
previous_passive_voice_states => #{}
|
||||
},
|
||||
<<"staff">> => #{
|
||||
session_id => <<"staff">>,
|
||||
user_id => 1002,
|
||||
pid => StaffPid,
|
||||
mref => make_ref(),
|
||||
active_guilds => sets:new(),
|
||||
user_roles => [],
|
||||
bot => false,
|
||||
is_staff => true,
|
||||
previous_passive_updates => #{},
|
||||
previous_passive_channel_versions => #{},
|
||||
previous_passive_voice_states => #{}
|
||||
}
|
||||
},
|
||||
presence_subscriptions => #{},
|
||||
member_list_subscriptions => #{},
|
||||
member_subscriptions => guild_subscriptions:init_state(),
|
||||
data => #{
|
||||
<<"guild">> => #{
|
||||
<<"features">> => []
|
||||
},
|
||||
<<"members">> => [
|
||||
#{<<"user">> => #{<<"id">> => <<"1001">>, <<"flags">> => <<"0">>}},
|
||||
#{<<"user">> => #{<<"id">> => <<"1002">>, <<"flags">> => <<"1">>}}
|
||||
]
|
||||
},
|
||||
voice_states => #{},
|
||||
pending_voice_connections => #{}
|
||||
}.
|
||||
|
||||
is_guild_unavailable_for_user_unavailable_for_everyone_test() ->
|
||||
State = #{
|
||||
data => #{
|
||||
<<"guild">> => #{
|
||||
<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]
|
||||
},
|
||||
<<"members">> => []
|
||||
}
|
||||
},
|
||||
?assertEqual(true, is_guild_unavailable_for_user(123, State)).
|
||||
|
||||
is_guild_unavailable_for_user_available_test() ->
|
||||
State = #{
|
||||
data => #{
|
||||
<<"guild">> => #{
|
||||
<<"features">> => []
|
||||
},
|
||||
<<"members">> => []
|
||||
}
|
||||
},
|
||||
?assertEqual(false, is_guild_unavailable_for_user(123, State)).
|
||||
|
||||
check_unavailability_transition_no_change_test() ->
|
||||
State = #{
|
||||
data => #{
|
||||
<<"guild">> => #{
|
||||
<<"features">> => []
|
||||
}
|
||||
}
|
||||
},
|
||||
?assertEqual(no_change, check_unavailability_transition(State, State)).
|
||||
|
||||
check_unavailability_transition_enabled_test() ->
|
||||
OldState = #{
|
||||
data => #{
|
||||
<<"guild">> => #{
|
||||
<<"features">> => []
|
||||
}
|
||||
}
|
||||
},
|
||||
NewState = #{
|
||||
data => #{
|
||||
<<"guild">> => #{
|
||||
<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]
|
||||
}
|
||||
}
|
||||
},
|
||||
?assertEqual({unavailable_enabled, false}, check_unavailability_transition(OldState, NewState)).
|
||||
|
||||
check_unavailability_transition_disabled_test() ->
|
||||
OldState = #{
|
||||
data => #{
|
||||
<<"guild">> => #{
|
||||
<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]
|
||||
}
|
||||
}
|
||||
},
|
||||
NewState = #{
|
||||
data => #{
|
||||
<<"guild">> => #{
|
||||
<<"features">> => []
|
||||
}
|
||||
}
|
||||
},
|
||||
?assertEqual(unavailable_disabled, check_unavailability_transition(OldState, NewState)).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -17,9 +17,7 @@
|
||||
|
||||
-module(guild_client).
|
||||
|
||||
-export([
|
||||
voice_state_update/3
|
||||
]).
|
||||
-export([voice_state_update/3]).
|
||||
|
||||
-export_type([
|
||||
voice_state_update_success/0,
|
||||
@@ -27,7 +25,10 @@
|
||||
voice_state_update_result/0
|
||||
]).
|
||||
|
||||
-define(DEFAULT_TIMEOUT, 12000).
|
||||
-define(CIRCUIT_BREAKER_TABLE, guild_circuit_breaker).
|
||||
-define(FAILURE_THRESHOLD, 5).
|
||||
-define(RECOVERY_TIMEOUT_MS, 30000).
|
||||
-define(MAX_CONCURRENT, 50).
|
||||
|
||||
-type voice_state_update_success() :: #{
|
||||
success := true,
|
||||
@@ -44,10 +45,39 @@
|
||||
{ok, voice_state_update_success()}
|
||||
| {error, timeout}
|
||||
| {error, noproc}
|
||||
| {error, circuit_breaker_open}
|
||||
| {error, too_many_requests}
|
||||
| {error, atom(), atom()}.
|
||||
|
||||
-type circuit_state() :: closed | open | half_open.
|
||||
|
||||
-spec voice_state_update(pid(), map(), timeout()) -> voice_state_update_result().
|
||||
voice_state_update(GuildPid, Request, Timeout) ->
|
||||
ensure_table(),
|
||||
case acquire_slot(GuildPid) of
|
||||
ok ->
|
||||
try
|
||||
execute_with_circuit_breaker(GuildPid, Request, Timeout)
|
||||
after
|
||||
release_slot(GuildPid)
|
||||
end;
|
||||
{error, Reason} ->
|
||||
{error, Reason}
|
||||
end.
|
||||
|
||||
-spec execute_with_circuit_breaker(pid(), map(), timeout()) -> voice_state_update_result().
|
||||
execute_with_circuit_breaker(GuildPid, Request, Timeout) ->
|
||||
case get_circuit_state(GuildPid) of
|
||||
open ->
|
||||
{error, circuit_breaker_open};
|
||||
State when State =:= closed; State =:= half_open ->
|
||||
Result = do_call(GuildPid, Request, Timeout),
|
||||
update_circuit_state(GuildPid, Result, State),
|
||||
Result
|
||||
end.
|
||||
|
||||
-spec do_call(pid(), map(), timeout()) -> voice_state_update_result().
|
||||
do_call(GuildPid, Request, Timeout) ->
|
||||
try gen_server:call(GuildPid, {voice_state_update, Request}, Timeout) of
|
||||
Response when is_map(Response) ->
|
||||
case maps:get(success, Response, false) of
|
||||
@@ -57,12 +87,138 @@ voice_state_update(GuildPid, Request, Timeout) ->
|
||||
{error, Category, ErrorAtom} when is_atom(Category), is_atom(ErrorAtom) ->
|
||||
{error, Category, ErrorAtom}
|
||||
catch
|
||||
exit:{timeout, _} ->
|
||||
{error, timeout};
|
||||
exit:{noproc, _} ->
|
||||
{error, noproc};
|
||||
exit:{normal, _} ->
|
||||
{error, noproc}
|
||||
exit:{timeout, _} -> {error, timeout};
|
||||
exit:{noproc, _} -> {error, noproc};
|
||||
exit:{normal, _} -> {error, noproc}
|
||||
end.
|
||||
|
||||
-spec get_circuit_state(pid()) -> circuit_state().
|
||||
get_circuit_state(GuildPid) ->
|
||||
case safe_lookup(GuildPid) of
|
||||
[] ->
|
||||
closed;
|
||||
[{_, #{state := open, opened_at := OpenedAt}}] ->
|
||||
Now = erlang:system_time(millisecond),
|
||||
case Now - OpenedAt > ?RECOVERY_TIMEOUT_MS of
|
||||
true -> half_open;
|
||||
false -> open
|
||||
end;
|
||||
[{_, #{state := State}}] ->
|
||||
State
|
||||
end.
|
||||
|
||||
-spec update_circuit_state(pid(), voice_state_update_result(), circuit_state()) -> ok.
|
||||
update_circuit_state(GuildPid, Result, PrevState) ->
|
||||
IsSuccess = is_success_result(Result),
|
||||
case {IsSuccess, PrevState} of
|
||||
{true, half_open} ->
|
||||
ets:delete(?CIRCUIT_BREAKER_TABLE, GuildPid),
|
||||
ok;
|
||||
{true, closed} ->
|
||||
reset_failures(GuildPid);
|
||||
{false, _} ->
|
||||
record_failure(GuildPid)
|
||||
end.
|
||||
|
||||
-spec is_success_result(voice_state_update_result()) -> boolean().
|
||||
is_success_result({ok, _}) -> true;
|
||||
is_success_result(_) -> false.
|
||||
|
||||
-spec reset_failures(pid()) -> ok.
|
||||
reset_failures(GuildPid) ->
|
||||
case safe_lookup(GuildPid) of
|
||||
[{_, State}] ->
|
||||
ets:insert(?CIRCUIT_BREAKER_TABLE, {GuildPid, State#{failures => 0}}),
|
||||
ok;
|
||||
[] ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec record_failure(pid()) -> ok.
|
||||
record_failure(GuildPid) ->
|
||||
Now = erlang:system_time(millisecond),
|
||||
case safe_lookup(GuildPid) of
|
||||
[] ->
|
||||
ets:insert(
|
||||
?CIRCUIT_BREAKER_TABLE,
|
||||
{GuildPid, #{
|
||||
state => closed,
|
||||
failures => 1,
|
||||
concurrent => 0
|
||||
}}
|
||||
),
|
||||
ok;
|
||||
[{_, #{failures := F} = State}] when F + 1 >= ?FAILURE_THRESHOLD ->
|
||||
ets:insert(
|
||||
?CIRCUIT_BREAKER_TABLE,
|
||||
{GuildPid, State#{
|
||||
state => open,
|
||||
failures => F + 1,
|
||||
opened_at => Now
|
||||
}}
|
||||
),
|
||||
ok;
|
||||
[{_, #{failures := F} = State}] ->
|
||||
ets:insert(?CIRCUIT_BREAKER_TABLE, {GuildPid, State#{failures => F + 1}}),
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec acquire_slot(pid()) -> ok | {error, too_many_requests}.
|
||||
acquire_slot(GuildPid) ->
|
||||
case safe_lookup(GuildPid) of
|
||||
[] ->
|
||||
ets:insert(
|
||||
?CIRCUIT_BREAKER_TABLE,
|
||||
{GuildPid, #{
|
||||
state => closed,
|
||||
failures => 0,
|
||||
concurrent => 1
|
||||
}}
|
||||
),
|
||||
ok;
|
||||
[{_, #{concurrent := C}}] when C >= ?MAX_CONCURRENT ->
|
||||
{error, too_many_requests};
|
||||
[{_, #{concurrent := C} = State}] ->
|
||||
ets:insert(?CIRCUIT_BREAKER_TABLE, {GuildPid, State#{concurrent => C + 1}}),
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec release_slot(pid()) -> ok.
|
||||
release_slot(GuildPid) ->
|
||||
case safe_lookup(GuildPid) of
|
||||
[{_, #{concurrent := C} = State}] when C > 0 ->
|
||||
ets:insert(?CIRCUIT_BREAKER_TABLE, {GuildPid, State#{concurrent => C - 1}}),
|
||||
ok;
|
||||
_ ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec safe_lookup(pid()) -> list().
|
||||
safe_lookup(GuildPid) ->
|
||||
try ets:lookup(?CIRCUIT_BREAKER_TABLE, GuildPid) of
|
||||
Result -> Result
|
||||
catch
|
||||
error:badarg -> []
|
||||
end.
|
||||
|
||||
-spec ensure_table() -> ok.
|
||||
ensure_table() ->
|
||||
case ets:whereis(?CIRCUIT_BREAKER_TABLE) of
|
||||
undefined ->
|
||||
try
|
||||
ets:new(?CIRCUIT_BREAKER_TABLE, [
|
||||
named_table,
|
||||
public,
|
||||
set,
|
||||
{read_concurrency, true},
|
||||
{write_concurrency, true}
|
||||
]),
|
||||
ok
|
||||
catch
|
||||
error:badarg -> ok
|
||||
end;
|
||||
_ ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-ifdef(TEST).
|
||||
@@ -72,4 +228,136 @@ module_exports_test() ->
|
||||
Exports = guild_client:module_info(exports),
|
||||
?assert(lists:member({voice_state_update, 3}, Exports)).
|
||||
|
||||
ensure_table_creates_table_test() ->
|
||||
catch ets:delete(?CIRCUIT_BREAKER_TABLE),
|
||||
?assertEqual(undefined, ets:whereis(?CIRCUIT_BREAKER_TABLE)),
|
||||
ensure_table(),
|
||||
?assertNotEqual(undefined, ets:whereis(?CIRCUIT_BREAKER_TABLE)).
|
||||
|
||||
ensure_table_idempotent_test() ->
|
||||
ensure_table(),
|
||||
ensure_table(),
|
||||
?assertNotEqual(undefined, ets:whereis(?CIRCUIT_BREAKER_TABLE)).
|
||||
|
||||
acquire_slot_creates_entry_test() ->
|
||||
ensure_table(),
|
||||
Pid = spawn(fun() ->
|
||||
receive
|
||||
done -> ok
|
||||
end
|
||||
end),
|
||||
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
|
||||
?assertEqual(ok, acquire_slot(Pid)),
|
||||
[{Pid, State}] = ets:lookup(?CIRCUIT_BREAKER_TABLE, Pid),
|
||||
?assertEqual(1, maps:get(concurrent, State)),
|
||||
Pid ! done.
|
||||
|
||||
acquire_slot_increments_test() ->
|
||||
ensure_table(),
|
||||
Pid = spawn(fun() ->
|
||||
receive
|
||||
done -> ok
|
||||
end
|
||||
end),
|
||||
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
|
||||
acquire_slot(Pid),
|
||||
acquire_slot(Pid),
|
||||
[{Pid, State}] = ets:lookup(?CIRCUIT_BREAKER_TABLE, Pid),
|
||||
?assertEqual(2, maps:get(concurrent, State)),
|
||||
Pid ! done.
|
||||
|
||||
release_slot_decrements_test() ->
|
||||
ensure_table(),
|
||||
Pid = spawn(fun() ->
|
||||
receive
|
||||
done -> ok
|
||||
end
|
||||
end),
|
||||
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
|
||||
acquire_slot(Pid),
|
||||
acquire_slot(Pid),
|
||||
release_slot(Pid),
|
||||
[{Pid, State}] = ets:lookup(?CIRCUIT_BREAKER_TABLE, Pid),
|
||||
?assertEqual(1, maps:get(concurrent, State)),
|
||||
Pid ! done.
|
||||
|
||||
get_circuit_state_closed_test() ->
|
||||
ensure_table(),
|
||||
Pid = spawn(fun() ->
|
||||
receive
|
||||
done -> ok
|
||||
end
|
||||
end),
|
||||
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
|
||||
?assertEqual(closed, get_circuit_state(Pid)),
|
||||
Pid ! done.
|
||||
|
||||
get_circuit_state_open_test() ->
|
||||
ensure_table(),
|
||||
Pid = spawn(fun() ->
|
||||
receive
|
||||
done -> ok
|
||||
end
|
||||
end),
|
||||
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
|
||||
Now = erlang:system_time(millisecond),
|
||||
ets:insert(
|
||||
?CIRCUIT_BREAKER_TABLE,
|
||||
{Pid, #{
|
||||
state => open,
|
||||
failures => 5,
|
||||
concurrent => 0,
|
||||
opened_at => Now
|
||||
}}
|
||||
),
|
||||
?assertEqual(open, get_circuit_state(Pid)),
|
||||
Pid ! done.
|
||||
|
||||
get_circuit_state_half_open_test() ->
|
||||
ensure_table(),
|
||||
Pid = spawn(fun() ->
|
||||
receive
|
||||
done -> ok
|
||||
end
|
||||
end),
|
||||
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
|
||||
OldTime = erlang:system_time(millisecond) - ?RECOVERY_TIMEOUT_MS - 1000,
|
||||
ets:insert(
|
||||
?CIRCUIT_BREAKER_TABLE,
|
||||
{Pid, #{
|
||||
state => open,
|
||||
failures => 5,
|
||||
concurrent => 0,
|
||||
opened_at => OldTime
|
||||
}}
|
||||
),
|
||||
?assertEqual(half_open, get_circuit_state(Pid)),
|
||||
Pid ! done.
|
||||
|
||||
record_failure_opens_circuit_test() ->
|
||||
ensure_table(),
|
||||
Pid = spawn(fun() ->
|
||||
receive
|
||||
done -> ok
|
||||
end
|
||||
end),
|
||||
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
|
||||
ets:insert(
|
||||
?CIRCUIT_BREAKER_TABLE,
|
||||
{Pid, #{
|
||||
state => closed,
|
||||
failures => ?FAILURE_THRESHOLD - 1,
|
||||
concurrent => 0
|
||||
}}
|
||||
),
|
||||
record_failure(Pid),
|
||||
[{Pid, State}] = ets:lookup(?CIRCUIT_BREAKER_TABLE, Pid),
|
||||
?assertEqual(open, maps:get(state, State)),
|
||||
Pid ! done.
|
||||
|
||||
is_success_result_test() ->
|
||||
?assertEqual(true, is_success_result({ok, #{}})),
|
||||
?assertEqual(false, is_success_result({error, timeout})),
|
||||
?assertEqual(false, is_success_result({error, noproc})).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -31,6 +31,7 @@
|
||||
-type guild_data_map() :: map().
|
||||
-type guild_member() :: map().
|
||||
-type channel_list() :: [map()].
|
||||
-type user_id() :: integer().
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
@@ -45,11 +46,10 @@ get_guild_data(#{user_id := UserId}, State) ->
|
||||
Reply = #{guild_data => GuildData},
|
||||
{reply, Reply, State};
|
||||
_ ->
|
||||
Members = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
|
||||
case member_in_list(UserId, Members) of
|
||||
false ->
|
||||
case guild_data_index:get_member(UserId, Data) of
|
||||
undefined ->
|
||||
{reply, #{guild_data => null, error_reason => <<"forbidden">>}, State};
|
||||
true ->
|
||||
_ ->
|
||||
GuildData = build_complete_guild_data(Data, State),
|
||||
{reply, #{guild_data => GuildData}, State}
|
||||
end
|
||||
@@ -76,7 +76,7 @@ has_member(#{user_id := UserId}, State) ->
|
||||
-spec list_guild_members(map(), guild_state()) -> guild_reply(map()).
|
||||
list_guild_members(#{limit := Limit, offset := Offset}, State) ->
|
||||
Data = guild_data_map(State),
|
||||
AllMembers = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
|
||||
AllMembers = guild_data_index:member_list(Data),
|
||||
TotalCount = length(AllMembers),
|
||||
PaginatedMembers = paginate_members(AllMembers, Limit, Offset),
|
||||
{reply, #{members => PaginatedMembers, total => TotalCount}, State}.
|
||||
@@ -93,19 +93,24 @@ get_first_viewable_text_channel(State) ->
|
||||
EveryoneChannelId = find_everyone_viewable_text_channel(Channels, State),
|
||||
{reply, #{channel_id => EveryoneChannelId}, State}.
|
||||
|
||||
-spec get_guild_state(integer(), guild_state()) -> map().
|
||||
-spec get_guild_state(user_id(), guild_state()) -> map().
|
||||
get_guild_state(UserId, State) ->
|
||||
Data = guild_data_map(State),
|
||||
GuildId = map_utils:get_integer(State, id, 0),
|
||||
AllChannels = channels_from_data(Data),
|
||||
AllMembers = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
|
||||
AllMembers = guild_data_index:member_values(Data),
|
||||
Member = find_member_by_user_id(UserId, State),
|
||||
{ViewableChannels, JoinedAt} = derive_member_view(UserId, Member, State, AllChannels),
|
||||
OnlineCount = guild_member_list:get_online_count(State),
|
||||
OwnMemberList = case Member of
|
||||
undefined -> [];
|
||||
M -> [M]
|
||||
end,
|
||||
OwnMemberList =
|
||||
case Member of
|
||||
undefined -> [];
|
||||
M -> [M]
|
||||
end,
|
||||
VoiceStates = guild_voice:get_voice_states_list(State),
|
||||
VoiceMembers = voice_members_from_states(VoiceStates, AllMembers),
|
||||
Members = merge_members(OwnMemberList, VoiceMembers),
|
||||
MemberCount = maps:get(member_count, State, length(AllMembers)),
|
||||
#{
|
||||
<<"id">> => integer_to_binary(GuildId),
|
||||
<<"properties">> => maps:get(<<"guild">>, Data, #{}),
|
||||
@@ -113,11 +118,11 @@ get_guild_state(UserId, State) ->
|
||||
<<"channels">> => ViewableChannels,
|
||||
<<"emojis">> => maps:get(<<"emojis">>, Data, []),
|
||||
<<"stickers">> => maps:get(<<"stickers">>, Data, []),
|
||||
<<"members">> => OwnMemberList,
|
||||
<<"member_count">> => length(AllMembers),
|
||||
<<"members">> => Members,
|
||||
<<"member_count">> => MemberCount,
|
||||
<<"online_count">> => OnlineCount,
|
||||
<<"presences">> => [],
|
||||
<<"voice_states">> => guild_voice:get_voice_states_list(State),
|
||||
<<"voice_states">> => VoiceStates,
|
||||
<<"joined_at">> => JoinedAt
|
||||
}.
|
||||
|
||||
@@ -140,6 +145,7 @@ find_everyone_viewable_text_channel(Channels, State) ->
|
||||
map_utils:ensure_list(Channels)
|
||||
).
|
||||
|
||||
-spec find_member_by_user_id(user_id(), guild_state()) -> guild_member() | undefined.
|
||||
find_member_by_user_id(UserId, State) ->
|
||||
guild_permissions:find_member_by_user_id(UserId, State).
|
||||
|
||||
@@ -164,19 +170,7 @@ channels_from_state(State) ->
|
||||
|
||||
-spec channels_from_data(guild_data_map()) -> channel_list().
|
||||
channels_from_data(Data) ->
|
||||
map_utils:ensure_list(maps:get(<<"channels">>, Data, [])).
|
||||
|
||||
-spec member_in_list(integer(), [guild_member()]) -> boolean().
|
||||
member_in_list(UserId, Members) ->
|
||||
lists:any(fun(Member) -> member_matches(UserId, Member) end, Members).
|
||||
|
||||
-spec member_matches(integer(), guild_member()) -> boolean().
|
||||
member_matches(UserId, Member) ->
|
||||
MemberUser = map_utils:ensure_map(maps:get(<<"user">>, Member, #{})),
|
||||
case map_utils:get_integer(MemberUser, <<"id">>, undefined) of
|
||||
undefined -> false;
|
||||
Id -> Id =:= UserId
|
||||
end.
|
||||
guild_data_index:channel_list(Data).
|
||||
|
||||
-spec paginate_members([guild_member()], non_neg_integer(), non_neg_integer()) -> [guild_member()].
|
||||
paginate_members(Members, Limit, Offset) ->
|
||||
@@ -191,7 +185,7 @@ paginate_members(Members, Limit, Offset) ->
|
||||
end
|
||||
end.
|
||||
|
||||
-spec derive_member_view(integer(), guild_member() | undefined, guild_state(), channel_list()) ->
|
||||
-spec derive_member_view(user_id(), guild_member() | undefined, guild_state(), channel_list()) ->
|
||||
{channel_list(), term()}.
|
||||
derive_member_view(_UserId, undefined, _State, _Channels) ->
|
||||
{[], null};
|
||||
@@ -210,7 +204,63 @@ derive_member_view(UserId, Member, State, Channels) ->
|
||||
JoinedAt = maps:get(<<"joined_at">>, Member, null),
|
||||
{Filtered, JoinedAt}.
|
||||
|
||||
-spec role_permissions_for_id(list(), integer()) -> integer().
|
||||
-spec voice_members_from_states([map()], [guild_member()]) -> [guild_member()].
|
||||
voice_members_from_states(VoiceStates, Members) ->
|
||||
MemberIndex = build_member_index(Members),
|
||||
lists:filtermap(
|
||||
fun(VoiceState) ->
|
||||
case voice_state_utils:voice_state_user_id(VoiceState) of
|
||||
undefined ->
|
||||
false;
|
||||
UserId ->
|
||||
case maps:get(UserId, MemberIndex, undefined) of
|
||||
undefined -> false;
|
||||
Member -> {true, Member}
|
||||
end
|
||||
end
|
||||
end,
|
||||
VoiceStates
|
||||
).
|
||||
|
||||
-spec build_member_index([guild_member()]) -> #{integer() => guild_member()}.
|
||||
build_member_index(Members) ->
|
||||
lists:foldl(
|
||||
fun(Member, Acc) ->
|
||||
case member_user_id(Member) of
|
||||
undefined -> Acc;
|
||||
UserId -> maps:put(UserId, Member, Acc)
|
||||
end
|
||||
end,
|
||||
#{},
|
||||
Members
|
||||
).
|
||||
|
||||
-spec merge_members([guild_member()], [guild_member()]) -> [guild_member()].
|
||||
merge_members(Primary, Secondary) ->
|
||||
{Merged, _} =
|
||||
lists:foldl(
|
||||
fun(Member, {Acc, Seen}) ->
|
||||
case member_user_id(Member) of
|
||||
undefined ->
|
||||
{Acc, Seen};
|
||||
UserId ->
|
||||
case sets:is_element(UserId, Seen) of
|
||||
true -> {Acc, Seen};
|
||||
false -> {[Member | Acc], sets:add_element(UserId, Seen)}
|
||||
end
|
||||
end
|
||||
end,
|
||||
{[], sets:new()},
|
||||
Primary ++ Secondary
|
||||
),
|
||||
lists:reverse(Merged).
|
||||
|
||||
-spec member_user_id(guild_member()) -> integer() | undefined.
|
||||
member_user_id(Member) ->
|
||||
MemberUser = map_utils:ensure_map(maps:get(<<"user">>, Member, #{})),
|
||||
map_utils:get_integer(MemberUser, <<"id">>, undefined).
|
||||
|
||||
-spec role_permissions_for_id([map()], integer()) -> integer().
|
||||
role_permissions_for_id(Roles, GuildId) ->
|
||||
lists:foldl(
|
||||
fun(Role, Acc) ->
|
||||
@@ -229,6 +279,10 @@ select_first_viewable(Channel, GuildId, BasePerms) ->
|
||||
ChannelId = map_utils:get_integer(Channel, <<"id">>, undefined),
|
||||
select_first_viewable(ChannelType, ChannelId, Channel, GuildId, BasePerms).
|
||||
|
||||
-spec select_first_viewable(
|
||||
integer() | undefined, integer() | undefined, map(), integer(), integer()
|
||||
) ->
|
||||
integer() | null.
|
||||
select_first_viewable(0, ChannelId, Channel, GuildId, BasePerms) when is_integer(ChannelId) ->
|
||||
case (BasePerms band constants:administrator_permission()) =/= 0 of
|
||||
true ->
|
||||
@@ -273,6 +327,13 @@ find_everyone_viewable_text_channel_test() ->
|
||||
ChannelId = find_everyone_viewable_text_channel(Channels, State),
|
||||
?assertEqual(500, ChannelId).
|
||||
|
||||
paginate_members_test() ->
|
||||
Members = [#{<<"id">> => 1}, #{<<"id">> => 2}, #{<<"id">> => 3}],
|
||||
?assertEqual([#{<<"id">> => 1}, #{<<"id">> => 2}], paginate_members(Members, 2, 0)),
|
||||
?assertEqual([#{<<"id">> => 2}, #{<<"id">> => 3}], paginate_members(Members, 2, 1)),
|
||||
?assertEqual([#{<<"id">> => 3}], paginate_members(Members, 2, 2)),
|
||||
?assertEqual([], paginate_members(Members, 2, 5)).
|
||||
|
||||
test_state() ->
|
||||
GuildId = 100,
|
||||
ViewPerm = constants:view_channel_permission(),
|
||||
|
||||
480
fluxer_gateway/src/guild/guild_data_index.erl
Normal file
480
fluxer_gateway/src/guild/guild_data_index.erl
Normal file
@@ -0,0 +1,480 @@
|
||||
%% Copyright (C) 2026 Fluxer Contributors
|
||||
%%
|
||||
%% This file is part of Fluxer.
|
||||
%%
|
||||
%% Fluxer is free software: you can redistribute it and/or modify
|
||||
%% it under the terms of the GNU Affero General Public License as published by
|
||||
%% the Free Software Foundation, either version 3 of the License, or
|
||||
%% (at your option) any later version.
|
||||
%%
|
||||
%% Fluxer is distributed in the hope that it will be useful,
|
||||
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
%% GNU Affero General Public License for more details.
|
||||
%%
|
||||
%% You should have received a copy of the GNU Affero General Public License
|
||||
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
-module(guild_data_index).
|
||||
|
||||
-export([
|
||||
normalize_data/1,
|
||||
member_map/1,
|
||||
member_values/1,
|
||||
member_list/1,
|
||||
member_count/1,
|
||||
member_ids/1,
|
||||
member_role_index/1,
|
||||
get_member/2,
|
||||
put_member/2,
|
||||
put_member_map/2,
|
||||
put_member_list/2,
|
||||
remove_member/2,
|
||||
role_list/1,
|
||||
role_index/1,
|
||||
put_roles/2,
|
||||
channel_list/1,
|
||||
channel_index/1,
|
||||
put_channels/2
|
||||
]).
|
||||
|
||||
-type guild_data() :: map().
|
||||
-type member() :: map().
|
||||
-type role() :: map().
|
||||
-type channel() :: map().
|
||||
-type user_id() :: integer().
|
||||
-type snowflake_id() :: integer().
|
||||
-type role_member_index() :: #{snowflake_id() => #{user_id() => true}}.
|
||||
|
||||
-spec normalize_data(guild_data()) -> guild_data().
|
||||
normalize_data(Data) when is_map(Data) ->
|
||||
MemberMap = member_map(Data),
|
||||
Roles = role_list(Data),
|
||||
Channels = channel_list(Data),
|
||||
Data1 = maps:put(<<"members">>, MemberMap, Data),
|
||||
Data2 = maps:put(<<"roles">>, Roles, Data1),
|
||||
Data3 = maps:put(<<"channels">>, Channels, Data2),
|
||||
Data4 = maps:put(<<"role_index">>, build_id_index(Roles), Data3),
|
||||
Data5 = maps:put(<<"channel_index">>, build_id_index(Channels), Data4),
|
||||
maps:put(<<"member_role_index">>, build_member_role_index(MemberMap), Data5);
|
||||
normalize_data(Data) ->
|
||||
Data.
|
||||
|
||||
-spec member_map(guild_data()) -> #{user_id() => member()}.
|
||||
member_map(Data) when is_map(Data) ->
|
||||
case maps:get(<<"members">>, Data, #{}) of
|
||||
Members when is_map(Members) ->
|
||||
normalize_member_map(Members);
|
||||
Members when is_list(Members) ->
|
||||
build_member_map(Members);
|
||||
_ ->
|
||||
#{}
|
||||
end;
|
||||
member_map(_) ->
|
||||
#{}.
|
||||
|
||||
-spec member_list(guild_data()) -> [member()].
|
||||
member_list(Data) ->
|
||||
MemberPairs = maps:to_list(member_map(Data)),
|
||||
SortedPairs = lists:sort(fun({A, _}, {B, _}) -> A =< B end, MemberPairs),
|
||||
[Member || {_UserId, Member} <- SortedPairs].
|
||||
|
||||
-spec member_values(guild_data()) -> [member()].
|
||||
member_values(Data) ->
|
||||
maps:values(member_map(Data)).
|
||||
|
||||
-spec member_count(guild_data()) -> non_neg_integer().
|
||||
member_count(Data) ->
|
||||
map_size(member_map(Data)).
|
||||
|
||||
-spec member_ids(guild_data()) -> [user_id()].
|
||||
member_ids(Data) ->
|
||||
maps:keys(member_map(Data)).
|
||||
|
||||
-spec member_role_index(guild_data()) -> role_member_index().
|
||||
member_role_index(Data) when is_map(Data) ->
|
||||
case maps:get(<<"member_role_index">>, Data, undefined) of
|
||||
Index when is_map(Index) -> normalize_member_role_index(Index);
|
||||
_ -> build_member_role_index(member_map(Data))
|
||||
end;
|
||||
member_role_index(_) ->
|
||||
#{}.
|
||||
|
||||
-spec get_member(user_id(), guild_data()) -> member() | undefined.
|
||||
get_member(UserId, Data) when is_integer(UserId) ->
|
||||
maps:get(UserId, member_map(Data), undefined);
|
||||
get_member(_, _) ->
|
||||
undefined.
|
||||
|
||||
-spec put_member(member(), guild_data()) -> guild_data().
|
||||
put_member(Member, Data) when is_map(Member), is_map(Data) ->
|
||||
case member_user_id(Member) of
|
||||
undefined ->
|
||||
Data;
|
||||
UserId ->
|
||||
MemberMap = member_map(Data),
|
||||
ExistingMember = maps:get(UserId, MemberMap, undefined),
|
||||
ExistingRoles = member_role_ids(ExistingMember),
|
||||
UpdatedRoles = member_role_ids(Member),
|
||||
RoleIndex = member_role_index(Data),
|
||||
RoleIndex1 = remove_user_from_member_role_index(UserId, ExistingRoles, RoleIndex),
|
||||
RoleIndex2 = add_user_to_member_role_index(UserId, UpdatedRoles, RoleIndex1),
|
||||
Data1 = maps:put(<<"members">>, maps:put(UserId, Member, MemberMap), Data),
|
||||
maps:put(<<"member_role_index">>, RoleIndex2, Data1)
|
||||
end;
|
||||
put_member(_, Data) ->
|
||||
Data.
|
||||
|
||||
-spec put_member_map(#{user_id() => member()}, guild_data()) -> guild_data().
|
||||
put_member_map(MemberMap, Data) when is_map(MemberMap), is_map(Data) ->
|
||||
NormalizedMemberMap = normalize_member_map(MemberMap),
|
||||
Data1 = maps:put(<<"members">>, NormalizedMemberMap, Data),
|
||||
maps:put(<<"member_role_index">>, build_member_role_index(NormalizedMemberMap), Data1);
|
||||
put_member_map(_, Data) ->
|
||||
Data.
|
||||
|
||||
-spec put_member_list([member()], guild_data()) -> guild_data().
|
||||
put_member_list(Members, Data) when is_list(Members), is_map(Data) ->
|
||||
put_member_map(build_member_map(Members), Data);
|
||||
put_member_list(_, Data) ->
|
||||
Data.
|
||||
|
||||
-spec remove_member(user_id(), guild_data()) -> guild_data().
|
||||
remove_member(UserId, Data) when is_integer(UserId), is_map(Data) ->
|
||||
MemberMap = member_map(Data),
|
||||
Member = maps:get(UserId, MemberMap, undefined),
|
||||
MemberRoles = member_role_ids(Member),
|
||||
RoleIndex = member_role_index(Data),
|
||||
RoleIndex1 = remove_user_from_member_role_index(UserId, MemberRoles, RoleIndex),
|
||||
Data1 = maps:put(<<"members">>, maps:remove(UserId, MemberMap), Data),
|
||||
maps:put(<<"member_role_index">>, RoleIndex1, Data1);
|
||||
remove_member(_, Data) ->
|
||||
Data.
|
||||
|
||||
-spec role_list(guild_data()) -> [role()].
|
||||
role_list(Data) when is_map(Data) ->
|
||||
ensure_list(maps:get(<<"roles">>, Data, []));
|
||||
role_list(_) ->
|
||||
[].
|
||||
|
||||
-spec role_index(guild_data()) -> #{snowflake_id() => role()}.
|
||||
role_index(Data) when is_map(Data) ->
|
||||
case maps:get(<<"role_index">>, Data, undefined) of
|
||||
Index when is_map(Index) -> normalize_id_index(Index);
|
||||
_ -> build_id_index(role_list(Data))
|
||||
end;
|
||||
role_index(_) ->
|
||||
#{}.
|
||||
|
||||
-spec put_roles([role()], guild_data()) -> guild_data().
|
||||
put_roles(Roles, Data) when is_map(Data) ->
|
||||
RoleList = ensure_list(Roles),
|
||||
Data1 = maps:put(<<"roles">>, RoleList, Data),
|
||||
maps:put(<<"role_index">>, build_id_index(RoleList), Data1);
|
||||
put_roles(_, Data) ->
|
||||
Data.
|
||||
|
||||
-spec channel_list(guild_data()) -> [channel()].
|
||||
channel_list(Data) when is_map(Data) ->
|
||||
ensure_list(maps:get(<<"channels">>, Data, []));
|
||||
channel_list(_) ->
|
||||
[].
|
||||
|
||||
-spec channel_index(guild_data()) -> #{snowflake_id() => channel()}.
|
||||
channel_index(Data) when is_map(Data) ->
|
||||
case maps:get(<<"channel_index">>, Data, undefined) of
|
||||
Index when is_map(Index) -> normalize_id_index(Index);
|
||||
_ -> build_id_index(channel_list(Data))
|
||||
end;
|
||||
channel_index(_) ->
|
||||
#{}.
|
||||
|
||||
-spec put_channels([channel()], guild_data()) -> guild_data().
|
||||
put_channels(Channels, Data) when is_map(Data) ->
|
||||
ChannelList = ensure_list(Channels),
|
||||
Data1 = maps:put(<<"channels">>, ChannelList, Data),
|
||||
maps:put(<<"channel_index">>, build_id_index(ChannelList), Data1);
|
||||
put_channels(_, Data) ->
|
||||
Data.
|
||||
|
||||
-spec build_member_map([member()]) -> #{user_id() => member()}.
|
||||
build_member_map(Members) ->
|
||||
lists:foldl(
|
||||
fun(Member, Acc) ->
|
||||
case member_user_id(Member) of
|
||||
undefined ->
|
||||
Acc;
|
||||
UserId ->
|
||||
maps:put(UserId, Member, Acc)
|
||||
end
|
||||
end,
|
||||
#{},
|
||||
Members
|
||||
).
|
||||
|
||||
-spec normalize_member_map(map()) -> #{user_id() => member()}.
|
||||
normalize_member_map(MemberMap) ->
|
||||
maps:fold(
|
||||
fun(Key, Member, Acc) ->
|
||||
case normalize_member_key(Key, Member) of
|
||||
undefined ->
|
||||
Acc;
|
||||
UserId ->
|
||||
maps:put(UserId, Member, Acc)
|
||||
end
|
||||
end,
|
||||
#{},
|
||||
MemberMap
|
||||
).
|
||||
|
||||
-spec normalize_member_key(term(), member()) -> user_id() | undefined.
|
||||
normalize_member_key(Key, Member) ->
|
||||
case type_conv:to_integer(Key) of
|
||||
undefined -> member_user_id(Member);
|
||||
UserId -> UserId
|
||||
end.
|
||||
|
||||
-spec member_user_id(member()) -> user_id() | undefined.
|
||||
member_user_id(Member) when is_map(Member) ->
|
||||
User = maps:get(<<"user">>, Member, #{}),
|
||||
map_utils:get_integer(User, <<"id">>, undefined);
|
||||
member_user_id(_) ->
|
||||
undefined.
|
||||
|
||||
-spec member_role_ids(term()) -> [snowflake_id()].
|
||||
member_role_ids(Member) when is_map(Member) ->
|
||||
extract_integer_list(maps:get(<<"roles">>, Member, []));
|
||||
member_role_ids(_) ->
|
||||
[].
|
||||
|
||||
-spec build_member_role_index(#{user_id() => member()}) -> role_member_index().
|
||||
build_member_role_index(MemberMap) ->
|
||||
maps:fold(
|
||||
fun(UserId, Member, Acc) ->
|
||||
add_user_to_member_role_index(UserId, member_role_ids(Member), Acc)
|
||||
end,
|
||||
#{},
|
||||
MemberMap
|
||||
).
|
||||
|
||||
-spec normalize_member_role_index(map()) -> role_member_index().
|
||||
normalize_member_role_index(Index) ->
|
||||
maps:fold(
|
||||
fun(RoleKey, Members, Acc) ->
|
||||
case type_conv:to_integer(RoleKey) of
|
||||
undefined ->
|
||||
Acc;
|
||||
RoleId ->
|
||||
NormalizedMembers = normalize_member_role_members(Members),
|
||||
case map_size(NormalizedMembers) of
|
||||
0 ->
|
||||
Acc;
|
||||
_ ->
|
||||
maps:put(RoleId, NormalizedMembers, Acc)
|
||||
end
|
||||
end
|
||||
end,
|
||||
#{},
|
||||
Index
|
||||
).
|
||||
|
||||
-spec normalize_member_role_members(term()) -> #{user_id() => true}.
|
||||
normalize_member_role_members(Members) when is_map(Members) ->
|
||||
maps:fold(
|
||||
fun(UserKey, _Flag, Acc) ->
|
||||
case type_conv:to_integer(UserKey) of
|
||||
undefined ->
|
||||
Acc;
|
||||
UserId ->
|
||||
maps:put(UserId, true, Acc)
|
||||
end
|
||||
end,
|
||||
#{},
|
||||
Members
|
||||
);
|
||||
normalize_member_role_members(_) ->
|
||||
#{}.
|
||||
|
||||
-spec add_user_to_member_role_index(user_id(), [snowflake_id()], role_member_index()) ->
|
||||
role_member_index().
|
||||
add_user_to_member_role_index(UserId, RoleIds, RoleIndex) when is_integer(UserId) ->
|
||||
lists:foldl(
|
||||
fun(RoleId, Acc) ->
|
||||
RoleMembers = maps:get(RoleId, Acc, #{}),
|
||||
maps:put(RoleId, maps:put(UserId, true, RoleMembers), Acc)
|
||||
end,
|
||||
RoleIndex,
|
||||
RoleIds
|
||||
).
|
||||
|
||||
-spec remove_user_from_member_role_index(user_id(), [snowflake_id()], role_member_index()) ->
|
||||
role_member_index().
|
||||
remove_user_from_member_role_index(UserId, RoleIds, RoleIndex) when is_integer(UserId) ->
|
||||
lists:foldl(
|
||||
fun(RoleId, Acc) ->
|
||||
RoleMembers = maps:get(RoleId, Acc, #{}),
|
||||
UpdatedRoleMembers = maps:remove(UserId, RoleMembers),
|
||||
case map_size(UpdatedRoleMembers) of
|
||||
0 ->
|
||||
maps:remove(RoleId, Acc);
|
||||
_ ->
|
||||
maps:put(RoleId, UpdatedRoleMembers, Acc)
|
||||
end
|
||||
end,
|
||||
RoleIndex,
|
||||
RoleIds
|
||||
).
|
||||
|
||||
-spec build_id_index([map()]) -> #{snowflake_id() => map()}.
|
||||
build_id_index(Items) ->
|
||||
lists:foldl(
|
||||
fun(Item, Acc) ->
|
||||
case map_utils:get_integer(Item, <<"id">>, undefined) of
|
||||
undefined ->
|
||||
Acc;
|
||||
Id ->
|
||||
maps:put(Id, Item, Acc)
|
||||
end
|
||||
end,
|
||||
#{},
|
||||
Items
|
||||
).
|
||||
|
||||
-spec normalize_id_index(map()) -> #{snowflake_id() => map()}.
|
||||
normalize_id_index(Index) ->
|
||||
maps:fold(
|
||||
fun(Key, Item, Acc) ->
|
||||
case type_conv:to_integer(Key) of
|
||||
undefined ->
|
||||
case map_utils:get_integer(Item, <<"id">>, undefined) of
|
||||
undefined -> Acc;
|
||||
Id -> maps:put(Id, Item, Acc)
|
||||
end;
|
||||
Id ->
|
||||
maps:put(Id, Item, Acc)
|
||||
end
|
||||
end,
|
||||
#{},
|
||||
Index
|
||||
).
|
||||
|
||||
-spec ensure_list(term()) -> list().
|
||||
ensure_list(List) when is_list(List) ->
|
||||
List;
|
||||
ensure_list(_) ->
|
||||
[].
|
||||
|
||||
-spec extract_integer_list(term()) -> [integer()].
|
||||
extract_integer_list(List) when is_list(List) ->
|
||||
lists:reverse(
|
||||
lists:foldl(
|
||||
fun(Value, Acc) ->
|
||||
case type_conv:to_integer(Value) of
|
||||
undefined -> Acc;
|
||||
Int -> [Int | Acc]
|
||||
end
|
||||
end,
|
||||
[],
|
||||
List
|
||||
)
|
||||
);
|
||||
extract_integer_list(_) ->
|
||||
[].
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
|
||||
member_map_from_list_test() ->
|
||||
Data = #{
|
||||
<<"members">> => [
|
||||
#{<<"user">> => #{<<"id">> => <<"2">>}, <<"nick">> => <<"beta">>},
|
||||
#{<<"user">> => #{<<"id">> => <<"1">>}, <<"nick">> => <<"alpha">>}
|
||||
]
|
||||
},
|
||||
MemberMap = member_map(Data),
|
||||
?assertMatch(#{1 := _, 2 := _}, MemberMap),
|
||||
?assertEqual(2, member_count(Data)).
|
||||
|
||||
member_list_is_sorted_test() ->
|
||||
Data = #{
|
||||
<<"members">> => #{
|
||||
5 => #{<<"user">> => #{<<"id">> => <<"5">>}},
|
||||
2 => #{<<"user">> => #{<<"id">> => <<"2">>}}
|
||||
}
|
||||
},
|
||||
[First, Second] = member_list(Data),
|
||||
?assertEqual(2, member_user_id(First)),
|
||||
?assertEqual(5, member_user_id(Second)).
|
||||
|
||||
member_values_returns_members_without_sorting_test() ->
|
||||
Data = #{
|
||||
<<"members">> => #{
|
||||
9 => #{<<"user">> => #{<<"id">> => <<"9">>}},
|
||||
2 => #{<<"user">> => #{<<"id">> => <<"2">>}}
|
||||
}
|
||||
},
|
||||
Values = member_values(Data),
|
||||
?assertEqual(2, length(Values)).
|
||||
|
||||
put_member_updates_entry_test() ->
|
||||
Data = #{
|
||||
<<"members">> => #{
|
||||
10 => #{<<"user">> => #{<<"id">> => <<"10">>}, <<"nick">> => <<"old">>}
|
||||
}
|
||||
},
|
||||
UpdatedData = put_member(
|
||||
#{<<"user">> => #{<<"id">> => <<"10">>}, <<"nick">> => <<"new">>},
|
||||
Data
|
||||
),
|
||||
UpdatedMember = get_member(10, UpdatedData),
|
||||
?assertEqual(<<"new">>, maps:get(<<"nick">>, UpdatedMember)).
|
||||
|
||||
remove_member_removes_entry_test() ->
|
||||
Data = #{
|
||||
<<"members">> => #{
|
||||
10 => #{<<"user">> => #{<<"id">> => <<"10">>}}
|
||||
}
|
||||
},
|
||||
UpdatedData = remove_member(10, Data),
|
||||
?assertEqual(undefined, get_member(10, UpdatedData)).
|
||||
|
||||
normalize_data_builds_indexes_test() ->
|
||||
Data = #{
|
||||
<<"members">> => [#{<<"user">> => #{<<"id">> => <<"1">>}}],
|
||||
<<"roles">> => [#{<<"id">> => <<"100">>}],
|
||||
<<"channels">> => [#{<<"id">> => <<"200">>}]
|
||||
},
|
||||
Normalized = normalize_data(Data),
|
||||
?assert(is_map(maps:get(<<"members">>, Normalized))),
|
||||
?assertMatch(#{100 := _}, role_index(Normalized)),
|
||||
?assertMatch(#{200 := _}, channel_index(Normalized)).
|
||||
|
||||
member_role_index_builds_role_to_user_lookup_test() ->
|
||||
Data = #{
|
||||
<<"members">> => #{
|
||||
1 => #{<<"user">> => #{<<"id">> => <<"1">>}, <<"roles">> => [<<"10">>, <<"11">>]},
|
||||
2 => #{<<"user">> => #{<<"id">> => <<"2">>}, <<"roles">> => [<<"11">>]}
|
||||
}
|
||||
},
|
||||
Index = member_role_index(Data),
|
||||
?assertEqual(#{1 => true}, maps:get(10, Index)),
|
||||
?assertEqual(#{1 => true, 2 => true}, maps:get(11, Index)).
|
||||
|
||||
put_member_and_remove_member_keep_member_role_index_in_sync_test() ->
|
||||
Data0 = #{
|
||||
<<"members">> => #{
|
||||
3 => #{<<"user">> => #{<<"id">> => <<"3">>}, <<"roles">> => [<<"20">>]}
|
||||
}
|
||||
},
|
||||
Data1 = put_member(
|
||||
#{<<"user">> => #{<<"id">> => <<"3">>}, <<"roles">> => [<<"30">>]},
|
||||
Data0
|
||||
),
|
||||
Index1 = member_role_index(Data1),
|
||||
?assertEqual(undefined, maps:get(20, Index1, undefined)),
|
||||
?assertEqual(#{3 => true}, maps:get(30, Index1)),
|
||||
Data2 = remove_member(3, Data1),
|
||||
Index2 = member_role_index(Data2),
|
||||
?assertEqual(undefined, maps:get(30, Index2, undefined)).
|
||||
|
||||
-endif.
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,23 +23,15 @@
|
||||
-export([start_link/0]).
|
||||
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
|
||||
|
||||
-define(GUILD_PID_CACHE, guild_pid_cache).
|
||||
|
||||
-type guild_id() :: integer().
|
||||
-type shard_map() :: #{pid => pid(), ref => reference()}.
|
||||
-type shard_map() :: #{pid := pid(), ref := reference()}.
|
||||
-type state() :: #{
|
||||
shards => #{non_neg_integer() => shard_map()},
|
||||
shard_count => pos_integer()
|
||||
shards := #{non_neg_integer() => shard_map()},
|
||||
shard_count := pos_integer()
|
||||
}.
|
||||
|
||||
-record(shard, {
|
||||
pid :: pid(),
|
||||
ref :: reference()
|
||||
}).
|
||||
|
||||
-record(state, {
|
||||
shards = #{} :: #{non_neg_integer() => #shard{}},
|
||||
shard_count = 1 :: pos_integer()
|
||||
}).
|
||||
|
||||
-spec start_link() -> {ok, pid()} | {error, term()}.
|
||||
start_link() ->
|
||||
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
|
||||
@@ -47,23 +39,23 @@ start_link() ->
|
||||
-spec init(list()) -> {ok, state()}.
|
||||
init([]) ->
|
||||
process_flag(trap_exit, true),
|
||||
{ShardCount, Source} = determine_shard_count(),
|
||||
ShardMap = start_shards(ShardCount, #{}),
|
||||
maybe_log_shard_source(guild_manager, ShardCount, Source),
|
||||
ets:new(?GUILD_PID_CACHE, [named_table, public, set, {read_concurrency, true}]),
|
||||
{ShardCount, _Source} = determine_shard_count(),
|
||||
ShardMap = start_shards(ShardCount),
|
||||
{ok, #{shards => ShardMap, shard_count => ShardCount}}.
|
||||
|
||||
-spec handle_call(term(), gen_server:from(), state()) -> {reply, term(), state()}.
|
||||
handle_call({start_or_lookup, GuildId} = Request, _From, State) ->
|
||||
{Reply, NewState} = forward_call(GuildId, Request, State),
|
||||
handle_call({start_or_lookup, GuildId}, _From, State) ->
|
||||
{Reply, NewState} = forward_call(GuildId, {start_or_lookup, GuildId}, State),
|
||||
{reply, Reply, NewState};
|
||||
handle_call({stop_guild, GuildId} = Request, _From, State) ->
|
||||
{Reply, NewState} = forward_call(GuildId, Request, State),
|
||||
handle_call({stop_guild, GuildId}, _From, State) ->
|
||||
{Reply, NewState} = forward_call(GuildId, {stop_guild, GuildId}, State),
|
||||
{reply, Reply, NewState};
|
||||
handle_call({reload_guild, GuildId} = Request, _From, State) ->
|
||||
{Reply, NewState} = forward_call(GuildId, Request, State),
|
||||
handle_call({reload_guild, GuildId}, _From, State) ->
|
||||
{Reply, NewState} = forward_call(GuildId, {reload_guild, GuildId}, State),
|
||||
{reply, Reply, NewState};
|
||||
handle_call({shutdown_guild, GuildId} = Request, _From, State) ->
|
||||
{Reply, NewState} = forward_call(GuildId, Request, State),
|
||||
handle_call({shutdown_guild, GuildId}, _From, State) ->
|
||||
{Reply, NewState} = forward_call(GuildId, {shutdown_guild, GuildId}, State),
|
||||
{reply, Reply, NewState};
|
||||
handle_call({reload_all_guilds, GuildIds}, _From, State) ->
|
||||
{Reply, NewState} = handle_reload_all(GuildIds, State),
|
||||
@@ -74,8 +66,7 @@ handle_call(get_local_count, _From, State) ->
|
||||
handle_call(get_global_count, _From, State) ->
|
||||
{Count, NewState} = aggregate_counts(get_global_count, State),
|
||||
{reply, {ok, Count}, NewState};
|
||||
handle_call(Request, _From, State) ->
|
||||
logger:warning("[guild_manager] unknown request ~p", [Request]),
|
||||
handle_call(_Request, _From, State) ->
|
||||
{reply, ok, State}.
|
||||
|
||||
-spec handle_cast(term(), state()) -> {noreply, state()}.
|
||||
@@ -83,21 +74,20 @@ handle_cast(_Msg, State) ->
|
||||
{noreply, State}.
|
||||
|
||||
-spec handle_info(term(), state()) -> {noreply, state()}.
|
||||
handle_info({'DOWN', Ref, process, _Pid, Reason}, State) ->
|
||||
handle_info({'DOWN', Ref, process, Pid, _Reason}, State) ->
|
||||
Shards = maps:get(shards, State),
|
||||
case find_shard_by_ref(Ref, Shards) of
|
||||
{ok, Index} ->
|
||||
logger:warning("[guild_manager] shard ~p crashed: ~p", [Index, Reason]),
|
||||
{_Shard, NewState} = restart_shard(Index, State),
|
||||
{noreply, NewState};
|
||||
not_found ->
|
||||
cleanup_guild_from_cache(Pid),
|
||||
{noreply, State}
|
||||
end;
|
||||
handle_info({'EXIT', Pid, Reason}, State) ->
|
||||
handle_info({'EXIT', Pid, _Reason}, State) ->
|
||||
Shards = maps:get(shards, State),
|
||||
case find_shard_by_pid(Pid, Shards) of
|
||||
{ok, Index} ->
|
||||
logger:warning("[guild_manager] shard ~p exited: ~p", [Index, Reason]),
|
||||
{_Shard, NewState} = restart_shard(Index, State),
|
||||
{noreply, NewState};
|
||||
not_found ->
|
||||
@@ -116,17 +106,10 @@ terminate(_Reason, State) ->
|
||||
end,
|
||||
maps:values(Shards)
|
||||
),
|
||||
catch ets:delete(?GUILD_PID_CACHE),
|
||||
ok.
|
||||
|
||||
-spec code_change(term(), term(), term()) -> {ok, state()}.
|
||||
code_change(_OldVsn, #state{shards = OldShards, shard_count = ShardCount}, _Extra) ->
|
||||
NewShards = maps:map(
|
||||
fun(_Index, #shard{pid = Pid, ref = Ref}) ->
|
||||
#{pid => Pid, ref => Ref}
|
||||
end,
|
||||
OldShards
|
||||
),
|
||||
{ok, #{shards => NewShards, shard_count => ShardCount}};
|
||||
code_change(_OldVsn, State, _Extra) when is_map(State) ->
|
||||
{ok, State}.
|
||||
|
||||
@@ -139,19 +122,26 @@ determine_shard_count() ->
|
||||
{default_shard_count(), auto}
|
||||
end.
|
||||
|
||||
-spec start_shards(pos_integer(), #{}) -> #{non_neg_integer() => shard_map()}.
|
||||
start_shards(Count, Acc) ->
|
||||
-spec default_shard_count() -> pos_integer().
|
||||
default_shard_count() ->
|
||||
Candidates = [
|
||||
erlang:system_info(logical_processors_available),
|
||||
erlang:system_info(schedulers_online)
|
||||
],
|
||||
max(1, lists:max([C || C <- Candidates, is_integer(C), C > 0] ++ [1])).
|
||||
|
||||
-spec start_shards(pos_integer()) -> #{non_neg_integer() => shard_map()}.
|
||||
start_shards(Count) ->
|
||||
lists:foldl(
|
||||
fun(Index, MapAcc) ->
|
||||
case start_shard(Index) of
|
||||
{ok, Shard} ->
|
||||
maps:put(Index, Shard, MapAcc);
|
||||
{error, Reason} ->
|
||||
logger:warning("[guild_manager] failed to start shard ~p: ~p", [Index, Reason]),
|
||||
{error, _Reason} ->
|
||||
MapAcc
|
||||
end
|
||||
end,
|
||||
Acc,
|
||||
#{},
|
||||
lists:seq(0, Count - 1)
|
||||
).
|
||||
|
||||
@@ -172,14 +162,31 @@ restart_shard(Index, State) ->
|
||||
{ok, Shard} ->
|
||||
Updated = State#{shards => maps:put(Index, Shard, Shards)},
|
||||
{Shard, Updated};
|
||||
{error, Reason} ->
|
||||
logger:error("[guild_manager] failed to restart shard ~p: ~p", [Index, Reason]),
|
||||
Dummy = #{pid => spawn(fun() -> exit(normal) end), ref => make_ref()},
|
||||
{error, _Reason} ->
|
||||
DummyPid = spawn(fun() -> ok end),
|
||||
Dummy = #{pid => DummyPid, ref => make_ref()},
|
||||
{Dummy, State}
|
||||
end.
|
||||
|
||||
-spec forward_call(guild_id(), term(), state()) -> {term(), state()}.
|
||||
forward_call(GuildId, {start_or_lookup, _} = Request, State) ->
|
||||
case ets:lookup(?GUILD_PID_CACHE, GuildId) of
|
||||
[{GuildId, GuildPid}] when is_pid(GuildPid) ->
|
||||
case erlang:is_process_alive(GuildPid) of
|
||||
true ->
|
||||
{{ok, GuildPid}, State};
|
||||
false ->
|
||||
ets:delete(?GUILD_PID_CACHE, GuildId),
|
||||
forward_call_to_shard(GuildId, Request, State)
|
||||
end;
|
||||
[] ->
|
||||
forward_call_to_shard(GuildId, Request, State)
|
||||
end;
|
||||
forward_call(GuildId, Request, State) ->
|
||||
forward_call_to_shard(GuildId, Request, State).
|
||||
|
||||
-spec forward_call_to_shard(guild_id(), term(), state()) -> {term(), state()}.
|
||||
forward_call_to_shard(GuildId, Request, State) ->
|
||||
{Index, State1} = ensure_shard(GuildId, State),
|
||||
Shards = maps:get(shards, State1),
|
||||
ShardMap = maps:get(Index, Shards),
|
||||
@@ -187,7 +194,11 @@ forward_call(GuildId, Request, State) ->
|
||||
case catch gen_server:call(Pid, Request, ?DEFAULT_GEN_SERVER_TIMEOUT) of
|
||||
{'EXIT', _} ->
|
||||
{_Shard, State2} = restart_shard(Index, State1),
|
||||
forward_call(GuildId, Request, State2);
|
||||
forward_call_to_shard(GuildId, Request, State2);
|
||||
{ok, GuildPid} = Reply ->
|
||||
ets:insert(?GUILD_PID_CACHE, {GuildId, GuildPid}),
|
||||
erlang:monitor(process, GuildPid),
|
||||
{Reply, State1};
|
||||
Reply ->
|
||||
{Reply, State1}
|
||||
end.
|
||||
@@ -223,146 +234,137 @@ select_shard(GuildId, Count) when Count > 0 ->
|
||||
-spec aggregate_counts(term(), state()) -> {non_neg_integer(), state()}.
|
||||
aggregate_counts(Request, State) ->
|
||||
Shards = maps:get(shards, State),
|
||||
Counts =
|
||||
[
|
||||
begin
|
||||
Pid = maps:get(pid, ShardMap),
|
||||
case catch gen_server:call(Pid, Request, ?DEFAULT_GEN_SERVER_TIMEOUT) of
|
||||
{ok, Count} -> Count;
|
||||
_ -> 0
|
||||
end
|
||||
Counts = lists:map(
|
||||
fun(ShardMap) ->
|
||||
Pid = maps:get(pid, ShardMap),
|
||||
case catch gen_server:call(Pid, Request, ?DEFAULT_GEN_SERVER_TIMEOUT) of
|
||||
{ok, Count} -> Count;
|
||||
_ -> 0
|
||||
end
|
||||
|| ShardMap <- maps:values(Shards)
|
||||
],
|
||||
end,
|
||||
maps:values(Shards)
|
||||
),
|
||||
{lists:sum(Counts), State}.
|
||||
|
||||
-spec handle_reload_all([guild_id()], state()) -> {#{count => non_neg_integer()}, state()}.
|
||||
-spec handle_reload_all([guild_id()], state()) -> {#{count := non_neg_integer()}, state()}.
|
||||
handle_reload_all([], State) ->
|
||||
Shards = maps:get(shards, State),
|
||||
{Replies, FinalState} =
|
||||
lists:foldl(
|
||||
fun({_Index, ShardMap}, {AccReplies, AccState}) ->
|
||||
Pid = maps:get(pid, ShardMap),
|
||||
case catch gen_server:call(Pid, {reload_all_guilds, []}, 60000) of
|
||||
Reply ->
|
||||
{AccReplies ++ [Reply], AccState}
|
||||
end
|
||||
end,
|
||||
{[], State},
|
||||
maps:to_list(Shards)
|
||||
),
|
||||
Count = lists:sum([maps:get(count, Reply, 0) || Reply <- Replies]),
|
||||
{Replies, FinalState} = lists:foldl(
|
||||
fun({_Index, ShardMap}, {AccReplies, AccState}) ->
|
||||
Pid = maps:get(pid, ShardMap),
|
||||
Reply = catch gen_server:call(Pid, {reload_all_guilds, []}, 15000),
|
||||
{[Reply | AccReplies], AccState}
|
||||
end,
|
||||
{[], State},
|
||||
maps:to_list(Shards)
|
||||
),
|
||||
Count = lists:sum([maps:get(count, Reply, 0) || Reply <- Replies, is_map(Reply)]),
|
||||
{#{count => Count}, FinalState};
|
||||
handle_reload_all(GuildIds, State) ->
|
||||
Count = maps:get(shard_count, State),
|
||||
Groups = group_ids_by_shard(GuildIds, Count),
|
||||
{TotalCount, FinalState} =
|
||||
lists:foldl(
|
||||
fun({Index, Ids}, {AccCount, AccState}) ->
|
||||
{ShardIdx, State1} = ensure_shard_for_index(Index, AccState),
|
||||
Shards = maps:get(shards, State1),
|
||||
ShardMap = maps:get(ShardIdx, Shards),
|
||||
Pid = maps:get(pid, ShardMap),
|
||||
case catch gen_server:call(Pid, {reload_all_guilds, Ids}, 60000) of
|
||||
#{count := CountReply} ->
|
||||
{AccCount + CountReply, State1};
|
||||
_ ->
|
||||
{AccCount, State1}
|
||||
end
|
||||
end,
|
||||
{0, State},
|
||||
Groups
|
||||
),
|
||||
{TotalCount, FinalState} = lists:foldl(
|
||||
fun({Index, Ids}, {AccCount, AccState}) ->
|
||||
{ShardIdx, State1} = ensure_shard_for_index(Index, AccState),
|
||||
Shards = maps:get(shards, State1),
|
||||
ShardMap = maps:get(ShardIdx, Shards),
|
||||
Pid = maps:get(pid, ShardMap),
|
||||
case catch gen_server:call(Pid, {reload_all_guilds, Ids}, 15000) of
|
||||
#{count := CountReply} ->
|
||||
{AccCount + CountReply, State1};
|
||||
_ ->
|
||||
{AccCount, State1}
|
||||
end
|
||||
end,
|
||||
{0, State},
|
||||
Groups
|
||||
),
|
||||
{#{count => TotalCount}, FinalState}.
|
||||
|
||||
-spec group_ids_by_shard([guild_id()], pos_integer()) -> [{non_neg_integer(), [guild_id()]}].
|
||||
group_ids_by_shard(GuildIds, ShardCount) ->
|
||||
lists:foldl(
|
||||
fun(GuildId, Acc) ->
|
||||
Index = select_shard(GuildId, ShardCount),
|
||||
case lists:keytake(Index, 1, Acc) of
|
||||
{value, {Index, Ids}, Rest} ->
|
||||
[{Index, [GuildId | Ids]} | Rest];
|
||||
false ->
|
||||
[{Index, [GuildId]} | Acc]
|
||||
end
|
||||
end,
|
||||
[],
|
||||
GuildIds
|
||||
).
|
||||
rendezvous_router:group_keys(GuildIds, ShardCount).
|
||||
|
||||
-spec find_shard_by_ref(reference(), #{non_neg_integer() => shard_map()}) ->
|
||||
{ok, non_neg_integer()} | not_found.
|
||||
find_shard_by_ref(Ref, Shards) ->
|
||||
maps:fold(
|
||||
fun
|
||||
(Index, ShardMap, _) when is_map(ShardMap) ->
|
||||
case maps:get(ref, ShardMap) of
|
||||
R when R =:= Ref -> {ok, Index};
|
||||
_ -> not_found
|
||||
end;
|
||||
(_, _, Acc) ->
|
||||
Acc
|
||||
end,
|
||||
not_found,
|
||||
Shards
|
||||
).
|
||||
find_shard_by(fun(#{ref := R}) -> R =:= Ref end, Shards).
|
||||
|
||||
-spec find_shard_by_pid(pid(), #{non_neg_integer() => shard_map()}) ->
|
||||
{ok, non_neg_integer()} | not_found.
|
||||
find_shard_by_pid(Pid, Shards) ->
|
||||
find_shard_by(fun(#{pid := P}) -> P =:= Pid end, Shards).
|
||||
|
||||
-spec find_shard_by(fun((shard_map()) -> boolean()), #{non_neg_integer() => shard_map()}) ->
|
||||
{ok, non_neg_integer()} | not_found.
|
||||
find_shard_by(Pred, Shards) ->
|
||||
maps:fold(
|
||||
fun
|
||||
(Index, ShardMap, _) when is_map(ShardMap) ->
|
||||
case maps:get(pid, ShardMap) of
|
||||
P when P =:= Pid -> {ok, Index};
|
||||
_ -> not_found
|
||||
end;
|
||||
(_, _, Acc) ->
|
||||
Acc
|
||||
(_, _, {ok, _} = Found) ->
|
||||
Found;
|
||||
(Index, ShardMap, not_found) ->
|
||||
case Pred(ShardMap) of
|
||||
true -> {ok, Index};
|
||||
false -> not_found
|
||||
end
|
||||
end,
|
||||
not_found,
|
||||
Shards
|
||||
).
|
||||
|
||||
-spec default_shard_count() -> pos_integer().
|
||||
default_shard_count() ->
|
||||
Candidates = [
|
||||
erlang:system_info(logical_processors_available), erlang:system_info(schedulers_online)
|
||||
],
|
||||
lists:max([C || C <- Candidates, is_integer(C), C > 0] ++ [1]).
|
||||
|
||||
-spec maybe_log_shard_source(atom(), pos_integer(), configured | auto) -> ok.
|
||||
maybe_log_shard_source(Name, Count, configured) ->
|
||||
logger:info("[~p] starting with ~p shards (configured)", [Name, Count]),
|
||||
ok;
|
||||
maybe_log_shard_source(Name, Count, auto) ->
|
||||
logger:info("[~p] starting with ~p shards (auto)", [Name, Count]),
|
||||
-spec cleanup_guild_from_cache(pid()) -> ok.
|
||||
cleanup_guild_from_cache(Pid) ->
|
||||
case ets:match_object(?GUILD_PID_CACHE, {'$1', Pid}) of
|
||||
[{GuildId, _Pid}] ->
|
||||
ets:delete(?GUILD_PID_CACHE, GuildId);
|
||||
[] ->
|
||||
ok
|
||||
end,
|
||||
ok.
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
|
||||
determine_shard_count_configured_test() ->
|
||||
with_runtime_config(guild_shards, 4, fun() ->
|
||||
?assertMatch({4, configured}, determine_shard_count())
|
||||
end).
|
||||
default_shard_count_positive_test() ->
|
||||
Count = default_shard_count(),
|
||||
?assert(Count >= 1).
|
||||
|
||||
determine_shard_count_auto_test() ->
|
||||
with_runtime_config(guild_shards, undefined, fun() ->
|
||||
{Count, auto} = determine_shard_count(),
|
||||
?assert(Count > 0)
|
||||
end).
|
||||
select_shard_deterministic_test() ->
|
||||
GuildId = 12345,
|
||||
ShardCount = 8,
|
||||
Shard1 = select_shard(GuildId, ShardCount),
|
||||
Shard2 = select_shard(GuildId, ShardCount),
|
||||
?assertEqual(Shard1, Shard2).
|
||||
|
||||
select_shard_in_range_test() ->
|
||||
ShardCount = 8,
|
||||
lists:foreach(
|
||||
fun(GuildId) ->
|
||||
Shard = select_shard(GuildId, ShardCount),
|
||||
?assert(Shard >= 0 andalso Shard < ShardCount)
|
||||
end,
|
||||
lists:seq(1, 100)
|
||||
).
|
||||
|
||||
group_ids_by_shard_test() ->
|
||||
GuildIds = [1, 2, 3, 4, 5],
|
||||
ShardCount = 2,
|
||||
Groups = group_ids_by_shard(GuildIds, ShardCount),
|
||||
AllIds = lists:flatten([Ids || {_, Ids} <- Groups]),
|
||||
?assertEqual(lists:sort(GuildIds), lists:sort(AllIds)).
|
||||
|
||||
find_shard_by_ref_found_test() ->
|
||||
Ref = make_ref(),
|
||||
Shards = #{0 => #{pid => self(), ref => Ref}},
|
||||
?assertMatch({ok, 0}, find_shard_by_ref(Ref, Shards)).
|
||||
|
||||
find_shard_by_ref_not_found_test() ->
|
||||
Shards = #{0 => #{pid => self(), ref => make_ref()}},
|
||||
?assertEqual(not_found, find_shard_by_ref(make_ref(), Shards)).
|
||||
|
||||
find_shard_by_pid_found_test() ->
|
||||
Pid = self(),
|
||||
Shards = #{0 => #{pid => Pid, ref => make_ref()}},
|
||||
?assertMatch({ok, 0}, find_shard_by_pid(Pid, Shards)).
|
||||
|
||||
with_runtime_config(Key, Value, Fun) ->
|
||||
Original = fluxer_gateway_env:get(Key),
|
||||
fluxer_gateway_env:patch(#{Key => Value}),
|
||||
Result = Fun(),
|
||||
fluxer_gateway_env:update(fun(Map) ->
|
||||
case Original of
|
||||
undefined -> maps:remove(Key, Map);
|
||||
Val -> maps:put(Key, Val, Map)
|
||||
end
|
||||
end),
|
||||
Result.
|
||||
-endif.
|
||||
|
||||
@@ -21,6 +21,8 @@
|
||||
-include_lib("fluxer_gateway/include/timeout_config.hrl").
|
||||
|
||||
-define(GUILD_API_CANARY_PERCENTAGE, 5).
|
||||
-define(BATCH_SIZE, 10).
|
||||
-define(BATCH_DELAY_MS, 100).
|
||||
|
||||
-export([start_link/1]).
|
||||
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
|
||||
@@ -30,56 +32,34 @@
|
||||
-type guild_data() :: #{binary() => term()}.
|
||||
-type fetch_result() :: {ok, guild_data()} | {error, term()}.
|
||||
-type state() :: #{
|
||||
guilds => #{guild_id() => guild_ref() | loading},
|
||||
api_host => string(),
|
||||
api_canary_host => undefined | string(),
|
||||
pending_requests => #{guild_id() => [gen_server:from()]}
|
||||
guilds := #{guild_id() => guild_ref() | loading},
|
||||
api_host := string(),
|
||||
api_canary_host := undefined | string(),
|
||||
pending_requests := #{guild_id() => [gen_server:from()]},
|
||||
shard_index := non_neg_integer()
|
||||
}.
|
||||
|
||||
-record(state, {
|
||||
guilds = #{} :: #{guild_id() => guild_ref() | loading},
|
||||
api_host :: string(),
|
||||
api_canary_host :: undefined | string(),
|
||||
pending_requests = #{} :: #{guild_id() => [gen_server:from()]}
|
||||
}).
|
||||
|
||||
-spec start_link(non_neg_integer()) -> {ok, pid()} | {error, term()}.
|
||||
start_link(ShardIndex) ->
|
||||
gen_server:start_link(?MODULE, #{shard_index => ShardIndex}, []).
|
||||
|
||||
-spec init(map()) -> {ok, state()}.
|
||||
init(_Args) ->
|
||||
init(Args) ->
|
||||
process_flag(trap_exit, true),
|
||||
fluxer_gateway_env:load(),
|
||||
ApiHost = fluxer_gateway_env:get(api_host),
|
||||
ApiCanaryHost = fluxer_gateway_env:get(api_canary_host),
|
||||
ShardIndex = maps:get(shard_index, Args, 0),
|
||||
{ok, #{
|
||||
guilds => #{},
|
||||
api_host => ApiHost,
|
||||
api_canary_host => ApiCanaryHost,
|
||||
pending_requests => #{}
|
||||
pending_requests => #{},
|
||||
shard_index => ShardIndex
|
||||
}}.
|
||||
|
||||
-spec handle_call(Request, From, State) -> Result when
|
||||
Request ::
|
||||
{start_or_lookup, guild_id()}
|
||||
| {stop_guild, guild_id()}
|
||||
| {reload_guild, guild_id()}
|
||||
| {reload_all_guilds, [guild_id()]}
|
||||
| {shutdown_guild, guild_id()}
|
||||
| get_local_count
|
||||
| get_global_count
|
||||
| term(),
|
||||
From :: gen_server:from(),
|
||||
State :: state(),
|
||||
Result ::
|
||||
{reply, Reply, state()}
|
||||
| {noreply, state()},
|
||||
Reply ::
|
||||
{ok, pid()}
|
||||
| {error, term()}
|
||||
| ok
|
||||
| {ok, non_neg_integer()}.
|
||||
-spec handle_call(term(), gen_server:from(), state()) ->
|
||||
{reply, term(), state()} | {noreply, state()}.
|
||||
handle_call({start_or_lookup, GuildId}, From, State) ->
|
||||
do_start_or_lookup(GuildId, From, State);
|
||||
handle_call({stop_guild, GuildId}, _From, State) ->
|
||||
@@ -87,37 +67,7 @@ handle_call({stop_guild, GuildId}, _From, State) ->
|
||||
handle_call({reload_guild, GuildId}, From, State) ->
|
||||
do_reload_guild(GuildId, From, State);
|
||||
handle_call({reload_all_guilds, GuildIds}, From, State) ->
|
||||
Guilds = maps:get(guilds, State),
|
||||
GuildsToReload =
|
||||
case GuildIds of
|
||||
[] ->
|
||||
[{GuildId, Pid} || {GuildId, {Pid, _Ref}} <- maps:to_list(Guilds)];
|
||||
Ids ->
|
||||
lists:filtermap(
|
||||
fun(GuildId) ->
|
||||
case maps:get(GuildId, Guilds, undefined) of
|
||||
{Pid, _Ref} -> {true, {GuildId, Pid}};
|
||||
_ -> false
|
||||
end
|
||||
end,
|
||||
Ids
|
||||
)
|
||||
end,
|
||||
Manager = self(),
|
||||
spawn(fun() ->
|
||||
try
|
||||
reload_guilds_in_batches(GuildsToReload, Manager, State, 10, 100),
|
||||
gen_server:cast(Manager, {all_guilds_reloaded, From, length(GuildsToReload)})
|
||||
catch
|
||||
Class:Error:Stacktrace ->
|
||||
logger:error(
|
||||
"[guild_manager] Spawned process failed: ~p:~p~n~p",
|
||||
[Class, Error, Stacktrace]
|
||||
),
|
||||
gen_server:cast(Manager, {all_guilds_reloaded, From, 0})
|
||||
end
|
||||
end),
|
||||
{noreply, State};
|
||||
do_reload_all_guilds(GuildIds, From, State);
|
||||
handle_call({shutdown_guild, GuildId}, _From, State) ->
|
||||
do_shutdown_guild(GuildId, State);
|
||||
handle_call(get_local_count, _From, State) ->
|
||||
@@ -131,60 +81,18 @@ handle_call(get_global_count, _From, State) ->
|
||||
handle_call(_Unknown, _From, State) ->
|
||||
{reply, ok, State}.
|
||||
|
||||
-spec handle_cast(Request, State) -> {noreply, state()} when
|
||||
Request ::
|
||||
{guild_data_fetched, guild_id(), fetch_result()}
|
||||
| {guild_data_reloaded, guild_id(), pid(), gen_server:from(), fetch_result()}
|
||||
| {all_guilds_reloaded, gen_server:from(), non_neg_integer()}
|
||||
| term(),
|
||||
State :: state().
|
||||
-spec handle_cast(term(), state()) -> {noreply, state()}.
|
||||
handle_cast({guild_data_fetched, GuildId, Result}, State) ->
|
||||
Pending = maps:get(pending_requests, State),
|
||||
Requests = maps:get(GuildId, Pending, []),
|
||||
Guilds = maps:get(guilds, State),
|
||||
case Result of
|
||||
{ok, Data} ->
|
||||
case start_guild(GuildId, Data, State) of
|
||||
{ok, Pid, NewState} ->
|
||||
lists:foreach(fun(From) -> gen_server:reply(From, {ok, Pid}) end, Requests),
|
||||
NewPending = maps:remove(GuildId, Pending),
|
||||
NewGuilds = maps:get(guilds, NewState),
|
||||
CleanGuilds = maps:remove(GuildId, NewGuilds),
|
||||
{noreply, NewState#{pending_requests => NewPending, guilds => CleanGuilds}};
|
||||
{error, Reason} ->
|
||||
logger:error("[guild_manager] Failed to start guild ~p: ~p", [GuildId, Reason]),
|
||||
lists:foreach(
|
||||
fun(From) -> gen_server:reply(From, {error, Reason}) end, Requests
|
||||
),
|
||||
NewGuilds = maps:remove(GuildId, Guilds),
|
||||
NewPending = maps:remove(GuildId, Pending),
|
||||
{noreply, State#{guilds => NewGuilds, pending_requests => NewPending}}
|
||||
end;
|
||||
_ ->
|
||||
lists:foreach(fun(From) -> gen_server:reply(From, {error, not_found}) end, Requests),
|
||||
NewGuilds = maps:remove(GuildId, Guilds),
|
||||
NewPending = maps:remove(GuildId, Pending),
|
||||
{noreply, State#{guilds => NewGuilds, pending_requests => NewPending}}
|
||||
end;
|
||||
handle_cast({guild_data_reloaded, _GuildId, Pid, From, Result}, State) ->
|
||||
case Result of
|
||||
{ok, Data} ->
|
||||
gen_server:call(Pid, {reload, Data}, ?GUILD_CALL_TIMEOUT),
|
||||
gen_server:reply(From, ok),
|
||||
{noreply, State};
|
||||
_ ->
|
||||
gen_server:reply(From, {error, fetch_failed}),
|
||||
{noreply, State}
|
||||
end;
|
||||
handle_guild_data_fetched(GuildId, Result, State);
|
||||
handle_cast({guild_data_reloaded, GuildId, Pid, From, Result}, State) ->
|
||||
handle_guild_data_reloaded(GuildId, Pid, From, Result, State);
|
||||
handle_cast({all_guilds_reloaded, From, Count}, State) ->
|
||||
gen_server:reply(From, #{count => Count}),
|
||||
{noreply, State};
|
||||
handle_cast(_Unknown, State) ->
|
||||
{noreply, State}.
|
||||
|
||||
-spec handle_info(Info, State) -> {noreply, state()} when
|
||||
Info :: {'DOWN', reference(), process, pid(), term()} | term(),
|
||||
State :: state().
|
||||
-spec handle_info(term(), state()) -> {noreply, state()}.
|
||||
handle_info({'DOWN', _Ref, process, Pid, _Reason}, State) ->
|
||||
Guilds = maps:get(guilds, State),
|
||||
NewGuilds = process_registry:cleanup_on_down(Pid, Guilds),
|
||||
@@ -192,23 +100,323 @@ handle_info({'DOWN', _Ref, process, Pid, _Reason}, State) ->
|
||||
handle_info(_Unknown, State) ->
|
||||
{noreply, State}.
|
||||
|
||||
-spec terminate(Reason, State) -> ok when
|
||||
Reason :: term(),
|
||||
State :: state().
|
||||
-spec terminate(term(), state()) -> ok.
|
||||
terminate(_Reason, _State) ->
|
||||
ok.
|
||||
|
||||
-spec code_change(term(), term(), term()) -> {ok, state()}.
|
||||
code_change(_OldVsn, #state{guilds = Guilds, api_host = ApiHost, api_canary_host = ApiCanaryHost, pending_requests = Pending}, _Extra) ->
|
||||
{ok, #{
|
||||
guilds => Guilds,
|
||||
api_host => ApiHost,
|
||||
api_canary_host => ApiCanaryHost,
|
||||
pending_requests => Pending
|
||||
}};
|
||||
code_change(_OldVsn, State, _Extra) when is_map(State) ->
|
||||
{ok, State}.
|
||||
|
||||
-spec do_start_or_lookup(guild_id(), gen_server:from(), state()) ->
|
||||
{reply, {ok, pid()} | {error, term()}, state()} | {noreply, state()}.
|
||||
do_start_or_lookup(GuildId, From, State) ->
|
||||
Guilds = maps:get(guilds, State),
|
||||
case maps:get(GuildId, Guilds, undefined) of
|
||||
{Pid, _Ref} ->
|
||||
{reply, {ok, Pid}, State};
|
||||
loading ->
|
||||
add_pending_request(GuildId, From, State);
|
||||
undefined ->
|
||||
lookup_or_fetch(GuildId, From, State)
|
||||
end.
|
||||
|
||||
-spec lookup_or_fetch(guild_id(), gen_server:from(), state()) ->
|
||||
{reply, {ok, pid()}, state()} | {noreply, state()}.
|
||||
lookup_or_fetch(GuildId, From, State) ->
|
||||
GuildName = process_registry:build_process_name(guild, GuildId),
|
||||
case whereis(GuildName) of
|
||||
undefined ->
|
||||
start_fetch(GuildId, From, State);
|
||||
_ExistingPid ->
|
||||
Guilds = maps:get(guilds, State),
|
||||
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
|
||||
{ok, Pid, _Ref, NewGuilds} ->
|
||||
{reply, {ok, Pid}, State#{guilds => NewGuilds}};
|
||||
{error, not_found} ->
|
||||
{reply, {error, process_died}, State}
|
||||
end
|
||||
end.
|
||||
|
||||
-spec start_fetch(guild_id(), gen_server:from(), state()) -> {noreply, state()}.
|
||||
start_fetch(GuildId, From, State) ->
|
||||
Guilds = maps:get(guilds, State),
|
||||
NewGuilds = maps:put(GuildId, loading, Guilds),
|
||||
Pending = maps:get(pending_requests, State),
|
||||
NewPending = maps:put(GuildId, [From], Pending),
|
||||
NewState = State#{guilds => NewGuilds, pending_requests => NewPending},
|
||||
spawn_fetch(GuildId, State),
|
||||
{noreply, NewState}.
|
||||
|
||||
-spec spawn_fetch(guild_id(), state()) -> pid().
|
||||
spawn_fetch(GuildId, State) ->
|
||||
Manager = self(),
|
||||
ApiHostInfo = select_api_host(State),
|
||||
spawn(fun() ->
|
||||
try
|
||||
Result = fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State),
|
||||
gen_server:cast(Manager, {guild_data_fetched, GuildId, Result})
|
||||
catch
|
||||
_:_:_ ->
|
||||
gen_server:cast(Manager, {guild_data_fetched, GuildId, {error, fetch_failed}})
|
||||
end
|
||||
end).
|
||||
|
||||
-spec add_pending_request(guild_id(), gen_server:from(), state()) -> {noreply, state()}.
|
||||
add_pending_request(GuildId, From, State) ->
|
||||
Pending = maps:get(pending_requests, State),
|
||||
Requests = maps:get(GuildId, Pending, []),
|
||||
NewPending = maps:put(GuildId, [From | Requests], Pending),
|
||||
{noreply, State#{pending_requests => NewPending}}.
|
||||
|
||||
-spec handle_guild_data_fetched(guild_id(), fetch_result(), state()) -> {noreply, state()}.
|
||||
handle_guild_data_fetched(GuildId, Result, State) ->
|
||||
Pending = maps:get(pending_requests, State),
|
||||
Requests = maps:get(GuildId, Pending, []),
|
||||
Guilds = maps:get(guilds, State),
|
||||
case Result of
|
||||
{ok, Data} ->
|
||||
case start_guild(GuildId, Data, State) of
|
||||
{ok, Pid, NewState} ->
|
||||
reply_to_all(Requests, {ok, Pid}),
|
||||
NewPending = maps:remove(GuildId, Pending),
|
||||
{noreply, NewState#{pending_requests => NewPending}};
|
||||
{error, Reason} ->
|
||||
reply_to_all(Requests, {error, Reason}),
|
||||
NewGuilds = maps:remove(GuildId, Guilds),
|
||||
NewPending = maps:remove(GuildId, Pending),
|
||||
{noreply, State#{guilds => NewGuilds, pending_requests => NewPending}}
|
||||
end;
|
||||
{error, Reason} ->
|
||||
reply_to_all(Requests, {error, Reason}),
|
||||
NewGuilds = maps:remove(GuildId, Guilds),
|
||||
NewPending = maps:remove(GuildId, Pending),
|
||||
{noreply, State#{guilds => NewGuilds, pending_requests => NewPending}}
|
||||
end.
|
||||
|
||||
-spec handle_guild_data_reloaded(guild_id(), pid(), gen_server:from(), fetch_result(), state()) ->
|
||||
{noreply, state()}.
|
||||
handle_guild_data_reloaded(_GuildId, Pid, From, Result, State) ->
|
||||
case Result of
|
||||
{ok, Data} ->
|
||||
catch gen_server:call(Pid, {reload, Data}, ?GUILD_CALL_TIMEOUT),
|
||||
gen_server:reply(From, ok);
|
||||
_ ->
|
||||
gen_server:reply(From, {error, fetch_failed})
|
||||
end,
|
||||
{noreply, State}.
|
||||
|
||||
-spec reply_to_all([gen_server:from()], term()) -> ok.
|
||||
reply_to_all(Requests, Reply) ->
|
||||
lists:foreach(fun(From) -> gen_server:reply(From, Reply) end, Requests).
|
||||
|
||||
-spec do_stop_guild(guild_id(), state()) -> {reply, ok, state()}.
|
||||
do_stop_guild(GuildId, State) ->
|
||||
Guilds = maps:get(guilds, State),
|
||||
GuildName = process_registry:build_process_name(guild, GuildId),
|
||||
case maps:get(GuildId, Guilds, undefined) of
|
||||
{Pid, Ref} ->
|
||||
demonitor(Ref, [flush]),
|
||||
catch gen_server:stop(Pid, normal, ?SHUTDOWN_TIMEOUT),
|
||||
process_registry:safe_unregister(GuildName),
|
||||
NewGuilds = maps:remove(GuildId, Guilds),
|
||||
{reply, ok, State#{guilds => NewGuilds}};
|
||||
_ ->
|
||||
case whereis(GuildName) of
|
||||
undefined ->
|
||||
{reply, ok, State};
|
||||
ExistingPid ->
|
||||
catch gen_server:stop(ExistingPid, normal, ?SHUTDOWN_TIMEOUT),
|
||||
process_registry:safe_unregister(GuildName),
|
||||
{reply, ok, State}
|
||||
end
|
||||
end.
|
||||
|
||||
-spec do_reload_guild(guild_id(), gen_server:from(), state()) ->
|
||||
{reply, {error, not_found}, state()} | {noreply, state()}.
|
||||
do_reload_guild(GuildId, From, State) ->
|
||||
Guilds = maps:get(guilds, State),
|
||||
GuildName = process_registry:build_process_name(guild, GuildId),
|
||||
case maps:get(GuildId, Guilds, undefined) of
|
||||
{Pid, _Ref} ->
|
||||
spawn_reload(GuildId, Pid, From, State),
|
||||
{noreply, State};
|
||||
_ ->
|
||||
case whereis(GuildName) of
|
||||
undefined ->
|
||||
{reply, {error, not_found}, State};
|
||||
_ExistingPid ->
|
||||
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
|
||||
{ok, Pid, _Ref, NewGuilds} ->
|
||||
NewState = State#{guilds => NewGuilds},
|
||||
spawn_reload(GuildId, Pid, From, NewState),
|
||||
{noreply, NewState};
|
||||
{error, not_found} ->
|
||||
{reply, {error, not_found}, State}
|
||||
end
|
||||
end
|
||||
end.
|
||||
|
||||
-spec spawn_reload(guild_id(), pid(), gen_server:from(), state()) -> pid().
|
||||
spawn_reload(GuildId, Pid, From, State) ->
|
||||
Manager = self(),
|
||||
ApiHostInfo = select_api_host(State),
|
||||
spawn(fun() ->
|
||||
try
|
||||
Result = fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State),
|
||||
gen_server:cast(Manager, {guild_data_reloaded, GuildId, Pid, From, Result})
|
||||
catch
|
||||
_:_:_ ->
|
||||
gen_server:cast(
|
||||
Manager, {guild_data_reloaded, GuildId, Pid, From, {error, fetch_failed}}
|
||||
)
|
||||
end
|
||||
end).
|
||||
|
||||
-spec do_reload_all_guilds([guild_id()], gen_server:from(), state()) -> {noreply, state()}.
|
||||
do_reload_all_guilds(GuildIds, From, State) ->
|
||||
Guilds = maps:get(guilds, State),
|
||||
GuildsToReload = select_guilds_to_reload(GuildIds, Guilds),
|
||||
Manager = self(),
|
||||
spawn(fun() ->
|
||||
try
|
||||
reload_guilds_in_batches(GuildsToReload, State),
|
||||
gen_server:cast(Manager, {all_guilds_reloaded, From, length(GuildsToReload)})
|
||||
catch
|
||||
_:_:_ ->
|
||||
gen_server:cast(Manager, {all_guilds_reloaded, From, 0})
|
||||
end
|
||||
end),
|
||||
{noreply, State}.
|
||||
|
||||
-spec select_guilds_to_reload([guild_id()], #{guild_id() => guild_ref() | loading}) ->
|
||||
[{guild_id(), pid()}].
|
||||
select_guilds_to_reload([], Guilds) ->
|
||||
[{GuildId, Pid} || {GuildId, {Pid, _Ref}} <- maps:to_list(Guilds)];
|
||||
select_guilds_to_reload(GuildIds, Guilds) ->
|
||||
lists:filtermap(
|
||||
fun(GuildId) ->
|
||||
case maps:get(GuildId, Guilds, undefined) of
|
||||
{Pid, _Ref} -> {true, {GuildId, Pid}};
|
||||
_ -> false
|
||||
end
|
||||
end,
|
||||
GuildIds
|
||||
).
|
||||
|
||||
-spec reload_guilds_in_batches([{guild_id(), pid()}], state()) -> ok.
|
||||
reload_guilds_in_batches([], _State) ->
|
||||
ok;
|
||||
reload_guilds_in_batches(Guilds, State) ->
|
||||
{Batch, Remaining} = lists:split(min(?BATCH_SIZE, length(Guilds)), Guilds),
|
||||
reload_batch(Batch, State),
|
||||
case Remaining of
|
||||
[] ->
|
||||
ok;
|
||||
_ ->
|
||||
timer:sleep(?BATCH_DELAY_MS),
|
||||
reload_guilds_in_batches(Remaining, State)
|
||||
end.
|
||||
|
||||
-spec reload_batch([{guild_id(), pid()}], state()) -> ok.
|
||||
reload_batch(Batch, State) ->
|
||||
ApiHostInfo = select_api_host(State),
|
||||
lists:foreach(
|
||||
fun({GuildId, Pid}) ->
|
||||
spawn(fun() ->
|
||||
try
|
||||
case fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State) of
|
||||
{ok, Data} ->
|
||||
catch gen_server:call(Pid, {reload, Data}, ?GUILD_CALL_TIMEOUT);
|
||||
{error, _Reason} ->
|
||||
ok
|
||||
end
|
||||
catch
|
||||
_:_ -> ok
|
||||
end
|
||||
end)
|
||||
end,
|
||||
Batch
|
||||
).
|
||||
|
||||
-spec do_shutdown_guild(guild_id(), state()) -> {reply, ok, state()}.
|
||||
do_shutdown_guild(GuildId, State) ->
|
||||
Guilds = maps:get(guilds, State),
|
||||
GuildName = process_registry:build_process_name(guild, GuildId),
|
||||
case maps:get(GuildId, Guilds, undefined) of
|
||||
{Pid, Ref} ->
|
||||
demonitor(Ref, [flush]),
|
||||
catch gen_server:call(Pid, {terminate}, ?SHUTDOWN_TIMEOUT),
|
||||
process_registry:safe_unregister(GuildName),
|
||||
NewGuilds = maps:remove(GuildId, Guilds),
|
||||
{reply, ok, State#{guilds => NewGuilds}};
|
||||
_ ->
|
||||
case whereis(GuildName) of
|
||||
undefined ->
|
||||
{reply, ok, State};
|
||||
ExistingPid ->
|
||||
catch gen_server:call(ExistingPid, {terminate}, ?SHUTDOWN_TIMEOUT),
|
||||
process_registry:safe_unregister(GuildName),
|
||||
{reply, ok, State}
|
||||
end
|
||||
end.
|
||||
|
||||
-spec start_guild(guild_id(), guild_data(), state()) -> {ok, pid(), state()} | {error, term()}.
|
||||
start_guild(GuildId, Data, State) ->
|
||||
GuildName = process_registry:build_process_name(guild, GuildId),
|
||||
case whereis(GuildName) of
|
||||
undefined ->
|
||||
start_new_guild(GuildId, Data, GuildName, State);
|
||||
_ExistingPid ->
|
||||
lookup_existing_guild(GuildId, GuildName, State)
|
||||
end.
|
||||
|
||||
-spec start_new_guild(guild_id(), guild_data(), atom(), state()) ->
|
||||
{ok, pid(), state()} | {error, term()}.
|
||||
start_new_guild(GuildId, Data, GuildName, State) ->
|
||||
GuildState = #{
|
||||
id => GuildId,
|
||||
data => Data,
|
||||
sessions => #{}
|
||||
},
|
||||
Guilds = maps:get(guilds, State),
|
||||
GuildModule =
|
||||
case is_very_large_guild(Data) of
|
||||
true -> very_large_guild;
|
||||
false -> guild
|
||||
end,
|
||||
case GuildModule:start_link(GuildState) of
|
||||
{ok, Pid} ->
|
||||
case process_registry:register_and_monitor(GuildName, Pid, Guilds) of
|
||||
{ok, RegisteredPid, Ref, NewGuilds0} ->
|
||||
CleanGuilds = maps:remove(GuildName, NewGuilds0),
|
||||
NewGuilds = maps:put(GuildId, {RegisteredPid, Ref}, CleanGuilds),
|
||||
{ok, RegisteredPid, State#{guilds => NewGuilds}};
|
||||
{error, Reason} ->
|
||||
{error, Reason}
|
||||
end;
|
||||
Error ->
|
||||
Error
|
||||
end.
|
||||
|
||||
-spec is_very_large_guild(guild_data()) -> boolean().
|
||||
is_very_large_guild(Data) when is_map(Data) ->
|
||||
Guild = maps:get(<<"guild">>, Data, #{}),
|
||||
Features = maps:get(<<"features">>, Guild, []),
|
||||
lists:member(<<"VERY_LARGE_GUILD">>, Features);
|
||||
is_very_large_guild(_) ->
|
||||
false.
|
||||
|
||||
-spec lookup_existing_guild(guild_id(), atom(), state()) -> {ok, pid(), state()} | {error, term()}.
|
||||
lookup_existing_guild(GuildId, GuildName, State) ->
|
||||
Guilds = maps:get(guilds, State),
|
||||
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
|
||||
{ok, Pid, _Ref, NewGuilds} ->
|
||||
{ok, Pid, State#{guilds => NewGuilds}};
|
||||
{error, not_found} ->
|
||||
{error, process_died}
|
||||
end.
|
||||
|
||||
-spec fetch_guild_data(guild_id(), string()) -> fetch_result().
|
||||
fetch_guild_data(GuildId, ApiHost) ->
|
||||
RpcRequest = #{
|
||||
@@ -217,43 +425,27 @@ fetch_guild_data(GuildId, ApiHost) ->
|
||||
<<"version">> => 1
|
||||
},
|
||||
Url = rpc_client:get_rpc_url(ApiHost),
|
||||
Headers =
|
||||
rpc_client:get_rpc_headers() ++ [{<<"content-type">>, <<"application/json">>}],
|
||||
Body = jsx:encode(RpcRequest),
|
||||
case
|
||||
hackney:request(post, Url, Headers, Body, [{recv_timeout, 30000}, {connect_timeout, 5000}])
|
||||
of
|
||||
{ok, 200, _RespHeaders, ClientRef} ->
|
||||
case hackney:body(ClientRef) of
|
||||
{ok, RespBody} ->
|
||||
hackney:close(ClientRef),
|
||||
Response = jsx:decode(RespBody, [return_maps]),
|
||||
Data = maps:get(<<"data">>, Response, #{}),
|
||||
{ok, Data};
|
||||
{error, BodyReason} ->
|
||||
hackney:close(ClientRef),
|
||||
logger:error("[guild_manager] Failed to read guild response body: ~p", [
|
||||
BodyReason
|
||||
]),
|
||||
{error, fetch_failed}
|
||||
end;
|
||||
{ok, StatusCode, _RespHeaders, ClientRef} ->
|
||||
ErrorBody =
|
||||
case hackney:body(ClientRef) of
|
||||
{ok, Body2} -> Body2;
|
||||
{error, _} -> <<"<unable to read error body>">>
|
||||
end,
|
||||
hackney:close(ClientRef),
|
||||
logger:error(
|
||||
"[guild_manager] Guild RPC failed with status ~p: ~s",
|
||||
[StatusCode, ErrorBody]
|
||||
),
|
||||
{error, fetch_failed};
|
||||
Headers = rpc_client:get_rpc_headers() ++ [{<<"content-type">>, <<"application/json">>}],
|
||||
Body = json:encode(RpcRequest),
|
||||
case gateway_http_client:request(rpc, post, Url, Headers, Body) of
|
||||
{ok, 200, _RespHeaders, RespBody} ->
|
||||
handle_fetch_response(RespBody);
|
||||
{ok, StatusCode, _RespHeaders, _RespBody} ->
|
||||
handle_fetch_error(StatusCode);
|
||||
{error, Reason} ->
|
||||
logger:error("[guild_manager] Guild RPC request failed: ~p", [Reason]),
|
||||
{error, fetch_failed}
|
||||
{error, {request_failed, Reason}}
|
||||
end.
|
||||
|
||||
-spec handle_fetch_response(binary()) -> fetch_result().
|
||||
handle_fetch_response(RespBody) ->
|
||||
Response = json:decode(RespBody),
|
||||
Data = maps:get(<<"data">>, Response, #{}),
|
||||
{ok, Data}.
|
||||
|
||||
-spec handle_fetch_error(integer()) -> {error, {http_status, integer()}}.
|
||||
handle_fetch_error(StatusCode) ->
|
||||
{error, {http_status, StatusCode}}.
|
||||
|
||||
-spec select_api_host(state()) -> {string(), boolean()}.
|
||||
select_api_host(State) ->
|
||||
case maps:get(api_canary_host, State) of
|
||||
@@ -270,12 +462,8 @@ select_api_host(State) ->
|
||||
should_use_canary_api() ->
|
||||
erlang:unique_integer([positive]) rem 100 < ?GUILD_API_CANARY_PERCENTAGE.
|
||||
|
||||
-spec fetch_guild_data_with_fallback(
|
||||
guild_id(),
|
||||
{string(), boolean()},
|
||||
state()
|
||||
) -> fetch_result().
|
||||
fetch_guild_data_with_fallback(GuildId, {ApiHost, false}, _) ->
|
||||
-spec fetch_guild_data_with_fallback(guild_id(), {string(), boolean()}, state()) -> fetch_result().
|
||||
fetch_guild_data_with_fallback(GuildId, {ApiHost, false}, _State) ->
|
||||
fetch_guild_data(GuildId, ApiHost);
|
||||
fetch_guild_data_with_fallback(GuildId, {ApiHost, true}, State) ->
|
||||
case fetch_guild_data(GuildId, ApiHost) of
|
||||
@@ -287,244 +475,57 @@ fetch_guild_data_with_fallback(GuildId, {ApiHost, true}, State) ->
|
||||
true ->
|
||||
Error;
|
||||
false ->
|
||||
logger:warning(
|
||||
"[guild_manager] Canary API request failed for ~p, retrying against stable host",
|
||||
[GuildId]
|
||||
),
|
||||
fetch_guild_data(GuildId, StableHost)
|
||||
end
|
||||
end.
|
||||
|
||||
-spec start_guild(guild_id(), guild_data(), state()) -> {ok, pid(), state()} | {error, term()}.
|
||||
start_guild(GuildId, Data, State) ->
|
||||
GuildName = process_registry:build_process_name(guild, GuildId),
|
||||
case whereis(GuildName) of
|
||||
undefined ->
|
||||
GuildState = #{
|
||||
id => GuildId,
|
||||
data => Data,
|
||||
sessions => #{},
|
||||
presences => #{}
|
||||
},
|
||||
Guilds = maps:get(guilds, State),
|
||||
case guild:start_link(GuildState) of
|
||||
{ok, Pid} ->
|
||||
case process_registry:register_and_monitor(GuildName, Pid, Guilds) of
|
||||
{ok, RegisteredPid, Ref, NewGuilds0} ->
|
||||
CleanGuilds = maps:remove(GuildName, NewGuilds0),
|
||||
NewGuilds = maps:put(GuildId, {RegisteredPid, Ref}, CleanGuilds),
|
||||
{ok, RegisteredPid, State#{guilds => NewGuilds}};
|
||||
{error, Reason} ->
|
||||
{error, Reason}
|
||||
end;
|
||||
Error ->
|
||||
Error
|
||||
end;
|
||||
_ExistingPid ->
|
||||
Guilds = maps:get(guilds, State),
|
||||
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
|
||||
{ok, Pid, _Ref, NewGuilds} ->
|
||||
{ok, Pid, State#{guilds => NewGuilds}};
|
||||
{error, not_found} ->
|
||||
{error, process_died}
|
||||
end
|
||||
end.
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
|
||||
-spec reload_guilds_in_batches(
|
||||
[{guild_id(), pid()}],
|
||||
pid(),
|
||||
state(),
|
||||
pos_integer(),
|
||||
non_neg_integer()
|
||||
) -> ok.
|
||||
reload_guilds_in_batches([], _Manager, _State, _BatchSize, _DelayMs) ->
|
||||
ok;
|
||||
reload_guilds_in_batches(Guilds, Manager, State, BatchSize, DelayMs) ->
|
||||
{Batch, Remaining} = lists:split(min(BatchSize, length(Guilds)), Guilds),
|
||||
lists:foreach(
|
||||
fun({GuildId, Pid}) ->
|
||||
ApiHostInfo = select_api_host(State),
|
||||
spawn(fun() ->
|
||||
try
|
||||
case fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State) of
|
||||
{ok, Data} ->
|
||||
gen_server:call(Pid, {reload, Data}, ?GUILD_CALL_TIMEOUT);
|
||||
{error, Reason} ->
|
||||
logger:error("[guild_manager] Failed to reload guild ~p: ~p", [
|
||||
GuildId, Reason
|
||||
])
|
||||
end
|
||||
catch
|
||||
Class:Error:Stacktrace ->
|
||||
logger:error(
|
||||
"[guild_manager] Spawned process failed: ~p:~p~n~p",
|
||||
[Class, Error, Stacktrace]
|
||||
)
|
||||
end
|
||||
end)
|
||||
end,
|
||||
Batch
|
||||
),
|
||||
case Remaining of
|
||||
[] ->
|
||||
ok;
|
||||
_ ->
|
||||
timer:sleep(DelayMs),
|
||||
reload_guilds_in_batches(Remaining, Manager, State, BatchSize, DelayMs)
|
||||
end.
|
||||
select_api_host_no_canary_test() ->
|
||||
State = #{api_host => "http://api.local", api_canary_host => undefined},
|
||||
{Host, IsCanary} = select_api_host(State),
|
||||
?assertEqual("http://api.local", Host),
|
||||
?assertEqual(false, IsCanary).
|
||||
|
||||
-spec do_start_or_lookup(guild_id(), gen_server:from(), state()) ->
|
||||
{reply, {ok, pid()} | {error, term()}, state()} | {noreply, state()}.
|
||||
do_start_or_lookup(GuildId, From, State) ->
|
||||
Guilds = maps:get(guilds, State),
|
||||
case maps:get(GuildId, Guilds, undefined) of
|
||||
{Pid, _Ref} ->
|
||||
{reply, {ok, Pid}, State};
|
||||
loading ->
|
||||
Pending = maps:get(pending_requests, State),
|
||||
Requests = maps:get(GuildId, Pending, []),
|
||||
NewPending = maps:put(GuildId, [From | Requests], Pending),
|
||||
{noreply, State#{pending_requests => NewPending}};
|
||||
undefined ->
|
||||
GuildName = process_registry:build_process_name(guild, GuildId),
|
||||
case whereis(GuildName) of
|
||||
undefined ->
|
||||
NewGuilds = maps:put(GuildId, loading, Guilds),
|
||||
Pending = maps:get(pending_requests, State),
|
||||
NewPending = maps:put(GuildId, [From], Pending),
|
||||
NewState = State#{guilds => NewGuilds, pending_requests => NewPending},
|
||||
Manager = self(),
|
||||
ApiHostInfo = select_api_host(State),
|
||||
spawn(fun() ->
|
||||
try
|
||||
Result = fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State),
|
||||
gen_server:cast(Manager, {guild_data_fetched, GuildId, Result})
|
||||
catch
|
||||
Class:Error:Stacktrace ->
|
||||
logger:error(
|
||||
"[guild_manager] Spawned process failed: ~p:~p~n~p",
|
||||
[Class, Error, Stacktrace]
|
||||
),
|
||||
gen_server:cast(
|
||||
Manager, {guild_data_fetched, GuildId, {error, fetch_failed}}
|
||||
)
|
||||
end
|
||||
end),
|
||||
{noreply, NewState};
|
||||
_ExistingPid ->
|
||||
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
|
||||
{ok, Pid, _Ref, NewGuilds} ->
|
||||
{reply, {ok, Pid}, State#{guilds => NewGuilds}};
|
||||
{error, not_found} ->
|
||||
{reply, {error, process_died}, State}
|
||||
end
|
||||
end
|
||||
end.
|
||||
should_use_canary_api_returns_boolean_test() ->
|
||||
Result = should_use_canary_api(),
|
||||
?assert(is_boolean(Result)).
|
||||
|
||||
-spec do_stop_guild(guild_id(), state()) -> {reply, ok, state()}.
|
||||
do_stop_guild(GuildId, State) ->
|
||||
Guilds = maps:get(guilds, State),
|
||||
GuildName = process_registry:build_process_name(guild, GuildId),
|
||||
case maps:get(GuildId, Guilds, undefined) of
|
||||
{Pid, Ref} ->
|
||||
demonitor(Ref, [flush]),
|
||||
gen_server:stop(Pid, normal, ?SHUTDOWN_TIMEOUT),
|
||||
process_registry:safe_unregister(GuildName),
|
||||
NewGuilds = maps:remove(GuildId, Guilds),
|
||||
{reply, ok, State#{guilds => NewGuilds}};
|
||||
_ ->
|
||||
case whereis(GuildName) of
|
||||
undefined ->
|
||||
{reply, ok, State};
|
||||
ExistingPid ->
|
||||
gen_server:stop(ExistingPid, normal, ?SHUTDOWN_TIMEOUT),
|
||||
process_registry:safe_unregister(GuildName),
|
||||
{reply, ok, State}
|
||||
end
|
||||
end.
|
||||
select_guilds_to_reload_empty_ids_test() ->
|
||||
Guilds = #{1 => {self(), make_ref()}, 2 => {self(), make_ref()}},
|
||||
Result = select_guilds_to_reload([], Guilds),
|
||||
?assertEqual(2, length(Result)).
|
||||
|
||||
-spec do_reload_guild(guild_id(), gen_server:from(), state()) ->
|
||||
{reply, {error, not_found}, state()} | {noreply, state()}.
|
||||
do_reload_guild(GuildId, From, State) ->
|
||||
Guilds = maps:get(guilds, State),
|
||||
GuildName = process_registry:build_process_name(guild, GuildId),
|
||||
case maps:get(GuildId, Guilds, undefined) of
|
||||
{Pid, _Ref} ->
|
||||
Manager = self(),
|
||||
ApiHostInfo = select_api_host(State),
|
||||
spawn(fun() ->
|
||||
try
|
||||
Result = fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State),
|
||||
gen_server:cast(Manager, {guild_data_reloaded, GuildId, Pid, From, Result})
|
||||
catch
|
||||
Class:Error:Stacktrace ->
|
||||
logger:error(
|
||||
"[guild_manager] Spawned process failed: ~p:~p~n~p",
|
||||
[Class, Error, Stacktrace]
|
||||
),
|
||||
gen_server:cast(
|
||||
Manager,
|
||||
{guild_data_reloaded, GuildId, Pid, From, {error, fetch_failed}}
|
||||
)
|
||||
end
|
||||
end),
|
||||
{noreply, State};
|
||||
_ ->
|
||||
case whereis(GuildName) of
|
||||
undefined ->
|
||||
{reply, {error, not_found}, State};
|
||||
_ExistingPid ->
|
||||
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
|
||||
{ok, Pid, _Ref, NewGuilds} ->
|
||||
NewState = State#{guilds => NewGuilds},
|
||||
Manager = self(),
|
||||
ApiHostInfo = select_api_host(NewState),
|
||||
spawn(fun() ->
|
||||
try
|
||||
Result = fetch_guild_data_with_fallback(
|
||||
GuildId, ApiHostInfo, NewState
|
||||
),
|
||||
gen_server:cast(
|
||||
Manager, {guild_data_reloaded, GuildId, Pid, From, Result}
|
||||
)
|
||||
catch
|
||||
Class:Error:Stacktrace ->
|
||||
logger:error(
|
||||
"[guild_manager] Spawned process failed: ~p:~p~n~p",
|
||||
[Class, Error, Stacktrace]
|
||||
),
|
||||
gen_server:cast(
|
||||
Manager,
|
||||
{guild_data_reloaded, GuildId, Pid, From,
|
||||
{error, fetch_failed}}
|
||||
)
|
||||
end
|
||||
end),
|
||||
{noreply, NewState};
|
||||
{error, not_found} ->
|
||||
{reply, {error, not_found}, State}
|
||||
end
|
||||
end
|
||||
end.
|
||||
select_guilds_to_reload_specific_ids_test() ->
|
||||
Pid = self(),
|
||||
Ref = make_ref(),
|
||||
Guilds = #{1 => {Pid, Ref}, 2 => {Pid, Ref}, 3 => loading},
|
||||
Result = select_guilds_to_reload([1, 3], Guilds),
|
||||
?assertEqual(1, length(Result)).
|
||||
|
||||
-spec do_shutdown_guild(guild_id(), state()) -> {reply, ok, state()}.
|
||||
do_shutdown_guild(GuildId, State) ->
|
||||
Guilds = maps:get(guilds, State),
|
||||
GuildName = process_registry:build_process_name(guild, GuildId),
|
||||
case maps:get(GuildId, Guilds, undefined) of
|
||||
{Pid, Ref} ->
|
||||
demonitor(Ref, [flush]),
|
||||
gen_server:call(Pid, {terminate}, ?SHUTDOWN_TIMEOUT),
|
||||
process_registry:safe_unregister(GuildName),
|
||||
NewGuilds = maps:remove(GuildId, Guilds),
|
||||
{reply, ok, State#{guilds => NewGuilds}};
|
||||
_ ->
|
||||
case whereis(GuildName) of
|
||||
undefined ->
|
||||
{reply, ok, State};
|
||||
ExistingPid ->
|
||||
catch gen_server:call(ExistingPid, {terminate}, ?SHUTDOWN_TIMEOUT),
|
||||
process_registry:safe_unregister(GuildName),
|
||||
{reply, ok, State}
|
||||
end
|
||||
end.
|
||||
reply_to_all_empty_list_test() ->
|
||||
?assertEqual(ok, reply_to_all([], ok)).
|
||||
|
||||
do_start_or_lookup_loading_deduplicates_requests_test() ->
|
||||
GuildId = 4444,
|
||||
From1 = {self(), make_ref()},
|
||||
From2 = {self(), make_ref()},
|
||||
State0 = #{
|
||||
guilds => #{GuildId => loading},
|
||||
api_host => "http://api.local",
|
||||
api_canary_host => undefined,
|
||||
pending_requests => #{},
|
||||
shard_index => 0
|
||||
},
|
||||
{noreply, State1} = do_start_or_lookup(GuildId, From1, State0),
|
||||
Pending1 = maps:get(pending_requests, State1),
|
||||
?assertEqual([From1], maps:get(GuildId, Pending1)),
|
||||
{noreply, State2} = do_start_or_lookup(GuildId, From2, State1),
|
||||
Pending2 = maps:get(pending_requests, State2),
|
||||
Requests = maps:get(GuildId, Pending2),
|
||||
?assertEqual(2, length(Requests)),
|
||||
?assert(lists:member(From1, Requests)),
|
||||
?assert(lists:member(From2, Requests)).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -25,11 +25,14 @@
|
||||
get_items_in_range/3,
|
||||
handle_member_update/3,
|
||||
build_sync_response/4,
|
||||
member_list_delta/4,
|
||||
member_list_snapshot/2,
|
||||
get_online_count/1,
|
||||
broadcast_member_list_updates/3,
|
||||
broadcast_all_member_list_updates/1,
|
||||
broadcast_member_list_updates_for_channel/2,
|
||||
normalize_ranges/1
|
||||
normalize_ranges/1,
|
||||
get_members_cursor/2
|
||||
]).
|
||||
|
||||
-ifdef(TEST).
|
||||
@@ -43,13 +46,19 @@
|
||||
-type group_item() :: map().
|
||||
-type list_item() :: member_item() | group_item().
|
||||
-type user_id() :: integer().
|
||||
-type channel_id() :: integer().
|
||||
|
||||
-define(MAX_RANGE_END, 100000).
|
||||
|
||||
-spec validate_range(range()) -> range() | invalid.
|
||||
validate_range({Start, End}) when is_integer(Start), is_integer(End),
|
||||
Start >= 0, End >= 0, Start =< End,
|
||||
End =< ?MAX_RANGE_END ->
|
||||
validate_range({Start, End}) when
|
||||
is_integer(Start),
|
||||
is_integer(End),
|
||||
Start >= 0,
|
||||
End >= 0,
|
||||
Start =< End,
|
||||
End =< ?MAX_RANGE_END
|
||||
->
|
||||
{Start, End};
|
||||
validate_range(_) ->
|
||||
invalid.
|
||||
@@ -77,27 +86,39 @@ merge_overlapping_ranges([{S1, E1}, {S2, E2} | Rest]) when S2 =< E1 + 1 ->
|
||||
merge_overlapping_ranges([Range | Rest]) ->
|
||||
[Range | merge_overlapping_ranges(Rest)].
|
||||
|
||||
-spec calculate_list_id(integer(), guild_state()) -> list_id().
|
||||
calculate_list_id(ChannelId, _State) when is_integer(ChannelId), ChannelId > 0 ->
|
||||
integer_to_binary(ChannelId);
|
||||
-spec calculate_list_id(channel_id(), guild_state()) -> list_id().
|
||||
calculate_list_id(ChannelId, State) when is_integer(ChannelId), ChannelId > 0 ->
|
||||
Data = maps:get(data, State, #{}),
|
||||
Channels = map_utils:ensure_list(maps:get(<<"channels">>, Data, [])),
|
||||
ChannelIdBin = integer_to_binary(ChannelId),
|
||||
ChannelExists = lists:any(
|
||||
fun(Ch) -> maps:get(<<"id">>, Ch, undefined) =:= ChannelIdBin end, Channels
|
||||
),
|
||||
case ChannelExists of
|
||||
true -> ChannelIdBin;
|
||||
false -> <<"0">>
|
||||
end;
|
||||
calculate_list_id(_, _) ->
|
||||
<<"0">>.
|
||||
|
||||
-spec get_member_groups(list_id(), guild_state()) -> [group_item()].
|
||||
get_member_groups(ListId, State) ->
|
||||
Data = maps:get(data, State, #{}),
|
||||
Members = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
|
||||
Members = guild_data_index:member_values(Data),
|
||||
Roles = map_utils:ensure_list(maps:get(<<"roles">>, Data, [])),
|
||||
GuildId = maps:get(id, State, 0),
|
||||
HoistedRoles = get_hoisted_roles_sorted(Roles, GuildId),
|
||||
FilteredMembers = filter_members_for_list(ListId, Members, State),
|
||||
{OnlineMembers, OfflineMembers} = partition_members_by_online(FilteredMembers, State),
|
||||
RoleGroups = build_role_groups(HoistedRoles, OnlineMembers),
|
||||
OnlineGroup = #{<<"id">> => <<"online">>, <<"count">> => count_ungrouped_online(OnlineMembers, HoistedRoles)},
|
||||
OnlineGroup = #{
|
||||
<<"id">> => <<"online">>, <<"count">> => count_ungrouped_online(OnlineMembers, HoistedRoles)
|
||||
},
|
||||
OfflineGroup = #{<<"id">> => <<"offline">>, <<"count">> => length(OfflineMembers)},
|
||||
RoleGroups ++ [OnlineGroup, OfflineGroup].
|
||||
|
||||
-spec subscribe_ranges(binary(), list_id(), [range()], guild_state()) -> {guild_state(), boolean(), [range()]}.
|
||||
-spec subscribe_ranges(binary(), list_id(), [range()], guild_state()) ->
|
||||
{guild_state(), boolean(), [range()]}.
|
||||
subscribe_ranges(SessionId, ListId, Ranges, State) ->
|
||||
NormalizedRanges = normalize_ranges(Ranges),
|
||||
Subscriptions = maps:get(member_list_subscriptions, State, #{}),
|
||||
@@ -205,8 +226,9 @@ build_sync_response(GuildId, ListId, Ranges, State) ->
|
||||
-spec get_online_count(guild_state()) -> non_neg_integer().
|
||||
get_online_count(State) ->
|
||||
Data = maps:get(data, State, #{}),
|
||||
Members = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
|
||||
{OnlineMembers, _} = partition_members_by_online(Members, State),
|
||||
Members = guild_data_index:member_values(Data),
|
||||
EligibleMembers = filter_members_for_list(<<"0">>, Members, State),
|
||||
{OnlineMembers, _} = partition_members_by_online(EligibleMembers, State),
|
||||
length(OnlineMembers).
|
||||
|
||||
-spec broadcast_member_list_updates(user_id(), guild_state(), guild_state()) -> ok.
|
||||
@@ -226,7 +248,9 @@ broadcast_member_list_updates(_UserId, OldState, UpdatedState) ->
|
||||
<<"groups">> => Groups,
|
||||
<<"ops">> => Ops
|
||||
},
|
||||
send_member_list_update_to_sessions(ListId, ListSubs, Sessions, Payload, UpdatedState);
|
||||
send_member_list_update_to_sessions(
|
||||
ListId, ListSubs, Sessions, Payload, UpdatedState
|
||||
);
|
||||
_ ->
|
||||
ok
|
||||
end
|
||||
@@ -246,7 +270,9 @@ send_member_list_update_to_sessions(ListId, ListSubs, Sessions, Payload, State)
|
||||
true ->
|
||||
case session_can_view_channel(SessionData, ChannelId, State) of
|
||||
true ->
|
||||
gen_server:cast(SessionPid, {dispatch, guild_member_list_update, Payload});
|
||||
gen_server:cast(
|
||||
SessionPid, {dispatch, guild_member_list_update, Payload}
|
||||
);
|
||||
false ->
|
||||
ok
|
||||
end;
|
||||
@@ -260,7 +286,7 @@ send_member_list_update_to_sessions(ListId, ListSubs, Sessions, Payload, State)
|
||||
ListSubs
|
||||
).
|
||||
|
||||
-spec member_list_delta(list_id(), guild_state(), guild_state(), user_id()) ->
|
||||
-spec member_list_delta(list_id(), guild_state(), guild_state(), user_id()) ->
|
||||
{non_neg_integer(), non_neg_integer(), [group_item()], [list_item()], boolean()}.
|
||||
member_list_delta(ListId, OldState, UpdatedState, UserId) ->
|
||||
{OldCount, OldOnline, OldGroups, OldItems} = member_list_snapshot(ListId, OldState),
|
||||
@@ -270,10 +296,14 @@ member_list_delta(ListId, OldState, UpdatedState, UserId) ->
|
||||
{MemberCount, OnlineCount, Groups, Ops, true};
|
||||
{false, _} ->
|
||||
Ops = diff_items_to_ops(OldItems, Items),
|
||||
Changed = Ops =/= [] orelse OldCount =/= MemberCount orelse OldOnline =/= OnlineCount orelse OldGroups =/= Groups,
|
||||
Changed =
|
||||
Ops =/= [] orelse OldCount =/= MemberCount orelse OldOnline =/= OnlineCount orelse
|
||||
OldGroups =/= Groups,
|
||||
{MemberCount, OnlineCount, Groups, Ops, Changed}
|
||||
end.
|
||||
|
||||
-spec presence_move_ops(user_id(), guild_state(), guild_state(), [list_item()], [list_item()]) ->
|
||||
{boolean(), [map()]}.
|
||||
presence_move_ops(UserId, OldState, UpdatedState, OldItems, NewItems) ->
|
||||
case presence_status_changed(UserId, OldState, UpdatedState) of
|
||||
false ->
|
||||
@@ -295,6 +325,7 @@ presence_move_ops(UserId, OldState, UpdatedState, OldItems, NewItems) ->
|
||||
end
|
||||
end.
|
||||
|
||||
-spec presence_status_changed(user_id(), guild_state(), guild_state()) -> boolean().
|
||||
presence_status_changed(UserId, OldState, UpdatedState) ->
|
||||
OldPresence = resolve_presence_for_user(OldState, UserId),
|
||||
NewPresence = resolve_presence_for_user(UpdatedState, UserId),
|
||||
@@ -302,9 +333,13 @@ presence_status_changed(UserId, OldState, UpdatedState) ->
|
||||
NewStatus = maps:get(<<"status">>, NewPresence, <<"offline">>),
|
||||
OldStatus =/= NewStatus.
|
||||
|
||||
-spec find_member_entry(user_id(), [list_item()]) ->
|
||||
{ok, non_neg_integer(), list_item()} | {error, not_found}.
|
||||
find_member_entry(UserId, Items) ->
|
||||
find_member_entry(UserId, Items, 0).
|
||||
|
||||
-spec find_member_entry(user_id(), [list_item()], non_neg_integer()) ->
|
||||
{ok, non_neg_integer(), list_item()} | {error, not_found}.
|
||||
find_member_entry(_UserId, [], _Index) ->
|
||||
{error, not_found};
|
||||
find_member_entry(UserId, [Item | Rest], Index) ->
|
||||
@@ -318,6 +353,7 @@ find_member_entry(UserId, [Item | Rest], Index) ->
|
||||
end
|
||||
end.
|
||||
|
||||
-spec adjusted_insert_index(non_neg_integer(), non_neg_integer()) -> non_neg_integer().
|
||||
adjusted_insert_index(OldIdx, NewIdx) when NewIdx > OldIdx ->
|
||||
NewIdx - 1;
|
||||
adjusted_insert_index(_OldIdx, NewIdx) ->
|
||||
@@ -345,9 +381,16 @@ broadcast_all_member_list_updates(State) ->
|
||||
case maps:get(SessionId, Sessions, undefined) of
|
||||
SessionData when is_map(SessionData) ->
|
||||
SessionPid = maps:get(pid, SessionData, undefined),
|
||||
case {is_pid(SessionPid), session_can_view_channel(SessionData, ChannelId, State)} of
|
||||
case
|
||||
{
|
||||
is_pid(SessionPid),
|
||||
session_can_view_channel(SessionData, ChannelId, State)
|
||||
}
|
||||
of
|
||||
{true, true} ->
|
||||
SyncResponse = build_sync_response(GuildId, ListId, Ranges, State),
|
||||
SyncResponse = build_sync_response(
|
||||
GuildId, ListId, Ranges, State
|
||||
),
|
||||
gen_server:cast(
|
||||
SessionPid,
|
||||
{dispatch, guild_member_list_update, SyncResponse}
|
||||
@@ -366,7 +409,7 @@ broadcast_all_member_list_updates(State) ->
|
||||
),
|
||||
ok.
|
||||
|
||||
-spec broadcast_member_list_updates_for_channel(integer(), guild_state()) -> ok.
|
||||
-spec broadcast_member_list_updates_for_channel(channel_id(), guild_state()) -> ok.
|
||||
broadcast_member_list_updates_for_channel(ChannelId, State) ->
|
||||
GuildId = maps:get(id, State, 0),
|
||||
ListId = calculate_list_id(ChannelId, State),
|
||||
@@ -381,13 +424,23 @@ broadcast_member_list_updates_for_channel(ChannelId, State) ->
|
||||
case maps:get(SessionId, Sessions, undefined) of
|
||||
SessionData when is_map(SessionData) ->
|
||||
SessionPid = maps:get(pid, SessionData, undefined),
|
||||
case {is_pid(SessionPid), session_can_view_channel(SessionData, ChannelId, State)} of
|
||||
case
|
||||
{
|
||||
is_pid(SessionPid),
|
||||
session_can_view_channel(SessionData, ChannelId, State)
|
||||
}
|
||||
of
|
||||
{true, true} ->
|
||||
SyncResponse = build_sync_response(GuildId, ListId, Ranges, State),
|
||||
SyncResponseWithChannel = maps:put(<<"channel_id">>, integer_to_binary(ChannelId), SyncResponse),
|
||||
SyncResponse = build_sync_response(
|
||||
GuildId, ListId, Ranges, State
|
||||
),
|
||||
SyncResponseWithChannel = maps:put(
|
||||
<<"channel_id">>, integer_to_binary(ChannelId), SyncResponse
|
||||
),
|
||||
gen_server:cast(
|
||||
SessionPid,
|
||||
{dispatch, guild_member_list_update, SyncResponseWithChannel}
|
||||
{dispatch, guild_member_list_update,
|
||||
SyncResponseWithChannel}
|
||||
);
|
||||
_ ->
|
||||
ok
|
||||
@@ -445,14 +498,33 @@ filter_members_for_list(ListId, Members, State) ->
|
||||
)
|
||||
end.
|
||||
|
||||
-spec connected_session_user_ids(guild_state()) -> sets:set().
|
||||
connected_session_user_ids(State) ->
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
maps:fold(
|
||||
fun(_SessionId, SessionData, Acc) ->
|
||||
case maps:get(user_id, SessionData, undefined) of
|
||||
UserId when is_integer(UserId), UserId > 0 ->
|
||||
sets:add_element(UserId, Acc);
|
||||
_ ->
|
||||
Acc
|
||||
end
|
||||
end,
|
||||
sets:new(),
|
||||
Sessions
|
||||
).
|
||||
|
||||
-spec partition_members_by_online([map()], guild_state()) -> {[map()], [map()]}.
|
||||
partition_members_by_online(Members, State) ->
|
||||
ConnectedUserIds = connected_session_user_ids(State),
|
||||
lists:partition(
|
||||
fun(Member) ->
|
||||
UserId = get_member_user_id(Member),
|
||||
Presence = resolve_presence_for_user(State, UserId),
|
||||
Status = maps:get(<<"status">>, Presence, <<"offline">>),
|
||||
Status =/= <<"offline">> andalso Status =/= <<"invisible">>
|
||||
IsConnected = sets:is_element(UserId, ConnectedUserIds),
|
||||
IsOnlineStatus = Status =/= <<"offline">> andalso Status =/= <<"invisible">>,
|
||||
IsConnected andalso IsOnlineStatus
|
||||
end,
|
||||
Members
|
||||
).
|
||||
@@ -471,15 +543,17 @@ build_role_groups(HoistedRoles, OnlineMembers) ->
|
||||
-spec count_members_with_top_role(integer(), [map()], [map()]) -> non_neg_integer().
|
||||
count_members_with_top_role(RoleId, Members, HoistedRoles) ->
|
||||
HoistedRoleIds = [map_utils:get_integer(R, <<"id">>, 0) || R <- HoistedRoles],
|
||||
length(lists:filter(
|
||||
fun(Member) ->
|
||||
MemberRoles = map_utils:ensure_list(maps:get(<<"roles">>, Member, [])),
|
||||
MemberRoleIds = [type_conv:to_integer(R) || R <- MemberRoles],
|
||||
TopHoisted = find_top_hoisted_role(MemberRoleIds, HoistedRoleIds),
|
||||
TopHoisted =:= RoleId
|
||||
end,
|
||||
Members
|
||||
)).
|
||||
length(
|
||||
lists:filter(
|
||||
fun(Member) ->
|
||||
MemberRoles = map_utils:ensure_list(maps:get(<<"roles">>, Member, [])),
|
||||
MemberRoleIds = [type_conv:to_integer(R) || R <- MemberRoles],
|
||||
TopHoisted = find_top_hoisted_role(MemberRoleIds, HoistedRoleIds),
|
||||
TopHoisted =:= RoleId
|
||||
end,
|
||||
Members
|
||||
)
|
||||
).
|
||||
|
||||
-spec find_top_hoisted_role([integer()], [integer()]) -> integer() | undefined.
|
||||
find_top_hoisted_role(MemberRoleIds, HoistedRoleIds) ->
|
||||
@@ -491,20 +565,22 @@ find_top_hoisted_role(MemberRoleIds, HoistedRoleIds) ->
|
||||
-spec count_ungrouped_online([map()], [map()]) -> non_neg_integer().
|
||||
count_ungrouped_online(OnlineMembers, HoistedRoles) ->
|
||||
HoistedRoleIds = [map_utils:get_integer(R, <<"id">>, 0) || R <- HoistedRoles],
|
||||
length(lists:filter(
|
||||
fun(Member) ->
|
||||
MemberRoles = map_utils:ensure_list(maps:get(<<"roles">>, Member, [])),
|
||||
MemberRoleIds = [type_conv:to_integer(R) || R <- MemberRoles],
|
||||
TopHoisted = find_top_hoisted_role(MemberRoleIds, HoistedRoleIds),
|
||||
TopHoisted =:= undefined
|
||||
end,
|
||||
OnlineMembers
|
||||
)).
|
||||
length(
|
||||
lists:filter(
|
||||
fun(Member) ->
|
||||
MemberRoles = map_utils:ensure_list(maps:get(<<"roles">>, Member, [])),
|
||||
MemberRoleIds = [type_conv:to_integer(R) || R <- MemberRoles],
|
||||
TopHoisted = find_top_hoisted_role(MemberRoleIds, HoistedRoleIds),
|
||||
TopHoisted =:= undefined
|
||||
end,
|
||||
OnlineMembers
|
||||
)
|
||||
).
|
||||
|
||||
-spec get_sorted_members_for_list(list_id(), guild_state()) -> [map()].
|
||||
get_sorted_members_for_list(ListId, State) ->
|
||||
Data = maps:get(data, State, #{}),
|
||||
Members = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
|
||||
Members = guild_data_index:member_values(Data),
|
||||
FilteredMembers = filter_members_for_list(ListId, Members, State),
|
||||
lists:sort(
|
||||
fun(A, B) ->
|
||||
@@ -519,11 +595,18 @@ get_member_user_id(Member) ->
|
||||
map_utils:get_integer(User, <<"id">>, 0).
|
||||
|
||||
-spec normalize_name(term()) -> binary().
|
||||
normalize_name(undefined) -> <<>>;
|
||||
normalize_name(null) -> <<>>;
|
||||
normalize_name(<<_/binary>> = B) -> B;
|
||||
normalize_name(undefined) ->
|
||||
<<>>;
|
||||
normalize_name(null) ->
|
||||
<<>>;
|
||||
normalize_name(<<_/binary>> = B) ->
|
||||
B;
|
||||
normalize_name(L) when is_list(L) ->
|
||||
try unicode:characters_to_binary(L) catch _:_ -> <<>> end;
|
||||
try
|
||||
unicode:characters_to_binary(L)
|
||||
catch
|
||||
_:_ -> <<>>
|
||||
end;
|
||||
normalize_name(I) when is_integer(I) ->
|
||||
integer_to_binary(I);
|
||||
normalize_name(_) ->
|
||||
@@ -560,11 +643,13 @@ casefold_binary(Value) ->
|
||||
_:_ -> Bin
|
||||
end.
|
||||
|
||||
-spec add_presence_to_member(map(), guild_state()) -> map().
|
||||
add_presence_to_member(Member, State) ->
|
||||
UserId = get_member_user_id(Member),
|
||||
Presence = resolve_presence_for_user(State, UserId),
|
||||
maps:put(<<"presence">>, Presence, Member).
|
||||
|
||||
-spec default_presence() -> map().
|
||||
default_presence() ->
|
||||
#{
|
||||
<<"status">> => <<"offline">>,
|
||||
@@ -572,20 +657,10 @@ default_presence() ->
|
||||
<<"afk">> => false
|
||||
}.
|
||||
|
||||
-spec resolve_presence_for_user(guild_state(), user_id()) -> map().
|
||||
resolve_presence_for_user(State, UserId) ->
|
||||
Presences = maps:get(presences, State, #{}),
|
||||
case maps:get(UserId, Presences, undefined) of
|
||||
undefined -> fetch_presence_from_cache(UserId);
|
||||
Presence -> Presence
|
||||
end.
|
||||
|
||||
fetch_presence_from_cache(UserId) ->
|
||||
try presence_cache:get(UserId) of
|
||||
{ok, Presence} -> Presence;
|
||||
_ -> default_presence()
|
||||
catch
|
||||
exit:{noproc, _} -> default_presence()
|
||||
end.
|
||||
MemberPresence = maps:get(member_presence, State, #{}),
|
||||
maps:get(UserId, MemberPresence, default_presence()).
|
||||
|
||||
-spec build_member_list_items([group_item()], [map()], guild_state()) -> [list_item()].
|
||||
build_member_list_items(Groups, Members, State) ->
|
||||
@@ -609,9 +684,21 @@ build_member_list_items(Groups, Members, State) ->
|
||||
end,
|
||||
OnlineMembers
|
||||
),
|
||||
[GroupHeader | [#{<<"member">> => add_presence_to_member(M, State)} || M <- UngroupedOnline]];
|
||||
[
|
||||
GroupHeader
|
||||
| [
|
||||
#{<<"member">> => add_presence_to_member(M, State)}
|
||||
|| M <- UngroupedOnline
|
||||
]
|
||||
];
|
||||
<<"offline">> ->
|
||||
[GroupHeader | [#{<<"member">> => add_presence_to_member(M, State)} || M <- OfflineMembers]];
|
||||
[
|
||||
GroupHeader
|
||||
| [
|
||||
#{<<"member">> => add_presence_to_member(M, State)}
|
||||
|| M <- OfflineMembers
|
||||
]
|
||||
];
|
||||
RoleIdBin ->
|
||||
RoleId = type_conv:to_integer(RoleIdBin),
|
||||
RoleMembers = lists:filter(
|
||||
@@ -622,7 +709,10 @@ build_member_list_items(Groups, Members, State) ->
|
||||
end,
|
||||
OnlineMembers
|
||||
),
|
||||
[GroupHeader | [#{<<"member">> => add_presence_to_member(M, State)} || M <- RoleMembers]]
|
||||
[
|
||||
GroupHeader
|
||||
| [#{<<"member">> => add_presence_to_member(M, State)} || M <- RoleMembers]
|
||||
]
|
||||
end
|
||||
end,
|
||||
Groups
|
||||
@@ -636,7 +726,7 @@ slice_items(Items, Start, End) ->
|
||||
false -> lists:sublist(Items, Start + 1, SafeEnd - Start + 1)
|
||||
end.
|
||||
|
||||
-spec member_in_list(integer(), [map()]) -> boolean().
|
||||
-spec member_in_list(user_id(), [map()]) -> boolean().
|
||||
member_in_list(UserId, Members) ->
|
||||
lists:any(fun(M) -> get_member_user_id(M) =:= UserId end, Members).
|
||||
|
||||
@@ -650,50 +740,83 @@ full_sync_ops(ListId, State) ->
|
||||
SortedMembers = get_sorted_members_for_list(ListId, State),
|
||||
Items = build_full_items(ListId, State, SortedMembers),
|
||||
case length(Items) of
|
||||
0 -> [];
|
||||
0 ->
|
||||
[];
|
||||
N ->
|
||||
[#{
|
||||
<<"op">> => <<"SYNC">>,
|
||||
<<"range">> => [0, N - 1],
|
||||
<<"items">> => Items
|
||||
}]
|
||||
[
|
||||
#{
|
||||
<<"op">> => <<"SYNC">>,
|
||||
<<"range">> => [0, N - 1],
|
||||
<<"items">> => Items
|
||||
}
|
||||
]
|
||||
end.
|
||||
|
||||
-spec upsert_member_in_state(integer(), map(), guild_state()) -> {map() | undefined, map(), guild_state()}.
|
||||
-spec get_members_cursor(map(), guild_state()) -> {reply, map(), guild_state()}.
|
||||
get_members_cursor(Request, State) ->
|
||||
Limit = maps:get(<<"limit">>, Request, 1),
|
||||
AfterId = maps:get(<<"after">>, Request, undefined),
|
||||
SortedMembers = sort_members_by_user_id(State),
|
||||
FilteredMembers = filter_members_after(SortedMembers, AfterId),
|
||||
ResponseMembers = take_first(FilteredMembers, Limit),
|
||||
{reply,
|
||||
#{
|
||||
members => ResponseMembers,
|
||||
total => length(SortedMembers)
|
||||
},
|
||||
State}.
|
||||
|
||||
-spec take_first([map()], integer()) -> [map()].
|
||||
take_first(_Members, Limit) when Limit =< 0 ->
|
||||
[];
|
||||
take_first(Members, Limit) ->
|
||||
Count = min(Limit, length(Members)),
|
||||
case Count of
|
||||
0 -> [];
|
||||
_ -> lists:sublist(Members, 1, Count)
|
||||
end.
|
||||
|
||||
-spec filter_members_after([map()], integer() | undefined) -> [map()].
|
||||
filter_members_after(Members, undefined) ->
|
||||
Members;
|
||||
filter_members_after(Members, AfterId) ->
|
||||
lists:dropwhile(
|
||||
fun(Member) ->
|
||||
get_member_user_id(Member) =< AfterId
|
||||
end,
|
||||
Members
|
||||
).
|
||||
|
||||
-spec sort_members_by_user_id(guild_state()) -> [map()].
|
||||
sort_members_by_user_id(State) ->
|
||||
Data = maps:get(data, State, #{}),
|
||||
Members = guild_data_index:member_values(Data),
|
||||
lists:sort(
|
||||
fun(A, B) ->
|
||||
get_member_user_id(A) =< get_member_user_id(B)
|
||||
end,
|
||||
Members
|
||||
).
|
||||
|
||||
-spec upsert_member_in_state(user_id(), map(), guild_state()) ->
|
||||
{map() | undefined, map(), guild_state()}.
|
||||
upsert_member_in_state(UserId, MemberUpdate, State) ->
|
||||
Data = maps:get(data, State, #{}),
|
||||
Members0 = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
|
||||
{Found, CurrentMember} = find_member_by_user_id(UserId, Members0),
|
||||
Members0 = guild_data_index:member_map(Data),
|
||||
CurrentMember = maps:get(UserId, Members0, undefined),
|
||||
UpdatedMember =
|
||||
case Found of
|
||||
true -> deep_merge_member(CurrentMember, MemberUpdate);
|
||||
false -> MemberUpdate
|
||||
case CurrentMember of
|
||||
undefined -> MemberUpdate;
|
||||
_ -> deep_merge_member(CurrentMember, MemberUpdate)
|
||||
end,
|
||||
Members1 =
|
||||
case Found of
|
||||
true ->
|
||||
lists:map(
|
||||
fun(M) ->
|
||||
case get_member_user_id(M) =:= UserId of
|
||||
true -> UpdatedMember;
|
||||
false -> M
|
||||
end
|
||||
end,
|
||||
Members0
|
||||
);
|
||||
false ->
|
||||
Members0 ++ [UpdatedMember]
|
||||
end,
|
||||
Data1 = maps:put(<<"members">>, Members1, Data),
|
||||
Members1 = maps:put(UserId, UpdatedMember, Members0),
|
||||
Data1 = guild_data_index:put_member_map(Members1, Data),
|
||||
NewState = maps:put(data, Data1, State),
|
||||
{case Found of true -> CurrentMember; false -> undefined end, UpdatedMember, NewState}.
|
||||
|
||||
-spec find_member_by_user_id(integer(), [map()]) -> {boolean(), map()} | {false, undefined}.
|
||||
find_member_by_user_id(UserId, Members) ->
|
||||
case lists:search(fun(M) -> get_member_user_id(M) =:= UserId end, Members) of
|
||||
{value, M} -> {true, M};
|
||||
false -> {false, undefined}
|
||||
end.
|
||||
{
|
||||
CurrentMember,
|
||||
UpdatedMember,
|
||||
NewState
|
||||
}.
|
||||
|
||||
-spec deep_merge_member(map(), map()) -> map().
|
||||
deep_merge_member(CurrentMember, MemberUpdate) ->
|
||||
@@ -741,13 +864,16 @@ diff_items_to_ops(OldItems, NewItems) ->
|
||||
-spec full_sync_from_items([list_item()]) -> [map()].
|
||||
full_sync_from_items(Items) ->
|
||||
case length(Items) of
|
||||
0 -> [];
|
||||
0 ->
|
||||
[];
|
||||
N ->
|
||||
[#{
|
||||
<<"op">> => <<"SYNC">>,
|
||||
<<"range">> => [0, N - 1],
|
||||
<<"items">> => Items
|
||||
}]
|
||||
[
|
||||
#{
|
||||
<<"op">> => <<"SYNC">>,
|
||||
<<"range">> => [0, N - 1],
|
||||
<<"items">> => Items
|
||||
}
|
||||
]
|
||||
end.
|
||||
|
||||
-spec item_keys([list_item()]) -> [term()].
|
||||
@@ -777,11 +903,14 @@ updates_for_changed_items(OldItems, NewItems) ->
|
||||
true ->
|
||||
Acc;
|
||||
false ->
|
||||
[#{
|
||||
<<"op">> => <<"UPDATE">>,
|
||||
<<"index">> => Idx,
|
||||
<<"item">> => NewItem
|
||||
} | Acc]
|
||||
[
|
||||
#{
|
||||
<<"op">> => <<"UPDATE">>,
|
||||
<<"index">> => Idx,
|
||||
<<"item">> => NewItem
|
||||
}
|
||||
| Acc
|
||||
]
|
||||
end
|
||||
end,
|
||||
[],
|
||||
@@ -792,6 +921,9 @@ updates_for_changed_items(OldItems, NewItems) ->
|
||||
-spec zip_with_index([term()], [term()]) -> [{non_neg_integer(), term(), term()}].
|
||||
zip_with_index(A, B) ->
|
||||
zip_with_index(A, B, 0, []).
|
||||
|
||||
-spec zip_with_index([term()], [term()], non_neg_integer(), [{non_neg_integer(), term(), term()}]) ->
|
||||
[{non_neg_integer(), term(), term()}].
|
||||
zip_with_index([], [], _I, Acc) ->
|
||||
lists:reverse(Acc);
|
||||
zip_with_index([HA | TA], [HB | TB], I, Acc) ->
|
||||
@@ -802,6 +934,11 @@ zip_with_index(_, _, _I, Acc) ->
|
||||
-spec mismatch_span([term()], [term()]) -> none | {non_neg_integer(), non_neg_integer()}.
|
||||
mismatch_span(A, B) ->
|
||||
mismatch_span(A, B, 0, none).
|
||||
|
||||
-spec mismatch_span(
|
||||
[term()], [term()], non_neg_integer(), none | {non_neg_integer(), non_neg_integer()}
|
||||
) ->
|
||||
none | {non_neg_integer(), non_neg_integer()}.
|
||||
mismatch_span([], [], _I, none) ->
|
||||
none;
|
||||
mismatch_span([], [], _I, {S, E}) ->
|
||||
@@ -831,7 +968,8 @@ sync_range_op(Start, End, NewItems) ->
|
||||
<<"items">> => slice_items(NewItems, Start, End)
|
||||
}.
|
||||
|
||||
-spec try_pure_insert_delete([list_item()], [list_item()], [term()], [term()]) -> {ok, [map()]} | error.
|
||||
-spec try_pure_insert_delete([list_item()], [list_item()], [term()], [term()]) ->
|
||||
{ok, [map()]} | error.
|
||||
try_pure_insert_delete(OldItems, NewItems, OldKeys, NewKeys) ->
|
||||
LenOld = length(OldKeys),
|
||||
LenNew = length(NewKeys),
|
||||
@@ -872,6 +1010,8 @@ try_pure_insert_delete(OldItems, NewItems, OldKeys, NewKeys) ->
|
||||
-spec first_mismatch_index([term()], [term()]) -> non_neg_integer().
|
||||
first_mismatch_index(A, B) ->
|
||||
first_mismatch_index(A, B, 0).
|
||||
|
||||
-spec first_mismatch_index([term()], [term()], non_neg_integer()) -> non_neg_integer().
|
||||
first_mismatch_index([], _B, I) ->
|
||||
I;
|
||||
first_mismatch_index(_A, [], I) ->
|
||||
@@ -896,6 +1036,8 @@ delete_many(List, Index, Count) ->
|
||||
-spec insert_ops(non_neg_integer(), [list_item()]) -> [map()].
|
||||
insert_ops(StartIdx, Items) ->
|
||||
insert_ops(StartIdx, Items, 0, []).
|
||||
|
||||
-spec insert_ops(non_neg_integer(), [list_item()], non_neg_integer(), [map()]) -> [map()].
|
||||
insert_ops(_StartIdx, [], _Offset, Acc) ->
|
||||
lists:reverse(Acc);
|
||||
insert_ops(StartIdx, [Item | Rest], Offset, Acc) ->
|
||||
@@ -920,18 +1062,23 @@ delete_ops(Idx, Count) ->
|
||||
lists:seq(1, Count)
|
||||
).
|
||||
|
||||
-spec session_can_view_channel(map(), integer(), guild_state()) -> boolean().
|
||||
session_can_view_channel(_SessionData, ChannelId, _State) when not is_integer(ChannelId); ChannelId =< 0 ->
|
||||
-spec session_can_view_channel(map(), channel_id(), guild_state()) -> boolean().
|
||||
session_can_view_channel(_SessionData, ChannelId, _State) when
|
||||
not is_integer(ChannelId); ChannelId =< 0
|
||||
->
|
||||
false;
|
||||
session_can_view_channel(SessionData, ChannelId, State) ->
|
||||
case maps:get(user_id, SessionData, undefined) of
|
||||
UserId when is_integer(UserId) ->
|
||||
case {maps:get(user_id, SessionData, undefined), maps:get(viewable_channels, SessionData, undefined)} of
|
||||
{UserId, ViewableChannels} when is_integer(UserId), is_map(ViewableChannels) ->
|
||||
maps:is_key(ChannelId, ViewableChannels) orelse
|
||||
guild_permissions:can_view_channel(UserId, ChannelId, undefined, State);
|
||||
{UserId, _} when is_integer(UserId) ->
|
||||
guild_permissions:can_view_channel(UserId, ChannelId, undefined, State);
|
||||
_ ->
|
||||
false
|
||||
end.
|
||||
|
||||
-spec list_id_channel_id(list_id()) -> integer().
|
||||
-spec list_id_channel_id(list_id()) -> channel_id().
|
||||
list_id_channel_id(ListId) when is_binary(ListId) ->
|
||||
case type_conv:to_integer(ListId) of
|
||||
undefined -> 0;
|
||||
@@ -966,6 +1113,16 @@ list_id_channel_id_parses_binary_test() ->
|
||||
list_id_channel_id_invalid_value_test() ->
|
||||
?assertEqual(0, list_id_channel_id(<<"abc">>)).
|
||||
|
||||
session_can_view_channel_uses_cached_visibility_test() ->
|
||||
SessionData = #{user_id => 12, viewable_channels => #{500 => true}},
|
||||
State = #{data => #{<<"members">> => #{}}},
|
||||
?assertEqual(true, session_can_view_channel(SessionData, 500, State)).
|
||||
|
||||
session_can_view_channel_rejects_when_cache_misses_and_user_missing_test() ->
|
||||
SessionData = #{user_id => 99, viewable_channels => #{}},
|
||||
State = #{data => #{<<"members">> => #{}}},
|
||||
?assertEqual(false, session_can_view_channel(SessionData, 500, State)).
|
||||
|
||||
subscribe_ranges_test() ->
|
||||
State = #{member_list_subscriptions => #{}},
|
||||
{NewState, ShouldSync, NormalizedRanges} =
|
||||
@@ -1024,4 +1181,64 @@ validate_range_invalid_negative_test() ->
|
||||
validate_range_invalid_too_large_test() ->
|
||||
?assertEqual(invalid, validate_range({0, 100001})).
|
||||
|
||||
get_members_cursor_returns_atom_keys_test() ->
|
||||
Member1 = #{<<"user">> => #{<<"id">> => <<"2">>}},
|
||||
Member2 = #{<<"user">> => #{<<"id">> => <<"1">>}},
|
||||
State = #{data => #{<<"members">> => [Member1, Member2]}},
|
||||
{reply, Reply, _NewState} = get_members_cursor(#{<<"limit">> => 1}, State),
|
||||
?assert(maps:is_key(members, Reply)),
|
||||
?assert(maps:is_key(total, Reply)),
|
||||
?assertNot(maps:is_key(<<"members">>, Reply)),
|
||||
?assertNot(maps:is_key(<<"total">>, Reply)).
|
||||
|
||||
get_online_count_ignores_members_without_connected_session_test() ->
|
||||
Members = [
|
||||
#{<<"user">> => #{<<"id">> => <<"1">>}},
|
||||
#{<<"user">> => #{<<"id">> => <<"2">>}}
|
||||
],
|
||||
State = #{
|
||||
data => #{<<"members">> => Members},
|
||||
sessions => #{<<"s1">> => #{user_id => 1}},
|
||||
member_presence => #{
|
||||
1 => #{<<"status">> => <<"online">>},
|
||||
2 => #{<<"status">> => <<"online">>}
|
||||
}
|
||||
},
|
||||
?assertEqual(1, get_online_count(State)).
|
||||
|
||||
filter_members_for_list_keeps_members_without_connected_session_test() ->
|
||||
Members = [
|
||||
#{<<"user">> => #{<<"id">> => <<"1">>}},
|
||||
#{<<"user">> => #{<<"id">> => <<"2">>}}
|
||||
],
|
||||
State = #{
|
||||
sessions => #{<<"s1">> => #{user_id => 1}},
|
||||
data => #{<<"channels">> => []}
|
||||
},
|
||||
FilteredMembers = filter_members_for_list(<<"0">>, Members, State),
|
||||
?assertEqual([1, 2], [get_member_user_id(Member) || Member <- FilteredMembers]).
|
||||
|
||||
get_member_groups_counts_members_without_connected_session_as_offline_test() ->
|
||||
Members = [
|
||||
#{<<"user">> => #{<<"id">> => <<"1">>}},
|
||||
#{<<"user">> => #{<<"id">> => <<"2">>}}
|
||||
],
|
||||
State = #{
|
||||
id => 123,
|
||||
data => #{
|
||||
<<"members">> => Members,
|
||||
<<"roles">> => []
|
||||
},
|
||||
sessions => #{<<"s1">> => #{user_id => 1}},
|
||||
member_presence => #{
|
||||
1 => #{<<"status">> => <<"online">>},
|
||||
2 => #{<<"status">> => <<"online">>}
|
||||
}
|
||||
},
|
||||
Groups = get_member_groups(<<"0">>, State),
|
||||
OnlineGroup = lists:nth(1, Groups),
|
||||
OfflineGroup = lists:nth(2, Groups),
|
||||
?assertEqual(1, maps:get(<<"count">>, OnlineGroup)),
|
||||
?assertEqual(1, maps:get(<<"count">>, OfflineGroup)).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -29,40 +29,35 @@
|
||||
compute_list_id/1
|
||||
]).
|
||||
|
||||
-record(member_storage, {
|
||||
members_table :: ets:tid(),
|
||||
display_name_index :: gb_trees:tree()
|
||||
}).
|
||||
|
||||
-type storage() :: #member_storage{}.
|
||||
-type storage() :: #{
|
||||
members_table := ets:tid(),
|
||||
display_name_index := gb_trees:tree({binary(), user_id()}, user_id())
|
||||
}.
|
||||
-type user_id() :: integer().
|
||||
-type member() :: map().
|
||||
-type index_key() :: {binary(), user_id()}.
|
||||
|
||||
-export_type([storage/0]).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-spec new() -> storage().
|
||||
new() ->
|
||||
MembersTable = ets:new(members, [set, private]),
|
||||
MembersTable = ets:new(members, [set, private, {read_concurrency, true}]),
|
||||
DisplayNameIndex = gb_trees:empty(),
|
||||
#member_storage{
|
||||
members_table = MembersTable,
|
||||
display_name_index = DisplayNameIndex
|
||||
#{
|
||||
members_table => MembersTable,
|
||||
display_name_index => DisplayNameIndex
|
||||
}.
|
||||
|
||||
-spec insert_member(member(), storage()) -> storage().
|
||||
insert_member(Member, Storage) ->
|
||||
UserId = extract_user_id(Member),
|
||||
case UserId of
|
||||
case extract_user_id(Member) of
|
||||
undefined ->
|
||||
Storage;
|
||||
_ ->
|
||||
UserId ->
|
||||
OldMember = get_member(UserId, Storage),
|
||||
Storage1 = remove_from_index(OldMember, Storage),
|
||||
ets:insert(Storage1#member_storage.members_table, {UserId, Member}),
|
||||
#{members_table := MembersTable} = Storage1,
|
||||
ets:insert(MembersTable, {UserId, Member}),
|
||||
add_to_index(UserId, Member, Storage1)
|
||||
end.
|
||||
|
||||
@@ -73,13 +68,15 @@ remove_member(UserId, Storage) ->
|
||||
Storage;
|
||||
Member ->
|
||||
Storage1 = remove_from_index(Member, Storage),
|
||||
ets:delete(Storage1#member_storage.members_table, UserId),
|
||||
#{members_table := MembersTable} = Storage1,
|
||||
ets:delete(MembersTable, UserId),
|
||||
Storage1
|
||||
end.
|
||||
|
||||
-spec get_member(user_id(), storage()) -> member() | undefined.
|
||||
get_member(UserId, Storage) ->
|
||||
case ets:lookup(Storage#member_storage.members_table, UserId) of
|
||||
#{members_table := MembersTable} = Storage,
|
||||
case ets:lookup(MembersTable, UserId) of
|
||||
[{UserId, Member}] -> Member;
|
||||
[] -> undefined
|
||||
end.
|
||||
@@ -100,41 +97,46 @@ get_members_by_ids(UserIds, Storage) ->
|
||||
search_members(Query, Limit, Storage) when is_binary(Query), Limit > 0 ->
|
||||
NormalizedQuery = normalize_display_name(Query),
|
||||
case NormalizedQuery of
|
||||
<<>> ->
|
||||
[];
|
||||
_ ->
|
||||
search_by_prefix(NormalizedQuery, Limit, Storage)
|
||||
<<>> -> [];
|
||||
_ -> search_by_prefix(NormalizedQuery, Limit, Storage)
|
||||
end;
|
||||
search_members(_, _, _) ->
|
||||
[].
|
||||
|
||||
-spec get_range(non_neg_integer(), non_neg_integer(), storage()) -> [member()].
|
||||
get_range(Offset, Limit, Storage) when is_integer(Offset), is_integer(Limit), Limit > 0 ->
|
||||
Index = Storage#member_storage.display_name_index,
|
||||
case gb_trees:size(Index) of
|
||||
Size when Offset >= Size ->
|
||||
[];
|
||||
Size ->
|
||||
AllKeys = gb_trees:keys(Index),
|
||||
EndIdx = min(Offset + Limit, Size),
|
||||
SelectedKeys = lists:sublist(AllKeys, Offset + 1, EndIdx - Offset),
|
||||
lists:filtermap(
|
||||
fun(Key) ->
|
||||
UserId = gb_trees:get(Key, Index),
|
||||
case get_member(UserId, Storage) of
|
||||
undefined -> false;
|
||||
Member -> {true, Member}
|
||||
end
|
||||
end,
|
||||
SelectedKeys
|
||||
)
|
||||
#{display_name_index := Index} = Storage,
|
||||
Size = gb_trees:size(Index),
|
||||
case Offset >= Size of
|
||||
true -> [];
|
||||
false -> get_range_from_index(Offset, Limit, Size, Index, Storage)
|
||||
end;
|
||||
get_range(_, _, _) ->
|
||||
[].
|
||||
|
||||
-spec get_range_from_index(
|
||||
non_neg_integer(), non_neg_integer(), non_neg_integer(), gb_trees:tree(), storage()
|
||||
) ->
|
||||
[member()].
|
||||
get_range_from_index(Offset, Limit, Size, Index, Storage) ->
|
||||
AllKeys = gb_trees:keys(Index),
|
||||
EndIdx = min(Offset + Limit, Size),
|
||||
SelectedKeys = lists:sublist(AllKeys, Offset + 1, EndIdx - Offset),
|
||||
lists:filtermap(
|
||||
fun(Key) ->
|
||||
UserId = gb_trees:get(Key, Index),
|
||||
case get_member(UserId, Storage) of
|
||||
undefined -> false;
|
||||
Member -> {true, Member}
|
||||
end
|
||||
end,
|
||||
SelectedKeys
|
||||
).
|
||||
|
||||
-spec count(storage()) -> non_neg_integer().
|
||||
count(Storage) ->
|
||||
ets:info(Storage#member_storage.members_table, size).
|
||||
#{members_table := MembersTable} = Storage,
|
||||
ets:info(MembersTable, size).
|
||||
|
||||
-spec compute_list_id([user_id()]) -> integer().
|
||||
compute_list_id(UserIds) ->
|
||||
@@ -155,19 +157,17 @@ extract_user_id(_) ->
|
||||
|
||||
-spec get_display_name(member()) -> binary().
|
||||
get_display_name(Member) when is_map(Member) ->
|
||||
Nick = maps:get(<<"nick">>, Member, undefined),
|
||||
case Nick of
|
||||
undefined ->
|
||||
User = maps:get(<<"user">>, Member, #{}),
|
||||
GlobalName = maps:get(<<"global_name">>, User, undefined),
|
||||
case GlobalName of
|
||||
undefined ->
|
||||
maps:get(<<"username">>, User, <<>>);
|
||||
_ ->
|
||||
GlobalName
|
||||
end;
|
||||
_ ->
|
||||
Nick
|
||||
case maps:get(<<"nick">>, Member, undefined) of
|
||||
undefined -> get_display_name_from_user(Member);
|
||||
Nick -> Nick
|
||||
end.
|
||||
|
||||
-spec get_display_name_from_user(member()) -> binary().
|
||||
get_display_name_from_user(Member) ->
|
||||
User = maps:get(<<"user">>, Member, #{}),
|
||||
case maps:get(<<"global_name">>, User, undefined) of
|
||||
undefined -> maps:get(<<"username">>, User, <<>>);
|
||||
GlobalName -> GlobalName
|
||||
end.
|
||||
|
||||
-spec normalize_display_name(binary()) -> binary().
|
||||
@@ -180,9 +180,9 @@ add_to_index(UserId, Member, Storage) ->
|
||||
DisplayName = get_display_name(Member),
|
||||
NormalizedName = normalize_display_name(DisplayName),
|
||||
Key = make_index_key(NormalizedName, UserId),
|
||||
Index = Storage#member_storage.display_name_index,
|
||||
#{display_name_index := Index} = Storage,
|
||||
NewIndex = gb_trees:enter(Key, UserId, Index),
|
||||
Storage#member_storage{display_name_index = NewIndex}.
|
||||
Storage#{display_name_index => NewIndex}.
|
||||
|
||||
-spec remove_from_index(member() | undefined, storage()) -> storage().
|
||||
remove_from_index(undefined, Storage) ->
|
||||
@@ -192,41 +192,46 @@ remove_from_index(Member, Storage) ->
|
||||
DisplayName = get_display_name(Member),
|
||||
NormalizedName = normalize_display_name(DisplayName),
|
||||
Key = make_index_key(NormalizedName, UserId),
|
||||
Index = Storage#member_storage.display_name_index,
|
||||
#{display_name_index := Index} = Storage,
|
||||
case gb_trees:is_defined(Key, Index) of
|
||||
true ->
|
||||
NewIndex = gb_trees:delete(Key, Index),
|
||||
Storage#member_storage{display_name_index = NewIndex};
|
||||
Storage#{display_name_index => NewIndex};
|
||||
false ->
|
||||
Storage
|
||||
end.
|
||||
|
||||
-spec make_index_key(binary(), user_id()) -> {binary(), user_id()}.
|
||||
-spec make_index_key(binary(), user_id()) -> index_key().
|
||||
make_index_key(NormalizedName, UserId) ->
|
||||
{NormalizedName, UserId}.
|
||||
|
||||
-spec search_by_prefix(binary(), non_neg_integer(), storage()) -> [member()].
|
||||
search_by_prefix(Prefix, Limit, Storage) ->
|
||||
Index = Storage#member_storage.display_name_index,
|
||||
#{display_name_index := Index} = Storage,
|
||||
AllKeys = gb_trees:keys(Index),
|
||||
Matches = lists:filtermap(
|
||||
fun({Name, UserId}) ->
|
||||
PrefixLen = byte_size(Prefix),
|
||||
case Name of
|
||||
<<Prefix:PrefixLen/binary, _/binary>> ->
|
||||
case get_member(UserId, Storage) of
|
||||
undefined -> false;
|
||||
Member -> {true, Member}
|
||||
end;
|
||||
_ ->
|
||||
false
|
||||
end
|
||||
end,
|
||||
AllKeys
|
||||
),
|
||||
PrefixLen = byte_size(Prefix),
|
||||
Matches = find_prefix_matches(AllKeys, Prefix, PrefixLen, Storage, []),
|
||||
lists:sublist(Matches, Limit).
|
||||
|
||||
-spec find_prefix_matches([index_key()], binary(), non_neg_integer(), storage(), [member()]) ->
|
||||
[member()].
|
||||
find_prefix_matches([], _Prefix, _PrefixLen, _Storage, Acc) ->
|
||||
lists:reverse(Acc);
|
||||
find_prefix_matches([{Name, UserId} | Rest], Prefix, PrefixLen, Storage, Acc) ->
|
||||
case Name of
|
||||
<<Prefix:PrefixLen/binary, _/binary>> ->
|
||||
case get_member(UserId, Storage) of
|
||||
undefined ->
|
||||
find_prefix_matches(Rest, Prefix, PrefixLen, Storage, Acc);
|
||||
Member ->
|
||||
find_prefix_matches(Rest, Prefix, PrefixLen, Storage, [Member | Acc])
|
||||
end;
|
||||
_ ->
|
||||
find_prefix_matches(Rest, Prefix, PrefixLen, Storage, Acc)
|
||||
end.
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
|
||||
new_creates_empty_storage_test() ->
|
||||
Storage = new(),
|
||||
@@ -309,11 +314,16 @@ display_name_username_fallback_test() ->
|
||||
},
|
||||
?assertEqual(<<"user">>, get_display_name(Member)).
|
||||
|
||||
compute_list_id_test() ->
|
||||
compute_list_id_deterministic_test() ->
|
||||
Id1 = compute_list_id([1, 2, 3]),
|
||||
Id2 = compute_list_id([3, 2, 1]),
|
||||
?assertEqual(Id1, Id2).
|
||||
|
||||
compute_list_id_different_for_different_lists_test() ->
|
||||
Id1 = compute_list_id([1, 2, 3]),
|
||||
Id2 = compute_list_id([1, 2, 4]),
|
||||
?assertNotEqual(Id1, Id2).
|
||||
|
||||
get_range_test() ->
|
||||
Storage = new(),
|
||||
Member1 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"alice">>}},
|
||||
@@ -325,4 +335,46 @@ get_range_test() ->
|
||||
Results = get_range(1, 2, Storage3),
|
||||
?assertEqual(2, length(Results)).
|
||||
|
||||
get_range_offset_beyond_size_test() ->
|
||||
Storage = new(),
|
||||
Member1 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"alice">>}},
|
||||
Storage1 = insert_member(Member1, Storage),
|
||||
Results = get_range(10, 5, Storage1),
|
||||
?assertEqual([], Results).
|
||||
|
||||
search_members_empty_query_test() ->
|
||||
Storage = new(),
|
||||
Member1 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"alice">>}},
|
||||
Storage1 = insert_member(Member1, Storage),
|
||||
Results = search_members(<<>>, 10, Storage1),
|
||||
?assertEqual([], Results).
|
||||
|
||||
search_members_limit_test() ->
|
||||
Storage = new(),
|
||||
Member1 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"aaa">>}},
|
||||
Member2 = #{<<"user">> => #{<<"id">> => <<"2">>, <<"username">> => <<"aab">>}},
|
||||
Member3 = #{<<"user">> => #{<<"id">> => <<"3">>, <<"username">> => <<"aac">>}},
|
||||
Storage1 = insert_member(Member1, Storage),
|
||||
Storage2 = insert_member(Member2, Storage1),
|
||||
Storage3 = insert_member(Member3, Storage2),
|
||||
Results = search_members(<<"aa">>, 2, Storage3),
|
||||
?assertEqual(2, length(Results)).
|
||||
|
||||
normalize_display_name_test() ->
|
||||
?assertEqual(<<"hello">>, normalize_display_name(<<"HELLO">>)),
|
||||
?assertEqual(<<"hello">>, normalize_display_name(<<"Hello">>)),
|
||||
?assertEqual(<<"hello">>, normalize_display_name(<<"hello">>)).
|
||||
|
||||
insert_member_updates_index_test() ->
|
||||
Storage = new(),
|
||||
Member1 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"alice">>}},
|
||||
Storage1 = insert_member(Member1, Storage),
|
||||
Member2 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"bob">>}},
|
||||
Storage2 = insert_member(Member2, Storage1),
|
||||
?assertEqual(1, count(Storage2)),
|
||||
AliceResults = search_members(<<"alice">>, 10, Storage2),
|
||||
?assertEqual(0, length(AliceResults)),
|
||||
BobResults = search_members(<<"bob">>, 10, Storage2),
|
||||
?assertEqual(1, length(BobResults)).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -17,16 +17,18 @@
|
||||
|
||||
-module(guild_members).
|
||||
|
||||
-export([get_users_to_mention_by_roles/2]).
|
||||
-export([get_users_to_mention_by_user_ids/2]).
|
||||
-export([get_all_users_to_mention/2]).
|
||||
-export([resolve_all_mentions/2]).
|
||||
-export([get_members_with_role/2]).
|
||||
-export([can_manage_roles/2]).
|
||||
-export([can_manage_role/2]).
|
||||
-export([get_assignable_roles/2]).
|
||||
-export([check_target_member/2]).
|
||||
-export([get_viewable_channels/2]).
|
||||
-export([
|
||||
get_users_to_mention_by_roles/2,
|
||||
get_users_to_mention_by_user_ids/2,
|
||||
get_all_users_to_mention/2,
|
||||
resolve_all_mentions/2,
|
||||
get_members_with_role/2,
|
||||
can_manage_roles/2,
|
||||
can_manage_role/2,
|
||||
get_assignable_roles/2,
|
||||
check_target_member/2,
|
||||
get_viewable_channels/2
|
||||
]).
|
||||
|
||||
-type guild_state() :: map().
|
||||
-type guild_reply(T) :: {reply, T, guild_state()}.
|
||||
@@ -37,22 +39,18 @@
|
||||
-type role_id() :: integer().
|
||||
-type channel_id() :: integer().
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-spec get_users_to_mention_by_roles(map(), guild_state()) -> guild_reply(map()).
|
||||
get_users_to_mention_by_roles(
|
||||
#{channel_id := ChannelId, role_ids := RoleIds, author_id := AuthorId}, State
|
||||
) ->
|
||||
Members = guild_members(State),
|
||||
RoleIdSet = normalize_int_list(RoleIds),
|
||||
UserIds = collect_mentions(
|
||||
Members,
|
||||
RoleIdList = normalize_int_list(RoleIds),
|
||||
CandidateUserIds = user_ids_for_any_role(RoleIdList, State),
|
||||
UserIds = collect_mentions_for_user_ids(
|
||||
CandidateUserIds,
|
||||
AuthorId,
|
||||
ChannelId,
|
||||
State,
|
||||
fun(Member) -> member_has_any_role(Member, RoleIdSet) end
|
||||
fun(_UserId, _Member) -> true end
|
||||
),
|
||||
{reply, #{user_ids => UserIds}, State}.
|
||||
|
||||
@@ -60,19 +58,13 @@ get_users_to_mention_by_roles(
|
||||
get_users_to_mention_by_user_ids(
|
||||
#{channel_id := ChannelId, user_ids := UserIdsReq, author_id := AuthorId}, State
|
||||
) ->
|
||||
Members = guild_members(State),
|
||||
TargetIds = normalize_int_list(UserIdsReq),
|
||||
UserIds = collect_mentions(
|
||||
Members,
|
||||
UserIds = collect_mentions_for_user_ids(
|
||||
TargetIds,
|
||||
AuthorId,
|
||||
ChannelId,
|
||||
State,
|
||||
fun(Member) ->
|
||||
case member_user_id(Member) of
|
||||
undefined -> false;
|
||||
Id -> lists:member(Id, TargetIds)
|
||||
end
|
||||
end
|
||||
fun(_UserId, _Member) -> true end
|
||||
),
|
||||
{reply, #{user_ids => UserIds}, State}.
|
||||
|
||||
@@ -83,7 +75,7 @@ get_all_users_to_mention(#{channel_id := ChannelId, author_id := AuthorId}, Stat
|
||||
{reply, #{user_ids => UserIds}, State}.
|
||||
|
||||
-spec resolve_all_mentions(map(), guild_state()) -> guild_reply(map()).
|
||||
resolve_all_mentions(
|
||||
resolve_all_mentions(Request, State) ->
|
||||
#{
|
||||
channel_id := ChannelId,
|
||||
author_id := AuthorId,
|
||||
@@ -91,48 +83,122 @@ resolve_all_mentions(
|
||||
mention_here := MentionHere,
|
||||
role_ids := RoleIds,
|
||||
user_ids := DirectUserIds
|
||||
},
|
||||
State
|
||||
) ->
|
||||
} = Request,
|
||||
Members = guild_members(State),
|
||||
MemberMap = guild_data_index:member_map(guild_data(State)),
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
|
||||
RoleIdSet = gb_sets:from_list(normalize_int_list(RoleIds)),
|
||||
DirectUserIdSet = gb_sets:from_list(normalize_int_list(DirectUserIds)),
|
||||
HasRoleMentions = not gb_sets:is_empty(RoleIdSet),
|
||||
HasDirectMentions = not gb_sets:is_empty(DirectUserIdSet),
|
||||
|
||||
ConnectedUserIds =
|
||||
case MentionHere of
|
||||
ConnectedUserIds = build_connected_user_ids(MentionHere, Sessions),
|
||||
UserIds =
|
||||
case MentionEveryone of
|
||||
true ->
|
||||
gb_sets:from_list([
|
||||
maps:get(user_id, S)
|
||||
|| {_Sid, S} <- maps:to_list(Sessions)
|
||||
]);
|
||||
resolve_mentions(
|
||||
Members,
|
||||
AuthorId,
|
||||
ChannelId,
|
||||
MentionEveryone,
|
||||
MentionHere,
|
||||
HasRoleMentions,
|
||||
HasDirectMentions,
|
||||
RoleIdSet,
|
||||
DirectUserIdSet,
|
||||
ConnectedUserIds,
|
||||
State
|
||||
);
|
||||
false ->
|
||||
gb_sets:empty()
|
||||
CandidateUserIds = candidate_user_ids_for_mentions(
|
||||
MentionHere,
|
||||
HasRoleMentions,
|
||||
HasDirectMentions,
|
||||
RoleIdSet,
|
||||
DirectUserIdSet,
|
||||
ConnectedUserIds,
|
||||
State
|
||||
),
|
||||
resolve_mentions_for_user_ids(
|
||||
CandidateUserIds,
|
||||
MemberMap,
|
||||
AuthorId,
|
||||
ChannelId,
|
||||
MentionEveryone,
|
||||
MentionHere,
|
||||
HasRoleMentions,
|
||||
HasDirectMentions,
|
||||
RoleIdSet,
|
||||
DirectUserIdSet,
|
||||
ConnectedUserIds,
|
||||
State
|
||||
)
|
||||
end,
|
||||
{reply, #{user_ids => UserIds}, State}.
|
||||
|
||||
UserIds = lists:filtermap(
|
||||
-spec build_connected_user_ids(boolean(), map()) -> gb_sets:set(user_id()).
|
||||
build_connected_user_ids(false, _Sessions) ->
|
||||
gb_sets:empty();
|
||||
build_connected_user_ids(true, Sessions) ->
|
||||
gb_sets:from_list(
|
||||
lists:filtermap(
|
||||
fun({_Sid, SessionData}) ->
|
||||
case maps:get(user_id, SessionData, undefined) of
|
||||
UserId when is_integer(UserId) -> {true, UserId};
|
||||
_ -> false
|
||||
end
|
||||
end,
|
||||
maps:to_list(Sessions)
|
||||
)
|
||||
).
|
||||
|
||||
-spec resolve_mentions(
|
||||
[member()],
|
||||
user_id(),
|
||||
channel_id(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
gb_sets:set(),
|
||||
gb_sets:set(),
|
||||
gb_sets:set(),
|
||||
guild_state()
|
||||
) -> [user_id()].
|
||||
resolve_mentions(
|
||||
Members,
|
||||
AuthorId,
|
||||
ChannelId,
|
||||
MentionEveryone,
|
||||
MentionHere,
|
||||
HasRoleMentions,
|
||||
HasDirectMentions,
|
||||
RoleIdSet,
|
||||
DirectUserIdSet,
|
||||
ConnectedUserIds,
|
||||
State
|
||||
) ->
|
||||
lists:filtermap(
|
||||
fun(Member) ->
|
||||
case member_user_id(Member) of
|
||||
undefined ->
|
||||
false;
|
||||
UserId when UserId =:= AuthorId ->
|
||||
false;
|
||||
UserId when UserId =:= AuthorId -> false;
|
||||
UserId ->
|
||||
case is_member_bot(Member) of
|
||||
true ->
|
||||
false;
|
||||
false ->
|
||||
ShouldMention =
|
||||
MentionEveryone orelse
|
||||
(MentionHere andalso
|
||||
gb_sets:is_member(UserId, ConnectedUserIds)) orelse
|
||||
(HasRoleMentions andalso
|
||||
member_has_any_role_set(Member, RoleIdSet)) orelse
|
||||
(HasDirectMentions andalso
|
||||
gb_sets:is_member(UserId, DirectUserIdSet)),
|
||||
ShouldMention = check_should_mention(
|
||||
UserId,
|
||||
Member,
|
||||
MentionEveryone,
|
||||
MentionHere,
|
||||
HasRoleMentions,
|
||||
HasDirectMentions,
|
||||
RoleIdSet,
|
||||
DirectUserIdSet,
|
||||
ConnectedUserIds
|
||||
),
|
||||
case
|
||||
ShouldMention andalso
|
||||
member_can_view_channel(UserId, ChannelId, Member, State)
|
||||
@@ -144,61 +210,206 @@ resolve_all_mentions(
|
||||
end
|
||||
end,
|
||||
Members
|
||||
),
|
||||
{reply, #{user_ids => UserIds}, State}.
|
||||
).
|
||||
|
||||
-spec get_members_with_role(map(), guild_state()) -> guild_reply(map()).
|
||||
get_members_with_role(#{role_id := RoleId}, State) ->
|
||||
Members = guild_members(State),
|
||||
TargetRoles = [RoleId],
|
||||
UserIds = lists:filtermap(
|
||||
fun(Member) ->
|
||||
case member_user_id(Member) of
|
||||
undefined ->
|
||||
-spec resolve_mentions_for_user_ids(
|
||||
[user_id()],
|
||||
#{user_id() => member()},
|
||||
user_id(),
|
||||
channel_id(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
gb_sets:set(),
|
||||
gb_sets:set(),
|
||||
gb_sets:set(),
|
||||
guild_state()
|
||||
) -> [user_id()].
|
||||
resolve_mentions_for_user_ids(
|
||||
CandidateUserIds,
|
||||
MemberMap,
|
||||
AuthorId,
|
||||
ChannelId,
|
||||
MentionEveryone,
|
||||
MentionHere,
|
||||
HasRoleMentions,
|
||||
HasDirectMentions,
|
||||
RoleIdSet,
|
||||
DirectUserIdSet,
|
||||
ConnectedUserIds,
|
||||
State
|
||||
) ->
|
||||
lists:filtermap(
|
||||
fun(UserId) ->
|
||||
case UserId =:= AuthorId of
|
||||
true ->
|
||||
false;
|
||||
UserId ->
|
||||
case member_has_any_role(Member, TargetRoles) of
|
||||
true -> {true, UserId};
|
||||
false -> false
|
||||
false ->
|
||||
case maps:get(UserId, MemberMap, undefined) of
|
||||
undefined ->
|
||||
false;
|
||||
Member ->
|
||||
case is_member_bot(Member) of
|
||||
true ->
|
||||
false;
|
||||
false ->
|
||||
ShouldMention = check_should_mention(
|
||||
UserId,
|
||||
Member,
|
||||
MentionEveryone,
|
||||
MentionHere,
|
||||
HasRoleMentions,
|
||||
HasDirectMentions,
|
||||
RoleIdSet,
|
||||
DirectUserIdSet,
|
||||
ConnectedUserIds
|
||||
),
|
||||
case
|
||||
ShouldMention andalso
|
||||
member_can_view_channel(UserId, ChannelId, Member, State)
|
||||
of
|
||||
true -> {true, UserId};
|
||||
false -> false
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end,
|
||||
Members
|
||||
lists:usort(CandidateUserIds)
|
||||
).
|
||||
|
||||
-spec candidate_user_ids_for_mentions(
|
||||
boolean(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
gb_sets:set(),
|
||||
gb_sets:set(),
|
||||
gb_sets:set(),
|
||||
guild_state()
|
||||
) -> [user_id()].
|
||||
candidate_user_ids_for_mentions(
|
||||
MentionHere,
|
||||
HasRoleMentions,
|
||||
HasDirectMentions,
|
||||
RoleIdSet,
|
||||
DirectUserIdSet,
|
||||
ConnectedUserIds,
|
||||
State
|
||||
) ->
|
||||
HereSet =
|
||||
case MentionHere of
|
||||
true -> ConnectedUserIds;
|
||||
false -> gb_sets:empty()
|
||||
end,
|
||||
RoleUsersSet =
|
||||
case HasRoleMentions of
|
||||
true ->
|
||||
gb_sets:from_list(user_ids_for_any_role(gb_sets:to_list(RoleIdSet), State));
|
||||
false ->
|
||||
gb_sets:empty()
|
||||
end,
|
||||
DirectSet =
|
||||
case HasDirectMentions of
|
||||
true -> DirectUserIdSet;
|
||||
false -> gb_sets:empty()
|
||||
end,
|
||||
gb_sets:to_list(gb_sets:union(HereSet, gb_sets:union(RoleUsersSet, DirectSet))).
|
||||
|
||||
-spec user_ids_for_any_role([role_id()], guild_state()) -> [user_id()].
|
||||
user_ids_for_any_role(RoleIds, State) ->
|
||||
Data = guild_data(State),
|
||||
MemberRoleIndex = guild_data_index:member_role_index(Data),
|
||||
RoleUserIds = lists:foldl(
|
||||
fun(RoleId, AccSet) ->
|
||||
case maps:get(RoleId, MemberRoleIndex, undefined) of
|
||||
undefined ->
|
||||
AccSet;
|
||||
UserMap ->
|
||||
lists:foldl(
|
||||
fun(UserId, InnerSet) -> gb_sets:add(UserId, InnerSet) end,
|
||||
AccSet,
|
||||
maps:keys(UserMap)
|
||||
)
|
||||
end
|
||||
end,
|
||||
gb_sets:empty(),
|
||||
RoleIds
|
||||
),
|
||||
gb_sets:to_list(RoleUserIds).
|
||||
|
||||
-spec check_should_mention(
|
||||
user_id(),
|
||||
member(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
gb_sets:set(),
|
||||
gb_sets:set(),
|
||||
gb_sets:set()
|
||||
) -> boolean().
|
||||
check_should_mention(
|
||||
UserId,
|
||||
Member,
|
||||
MentionEveryone,
|
||||
MentionHere,
|
||||
HasRoleMentions,
|
||||
HasDirectMentions,
|
||||
RoleIdSet,
|
||||
DirectUserIdSet,
|
||||
ConnectedUserIds
|
||||
) ->
|
||||
MentionEveryone orelse
|
||||
(MentionHere andalso gb_sets:is_member(UserId, ConnectedUserIds)) orelse
|
||||
(HasRoleMentions andalso member_has_any_role_set(Member, RoleIdSet)) orelse
|
||||
(HasDirectMentions andalso gb_sets:is_member(UserId, DirectUserIdSet)).
|
||||
|
||||
-spec get_members_with_role(map(), guild_state()) -> guild_reply(map()).
|
||||
get_members_with_role(#{role_id := RoleId}, State) ->
|
||||
Data = guild_data(State),
|
||||
MemberRoleIndex = guild_data_index:member_role_index(Data),
|
||||
TargetRoleId = type_conv:to_integer(RoleId),
|
||||
UserMap =
|
||||
case TargetRoleId of
|
||||
undefined ->
|
||||
#{};
|
||||
_ ->
|
||||
maps:get(TargetRoleId, MemberRoleIndex, #{})
|
||||
end,
|
||||
UserIds = lists:sort(maps:keys(UserMap)),
|
||||
{reply, #{user_ids => UserIds}, State}.
|
||||
|
||||
-spec can_manage_roles(map(), guild_state()) -> guild_reply(map()).
|
||||
can_manage_roles(#{user_id := UserId, role_id := RoleId}, State) ->
|
||||
Data = guild_data(State),
|
||||
OwnerId = owner_id(State),
|
||||
Reply =
|
||||
if
|
||||
UserId =:= OwnerId ->
|
||||
true;
|
||||
true ->
|
||||
UserPermissions = guild_permissions:get_member_permissions(
|
||||
UserId, undefined, State
|
||||
),
|
||||
case (UserPermissions band constants:manage_roles_permission()) =/= 0 of
|
||||
false ->
|
||||
false;
|
||||
true ->
|
||||
Roles = maps:get(<<"roles">>, Data, []),
|
||||
case find_role_by_id(RoleId, Roles) of
|
||||
undefined ->
|
||||
false;
|
||||
Role ->
|
||||
UserMax = guild_permissions:get_max_role_position(UserId, State),
|
||||
UserMax > role_position(Role)
|
||||
end
|
||||
end
|
||||
end,
|
||||
Reply = check_can_manage_roles(UserId, RoleId, OwnerId, Data, State),
|
||||
{reply, #{can_manage => Reply}, State}.
|
||||
|
||||
-spec check_can_manage_roles(user_id(), role_id(), user_id(), map(), guild_state()) -> boolean().
|
||||
check_can_manage_roles(UserId, _RoleId, UserId, _Data, _State) ->
|
||||
true;
|
||||
check_can_manage_roles(UserId, RoleId, _OwnerId, Data, State) ->
|
||||
UserPermissions = guild_permissions:get_member_permissions(UserId, undefined, State),
|
||||
case (UserPermissions band constants:manage_roles_permission()) =/= 0 of
|
||||
false ->
|
||||
false;
|
||||
true ->
|
||||
Roles = guild_data_index:role_index(Data),
|
||||
case find_role_by_id(RoleId, Roles) of
|
||||
undefined ->
|
||||
false;
|
||||
Role ->
|
||||
UserMax = guild_permissions:get_max_role_position(UserId, State),
|
||||
UserMax > role_position(Role)
|
||||
end
|
||||
end.
|
||||
|
||||
-spec can_manage_role(map(), guild_state()) -> guild_reply(map()).
|
||||
can_manage_role(#{user_id := UserId, role_id := RoleId}, State) ->
|
||||
Data = guild_data(State),
|
||||
Roles = maps:get(<<"roles">>, Data, []),
|
||||
Roles = guild_data_index:role_index(Data),
|
||||
Reply =
|
||||
case find_role_by_id(RoleId, Roles) of
|
||||
undefined ->
|
||||
@@ -212,6 +423,7 @@ can_manage_role(#{user_id := UserId, role_id := RoleId}, State) ->
|
||||
end,
|
||||
{reply, #{can_manage => Reply}, State}.
|
||||
|
||||
-spec compare_role_ids_for_equal_position(user_id(), role_id(), guild_state()) -> boolean().
|
||||
compare_role_ids_for_equal_position(UserId, TargetRoleId, State) ->
|
||||
case guild_permissions:find_member_by_user_id(UserId, State) of
|
||||
undefined ->
|
||||
@@ -219,9 +431,8 @@ compare_role_ids_for_equal_position(UserId, TargetRoleId, State) ->
|
||||
Member ->
|
||||
MemberRoles = member_roles(Member),
|
||||
Data = guild_data(State),
|
||||
Roles = maps:get(<<"roles">>, Data, []),
|
||||
UserHighestRole = get_highest_role(MemberRoles, Roles),
|
||||
case UserHighestRole of
|
||||
Roles = guild_data_index:role_index(Data),
|
||||
case get_highest_role(MemberRoles, Roles) of
|
||||
undefined ->
|
||||
false;
|
||||
HighestRole ->
|
||||
@@ -230,39 +441,42 @@ compare_role_ids_for_equal_position(UserId, TargetRoleId, State) ->
|
||||
end
|
||||
end.
|
||||
|
||||
-spec get_highest_role([role_id()], [role()] | map()) -> role() | undefined.
|
||||
get_highest_role(MemberRoleIds, Roles) ->
|
||||
lists:foldl(
|
||||
fun(RoleId, Acc) ->
|
||||
case find_role_by_id(RoleId, Roles) of
|
||||
undefined ->
|
||||
Acc;
|
||||
Role ->
|
||||
case Acc of
|
||||
undefined ->
|
||||
Role;
|
||||
AccRole ->
|
||||
AccPos = role_position(AccRole),
|
||||
RolePos = role_position(Role),
|
||||
if
|
||||
RolePos > AccPos ->
|
||||
Role;
|
||||
RolePos =:= AccPos ->
|
||||
AccId = map_utils:get_integer(AccRole, <<"id">>, 0),
|
||||
RId = map_utils:get_integer(Role, <<"id">>, 0),
|
||||
if
|
||||
RId < AccId -> Role;
|
||||
true -> AccRole
|
||||
end;
|
||||
true ->
|
||||
AccRole
|
||||
end
|
||||
end
|
||||
undefined -> Acc;
|
||||
Role -> compare_roles(Role, Acc)
|
||||
end
|
||||
end,
|
||||
undefined,
|
||||
MemberRoleIds
|
||||
).
|
||||
|
||||
-spec compare_roles(role(), role() | undefined) -> role().
|
||||
compare_roles(Role, undefined) ->
|
||||
Role;
|
||||
compare_roles(Role, AccRole) ->
|
||||
AccPos = role_position(AccRole),
|
||||
RolePos = role_position(Role),
|
||||
case RolePos > AccPos of
|
||||
true ->
|
||||
Role;
|
||||
false ->
|
||||
case RolePos =:= AccPos of
|
||||
true ->
|
||||
AccId = map_utils:get_integer(AccRole, <<"id">>, 0),
|
||||
RId = map_utils:get_integer(Role, <<"id">>, 0),
|
||||
case RId < AccId of
|
||||
true -> Role;
|
||||
false -> AccRole
|
||||
end;
|
||||
false ->
|
||||
AccRole
|
||||
end
|
||||
end.
|
||||
|
||||
-spec get_assignable_roles(map(), guild_state()) -> guild_reply(map()).
|
||||
get_assignable_roles(#{user_id := UserId}, State) ->
|
||||
Roles = guild_roles(State),
|
||||
@@ -270,6 +484,7 @@ get_assignable_roles(#{user_id := UserId}, State) ->
|
||||
RoleIds = get_assignable_role_ids(UserId, OwnerId, Roles, State),
|
||||
{reply, #{role_ids => RoleIds}, State}.
|
||||
|
||||
-spec get_assignable_role_ids(user_id(), user_id(), [role()], guild_state()) -> [role_id()].
|
||||
get_assignable_role_ids(OwnerId, OwnerId, Roles, _State) ->
|
||||
role_ids_from_roles(Roles);
|
||||
get_assignable_role_ids(UserId, _OwnerId, Roles, State) ->
|
||||
@@ -279,6 +494,7 @@ get_assignable_role_ids(UserId, _OwnerId, Roles, State) ->
|
||||
Roles
|
||||
).
|
||||
|
||||
-spec filter_assignable_role(role(), integer()) -> {true, role_id()} | false.
|
||||
filter_assignable_role(Role, UserMaxPosition) ->
|
||||
case role_position(Role) < UserMaxPosition of
|
||||
true ->
|
||||
@@ -293,19 +509,19 @@ filter_assignable_role(Role, UserMaxPosition) ->
|
||||
-spec check_target_member(map(), guild_state()) -> guild_reply(map()).
|
||||
check_target_member(#{user_id := UserId, target_user_id := TargetUserId}, State) ->
|
||||
OwnerId = owner_id(State),
|
||||
CanManage =
|
||||
if
|
||||
UserId =:= OwnerId ->
|
||||
true;
|
||||
TargetUserId =:= OwnerId ->
|
||||
false;
|
||||
true ->
|
||||
UserMaxPos = guild_permissions:get_max_role_position(UserId, State),
|
||||
TargetMaxPos = guild_permissions:get_max_role_position(TargetUserId, State),
|
||||
UserMaxPos > TargetMaxPos
|
||||
end,
|
||||
CanManage = check_can_manage_target(UserId, TargetUserId, OwnerId, State),
|
||||
{reply, #{can_manage => CanManage}, State}.
|
||||
|
||||
-spec check_can_manage_target(user_id(), user_id(), user_id(), guild_state()) -> boolean().
|
||||
check_can_manage_target(UserId, _TargetUserId, UserId, _State) ->
|
||||
true;
|
||||
check_can_manage_target(_UserId, OwnerId, OwnerId, _State) ->
|
||||
false;
|
||||
check_can_manage_target(UserId, TargetUserId, _OwnerId, State) ->
|
||||
UserMaxPos = guild_permissions:get_max_role_position(UserId, State),
|
||||
TargetMaxPos = guild_permissions:get_max_role_position(TargetUserId, State),
|
||||
UserMaxPos > TargetMaxPos.
|
||||
|
||||
-spec get_viewable_channels(map(), guild_state()) -> guild_reply(map()).
|
||||
get_viewable_channels(#{user_id := UserId}, State) ->
|
||||
Channels = guild_channels(State),
|
||||
@@ -313,29 +529,33 @@ get_viewable_channels(#{user_id := UserId}, State) ->
|
||||
undefined ->
|
||||
{reply, #{channel_ids => []}, State};
|
||||
Member ->
|
||||
ChannelIds = lists:filtermap(
|
||||
fun(Channel) ->
|
||||
ChannelId = map_utils:get_integer(Channel, <<"id">>, undefined),
|
||||
case ChannelId of
|
||||
undefined ->
|
||||
false;
|
||||
_ ->
|
||||
case
|
||||
guild_permissions:can_view_channel(UserId, ChannelId, Member, State)
|
||||
of
|
||||
true -> {true, ChannelId};
|
||||
false -> false
|
||||
end
|
||||
end
|
||||
end,
|
||||
Channels
|
||||
),
|
||||
ChannelIds = filter_viewable_channels(Channels, UserId, Member, State),
|
||||
{reply, #{channel_ids => ChannelIds}, State}
|
||||
end.
|
||||
|
||||
-spec filter_viewable_channels([channel()], user_id(), member(), guild_state()) -> [channel_id()].
|
||||
filter_viewable_channels(Channels, UserId, Member, State) ->
|
||||
lists:filtermap(
|
||||
fun(Channel) ->
|
||||
ChannelId = map_utils:get_integer(Channel, <<"id">>, undefined),
|
||||
case ChannelId of
|
||||
undefined ->
|
||||
false;
|
||||
_ ->
|
||||
case guild_permissions:can_view_channel(UserId, ChannelId, Member, State) of
|
||||
true -> {true, ChannelId};
|
||||
false -> false
|
||||
end
|
||||
end
|
||||
end,
|
||||
Channels
|
||||
).
|
||||
|
||||
-spec find_member_by_user_id(user_id(), guild_state()) -> member() | undefined.
|
||||
find_member_by_user_id(UserId, State) ->
|
||||
guild_permissions:find_member_by_user_id(UserId, State).
|
||||
|
||||
-spec find_role_by_id(role_id(), [role()] | map()) -> role() | undefined.
|
||||
find_role_by_id(RoleId, Roles) ->
|
||||
guild_permissions:find_role_by_id(RoleId, Roles).
|
||||
|
||||
@@ -345,15 +565,15 @@ guild_data(State) ->
|
||||
|
||||
-spec guild_members(guild_state()) -> [member()].
|
||||
guild_members(State) ->
|
||||
map_utils:ensure_list(maps:get(<<"members">>, guild_data(State), [])).
|
||||
guild_data_index:member_values(guild_data(State)).
|
||||
|
||||
-spec guild_roles(guild_state()) -> [role()].
|
||||
guild_roles(State) ->
|
||||
map_utils:ensure_list(maps:get(<<"roles">>, guild_data(State), [])).
|
||||
guild_data_index:role_list(guild_data(State)).
|
||||
|
||||
-spec guild_channels(guild_state()) -> [channel()].
|
||||
guild_channels(State) ->
|
||||
map_utils:ensure_list(maps:get(<<"channels">>, guild_data(State), [])).
|
||||
guild_data_index:channel_list(guild_data(State)).
|
||||
|
||||
-spec owner_id(guild_state()) -> user_id().
|
||||
owner_id(State) ->
|
||||
@@ -369,11 +589,6 @@ member_user_id(Member) ->
|
||||
member_roles(Member) ->
|
||||
normalize_int_list(map_utils:ensure_list(maps:get(<<"roles">>, Member, []))).
|
||||
|
||||
-spec member_has_any_role(member(), [role_id()]) -> boolean().
|
||||
member_has_any_role(Member, RoleIds) ->
|
||||
MemberRoles = member_roles(Member),
|
||||
lists:any(fun(RoleId) -> lists:member(RoleId, MemberRoles) end, RoleIds).
|
||||
|
||||
-spec member_has_any_role_set(member(), gb_sets:set(role_id())) -> boolean().
|
||||
member_has_any_role_set(Member, RoleIdSet) ->
|
||||
MemberRoles = member_roles(Member),
|
||||
@@ -390,6 +605,39 @@ member_can_view_channel(UserId, ChannelId, Member, State) when is_integer(Channe
|
||||
member_can_view_channel(_, _, _, _) ->
|
||||
false.
|
||||
|
||||
-spec collect_mentions_for_user_ids(
|
||||
[user_id()],
|
||||
user_id(),
|
||||
channel_id(),
|
||||
guild_state(),
|
||||
fun((user_id(), member()) -> boolean())
|
||||
) ->
|
||||
[user_id()].
|
||||
collect_mentions_for_user_ids(UserIds, AuthorId, ChannelId, State, Predicate) ->
|
||||
MemberMap = guild_data_index:member_map(guild_data(State)),
|
||||
lists:filtermap(
|
||||
fun(UserId) ->
|
||||
case UserId =:= AuthorId of
|
||||
true ->
|
||||
false;
|
||||
false ->
|
||||
case maps:get(UserId, MemberMap, undefined) of
|
||||
undefined ->
|
||||
false;
|
||||
Member ->
|
||||
case
|
||||
Predicate(UserId, Member) andalso
|
||||
member_can_view_channel(UserId, ChannelId, Member, State)
|
||||
of
|
||||
true -> {true, UserId};
|
||||
false -> false
|
||||
end
|
||||
end
|
||||
end
|
||||
end,
|
||||
lists:usort(UserIds)
|
||||
).
|
||||
|
||||
-spec collect_mentions([member()], user_id(), channel_id(), guild_state(), fun(
|
||||
(member()) -> boolean()
|
||||
)) ->
|
||||
@@ -446,6 +694,7 @@ role_position(Role) ->
|
||||
maps:get(<<"position">>, Role, 0).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
|
||||
get_users_to_mention_by_roles_basic_test() ->
|
||||
State = test_state(),
|
||||
@@ -455,6 +704,30 @@ get_users_to_mention_by_roles_basic_test() ->
|
||||
{reply, #{user_ids := UserIds}, _} = get_users_to_mention_by_roles(Request, State),
|
||||
?assertEqual([2], UserIds).
|
||||
|
||||
get_users_to_mention_by_user_ids_filters_unknown_members_test() ->
|
||||
State = test_state(),
|
||||
Request = #{channel_id => 500, user_ids => [2, 999], author_id => 1},
|
||||
{reply, #{user_ids := UserIds}, _} = get_users_to_mention_by_user_ids(Request, State),
|
||||
?assertEqual([2], UserIds).
|
||||
|
||||
get_members_with_role_uses_member_role_index_test() ->
|
||||
State = test_state(),
|
||||
{reply, #{user_ids := UserIds}, _} = get_members_with_role(#{role_id => 200}, State),
|
||||
?assertEqual([2], UserIds).
|
||||
|
||||
resolve_all_mentions_merges_role_and_direct_candidates_test() ->
|
||||
State = test_state(),
|
||||
Request = #{
|
||||
channel_id => 500,
|
||||
author_id => 1,
|
||||
mention_everyone => false,
|
||||
mention_here => false,
|
||||
role_ids => [200],
|
||||
user_ids => [3]
|
||||
},
|
||||
{reply, #{user_ids := UserIds}, _} = resolve_all_mentions(Request, State),
|
||||
?assertEqual([2, 3], UserIds).
|
||||
|
||||
get_assignable_roles_owner_test() ->
|
||||
State = test_state(),
|
||||
{reply, #{role_ids := RoleIds}, _} = get_assignable_roles(#{user_id => 1}, State),
|
||||
@@ -470,6 +743,32 @@ get_viewable_channels_filters_test() ->
|
||||
{reply, #{channel_ids := ChannelIds}, _} = get_viewable_channels(#{user_id => 2}, State),
|
||||
?assert(lists:member(500, ChannelIds)).
|
||||
|
||||
member_has_any_role_set_test() ->
|
||||
Member = #{<<"roles">> => [<<"100">>, <<"200">>]},
|
||||
RoleSet = gb_sets:from_list([200]),
|
||||
MissingRoleSet = gb_sets:from_list([999]),
|
||||
?assertEqual(true, member_has_any_role_set(Member, RoleSet)),
|
||||
?assertEqual(false, member_has_any_role_set(Member, MissingRoleSet)).
|
||||
|
||||
is_member_bot_test() ->
|
||||
BotMember = #{<<"user">> => #{<<"bot">> => true}},
|
||||
HumanMember = #{<<"user">> => #{<<"bot">> => false}},
|
||||
?assertEqual(true, is_member_bot(BotMember)),
|
||||
?assertEqual(false, is_member_bot(HumanMember)).
|
||||
|
||||
normalize_int_list_test() ->
|
||||
?assertEqual([1, 2, 3], normalize_int_list([<<"1">>, <<"2">>, <<"3">>])),
|
||||
?assertEqual([1, 2], normalize_int_list([1, 2])),
|
||||
?assertEqual([], normalize_int_list([])).
|
||||
|
||||
role_ids_from_roles_test() ->
|
||||
Roles = [
|
||||
#{<<"id">> => <<"100">>},
|
||||
#{<<"id">> => <<"200">>},
|
||||
#{<<"name">> => <<"no_id">>}
|
||||
],
|
||||
?assertEqual([100, 200], role_ids_from_roles(Roles)).
|
||||
|
||||
test_state() ->
|
||||
GuildId = 100,
|
||||
OwnerId = 1,
|
||||
|
||||
@@ -21,7 +21,8 @@
|
||||
schedule_passive_sync/1,
|
||||
handle_passive_sync/1,
|
||||
send_passive_updates_to_sessions/1,
|
||||
compute_delta/2
|
||||
compute_delta/2,
|
||||
compute_channel_diffs/2
|
||||
]).
|
||||
|
||||
-ifdef(TEST).
|
||||
@@ -30,71 +31,156 @@
|
||||
|
||||
-define(PASSIVE_SYNC_INTERVAL, 30000).
|
||||
|
||||
-type guild_state() :: map().
|
||||
-type channel_id() :: binary().
|
||||
-type last_message_id() :: binary().
|
||||
-type version() :: integer().
|
||||
-type voice_state() :: map().
|
||||
|
||||
-spec schedule_passive_sync(guild_state()) -> guild_state().
|
||||
schedule_passive_sync(State) ->
|
||||
erlang:send_after(?PASSIVE_SYNC_INTERVAL, self(), passive_sync),
|
||||
State.
|
||||
|
||||
-spec handle_passive_sync(guild_state()) -> {noreply, guild_state()}.
|
||||
handle_passive_sync(State) ->
|
||||
NewState = send_passive_updates_to_sessions(State),
|
||||
schedule_passive_sync(NewState),
|
||||
{noreply, NewState}.
|
||||
|
||||
-spec send_passive_updates_to_sessions(guild_state()) -> guild_state().
|
||||
send_passive_updates_to_sessions(State) ->
|
||||
GuildId = maps:get(id, State),
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
Data = maps:get(data, State, #{}),
|
||||
Channels = maps:get(<<"channels">>, Data, []),
|
||||
|
||||
MemberCount = maps:get(member_count, State, undefined),
|
||||
|
||||
IsLargeGuild = case MemberCount of
|
||||
undefined -> false;
|
||||
Count when is_integer(Count) -> Count > 250
|
||||
end,
|
||||
|
||||
IsLargeGuild =
|
||||
case MemberCount of
|
||||
undefined -> false;
|
||||
Count when is_integer(Count) -> Count > 250
|
||||
end,
|
||||
PassiveSessions = maps:filter(
|
||||
fun(_SessionId, SessionData) ->
|
||||
IsLargeGuild andalso session_passive:is_passive(GuildId, SessionData)
|
||||
end,
|
||||
Sessions
|
||||
),
|
||||
|
||||
case map_size(PassiveSessions) of
|
||||
0 ->
|
||||
State;
|
||||
_ ->
|
||||
UpdatedSessions = lists:foldl(
|
||||
fun({SessionId, SessionData}, AccSessions) ->
|
||||
Pid = maps:get(pid, SessionData),
|
||||
UserId = maps:get(user_id, SessionData),
|
||||
Member = guild_permissions:find_member_by_user_id(UserId, State),
|
||||
|
||||
CurrentLastMessageIds = build_last_message_ids(Channels, UserId, Member, State),
|
||||
PreviousLastMessageIds = maps:get(previous_passive_updates, SessionData, #{}),
|
||||
Delta = compute_delta(CurrentLastMessageIds, PreviousLastMessageIds),
|
||||
|
||||
case {map_size(Delta), is_pid(Pid)} of
|
||||
{0, _} ->
|
||||
AccSessions;
|
||||
{_, true} ->
|
||||
EventData = #{
|
||||
<<"guild_id">> => integer_to_binary(GuildId),
|
||||
<<"channels">> => Delta
|
||||
},
|
||||
gen_server:cast(Pid, {dispatch, passive_updates, EventData}),
|
||||
MergedLastMessageIds = maps:merge(PreviousLastMessageIds, Delta),
|
||||
UpdatedSessionData = maps:put(previous_passive_updates, MergedLastMessageIds, SessionData),
|
||||
maps:put(SessionId, UpdatedSessionData, AccSessions);
|
||||
_ ->
|
||||
AccSessions
|
||||
end
|
||||
end,
|
||||
Sessions,
|
||||
maps:to_list(PassiveSessions)
|
||||
UpdatedSessions = process_passive_sessions(
|
||||
maps:to_list(PassiveSessions), GuildId, Sessions, Channels, State
|
||||
),
|
||||
maps:put(sessions, UpdatedSessions, State)
|
||||
end.
|
||||
|
||||
-spec process_passive_sessions([{binary(), map()}], integer(), map(), [map()], guild_state()) ->
|
||||
map().
|
||||
process_passive_sessions(PassiveSessionList, GuildId, Sessions, Channels, State) ->
|
||||
lists:foldl(
|
||||
fun({SessionId, SessionData}, AccSessions) ->
|
||||
process_single_passive_session(
|
||||
SessionId, SessionData, GuildId, Channels, State, AccSessions
|
||||
)
|
||||
end,
|
||||
Sessions,
|
||||
PassiveSessionList
|
||||
).
|
||||
|
||||
-spec process_single_passive_session(binary(), map(), integer(), [map()], guild_state(), map()) ->
|
||||
map().
|
||||
process_single_passive_session(SessionId, SessionData, GuildId, Channels, State, AccSessions) ->
|
||||
Pid = maps:get(pid, SessionData),
|
||||
UserId = maps:get(user_id, SessionData),
|
||||
Member = guild_permissions:find_member_by_user_id(UserId, State),
|
||||
CurrentLastMessageIds = build_last_message_ids(Channels, UserId, Member, State),
|
||||
PreviousLastMessageIds = maps:get(previous_passive_updates, SessionData, #{}),
|
||||
Delta = compute_delta(CurrentLastMessageIds, PreviousLastMessageIds),
|
||||
PreviousChannelVersions = maps:get(previous_passive_channel_versions, SessionData, #{}),
|
||||
{CurrentChannelVersions, CurrentChannelsById} =
|
||||
build_viewable_channel_snapshots(Channels, UserId, Member, State),
|
||||
{CreatedChannelIds, UpdatedChannelIds, DeletedChannelIds} =
|
||||
compute_channel_diffs(CurrentChannelVersions, PreviousChannelVersions),
|
||||
CreatedChannels = [maps:get(Id, CurrentChannelsById) || Id <- CreatedChannelIds],
|
||||
UpdatedChannels = [maps:get(Id, CurrentChannelsById) || Id <- UpdatedChannelIds],
|
||||
ViewableChannels = guild_visibility:viewable_channel_set(UserId, State),
|
||||
CurrentVoiceStates = build_current_voice_state_map(ViewableChannels, State),
|
||||
PreviousVoiceStates = maps:get(previous_passive_voice_states, SessionData, #{}),
|
||||
VoiceStateUpdates = compute_voice_state_updates(
|
||||
CurrentVoiceStates, PreviousVoiceStates, GuildId
|
||||
),
|
||||
UpdatedSessionDataBase =
|
||||
maps:put(previous_passive_voice_states, CurrentVoiceStates, SessionData),
|
||||
HasChannelDelta = map_size(Delta) > 0,
|
||||
HasVoiceUpdates = VoiceStateUpdates =/= [],
|
||||
HasCreatedChannels = CreatedChannels =/= [],
|
||||
HasUpdatedChannels = UpdatedChannels =/= [],
|
||||
HasDeletedChannels = DeletedChannelIds =/= [],
|
||||
ShouldSend =
|
||||
HasChannelDelta orelse HasVoiceUpdates orelse
|
||||
HasCreatedChannels orelse HasUpdatedChannels orelse HasDeletedChannels,
|
||||
case {ShouldSend, is_pid(Pid)} of
|
||||
{true, true} ->
|
||||
EventData = build_passive_event_data(
|
||||
GuildId,
|
||||
Delta,
|
||||
CreatedChannels,
|
||||
UpdatedChannels,
|
||||
DeletedChannelIds,
|
||||
VoiceStateUpdates
|
||||
),
|
||||
gen_server:cast(Pid, {dispatch, passive_updates, EventData}),
|
||||
PreviousLastMessageIds1 = maps:without(DeletedChannelIds, PreviousLastMessageIds),
|
||||
MergedLastMessageIds = maps:merge(PreviousLastMessageIds1, Delta),
|
||||
UpdatedSessionData0 =
|
||||
maps:put(previous_passive_updates, MergedLastMessageIds, UpdatedSessionDataBase),
|
||||
UpdatedSessionData =
|
||||
maps:put(
|
||||
previous_passive_channel_versions, CurrentChannelVersions, UpdatedSessionData0
|
||||
),
|
||||
maps:put(SessionId, UpdatedSessionData, AccSessions);
|
||||
_ ->
|
||||
UpdatedSessionData =
|
||||
maps:put(
|
||||
previous_passive_channel_versions,
|
||||
CurrentChannelVersions,
|
||||
UpdatedSessionDataBase
|
||||
),
|
||||
maps:put(SessionId, UpdatedSessionData, AccSessions)
|
||||
end.
|
||||
|
||||
-spec build_passive_event_data(integer(), map(), [map()], [map()], [binary()], [map()]) -> map().
|
||||
build_passive_event_data(
|
||||
GuildId, Delta, CreatedChannels, UpdatedChannels, DeletedChannelIds, VoiceStateUpdates
|
||||
) ->
|
||||
EventDataBase = #{
|
||||
<<"guild_id">> => integer_to_binary(GuildId),
|
||||
<<"channels">> => Delta
|
||||
},
|
||||
EventData1 =
|
||||
case CreatedChannels of
|
||||
[] -> EventDataBase;
|
||||
_ -> maps:put(<<"created_channels">>, CreatedChannels, EventDataBase)
|
||||
end,
|
||||
EventData2 =
|
||||
case UpdatedChannels of
|
||||
[] -> EventData1;
|
||||
_ -> maps:put(<<"updated_channels">>, UpdatedChannels, EventData1)
|
||||
end,
|
||||
EventData3 =
|
||||
case DeletedChannelIds of
|
||||
[] -> EventData2;
|
||||
_ -> maps:put(<<"deleted_channel_ids">>, DeletedChannelIds, EventData2)
|
||||
end,
|
||||
case VoiceStateUpdates of
|
||||
[] -> EventData3;
|
||||
_ -> maps:put(<<"voice_states">>, VoiceStateUpdates, EventData3)
|
||||
end.
|
||||
|
||||
-spec compute_delta(#{channel_id() => last_message_id()}, #{channel_id() => last_message_id()}) ->
|
||||
#{channel_id() => last_message_id()}.
|
||||
compute_delta(CurrentLastMessageIds, PreviousLastMessageIds) ->
|
||||
maps:filter(
|
||||
fun(ChannelId, CurrentValue) ->
|
||||
@@ -106,6 +192,23 @@ compute_delta(CurrentLastMessageIds, PreviousLastMessageIds) ->
|
||||
CurrentLastMessageIds
|
||||
).
|
||||
|
||||
-spec compute_channel_diffs(#{channel_id() => version()}, #{channel_id() => version()}) ->
|
||||
{[channel_id()], [channel_id()], [channel_id()]}.
|
||||
compute_channel_diffs(Current, Previous) ->
|
||||
Created =
|
||||
[Id || {Id, _} <- maps:to_list(Current), not maps:is_key(Id, Previous)],
|
||||
Updated =
|
||||
[
|
||||
Id
|
||||
|| {Id, CurV} <- maps:to_list(Current),
|
||||
maps:is_key(Id, Previous) andalso maps:get(Id, Previous) =/= CurV
|
||||
],
|
||||
Deleted =
|
||||
[Id || {Id, _} <- maps:to_list(Previous), not maps:is_key(Id, Current)],
|
||||
{Created, Updated, Deleted}.
|
||||
|
||||
-spec build_last_message_ids([map()], integer(), map() | undefined, guild_state()) ->
|
||||
#{channel_id() => last_message_id()}.
|
||||
build_last_message_ids(Channels, UserId, Member, State) ->
|
||||
lists:foldl(
|
||||
fun(Channel, Acc) ->
|
||||
@@ -117,16 +220,22 @@ build_last_message_ids(Channels, UserId, Member, State) ->
|
||||
{_, null} ->
|
||||
Acc;
|
||||
_ ->
|
||||
ChannelId = validation:snowflake_or_default(<<"id">>, ChannelIdBin, 0),
|
||||
case Member of
|
||||
case parse_snowflake(<<"id">>, ChannelIdBin) of
|
||||
undefined ->
|
||||
Acc;
|
||||
_ ->
|
||||
case guild_permissions:can_view_channel(UserId, ChannelId, Member, State) of
|
||||
true ->
|
||||
maps:put(ChannelIdBin, LastMessageId, Acc);
|
||||
false ->
|
||||
Acc
|
||||
ChannelId ->
|
||||
case Member of
|
||||
undefined ->
|
||||
Acc;
|
||||
_ ->
|
||||
case
|
||||
guild_permissions:can_view_channel(UserId, ChannelId, Member, State)
|
||||
of
|
||||
true ->
|
||||
maps:put(ChannelIdBin, LastMessageId, Acc);
|
||||
false ->
|
||||
Acc
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
@@ -135,41 +244,196 @@ build_last_message_ids(Channels, UserId, Member, State) ->
|
||||
Channels
|
||||
).
|
||||
|
||||
-spec build_viewable_channel_snapshots([map()], integer(), map() | undefined, guild_state()) ->
|
||||
{#{channel_id() => version()}, #{channel_id() => map()}}.
|
||||
build_viewable_channel_snapshots(Channels, UserId, Member, State) ->
|
||||
case Member of
|
||||
undefined ->
|
||||
{#{}, #{}};
|
||||
_ ->
|
||||
lists:foldl(
|
||||
fun(Channel, {VersionsAcc, ChannelsAcc}) ->
|
||||
ChannelIdBin = maps:get(<<"id">>, Channel, undefined),
|
||||
case ChannelIdBin of
|
||||
undefined ->
|
||||
{VersionsAcc, ChannelsAcc};
|
||||
_ ->
|
||||
case parse_snowflake(<<"id">>, ChannelIdBin) of
|
||||
undefined ->
|
||||
{VersionsAcc, ChannelsAcc};
|
||||
ChannelId ->
|
||||
case
|
||||
guild_permissions:can_view_channel(UserId, ChannelId, Member, State)
|
||||
of
|
||||
true ->
|
||||
Version = map_utils:get_integer(Channel, <<"version">>, 0),
|
||||
{
|
||||
maps:put(ChannelIdBin, Version, VersionsAcc),
|
||||
maps:put(ChannelIdBin, Channel, ChannelsAcc)
|
||||
};
|
||||
false ->
|
||||
{VersionsAcc, ChannelsAcc}
|
||||
end
|
||||
end
|
||||
end
|
||||
end,
|
||||
{#{}, #{}},
|
||||
Channels
|
||||
)
|
||||
end.
|
||||
|
||||
-spec build_current_voice_state_map(sets:set(), guild_state()) -> #{binary() => voice_state()}.
|
||||
build_current_voice_state_map(ViewableChannels, State) ->
|
||||
VoiceStates = maps:get(voice_states, State, #{}),
|
||||
maps:fold(
|
||||
fun(ConnectionId, VoiceState, Acc) ->
|
||||
ChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
|
||||
case ChannelIdBin of
|
||||
null ->
|
||||
Acc;
|
||||
_ ->
|
||||
case parse_snowflake(<<"channel_id">>, ChannelIdBin) of
|
||||
undefined ->
|
||||
Acc;
|
||||
ChannelId ->
|
||||
case sets:is_element(ChannelId, ViewableChannels) of
|
||||
true -> maps:put(ConnectionId, VoiceState, Acc);
|
||||
false -> Acc
|
||||
end
|
||||
end
|
||||
end
|
||||
end,
|
||||
#{},
|
||||
VoiceStates
|
||||
).
|
||||
|
||||
-spec parse_snowflake(binary(), term()) -> integer() | undefined.
|
||||
parse_snowflake(FieldName, Value) ->
|
||||
case validation:validate_snowflake(FieldName, Value) of
|
||||
{ok, Id} -> Id;
|
||||
{error, _, _} -> undefined
|
||||
end.
|
||||
|
||||
-spec compute_voice_state_updates(
|
||||
#{binary() => voice_state()}, #{binary() => voice_state()}, integer()
|
||||
) ->
|
||||
[voice_state()].
|
||||
compute_voice_state_updates(Current, Previous, GuildId) ->
|
||||
GuildIdBin = integer_to_binary(GuildId),
|
||||
Updated = maps:fold(
|
||||
fun(ConnectionId, VoiceState, Acc) ->
|
||||
PrevState = maps:get(ConnectionId, Previous, undefined),
|
||||
case is_voice_state_changed(VoiceState, PrevState) of
|
||||
true -> [ensure_voice_state_guild(VoiceState, GuildIdBin) | Acc];
|
||||
false -> Acc
|
||||
end
|
||||
end,
|
||||
[],
|
||||
Current
|
||||
),
|
||||
Removed = maps:fold(
|
||||
fun(ConnectionId, PrevState, Acc) ->
|
||||
case maps:is_key(ConnectionId, Current) of
|
||||
true -> Acc;
|
||||
false -> [build_removed_voice_state(PrevState, GuildIdBin) | Acc]
|
||||
end
|
||||
end,
|
||||
[],
|
||||
Previous
|
||||
),
|
||||
lists:reverse(Updated) ++ lists:reverse(Removed).
|
||||
|
||||
-spec is_voice_state_changed(voice_state(), voice_state() | undefined) -> boolean().
|
||||
is_voice_state_changed(_Current, undefined) ->
|
||||
true;
|
||||
is_voice_state_changed(Current, Previous) ->
|
||||
voice_state_version(Current) =/= voice_state_version(Previous).
|
||||
|
||||
-spec voice_state_version(voice_state()) -> integer().
|
||||
voice_state_version(VoiceState) ->
|
||||
map_utils:get_integer(VoiceState, <<"version">>, 0).
|
||||
|
||||
-spec ensure_voice_state_guild(voice_state(), binary()) -> voice_state().
|
||||
ensure_voice_state_guild(VoiceState, GuildIdBin) ->
|
||||
case maps:get(<<"guild_id">>, VoiceState, undefined) of
|
||||
undefined -> maps:put(<<"guild_id">>, GuildIdBin, VoiceState);
|
||||
_ -> VoiceState
|
||||
end.
|
||||
|
||||
-spec build_removed_voice_state(voice_state(), binary()) -> voice_state().
|
||||
build_removed_voice_state(PrevState, GuildIdBin) ->
|
||||
PrevWithGuild = ensure_voice_state_guild(PrevState, GuildIdBin),
|
||||
maps:put(<<"channel_id">>, null, PrevWithGuild).
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
compute_delta_empty_previous_test() ->
|
||||
Current = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
|
||||
Previous = #{},
|
||||
Delta = compute_delta(Current, Previous),
|
||||
?assertEqual(Current, Delta),
|
||||
ok.
|
||||
?assertEqual(Current, Delta).
|
||||
|
||||
compute_delta_no_changes_test() ->
|
||||
Current = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
|
||||
Previous = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
|
||||
Delta = compute_delta(Current, Previous),
|
||||
?assertEqual(#{}, Delta),
|
||||
ok.
|
||||
?assertEqual(#{}, Delta).
|
||||
|
||||
compute_delta_partial_changes_test() ->
|
||||
Current = #{<<"1">> => <<"101">>, <<"2">> => <<"200">>, <<"3">> => <<"300">>},
|
||||
Previous = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
|
||||
Delta = compute_delta(Current, Previous),
|
||||
?assertEqual(#{<<"1">> => <<"101">>, <<"3">> => <<"300">>}, Delta),
|
||||
ok.
|
||||
?assertEqual(#{<<"1">> => <<"101">>, <<"3">> => <<"300">>}, Delta).
|
||||
|
||||
compute_delta_only_new_channels_test() ->
|
||||
Current = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>, <<"3">> => <<"300">>},
|
||||
Previous = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
|
||||
Delta = compute_delta(Current, Previous),
|
||||
?assertEqual(#{<<"3">> => <<"300">>}, Delta),
|
||||
ok.
|
||||
?assertEqual(#{<<"3">> => <<"300">>}, Delta).
|
||||
|
||||
compute_delta_ignores_removed_channels_test() ->
|
||||
Current = #{<<"1">> => <<"100">>},
|
||||
Previous = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
|
||||
Delta = compute_delta(Current, Previous),
|
||||
?assertEqual(#{}, Delta),
|
||||
ok.
|
||||
?assertEqual(#{}, Delta).
|
||||
|
||||
compute_channel_diffs_detects_created_updated_deleted_test() ->
|
||||
Current = #{<<"1">> => 2, <<"2">> => 1},
|
||||
Previous = #{<<"1">> => 1, <<"3">> => 9},
|
||||
{Created, Updated, Deleted} = compute_channel_diffs(Current, Previous),
|
||||
?assertEqual([<<"2">>], lists:sort(Created)),
|
||||
?assertEqual([<<"1">>], lists:sort(Updated)),
|
||||
?assertEqual([<<"3">>], lists:sort(Deleted)).
|
||||
|
||||
compute_voice_state_updates_reports_changes_test() ->
|
||||
PrevVoiceState = #{
|
||||
<<"connection_id">> => <<"conn1">>,
|
||||
<<"channel_id">> => <<"100">>,
|
||||
<<"user_id">> => <<"200">>,
|
||||
<<"version">> => 1
|
||||
},
|
||||
CurrentVoiceState = maps:put(<<"version">>, 2, PrevVoiceState),
|
||||
Current = #{<<"conn1">> => CurrentVoiceState},
|
||||
Previous = #{<<"conn1">> => PrevVoiceState},
|
||||
Updates = compute_voice_state_updates(Current, Previous, 999),
|
||||
?assertEqual(1, length(Updates)),
|
||||
Update = hd(Updates),
|
||||
?assertEqual(<<"conn1">>, maps:get(<<"connection_id">>, Update)),
|
||||
?assertEqual(integer_to_binary(999), maps:get(<<"guild_id">>, Update)).
|
||||
|
||||
compute_voice_state_updates_reports_removal_test() ->
|
||||
RemovedVoiceState = #{
|
||||
<<"connection_id">> => <<"conn2">>,
|
||||
<<"channel_id">> => <<"200">>,
|
||||
<<"user_id">> => <<"300">>,
|
||||
<<"version">> => 3
|
||||
},
|
||||
Current = #{},
|
||||
Previous = #{<<"conn2">> => RemovedVoiceState},
|
||||
Updates = compute_voice_state_updates(Current, Previous, 101),
|
||||
?assertEqual(1, length(Updates)),
|
||||
Update = hd(Updates),
|
||||
?assertEqual(null, maps:get(<<"channel_id">>, Update)),
|
||||
?assertEqual(integer_to_binary(101), maps:get(<<"guild_id">>, Update)).
|
||||
|
||||
-endif.
|
||||
|
||||
132
fluxer_gateway/src/guild/guild_permission_cache.erl
Normal file
132
fluxer_gateway/src/guild/guild_permission_cache.erl
Normal file
@@ -0,0 +1,132 @@
|
||||
%% Copyright (C) 2026 Fluxer Contributors
|
||||
%%
|
||||
%% This file is part of Fluxer.
|
||||
%%
|
||||
%% Fluxer is free software: you can redistribute it and/or modify
|
||||
%% it under the terms of the GNU Affero General Public License as published by
|
||||
%% the Free Software Foundation, either version 3 of the License, or
|
||||
%% (at your option) any later version.
|
||||
%%
|
||||
%% Fluxer is distributed in the hope that it will be useful,
|
||||
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
%% GNU Affero General Public License for more details.
|
||||
%%
|
||||
%% You should have received a copy of the GNU Affero General Public License
|
||||
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
-module(guild_permission_cache).
|
||||
|
||||
-export([
|
||||
put_state/1,
|
||||
put_data/2,
|
||||
delete/1,
|
||||
get_permissions/3,
|
||||
get_snapshot/1
|
||||
]).
|
||||
|
||||
-type guild_id() :: integer().
|
||||
-type user_id() :: integer().
|
||||
-type channel_id() :: integer().
|
||||
-type guild_state() :: map().
|
||||
-type guild_data() :: map().
|
||||
|
||||
-define(TABLE, guild_permission_cache).
|
||||
|
||||
-spec put_state(guild_state()) -> ok.
|
||||
put_state(State) when is_map(State) ->
|
||||
GuildId = maps:get(id, State, undefined),
|
||||
Data = maps:get(data, State, #{}),
|
||||
case is_integer(GuildId) of
|
||||
true ->
|
||||
put_data(GuildId, Data);
|
||||
false ->
|
||||
ok
|
||||
end;
|
||||
put_state(_) ->
|
||||
ok.
|
||||
|
||||
-spec put_data(guild_id(), guild_data()) -> ok.
|
||||
put_data(GuildId, Data) when is_integer(GuildId), is_map(Data) ->
|
||||
ensure_table(),
|
||||
NormalizedData = guild_data_index:normalize_data(Data),
|
||||
Snapshot = #{id => GuildId, data => NormalizedData},
|
||||
true = ets:insert(?TABLE, {GuildId, Snapshot}),
|
||||
ok;
|
||||
put_data(_, _) ->
|
||||
ok.
|
||||
|
||||
-spec delete(guild_id()) -> ok.
|
||||
delete(GuildId) when is_integer(GuildId) ->
|
||||
ensure_table(),
|
||||
_ = ets:delete(?TABLE, GuildId),
|
||||
ok;
|
||||
delete(_) ->
|
||||
ok.
|
||||
|
||||
-spec get_permissions(guild_id(), user_id(), channel_id() | undefined) ->
|
||||
{ok, integer()} | {error, not_found}.
|
||||
get_permissions(GuildId, UserId, ChannelId) when is_integer(GuildId), is_integer(UserId) ->
|
||||
case get_snapshot(GuildId) of
|
||||
{ok, Snapshot} ->
|
||||
Permissions = guild_permissions:get_member_permissions(UserId, ChannelId, Snapshot),
|
||||
{ok, Permissions};
|
||||
{error, not_found} ->
|
||||
{error, not_found}
|
||||
end;
|
||||
get_permissions(_, _, _) ->
|
||||
{error, not_found}.
|
||||
|
||||
-spec get_snapshot(guild_id()) -> {ok, guild_state()} | {error, not_found}.
|
||||
get_snapshot(GuildId) when is_integer(GuildId) ->
|
||||
ensure_table(),
|
||||
case ets:lookup(?TABLE, GuildId) of
|
||||
[{GuildId, Snapshot}] ->
|
||||
{ok, Snapshot};
|
||||
[] ->
|
||||
{error, not_found}
|
||||
end;
|
||||
get_snapshot(_) ->
|
||||
{error, not_found}.
|
||||
|
||||
-spec ensure_table() -> ok.
|
||||
ensure_table() ->
|
||||
case ets:whereis(?TABLE) of
|
||||
undefined ->
|
||||
try ets:new(?TABLE, [named_table, public, set, {read_concurrency, true}]) of
|
||||
_ -> ok
|
||||
catch
|
||||
error:badarg -> ok
|
||||
end;
|
||||
_ ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
|
||||
put_and_get_permissions_test() ->
|
||||
GuildId = 101,
|
||||
UserId = 44,
|
||||
ViewPermission = constants:view_channel_permission(),
|
||||
Data = #{
|
||||
<<"guild">> => #{<<"owner_id">> => <<"999">>},
|
||||
<<"roles">> => [
|
||||
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(ViewPermission)}
|
||||
],
|
||||
<<"members">> => #{
|
||||
UserId => #{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
|
||||
},
|
||||
<<"channels">> => [
|
||||
#{<<"id">> => <<"500">>, <<"permission_overwrites">> => []}
|
||||
]
|
||||
},
|
||||
ok = put_data(GuildId, Data),
|
||||
{ok, Permissions} = get_permissions(GuildId, UserId, 500),
|
||||
?assert((Permissions band ViewPermission) =/= 0),
|
||||
ok = delete(GuildId).
|
||||
|
||||
missing_guild_returns_not_found_test() ->
|
||||
?assertEqual({error, not_found}, get_permissions(999999, 1, undefined)).
|
||||
|
||||
-endif.
|
||||
@@ -19,22 +19,21 @@
|
||||
|
||||
-define(ALL_PERMISSIONS, 16#FFFFFFFFFFFFFFFF).
|
||||
|
||||
-export([get_member_permissions/3]).
|
||||
-export([can_view_channel/4]).
|
||||
-export([can_view_channel_by_permissions/4]).
|
||||
-export([can_manage_channel/3]).
|
||||
-export([apply_channel_overwrites/5]).
|
||||
-export([get_max_role_position/2]).
|
||||
-export([find_member_by_user_id/2]).
|
||||
-export([find_role_by_id/2]).
|
||||
-export([find_channel_by_id/2]).
|
||||
-export([
|
||||
get_member_permissions/3,
|
||||
can_view_channel/4,
|
||||
can_view_channel_by_permissions/4,
|
||||
can_manage_channel/3,
|
||||
can_access_message_by_permissions/3,
|
||||
apply_channel_overwrites/5,
|
||||
get_max_role_position/2,
|
||||
find_member_by_user_id/2,
|
||||
find_role_by_id/2,
|
||||
find_channel_by_id/2
|
||||
]).
|
||||
|
||||
-export_type([permission/0]).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-type permission() :: non_neg_integer().
|
||||
-type user_id() :: integer().
|
||||
-type role_id() :: integer().
|
||||
@@ -61,13 +60,48 @@ can_view_channel(UserId, ChannelId, Member, State) ->
|
||||
-spec can_view_channel_by_permissions(user_id(), channel_id(), maybe_member(), guild_state()) ->
|
||||
boolean().
|
||||
can_view_channel_by_permissions(UserId, ChannelId, Member, State) ->
|
||||
(compute_member_permissions(UserId, ChannelId, Member, State) band
|
||||
constants:view_channel_permission()) =/= 0.
|
||||
Perms = compute_member_permissions(UserId, ChannelId, Member, State),
|
||||
(Perms band constants:view_channel_permission()) =/= 0.
|
||||
|
||||
-spec can_manage_channel(user_id(), maybe_channel_id(), guild_state()) -> boolean().
|
||||
can_manage_channel(UserId, ChannelId, State) ->
|
||||
(get_member_permissions(UserId, ChannelId, State) band
|
||||
constants:manage_channels_permission()) =/= 0.
|
||||
Perms = get_member_permissions(UserId, ChannelId, State),
|
||||
(Perms band constants:manage_channels_permission()) =/= 0.
|
||||
|
||||
-spec can_access_message_by_permissions(permission(), binary(), guild_state()) -> boolean().
|
||||
can_access_message_by_permissions(Permissions, MessageId, State) ->
|
||||
HasReadHistory = (Permissions band constants:read_message_history_permission()) =/= 0,
|
||||
case HasReadHistory of
|
||||
true ->
|
||||
true;
|
||||
false ->
|
||||
case get_message_history_cutoff(State) of
|
||||
null ->
|
||||
false;
|
||||
CutoffMs ->
|
||||
MessageMs = snowflake_util:extract_timestamp(MessageId),
|
||||
MessageMs >= CutoffMs
|
||||
end
|
||||
end.
|
||||
|
||||
-spec get_message_history_cutoff(guild_state()) -> integer() | null.
|
||||
get_message_history_cutoff(State) ->
|
||||
case resolve_data_map(State) of
|
||||
undefined ->
|
||||
null;
|
||||
Data ->
|
||||
Guild = maps:get(<<"guild">>, Data, #{}),
|
||||
case maps:get(<<"message_history_cutoff">>, Guild, null) of
|
||||
null ->
|
||||
null;
|
||||
CutoffBin when is_binary(CutoffBin) ->
|
||||
calendar:rfc3339_to_system_time(
|
||||
binary_to_list(CutoffBin), [{unit, millisecond}]
|
||||
);
|
||||
CutoffInt when is_integer(CutoffInt) ->
|
||||
CutoffInt
|
||||
end
|
||||
end.
|
||||
|
||||
-spec apply_channel_overwrites(permission(), user_id(), member_roles(), channel(), role_id()) ->
|
||||
permission().
|
||||
@@ -86,64 +120,51 @@ get_max_role_position(UserId, State) ->
|
||||
{_, undefined} ->
|
||||
-1;
|
||||
{Member, Data} ->
|
||||
Roles = ensure_list(maps:get(<<"roles">>, Data, [])),
|
||||
lists:foldl(
|
||||
fun(RoleId, MaxPos) ->
|
||||
case find_role_by_id(RoleId, Roles) of
|
||||
undefined ->
|
||||
MaxPos;
|
||||
Role ->
|
||||
Position = maps:get(<<"position">>, Role, 0),
|
||||
max(Position, MaxPos)
|
||||
end
|
||||
end,
|
||||
-1,
|
||||
member_role_ids(Member)
|
||||
)
|
||||
Roles = guild_data_index:role_index(Data),
|
||||
compute_max_position(Member, Roles)
|
||||
end.
|
||||
|
||||
-spec compute_max_position(member(), [role()] | map()) -> integer().
|
||||
compute_max_position(Member, Roles) ->
|
||||
lists:foldl(
|
||||
fun(RoleId, MaxPos) ->
|
||||
case find_role_by_id(RoleId, Roles) of
|
||||
undefined ->
|
||||
MaxPos;
|
||||
Role ->
|
||||
Position = maps:get(<<"position">>, Role, 0),
|
||||
max(Position, MaxPos)
|
||||
end
|
||||
end,
|
||||
-1,
|
||||
member_role_ids(Member)
|
||||
).
|
||||
|
||||
-spec find_member_by_user_id(user_id(), guild_state()) -> member() | undefined.
|
||||
find_member_by_user_id(UserId, State) when is_integer(UserId) ->
|
||||
case resolve_data_map(State) of
|
||||
undefined ->
|
||||
undefined;
|
||||
Data ->
|
||||
Members = ensure_list(maps:get(<<"members">>, Data, [])),
|
||||
lists:foldl(
|
||||
fun(Member, Acc) ->
|
||||
case Acc of
|
||||
undefined ->
|
||||
MUser = maps:get(<<"user">>, Member, #{}),
|
||||
MemberId = to_int(maps:get(<<"id">>, MUser, <<"0">>)),
|
||||
case MemberId =:= UserId of
|
||||
true -> Member;
|
||||
false -> undefined
|
||||
end;
|
||||
Found ->
|
||||
Found
|
||||
end
|
||||
end,
|
||||
undefined,
|
||||
Members
|
||||
)
|
||||
Members = guild_data_index:member_map(Data),
|
||||
maps:get(UserId, Members, undefined)
|
||||
end;
|
||||
find_member_by_user_id(_, _) ->
|
||||
undefined.
|
||||
|
||||
-spec find_role_by_id(role_id(), list()) -> role() | undefined.
|
||||
-spec find_role_by_id(role_id(), [role()] | map()) -> role() | undefined.
|
||||
find_role_by_id(RoleId, Roles) when is_map(Roles) ->
|
||||
maps:get(to_int(RoleId), Roles, undefined);
|
||||
find_role_by_id(RoleId, Roles) ->
|
||||
TargetId = to_int(RoleId),
|
||||
lists:foldl(
|
||||
fun(Role, Acc) ->
|
||||
case Acc of
|
||||
undefined ->
|
||||
case role_id(Role) =:= TargetId of
|
||||
true -> Role;
|
||||
false -> undefined
|
||||
end;
|
||||
Found ->
|
||||
Found
|
||||
end
|
||||
fun
|
||||
(_, Found) when Found =/= undefined -> Found;
|
||||
(Role, undefined) ->
|
||||
case role_id(Role) =:= TargetId of
|
||||
true -> Role;
|
||||
false -> undefined
|
||||
end
|
||||
end,
|
||||
undefined,
|
||||
ensure_list(Roles)
|
||||
@@ -155,23 +176,8 @@ find_channel_by_id(ChannelId, State) when is_integer(ChannelId) ->
|
||||
undefined ->
|
||||
undefined;
|
||||
Data ->
|
||||
Channels = ensure_list(maps:get(<<"channels">>, Data, [])),
|
||||
lists:foldl(
|
||||
fun(Channel, Acc) ->
|
||||
case Acc of
|
||||
undefined ->
|
||||
ChanId = to_int(maps:get(<<"id">>, Channel, <<"0">>)),
|
||||
case ChanId =:= ChannelId of
|
||||
true -> Channel;
|
||||
false -> undefined
|
||||
end;
|
||||
Found ->
|
||||
Found
|
||||
end
|
||||
end,
|
||||
undefined,
|
||||
Channels
|
||||
)
|
||||
Channels = guild_data_index:channel_index(Data),
|
||||
maps:get(ChannelId, Channels, undefined)
|
||||
end;
|
||||
find_channel_by_id(_, _) ->
|
||||
undefined.
|
||||
@@ -188,36 +194,36 @@ compute_member_permissions(UserId, ChannelId, ProvidedMember, State) when is_int
|
||||
true ->
|
||||
?ALL_PERMISSIONS;
|
||||
false ->
|
||||
case resolve_member(UserId, ProvidedMember, State) of
|
||||
undefined ->
|
||||
0;
|
||||
Member ->
|
||||
GuildId = guild_id(State),
|
||||
Roles = ensure_list(maps:get(<<"roles">>, Data, [])),
|
||||
BasePermissions = base_role_permissions(GuildId, Roles),
|
||||
MemberRoles = member_role_ids(Member),
|
||||
Permissions = aggregate_role_permissions(
|
||||
MemberRoles, Roles, BasePermissions
|
||||
),
|
||||
case (Permissions band constants:administrator_permission()) =/= 0 of
|
||||
true ->
|
||||
?ALL_PERMISSIONS;
|
||||
false ->
|
||||
maybe_apply_channel_overwrites(
|
||||
Permissions,
|
||||
UserId,
|
||||
MemberRoles,
|
||||
ChannelId,
|
||||
GuildId,
|
||||
State
|
||||
)
|
||||
end
|
||||
end
|
||||
compute_non_owner_permissions(UserId, ChannelId, ProvidedMember, State, Data)
|
||||
end
|
||||
end;
|
||||
compute_member_permissions(_, _, _, _) ->
|
||||
0.
|
||||
|
||||
-spec compute_non_owner_permissions(
|
||||
user_id(), maybe_channel_id(), maybe_member(), guild_state(), guild_data()
|
||||
) ->
|
||||
permission().
|
||||
compute_non_owner_permissions(UserId, ChannelId, ProvidedMember, State, Data) ->
|
||||
case resolve_member(UserId, ProvidedMember, State) of
|
||||
undefined ->
|
||||
0;
|
||||
Member ->
|
||||
GuildId = guild_id(State),
|
||||
Roles = guild_data_index:role_index(Data),
|
||||
BasePermissions = base_role_permissions(GuildId, Roles),
|
||||
MemberRoles = member_role_ids(Member),
|
||||
Permissions = aggregate_role_permissions(MemberRoles, Roles, BasePermissions),
|
||||
case (Permissions band constants:administrator_permission()) =/= 0 of
|
||||
true ->
|
||||
?ALL_PERMISSIONS;
|
||||
false ->
|
||||
maybe_apply_channel_overwrites(
|
||||
Permissions, UserId, MemberRoles, ChannelId, GuildId, State
|
||||
)
|
||||
end
|
||||
end.
|
||||
|
||||
-spec resolve_member(user_id(), maybe_member(), guild_state()) -> maybe_member().
|
||||
resolve_member(_UserId, Member, _State) when is_map(Member) ->
|
||||
Member;
|
||||
@@ -232,36 +238,25 @@ guild_owner_id(Data) ->
|
||||
-spec guild_id(guild_state()) -> integer().
|
||||
guild_id(State) ->
|
||||
case maps:get(id, State, undefined) of
|
||||
undefined ->
|
||||
to_int(maps:get(<<"id">>, State, 0));
|
||||
GuildId when is_integer(GuildId) ->
|
||||
GuildId;
|
||||
GuildId ->
|
||||
to_int(GuildId)
|
||||
undefined -> to_int(maps:get(<<"id">>, State, 0));
|
||||
GuildId when is_integer(GuildId) -> GuildId;
|
||||
GuildId -> to_int(GuildId)
|
||||
end.
|
||||
|
||||
-spec base_role_permissions(role_id(), list()) -> permission().
|
||||
-spec base_role_permissions(role_id(), map()) -> permission().
|
||||
base_role_permissions(GuildId, Roles) ->
|
||||
lists:foldl(
|
||||
fun(Role, Acc) ->
|
||||
case role_id(Role) =:= GuildId of
|
||||
true -> role_permissions(Role);
|
||||
false -> Acc
|
||||
end
|
||||
end,
|
||||
0,
|
||||
ensure_list(Roles)
|
||||
).
|
||||
case find_role_by_id(GuildId, Roles) of
|
||||
undefined -> 0;
|
||||
Role -> role_permissions(Role)
|
||||
end.
|
||||
|
||||
-spec aggregate_role_permissions(member_roles(), list(), permission()) -> permission().
|
||||
-spec aggregate_role_permissions(member_roles(), [role()] | map(), permission()) -> permission().
|
||||
aggregate_role_permissions(MemberRoles, Roles, BasePermissions) ->
|
||||
lists:foldl(
|
||||
fun(RoleId, Acc) ->
|
||||
case find_role_by_id(RoleId, Roles) of
|
||||
undefined ->
|
||||
Acc;
|
||||
Role ->
|
||||
Acc bor role_permissions(Role)
|
||||
undefined -> Acc;
|
||||
Role -> Acc bor role_permissions(Role)
|
||||
end
|
||||
end,
|
||||
BasePermissions,
|
||||
@@ -277,10 +272,8 @@ maybe_apply_channel_overwrites(Permissions, UserId, MemberRoles, ChannelId, Guil
|
||||
is_integer(ChannelId)
|
||||
->
|
||||
case find_channel_by_id(ChannelId, State) of
|
||||
undefined ->
|
||||
Permissions;
|
||||
Channel ->
|
||||
apply_channel_overwrites(Permissions, UserId, MemberRoles, Channel, GuildId)
|
||||
undefined -> Permissions;
|
||||
Channel -> apply_channel_overwrites(Permissions, UserId, MemberRoles, Channel, GuildId)
|
||||
end;
|
||||
maybe_apply_channel_overwrites(Permissions, _UserId, _MemberRoles, _ChannelId, _GuildId, _State) ->
|
||||
Permissions.
|
||||
@@ -327,10 +320,8 @@ accumulate_role_overwrites(MemberRoles, Overwrites) ->
|
||||
lists:foldl(
|
||||
fun(Overwrite, {A, D}) ->
|
||||
case overwrite_matches_role(Overwrite, RoleId) of
|
||||
true ->
|
||||
{A bor overwrite_allow(Overwrite), D bor overwrite_deny(Overwrite)};
|
||||
false ->
|
||||
{A, D}
|
||||
true -> {A bor overwrite_allow(Overwrite), D bor overwrite_deny(Overwrite)};
|
||||
false -> {A, D}
|
||||
end
|
||||
end,
|
||||
{AllowAcc, DenyAcc},
|
||||
@@ -406,10 +397,8 @@ extract_integer_list(_) ->
|
||||
[].
|
||||
|
||||
-spec ensure_list(term()) -> list().
|
||||
ensure_list(List) when is_list(List) ->
|
||||
List;
|
||||
ensure_list(_) ->
|
||||
[].
|
||||
ensure_list(List) when is_list(List) -> List;
|
||||
ensure_list(_) -> [].
|
||||
|
||||
-spec to_int(term()) -> integer().
|
||||
to_int(Value) ->
|
||||
@@ -421,22 +410,20 @@ to_int(Value) ->
|
||||
-spec resolve_data_map(guild_state() | map()) -> guild_data() | undefined.
|
||||
resolve_data_map(State) when is_map(State) ->
|
||||
case maps:find(data, State) of
|
||||
{ok, Data} when is_map(Data) ->
|
||||
Data;
|
||||
{ok, Data} when is_map(Data) =:= false ->
|
||||
{ok, Data} when is_map(Data) -> Data;
|
||||
{ok, Data} ->
|
||||
Data;
|
||||
error ->
|
||||
case State of
|
||||
#{<<"members">> := _} ->
|
||||
State;
|
||||
_ ->
|
||||
undefined
|
||||
case maps:is_key(<<"members">>, State) of
|
||||
true -> State;
|
||||
false -> undefined
|
||||
end
|
||||
end;
|
||||
resolve_data_map(_) ->
|
||||
undefined.
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
|
||||
owner_receives_full_permissions_test() ->
|
||||
OwnerId = 1,
|
||||
@@ -537,19 +524,16 @@ administrator_role_grants_all_permissions_test() ->
|
||||
data => #{
|
||||
<<"guild">> => #{<<"owner_id">> => integer_to_binary(OwnerId)},
|
||||
<<"roles">> => [
|
||||
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(Admin)}
|
||||
#{
|
||||
<<"id">> => integer_to_binary(GuildId),
|
||||
<<"permissions">> => integer_to_binary(Admin)
|
||||
}
|
||||
],
|
||||
<<"members">> => [
|
||||
#{
|
||||
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
|
||||
<<"roles">> => []
|
||||
}
|
||||
#{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
|
||||
],
|
||||
<<"channels">> => [
|
||||
#{
|
||||
<<"id">> => integer_to_binary(ChannelId),
|
||||
<<"permission_overwrites">> => []
|
||||
}
|
||||
#{<<"id">> => integer_to_binary(ChannelId), <<"permission_overwrites">> => []}
|
||||
]
|
||||
}
|
||||
},
|
||||
@@ -557,4 +541,122 @@ administrator_role_grants_all_permissions_test() ->
|
||||
?assertEqual(?ALL_PERMISSIONS, get_member_permissions(UserId, ChannelId, State)),
|
||||
?assert(can_view_channel(UserId, ChannelId, undefined, State)).
|
||||
|
||||
find_member_by_user_id_found_test() ->
|
||||
State = #{
|
||||
data => #{
|
||||
<<"members">> => [
|
||||
#{<<"user">> => #{<<"id">> => <<"123">>}, <<"nick">> => <<"Test">>}
|
||||
]
|
||||
}
|
||||
},
|
||||
Result = find_member_by_user_id(123, State),
|
||||
?assertEqual(<<"Test">>, maps:get(<<"nick">>, Result)).
|
||||
|
||||
find_member_by_user_id_not_found_test() ->
|
||||
State = #{data => #{<<"members">> => []}},
|
||||
?assertEqual(undefined, find_member_by_user_id(123, State)).
|
||||
|
||||
find_member_by_user_id_map_storage_test() ->
|
||||
State = #{
|
||||
data => #{
|
||||
<<"members">> => #{
|
||||
321 => #{<<"user">> => #{<<"id">> => <<"321">>}, <<"nick">> => <<"Mapped">>}
|
||||
}
|
||||
}
|
||||
},
|
||||
Result = find_member_by_user_id(321, State),
|
||||
?assertEqual(<<"Mapped">>, maps:get(<<"nick">>, Result)).
|
||||
|
||||
find_role_by_id_found_test() ->
|
||||
Roles = [#{<<"id">> => <<"100">>, <<"name">> => <<"Admin">>}],
|
||||
Result = find_role_by_id(100, Roles),
|
||||
?assertEqual(<<"Admin">>, maps:get(<<"name">>, Result)).
|
||||
|
||||
find_role_by_id_not_found_test() ->
|
||||
Roles = [#{<<"id">> => <<"100">>}],
|
||||
?assertEqual(undefined, find_role_by_id(999, Roles)).
|
||||
|
||||
find_role_by_id_map_index_test() ->
|
||||
Roles = #{
|
||||
100 => #{<<"id">> => <<"100">>, <<"name">> => <<"Admin">>}
|
||||
},
|
||||
Result = find_role_by_id(100, Roles),
|
||||
?assertEqual(<<"Admin">>, maps:get(<<"name">>, Result)).
|
||||
|
||||
find_channel_by_id_with_index_test() ->
|
||||
State = #{
|
||||
data => #{
|
||||
<<"channels">> => [#{<<"id">> => <<"900">>, <<"name">> => <<"general">>}],
|
||||
<<"channel_index">> => #{900 => #{<<"id">> => <<"900">>, <<"name">> => <<"general">>}}
|
||||
}
|
||||
},
|
||||
Result = find_channel_by_id(900, State),
|
||||
?assertEqual(<<"general">>, maps:get(<<"name">>, Result)).
|
||||
|
||||
to_int_test() ->
|
||||
?assertEqual(123, to_int(123)),
|
||||
?assertEqual(123, to_int(<<"123">>)),
|
||||
?assertEqual(0, to_int(undefined)).
|
||||
|
||||
ensure_list_test() ->
|
||||
?assertEqual([1, 2], ensure_list([1, 2])),
|
||||
?assertEqual([], ensure_list(undefined)),
|
||||
?assertEqual([], ensure_list(#{})).
|
||||
|
||||
can_access_message_with_read_history_test() ->
|
||||
ReadHistory = constants:read_message_history_permission(),
|
||||
State = #{data => #{<<"guild">> => #{}}},
|
||||
MessageId = <<"100">>,
|
||||
?assertEqual(true, can_access_message_by_permissions(ReadHistory, MessageId, State)).
|
||||
|
||||
can_access_message_no_read_history_no_cutoff_test() ->
|
||||
State = #{data => #{<<"guild">> => #{}}},
|
||||
MessageId = <<"100">>,
|
||||
?assertEqual(false, can_access_message_by_permissions(0, MessageId, State)).
|
||||
|
||||
can_access_message_no_read_history_null_cutoff_test() ->
|
||||
State = #{data => #{<<"guild">> => #{<<"message_history_cutoff">> => null}}},
|
||||
MessageId = <<"100">>,
|
||||
?assertEqual(false, can_access_message_by_permissions(0, MessageId, State)).
|
||||
|
||||
can_access_message_no_read_history_message_before_cutoff_test() ->
|
||||
CutoffMs = 1704067200000,
|
||||
BeforeCutoffTimestamp = CutoffMs - 60000,
|
||||
FluxerEpoch = 1420070400000,
|
||||
RelativeTs = BeforeCutoffTimestamp - FluxerEpoch,
|
||||
Snowflake = RelativeTs bsl 22,
|
||||
MessageId = integer_to_binary(Snowflake),
|
||||
State = #{data => #{<<"guild">> => #{<<"message_history_cutoff">> => CutoffMs}}},
|
||||
?assertEqual(false, can_access_message_by_permissions(0, MessageId, State)).
|
||||
|
||||
can_access_message_no_read_history_message_after_cutoff_test() ->
|
||||
CutoffMs = 1704067200000,
|
||||
AfterCutoffTimestamp = CutoffMs + 60000,
|
||||
FluxerEpoch = 1420070400000,
|
||||
RelativeTs = AfterCutoffTimestamp - FluxerEpoch,
|
||||
Snowflake = RelativeTs bsl 22,
|
||||
MessageId = integer_to_binary(Snowflake),
|
||||
State = #{data => #{<<"guild">> => #{<<"message_history_cutoff">> => CutoffMs}}},
|
||||
?assertEqual(true, can_access_message_by_permissions(0, MessageId, State)).
|
||||
|
||||
can_access_message_no_read_history_message_at_cutoff_test() ->
|
||||
CutoffMs = 1704067200000,
|
||||
FluxerEpoch = 1420070400000,
|
||||
RelativeTs = CutoffMs - FluxerEpoch,
|
||||
Snowflake = RelativeTs bsl 22,
|
||||
MessageId = integer_to_binary(Snowflake),
|
||||
State = #{data => #{<<"guild">> => #{<<"message_history_cutoff">> => CutoffMs}}},
|
||||
?assertEqual(true, can_access_message_by_permissions(0, MessageId, State)).
|
||||
|
||||
can_access_message_with_rfc3339_cutoff_test() ->
|
||||
CutoffBin = <<"2024-01-01T00:00:00Z">>,
|
||||
CutoffMs = calendar:rfc3339_to_system_time("2024-01-01T00:00:00Z", [{unit, millisecond}]),
|
||||
AfterCutoffTimestamp = CutoffMs + 60000,
|
||||
FluxerEpoch = 1420070400000,
|
||||
RelativeTs = AfterCutoffTimestamp - FluxerEpoch,
|
||||
Snowflake = RelativeTs bsl 22,
|
||||
MessageId = integer_to_binary(Snowflake),
|
||||
State = #{data => #{<<"guild">> => #{<<"message_history_cutoff">> => CutoffBin}}},
|
||||
?assertEqual(true, can_access_message_by_permissions(0, MessageId, State)).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -20,8 +20,6 @@
|
||||
-export([handle_bus_presence/3, send_cached_presence_to_session/3]).
|
||||
-export([broadcast_presence_update/3]).
|
||||
|
||||
-import(guild_sessions, [handle_user_offline/2]).
|
||||
|
||||
-type guild_state() :: map().
|
||||
-type member() :: map().
|
||||
-type user_id() :: integer().
|
||||
@@ -31,13 +29,19 @@
|
||||
-endif.
|
||||
|
||||
-spec handle_bus_presence(user_id(), map(), guild_state()) -> {noreply, guild_state()}.
|
||||
-spec send_cached_presence_to_session(user_id(), binary(), guild_state()) -> guild_state().
|
||||
handle_bus_presence(UserId, Payload, State) ->
|
||||
case maps:get(<<"user_update">>, Payload, false) of
|
||||
true ->
|
||||
UserData = maps:get(<<"user">>, Payload, #{}),
|
||||
UpdatedState = handle_user_data_update(UserId, UserData, State),
|
||||
guild_member_list:broadcast_member_list_updates(UserId, State, UpdatedState),
|
||||
MemberData = find_member_by_user_id(UserId, UpdatedState),
|
||||
case is_map(MemberData) of
|
||||
true ->
|
||||
maybe_notify_very_large_guild_member_list({upsert_member, MemberData}, UpdatedState);
|
||||
false ->
|
||||
maybe_notify_very_large_guild_member_list({notify_member_update, UserId}, UpdatedState)
|
||||
end,
|
||||
maybe_broadcast_member_list_updates(UserId, State, UpdatedState),
|
||||
{noreply, UpdatedState};
|
||||
false ->
|
||||
Member = find_member_by_user_id(UserId, State),
|
||||
@@ -50,7 +54,6 @@ handle_bus_presence(UserId, Payload, State) ->
|
||||
Status = constants:status_type_atom(NormalizedStatusBin),
|
||||
Mobile = maps:get(<<"mobile">>, Payload, false),
|
||||
Afk = maps:get(<<"afk">>, Payload, false),
|
||||
logger:debug("[guild_presence] Presence update for UserId=~p, Status=~p", [UserId, Status]),
|
||||
MemberUser = maps:get(<<"user">>, Member, #{}),
|
||||
CustomStatus = maps:get(<<"custom_status">>, Payload, null),
|
||||
PresenceMap = presence_payload:build(
|
||||
@@ -60,29 +63,48 @@ handle_bus_presence(UserId, Payload, State) ->
|
||||
Afk,
|
||||
CustomStatus
|
||||
),
|
||||
Presences = maps:get(presences, State, #{}),
|
||||
UpdatedPresences = maps:put(UserId, PresenceMap, Presences),
|
||||
StateWithPresences = maps:put(presences, UpdatedPresences, State),
|
||||
broadcast_presence_update(UserId, PresenceMap, StateWithPresences),
|
||||
logger:debug("[guild_presence] Broadcasting member list updates for UserId=~p", [UserId]),
|
||||
guild_member_list:broadcast_member_list_updates(UserId, State, StateWithPresences),
|
||||
StateWithPresence = store_member_presence(UserId, PresenceMap, State),
|
||||
broadcast_presence_update(UserId, PresenceMap, StateWithPresence),
|
||||
maybe_notify_very_large_guild_member_list(
|
||||
{presence_update, UserId, PresenceMap}, StateWithPresence
|
||||
),
|
||||
maybe_broadcast_member_list_updates(UserId, State, StateWithPresence),
|
||||
StateAfterOffline =
|
||||
case Status of
|
||||
offline ->
|
||||
handle_user_offline(UserId, StateWithPresences);
|
||||
guild_sessions:handle_user_offline(UserId, StateWithPresence);
|
||||
_ ->
|
||||
StateWithPresences
|
||||
StateWithPresence
|
||||
end,
|
||||
{noreply, StateAfterOffline}
|
||||
end
|
||||
end.
|
||||
|
||||
-spec maybe_broadcast_member_list_updates(user_id(), guild_state(), guild_state()) -> ok.
|
||||
maybe_broadcast_member_list_updates(UserId, OldState, UpdatedState) ->
|
||||
case maps:get(very_large_guild_coordinator_pid, UpdatedState, undefined) of
|
||||
Pid when is_pid(Pid) ->
|
||||
ok;
|
||||
_ ->
|
||||
guild_member_list:broadcast_member_list_updates(UserId, OldState, UpdatedState)
|
||||
end.
|
||||
|
||||
-spec maybe_notify_very_large_guild_member_list(term(), guild_state()) -> ok.
|
||||
maybe_notify_very_large_guild_member_list(NotifyMsg, State) ->
|
||||
case maps:get(very_large_guild_coordinator_pid, State, undefined) of
|
||||
CoordPid when is_pid(CoordPid) ->
|
||||
gen_server:cast(CoordPid, {very_large_guild_member_list_notify, NotifyMsg}),
|
||||
ok;
|
||||
_ ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec broadcast_presence_update(user_id(), map(), guild_state()) -> ok.
|
||||
broadcast_presence_update(UserId, Payload, State) ->
|
||||
case find_member_by_user_id(UserId, State) of
|
||||
undefined ->
|
||||
ok;
|
||||
_Member ->
|
||||
_Member ->
|
||||
GuildId = map_utils:get_integer(State, id, 0),
|
||||
PresenceUpdate = maps:put(<<"guild_id">>, integer_to_binary(GuildId), Payload),
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
@@ -90,7 +112,9 @@ broadcast_presence_update(UserId, Payload, State) ->
|
||||
SubscribedSessionIds = guild_subscriptions:get_subscribed_sessions(UserId, MemberSubs),
|
||||
TargetChannels = guild_visibility:viewable_channel_set(UserId, State),
|
||||
{ValidSessionIds, InvalidSessionIds} =
|
||||
partition_subscribed_sessions(SubscribedSessionIds, Sessions, TargetChannels, UserId, State),
|
||||
partition_subscribed_sessions(
|
||||
SubscribedSessionIds, Sessions, TargetChannels, UserId, State
|
||||
),
|
||||
StateAfterInvalidRemovals =
|
||||
lists:foldl(
|
||||
fun(SessionId, AccState) ->
|
||||
@@ -122,10 +146,12 @@ broadcast_presence_update(UserId, Payload, State) ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec normalize_presence_status(binary() | term()) -> binary().
|
||||
normalize_presence_status(<<"invisible">>) -> <<"offline">>;
|
||||
normalize_presence_status(Status) when is_binary(Status) -> Status;
|
||||
normalize_presence_status(_) -> <<"offline">>.
|
||||
|
||||
-spec send_cached_presence_to_session(user_id(), binary(), guild_state()) -> guild_state().
|
||||
send_cached_presence_to_session(UserId, SessionId, State) ->
|
||||
case presence_cache:get(UserId) of
|
||||
{ok, Payload} ->
|
||||
@@ -134,6 +160,7 @@ send_cached_presence_to_session(UserId, SessionId, State) ->
|
||||
State
|
||||
end.
|
||||
|
||||
-spec send_presence_payload_to_session(user_id(), binary(), map(), guild_state()) -> guild_state().
|
||||
send_presence_payload_to_session(UserId, SessionId, Payload, State) ->
|
||||
GuildId = map_utils:get_integer(State, id, 0),
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
@@ -151,7 +178,9 @@ send_presence_payload_to_session(UserId, SessionId, Payload, State) ->
|
||||
CustomStatus = maps:get(<<"custom_status">>, Payload, null),
|
||||
PresenceBase =
|
||||
presence_payload:build(MemberUser, StatusBin, Mobile, Afk, CustomStatus),
|
||||
PresenceUpdate = maps:put(<<"guild_id">>, integer_to_binary(GuildId), PresenceBase),
|
||||
PresenceUpdate = maps:put(
|
||||
<<"guild_id">>, integer_to_binary(GuildId), PresenceBase
|
||||
),
|
||||
gen_server:cast(SessionPid, {dispatch, presence_update, PresenceUpdate}),
|
||||
State
|
||||
end;
|
||||
@@ -162,7 +191,7 @@ send_presence_payload_to_session(UserId, SessionId, Payload, State) ->
|
||||
-spec handle_user_data_update(user_id(), map(), guild_state()) -> guild_state().
|
||||
handle_user_data_update(UserId, UserData, State) ->
|
||||
Data = guild_data(State),
|
||||
Members = guild_members(State),
|
||||
Members = guild_data_index:member_map(Data),
|
||||
case find_member_by_user_id(UserId, State) of
|
||||
undefined ->
|
||||
State;
|
||||
@@ -172,13 +201,13 @@ handle_user_data_update(UserId, UserData, State) ->
|
||||
false ->
|
||||
State;
|
||||
true ->
|
||||
UpdatedMembers = lists:map(
|
||||
fun(M) ->
|
||||
maybe_replace_member(M, UserId, UserData)
|
||||
UpdatedMembers = maps:map(
|
||||
fun(_MemberUserId, Member0) ->
|
||||
maybe_replace_member(Member0, UserId, UserData)
|
||||
end,
|
||||
Members
|
||||
),
|
||||
UpdatedData = maps:put(<<"members">>, UpdatedMembers, Data),
|
||||
UpdatedData = guild_data_index:put_member_map(UpdatedMembers, Data),
|
||||
UpdatedState = maps:put(data, UpdatedData, State),
|
||||
maybe_dispatch_member_update(UserId, UpdatedState),
|
||||
UpdatedState
|
||||
@@ -211,15 +240,13 @@ maybe_dispatch_member_update(UserId, State) ->
|
||||
guild_data(State) ->
|
||||
map_utils:ensure_map(map_utils:get_safe(State, data, #{})).
|
||||
|
||||
-spec guild_members(guild_state()) -> [map()].
|
||||
guild_members(State) ->
|
||||
map_utils:ensure_list(maps:get(<<"members">>, guild_data(State), [])).
|
||||
|
||||
-spec member_id(map()) -> user_id() | undefined.
|
||||
member_id(Member) ->
|
||||
User = map_utils:ensure_map(maps:get(<<"user">>, Member, #{})),
|
||||
map_utils:get_integer(User, <<"id">>, undefined).
|
||||
|
||||
-spec partition_subscribed_sessions([binary()], map(), sets:set(), user_id(), guild_state()) ->
|
||||
{[binary()], [binary()]}.
|
||||
partition_subscribed_sessions(SessionIds, Sessions, TargetChannels, TargetUserId, State) ->
|
||||
lists:foldl(
|
||||
fun(SessionId, {Valids, Invalids}) ->
|
||||
@@ -235,8 +262,12 @@ partition_subscribed_sessions(SessionIds, Sessions, TargetChannels, TargetUserId
|
||||
UserId when UserId =:= TargetUserId ->
|
||||
false;
|
||||
_ ->
|
||||
SessionChannels = guild_visibility:viewable_channel_set(SessionUserId, State),
|
||||
not sets:is_empty(sets:intersection(SessionChannels, TargetChannels))
|
||||
SessionChannels = guild_visibility:viewable_channel_set(
|
||||
SessionUserId, State
|
||||
),
|
||||
not sets:is_empty(
|
||||
sets:intersection(SessionChannels, TargetChannels)
|
||||
)
|
||||
end,
|
||||
case Shared of
|
||||
true -> {[SessionId | Valids], Invalids};
|
||||
@@ -248,12 +279,27 @@ partition_subscribed_sessions(SessionIds, Sessions, TargetChannels, TargetUserId
|
||||
SessionIds
|
||||
).
|
||||
|
||||
-spec remove_session_member_subscription(binary(), user_id(), guild_state()) -> guild_state().
|
||||
remove_session_member_subscription(SessionId, UserId, State) ->
|
||||
MemberSubs = maps:get(member_subscriptions, State, guild_subscriptions:init_state()),
|
||||
NewMemberSubs = guild_subscriptions:unsubscribe(SessionId, UserId, MemberSubs),
|
||||
State1 = maps:put(member_subscriptions, NewMemberSubs, State),
|
||||
guild_sessions:unsubscribe_from_user_presence(UserId, State1).
|
||||
|
||||
-spec check_user_data_differs(map(), map()) -> boolean().
|
||||
check_user_data_differs(CurrentUserData, NewUserData) ->
|
||||
utils:check_user_data_differs(CurrentUserData, NewUserData).
|
||||
|
||||
-spec find_member_by_user_id(user_id(), guild_state()) -> member() | undefined.
|
||||
find_member_by_user_id(UserId, State) ->
|
||||
guild_permissions:find_member_by_user_id(UserId, State).
|
||||
|
||||
-spec store_member_presence(user_id(), map(), guild_state()) -> guild_state().
|
||||
store_member_presence(UserId, PresenceMap, State) ->
|
||||
MemberPresence = maps:get(member_presence, State, #{}),
|
||||
UpdatedPresence = maps:put(UserId, PresenceMap, MemberPresence),
|
||||
maps:put(member_presence, UpdatedPresence, State).
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
handle_bus_presence_non_member_noop_test() ->
|
||||
@@ -279,24 +325,24 @@ handle_bus_presence_user_update_test() ->
|
||||
Payload = #{<<"user">> => UserData, <<"user_update">> => true},
|
||||
{noreply, NewState} = handle_bus_presence(1, Payload, State),
|
||||
Data = maps:get(data, NewState),
|
||||
[Member | _] = maps:get(<<"members">>, Data),
|
||||
Member = maps:get(1, maps:get(<<"members">>, Data)),
|
||||
?assertEqual(<<"Updated">>, maps:get(<<"username">>, maps:get(<<"user">>, Member))).
|
||||
|
||||
normalize_presence_status_test() ->
|
||||
?assertEqual(<<"offline">>, normalize_presence_status(<<"invisible">>)),
|
||||
?assertEqual(<<"online">>, normalize_presence_status(<<"online">>)),
|
||||
?assertEqual(<<"idle">>, normalize_presence_status(<<"idle">>)),
|
||||
?assertEqual(<<"offline">>, normalize_presence_status(undefined)).
|
||||
|
||||
presence_test_state() ->
|
||||
#{
|
||||
id => 42,
|
||||
data => #{
|
||||
<<"members">> => [
|
||||
#{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"Alpha">>}}
|
||||
]
|
||||
<<"members">> => #{
|
||||
1 => #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"Alpha">>}}
|
||||
}
|
||||
},
|
||||
sessions => #{}
|
||||
}.
|
||||
|
||||
-endif.
|
||||
|
||||
check_user_data_differs(CurrentUserData, NewUserData) ->
|
||||
utils:check_user_data_differs(CurrentUserData, NewUserData).
|
||||
|
||||
find_member_by_user_id(UserId, State) ->
|
||||
guild_permissions:find_member_by_user_id(UserId, State).
|
||||
|
||||
@@ -27,6 +27,8 @@
|
||||
|
||||
-type session_state() :: map().
|
||||
-type request_data() :: map().
|
||||
-type member() :: map().
|
||||
-type presence() :: map().
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
@@ -34,13 +36,10 @@
|
||||
|
||||
-spec handle_request(request_data(), pid(), session_state()) -> ok | {error, atom()}.
|
||||
handle_request(Data, SocketPid, SessionState) when is_map(Data), is_pid(SocketPid) ->
|
||||
logger:debug("[guild_request_members] Handling guild members request: ~p", [Data]),
|
||||
case parse_request(Data) of
|
||||
{ok, Request} ->
|
||||
logger:debug("[guild_request_members] Request parsed successfully: ~p", [Request]),
|
||||
process_request(Request, SocketPid, SessionState);
|
||||
{error, Reason} ->
|
||||
logger:warning("[guild_request_members] Failed to parse request: ~p", [Reason]),
|
||||
{error, Reason}
|
||||
end;
|
||||
handle_request(_, _, _) ->
|
||||
@@ -55,7 +54,6 @@ parse_request(Data) ->
|
||||
Presences = maps:get(<<"presences">>, Data, false),
|
||||
Nonce = maps:get(<<"nonce">>, Data, null),
|
||||
NormalizedNonce = normalize_nonce(Nonce),
|
||||
|
||||
case validate_guild_id(GuildIdRaw) of
|
||||
{ok, GuildId} ->
|
||||
case validate_user_ids(UserIdsRaw) of
|
||||
@@ -137,25 +135,16 @@ process_request(Request, SocketPid, SessionState) ->
|
||||
#{guild_id := GuildId, query := Query, limit := Limit, user_ids := UserIds} = Request,
|
||||
UserIdBin = maps:get(user_id, SessionState),
|
||||
UserId = type_conv:to_integer(UserIdBin),
|
||||
|
||||
logger:debug(
|
||||
"[guild_request_members] Processing request for guild ~p, user ~p, user_ids: ~p",
|
||||
[GuildId, UserId, UserIds]
|
||||
),
|
||||
|
||||
case check_permission(UserId, GuildId, Query, Limit, UserIds, SessionState) of
|
||||
ok ->
|
||||
logger:debug("[guild_request_members] Permission check passed, fetching members"),
|
||||
fetch_and_send_members(Request, SocketPid, SessionState);
|
||||
{error, Reason} ->
|
||||
logger:warning(
|
||||
"[guild_request_members] Permission check failed: ~p",
|
||||
[Reason]
|
||||
),
|
||||
{error, Reason}
|
||||
end.
|
||||
|
||||
-spec check_permission(integer(), integer(), binary(), non_neg_integer(), [integer()], session_state()) ->
|
||||
-spec check_permission(
|
||||
integer(), integer(), binary(), non_neg_integer(), [integer()], session_state()
|
||||
) ->
|
||||
ok | {error, atom()}.
|
||||
check_permission(UserId, GuildId, Query, Limit, UserIds, SessionState) ->
|
||||
RequiresPermission = Query =:= <<>> andalso Limit =:= 0 andalso UserIds =:= [],
|
||||
@@ -177,7 +166,6 @@ check_management_permission(UserId, _GuildId, GuildPid) ->
|
||||
KickMembers = constants:kick_members_permission(),
|
||||
BanMembers = constants:ban_members_permission(),
|
||||
RequiredPermission = ManageRoles bor KickMembers bor BanMembers,
|
||||
|
||||
PermRequest = #{
|
||||
user_id => UserId,
|
||||
permission => RequiredPermission,
|
||||
@@ -215,65 +203,45 @@ fetch_and_send_members(Request, _SocketPid, SessionState) ->
|
||||
nonce := Nonce
|
||||
} = Request,
|
||||
SessionId = maps:get(session_id, SessionState),
|
||||
|
||||
logger:debug(
|
||||
"[guild_request_members] Looking up guild ~p for member request",
|
||||
[GuildId]
|
||||
),
|
||||
|
||||
case lookup_guild(GuildId, SessionState) of
|
||||
{ok, GuildPid} ->
|
||||
logger:debug("[guild_request_members] Guild ~p found, fetching members", [GuildId]),
|
||||
Members = fetch_members(GuildPid, Query, Limit, UserIds),
|
||||
logger:debug("[guild_request_members] Found ~p members", [length(Members)]),
|
||||
PresencesList = maybe_fetch_presences(Presences, GuildPid, Members),
|
||||
send_member_chunks(GuildPid, SessionId, Members, PresencesList, Nonce),
|
||||
logger:debug(
|
||||
"[guild_request_members] Sent ~p member chunks for guild ~p with nonce ~p",
|
||||
[max(1, (length(Members) + ?CHUNK_SIZE - 1) div ?CHUNK_SIZE), GuildId, Nonce]
|
||||
),
|
||||
ok;
|
||||
{error, Reason} ->
|
||||
logger:warning(
|
||||
"[guild_request_members] Failed to lookup guild ~p: ~p",
|
||||
[GuildId, Reason]
|
||||
),
|
||||
{error, Reason}
|
||||
end.
|
||||
|
||||
-spec fetch_members(pid(), binary(), non_neg_integer(), [integer()]) -> [map()].
|
||||
-spec fetch_members(pid(), binary(), non_neg_integer(), [integer()]) -> [member()].
|
||||
fetch_members(GuildPid, _Query, _Limit, UserIds) when UserIds =/= [] ->
|
||||
logger:debug("[guild_request_members] Fetching members by user_ids: ~p", [UserIds]),
|
||||
case gen_server:call(GuildPid, {list_guild_members, #{limit => 100000, offset => 0}}, 10000) of
|
||||
#{members := AllMembers} ->
|
||||
logger:debug("[guild_request_members] Got ~p members from guild, filtering by user_ids", [length(AllMembers)]),
|
||||
Filtered = filter_members_by_ids(AllMembers, UserIds),
|
||||
logger:debug("[guild_request_members] Filtered to ~p members", [length(Filtered)]),
|
||||
Filtered;
|
||||
Other ->
|
||||
logger:warning("[guild_request_members] Unexpected response from guild: ~p", [Other]),
|
||||
filter_members_by_ids(AllMembers, UserIds);
|
||||
_ ->
|
||||
[]
|
||||
end;
|
||||
fetch_members(GuildPid, Query, Limit, []) ->
|
||||
ActualLimit = case Limit of 0 -> 100000; L -> L end,
|
||||
logger:debug("[guild_request_members] Fetching members with query '~s', limit ~p", [Query, ActualLimit]),
|
||||
case gen_server:call(GuildPid, {list_guild_members, #{limit => ActualLimit, offset => 0}}, 10000) of
|
||||
ActualLimit =
|
||||
case Limit of
|
||||
0 -> 100000;
|
||||
L -> L
|
||||
end,
|
||||
case
|
||||
gen_server:call(GuildPid, {list_guild_members, #{limit => ActualLimit, offset => 0}}, 10000)
|
||||
of
|
||||
#{members := AllMembers} ->
|
||||
logger:debug("[guild_request_members] Got ~p members from guild", [length(AllMembers)]),
|
||||
Result = case Query of
|
||||
case Query of
|
||||
<<>> ->
|
||||
lists:sublist(AllMembers, ActualLimit);
|
||||
_ ->
|
||||
filter_members_by_query(AllMembers, Query, ActualLimit)
|
||||
end,
|
||||
logger:debug("[guild_request_members] Returning ~p members after query/filter", [length(Result)]),
|
||||
Result;
|
||||
Other ->
|
||||
logger:warning("[guild_request_members] Unexpected response from guild: ~p", [Other]),
|
||||
end;
|
||||
_ ->
|
||||
[]
|
||||
end.
|
||||
|
||||
-spec filter_members_by_ids([map()], [integer()]) -> [map()].
|
||||
-spec filter_members_by_ids([member()], [integer()]) -> [member()].
|
||||
filter_members_by_ids(Members, UserIds) ->
|
||||
UserIdSet = sets:from_list(UserIds),
|
||||
lists:filter(
|
||||
@@ -284,7 +252,7 @@ filter_members_by_ids(Members, UserIds) ->
|
||||
Members
|
||||
).
|
||||
|
||||
-spec filter_members_by_query([map()], binary(), non_neg_integer()) -> [map()].
|
||||
-spec filter_members_by_query([member()], binary(), non_neg_integer()) -> [member()].
|
||||
filter_members_by_query(Members, Query, Limit) ->
|
||||
NormalizedQuery = string:lowercase(binary_to_list(Query)),
|
||||
Matches = lists:filter(
|
||||
@@ -297,50 +265,47 @@ filter_members_by_query(Members, Query, Limit) ->
|
||||
),
|
||||
lists:sublist(Matches, Limit).
|
||||
|
||||
-spec get_display_name(map()) -> binary().
|
||||
-spec get_display_name(member()) -> binary().
|
||||
get_display_name(Member) when is_map(Member) ->
|
||||
Nick = maps:get(<<"nick">>, Member, undefined),
|
||||
case Nick of
|
||||
undefined -> nick_isundefined(Member);
|
||||
null -> nick_isundefined(Member);
|
||||
undefined -> get_fallback_name(Member);
|
||||
null -> get_fallback_name(Member);
|
||||
_ when is_binary(Nick) -> Nick;
|
||||
_ -> nick_isundefined(Member)
|
||||
_ -> get_fallback_name(Member)
|
||||
end;
|
||||
get_display_name(_) ->
|
||||
<<>>.
|
||||
|
||||
nick_isundefined(Member) ->
|
||||
-spec get_fallback_name(member()) -> binary().
|
||||
get_fallback_name(Member) ->
|
||||
User = maps:get(<<"user">>, Member, #{}),
|
||||
GlobalName = maps:get(<<"global_name">>, User, undefined),
|
||||
case GlobalName of
|
||||
undefined ->
|
||||
Username = maps:get(<<"username">>, User, <<>>),
|
||||
case Username of
|
||||
null -> <<>>;
|
||||
undefined -> <<>>;
|
||||
_ when is_binary(Username) -> Username;
|
||||
_ -> <<>>
|
||||
end;
|
||||
null ->
|
||||
Username = maps:get(<<"username">>, User, <<>>),
|
||||
case Username of
|
||||
null -> <<>>;
|
||||
undefined -> <<>>;
|
||||
_ when is_binary(Username) -> Username;
|
||||
_ -> <<>>
|
||||
end;
|
||||
undefined -> get_username(User);
|
||||
null -> get_username(User);
|
||||
_ when is_binary(GlobalName) -> GlobalName;
|
||||
_ -> get_username(User)
|
||||
end.
|
||||
|
||||
-spec get_username(map()) -> binary().
|
||||
get_username(User) ->
|
||||
Username = maps:get(<<"username">>, User, <<>>),
|
||||
case Username of
|
||||
null -> <<>>;
|
||||
undefined -> <<>>;
|
||||
_ when is_binary(Username) -> Username;
|
||||
_ -> <<>>
|
||||
end.
|
||||
|
||||
-spec extract_user_id(map()) -> integer() | undefined.
|
||||
-spec extract_user_id(member()) -> integer() | undefined.
|
||||
extract_user_id(Member) when is_map(Member) ->
|
||||
User = maps:get(<<"user">>, Member, #{}),
|
||||
map_utils:get_integer(User, <<"id">>, undefined);
|
||||
extract_user_id(_) ->
|
||||
undefined.
|
||||
|
||||
-spec maybe_fetch_presences(boolean(), pid(), [map()]) -> [map()].
|
||||
-spec maybe_fetch_presences(boolean(), pid(), [member()]) -> [presence()].
|
||||
maybe_fetch_presences(false, _GuildPid, _Members) ->
|
||||
[];
|
||||
maybe_fetch_presences(true, _GuildPid, Members) ->
|
||||
@@ -361,41 +326,30 @@ maybe_fetch_presences(true, _GuildPid, Members) ->
|
||||
[P || P <- Cached, presence_visible(P)]
|
||||
end.
|
||||
|
||||
-spec presence_visible(map()) -> boolean().
|
||||
-spec presence_visible(presence()) -> boolean().
|
||||
presence_visible(P) ->
|
||||
Status = maps:get(<<"status">>, P, <<"offline">>),
|
||||
Status =/= <<"offline">> andalso Status =/= <<"invisible">>.
|
||||
|
||||
-spec send_member_chunks(pid(), binary(), [map()], [map()], term()) -> ok.
|
||||
-spec send_member_chunks(pid(), binary(), [member()], [presence()], term()) -> ok.
|
||||
send_member_chunks(GuildPid, SessionId, Members, Presences, Nonce) ->
|
||||
TotalChunks = max(1, (length(Members) + ?CHUNK_SIZE - 1) div ?CHUNK_SIZE),
|
||||
MemberChunks = chunk_list(Members, ?CHUNK_SIZE),
|
||||
PresenceChunks = chunk_presences(Presences, MemberChunks),
|
||||
|
||||
logger:debug(
|
||||
"[guild_request_members] Sending ~p member chunks (total members: ~p, nonce: ~p)",
|
||||
[TotalChunks, length(Members), Nonce]
|
||||
),
|
||||
|
||||
lists:foldl(
|
||||
fun({MemberChunk, PresenceChunk}, ChunkIndex) ->
|
||||
ChunkData = build_chunk_data(
|
||||
MemberChunk, PresenceChunk, ChunkIndex, TotalChunks, Nonce
|
||||
),
|
||||
logger:debug(
|
||||
"[guild_request_members] Sending chunk ~p/~p with ~p members, nonce: ~p",
|
||||
[ChunkIndex + 1, TotalChunks, length(MemberChunk), Nonce]
|
||||
),
|
||||
gen_server:cast(GuildPid, {send_members_chunk, SessionId, ChunkData}),
|
||||
ChunkIndex + 1
|
||||
end,
|
||||
0,
|
||||
lists:zip(MemberChunks, PresenceChunks)
|
||||
),
|
||||
logger:debug("[guild_request_members] All chunks sent successfully"),
|
||||
ok.
|
||||
|
||||
-spec build_chunk_data([map()], [map()], non_neg_integer(), non_neg_integer(), term()) ->
|
||||
-spec build_chunk_data([member()], [presence()], non_neg_integer(), non_neg_integer(), term()) ->
|
||||
map().
|
||||
build_chunk_data(Members, Presences, ChunkIndex, TotalChunks, Nonce) ->
|
||||
Base = #{
|
||||
@@ -403,15 +357,15 @@ build_chunk_data(Members, Presences, ChunkIndex, TotalChunks, Nonce) ->
|
||||
<<"chunk_index">> => ChunkIndex,
|
||||
<<"chunk_count">> => TotalChunks
|
||||
},
|
||||
WithPresences = case Presences of
|
||||
[] -> Base;
|
||||
_ -> maps:put(<<"presences">>, Presences, Base)
|
||||
end,
|
||||
WithNonce = case Nonce of
|
||||
WithPresences =
|
||||
case Presences of
|
||||
[] -> Base;
|
||||
_ -> maps:put(<<"presences">>, Presences, Base)
|
||||
end,
|
||||
case Nonce of
|
||||
null -> WithPresences;
|
||||
_ -> maps:put(<<"nonce">>, Nonce, WithPresences)
|
||||
end,
|
||||
WithNonce.
|
||||
end.
|
||||
|
||||
-spec chunk_list([T], pos_integer()) -> [[T]] when T :: term().
|
||||
chunk_list([], _Size) ->
|
||||
@@ -419,13 +373,14 @@ chunk_list([], _Size) ->
|
||||
chunk_list(List, Size) ->
|
||||
chunk_list(List, Size, []).
|
||||
|
||||
-spec chunk_list([T], pos_integer(), [[T]]) -> [[T]] when T :: term().
|
||||
chunk_list([], _Size, Acc) ->
|
||||
lists:reverse(Acc);
|
||||
chunk_list(List, Size, Acc) ->
|
||||
{Chunk, Rest} = lists:split(min(Size, length(List)), List),
|
||||
chunk_list(Rest, Size, [Chunk | Acc]).
|
||||
|
||||
-spec chunk_presences([map()], [[map()]]) -> [[map()]].
|
||||
-spec chunk_presences([presence()], [[member()]]) -> [[presence()]].
|
||||
chunk_presences(Presences, MemberChunks) ->
|
||||
lists:map(
|
||||
fun(MemberChunk) ->
|
||||
@@ -491,15 +446,18 @@ display_name_priority_test() ->
|
||||
<<"nick">> => <<"Nick">>
|
||||
},
|
||||
?assertEqual(<<"Nick">>, get_display_name(MemberWithNick)),
|
||||
|
||||
MemberWithGlobal = #{
|
||||
<<"user">> => #{<<"username">> => <<"user">>, <<"global_name">> => <<"Global">>}
|
||||
},
|
||||
?assertEqual(<<"Global">>, get_display_name(MemberWithGlobal)),
|
||||
|
||||
MemberWithUsername = #{
|
||||
<<"user">> => #{<<"username">> => <<"user">>}
|
||||
},
|
||||
?assertEqual(<<"user">>, get_display_name(MemberWithUsername)).
|
||||
|
||||
normalize_nonce_test() ->
|
||||
?assertEqual(<<"abc">>, normalize_nonce(<<"abc">>)),
|
||||
?assertEqual(null, normalize_nonce(<<"this_nonce_is_way_too_long_to_be_valid">>)),
|
||||
?assertEqual(null, normalize_nonce(undefined)).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -20,7 +20,9 @@
|
||||
-export([
|
||||
handle_session_connect/3,
|
||||
handle_session_down/2,
|
||||
remove_session/2,
|
||||
filter_sessions_for_channel/4,
|
||||
filter_sessions_for_message/5,
|
||||
filter_sessions_for_manage_channels/4,
|
||||
filter_sessions_exclude_session/2,
|
||||
handle_user_offline/2,
|
||||
@@ -29,67 +31,79 @@
|
||||
build_initial_last_message_ids/1,
|
||||
is_session_active/2,
|
||||
subscribe_to_user_presence/2,
|
||||
unsubscribe_from_user_presence/2
|
||||
unsubscribe_from_user_presence/2,
|
||||
set_session_viewable_channels/3,
|
||||
refresh_all_viewable_channels/1
|
||||
]).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-import(guild_permissions, [can_view_channel/4, can_manage_channel/3, find_member_by_user_id/2]).
|
||||
-import(guild_data, [get_guild_state/2]).
|
||||
-import(guild_availability, [is_guild_unavailable_for_user/2]).
|
||||
-type guild_state() :: map().
|
||||
-type session_id() :: binary().
|
||||
-type user_id() :: integer().
|
||||
-type guild_id() :: integer().
|
||||
-type channel_id() :: integer().
|
||||
-type session_data() :: map().
|
||||
-type sessions_map() :: #{session_id() => session_data()}.
|
||||
-type session_pair() :: {session_id(), session_data()}.
|
||||
|
||||
-spec handle_session_connect(map(), pid(), guild_state()) ->
|
||||
{reply, {ok, map()} | {ok, unavailable, map()}, guild_state()}.
|
||||
handle_session_connect(Request, Pid, State) ->
|
||||
#{session_id := SessionId, user_id := UserId} = Request,
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
case maps:is_key(SessionId, Sessions) of
|
||||
true ->
|
||||
{reply, {ok, guild_data:get_guild_state(UserId, State)}, State};
|
||||
false ->
|
||||
register_new_session(Request, Pid, UserId, SessionId, State)
|
||||
end.
|
||||
|
||||
-spec register_new_session(map(), pid(), user_id(), session_id(), guild_state()) ->
|
||||
{reply, {ok, map()} | {ok, unavailable, map()}, guild_state()}.
|
||||
register_new_session(Request, Pid, UserId, SessionId, State) ->
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
ActiveGuilds = maps:get(active_guilds, Request, sets:new()),
|
||||
InitialGuildId = maps:get(initial_guild_id, Request, undefined),
|
||||
UserRoles = session_passive:get_user_roles_for_guild(UserId, State),
|
||||
Bot = maps:get(bot, Request, false),
|
||||
GuildId = maps:get(id, State),
|
||||
|
||||
case maps:is_key(SessionId, Sessions) of
|
||||
Ref = monitor(process, Pid),
|
||||
GuildState = guild_data:get_guild_state(UserId, State),
|
||||
InitialLastMessageIds = build_initial_last_message_ids(GuildState),
|
||||
InitialChannelVersions = build_initial_channel_versions(GuildState),
|
||||
InitialViewableChannels = build_viewable_channel_map(
|
||||
guild_visibility:get_user_viewable_channels(UserId, State)
|
||||
),
|
||||
SessionData = #{
|
||||
session_id => SessionId,
|
||||
user_id => UserId,
|
||||
pid => Pid,
|
||||
mref => Ref,
|
||||
active_guilds => ActiveGuilds,
|
||||
user_roles => UserRoles,
|
||||
bot => Bot,
|
||||
is_staff => maps:get(is_staff, Request, false),
|
||||
previous_passive_updates => InitialLastMessageIds,
|
||||
previous_passive_channel_versions => InitialChannelVersions,
|
||||
previous_passive_voice_states => #{},
|
||||
viewable_channels => InitialViewableChannels
|
||||
},
|
||||
NewSessions = maps:put(SessionId, SessionData, Sessions),
|
||||
State1 = maps:put(sessions, NewSessions, State),
|
||||
State2 = subscribe_to_user_presence(UserId, State1),
|
||||
_ = maybe_notify_coordinator(session_connected, SessionId, UserId, State2),
|
||||
case guild_availability:is_guild_unavailable_for_user(UserId, State2) of
|
||||
true ->
|
||||
{reply, {ok, get_guild_state(UserId, State)}, State};
|
||||
false ->
|
||||
Ref = monitor(process, Pid),
|
||||
GuildState = get_guild_state(UserId, State),
|
||||
InitialLastMessageIds = build_initial_last_message_ids(GuildState),
|
||||
SessionData = #{
|
||||
session_id => SessionId,
|
||||
user_id => UserId,
|
||||
pid => Pid,
|
||||
mref => Ref,
|
||||
active_guilds => ActiveGuilds,
|
||||
user_roles => UserRoles,
|
||||
bot => Bot,
|
||||
previous_passive_updates => InitialLastMessageIds
|
||||
UnavailableResponse = #{
|
||||
<<"id">> => integer_to_binary(GuildId),
|
||||
<<"unavailable">> => true
|
||||
},
|
||||
NewSessions = maps:put(SessionId, SessionData, Sessions),
|
||||
State1 = maps:put(sessions, NewSessions, State),
|
||||
|
||||
State2 = subscribe_to_user_presence(UserId, State1),
|
||||
|
||||
case is_guild_unavailable_for_user(UserId, State2) of
|
||||
true ->
|
||||
GuildId = maps:get(id, State2),
|
||||
UnavailableResponse = #{
|
||||
<<"id">> => integer_to_binary(GuildId),
|
||||
<<"unavailable">> => true
|
||||
},
|
||||
{reply, {ok, unavailable, UnavailableResponse}, State2};
|
||||
false ->
|
||||
SyncedState = maybe_auto_sync_initial_guild(
|
||||
SessionId,
|
||||
GuildId,
|
||||
InitialGuildId,
|
||||
State2
|
||||
),
|
||||
{reply, {ok, GuildState}, SyncedState}
|
||||
end
|
||||
{reply, {ok, unavailable, UnavailableResponse}, State2};
|
||||
false ->
|
||||
SyncedState = maybe_auto_sync_initial_guild(SessionId, GuildId, InitialGuildId, State2),
|
||||
{reply, {ok, GuildState}, SyncedState}
|
||||
end.
|
||||
|
||||
-spec build_initial_last_message_ids(map()) -> #{binary() => binary()}.
|
||||
build_initial_last_message_ids(GuildState) ->
|
||||
Channels = maps:get(<<"channels">>, GuildState, []),
|
||||
lists:foldl(
|
||||
@@ -106,10 +120,69 @@ build_initial_last_message_ids(GuildState) ->
|
||||
Channels
|
||||
).
|
||||
|
||||
-spec build_initial_channel_versions(map()) -> #{binary() => integer()}.
|
||||
build_initial_channel_versions(GuildState) ->
|
||||
Channels = maps:get(<<"channels">>, GuildState, []),
|
||||
lists:foldl(
|
||||
fun(Channel, Acc) ->
|
||||
ChannelIdBin = maps:get(<<"id">>, Channel, undefined),
|
||||
case ChannelIdBin of
|
||||
undefined ->
|
||||
Acc;
|
||||
_ ->
|
||||
Version = map_utils:get_integer(Channel, <<"version">>, 0),
|
||||
maps:put(ChannelIdBin, Version, Acc)
|
||||
end
|
||||
end,
|
||||
#{},
|
||||
Channels
|
||||
).
|
||||
|
||||
-spec handle_session_down(reference(), guild_state()) ->
|
||||
{noreply, guild_state()} | {stop, normal, guild_state()}.
|
||||
handle_session_down(Ref, State) ->
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
DisconnectingSession = find_session_by_ref(Ref, Sessions),
|
||||
State1 = cleanup_disconnecting_session(DisconnectingSession, State),
|
||||
NewSessions = maps:filter(fun(_K, S) -> maps:get(mref, S) =/= Ref end, Sessions),
|
||||
NewState = maps:put(sessions, NewSessions, State1),
|
||||
case map_size(NewSessions) of
|
||||
0 ->
|
||||
case should_auto_stop_on_empty(NewState) of
|
||||
true -> {stop, normal, NewState};
|
||||
false -> {noreply, NewState}
|
||||
end;
|
||||
_ -> {noreply, NewState}
|
||||
end.
|
||||
|
||||
DisconnectingSession = maps:fold(
|
||||
-spec remove_session(session_id(), guild_state()) -> guild_state().
|
||||
remove_session(SessionId, State) ->
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
case maps:get(SessionId, Sessions, undefined) of
|
||||
undefined ->
|
||||
State;
|
||||
Session ->
|
||||
maybe_demonitor_session(Session),
|
||||
StateAfterCleanup = cleanup_disconnecting_session(Session, State),
|
||||
SessionsAfterCleanup = maps:get(sessions, StateAfterCleanup, #{}),
|
||||
NewSessions = maps:remove(SessionId, SessionsAfterCleanup),
|
||||
maps:put(sessions, NewSessions, StateAfterCleanup)
|
||||
end.
|
||||
|
||||
-spec maybe_demonitor_session(session_data()) -> ok.
|
||||
maybe_demonitor_session(Session) ->
|
||||
Ref = maps:get(mref, Session, undefined),
|
||||
case is_reference(Ref) of
|
||||
true ->
|
||||
demonitor(Ref, [flush]),
|
||||
ok;
|
||||
false ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec find_session_by_ref(reference(), sessions_map()) -> session_data() | undefined.
|
||||
find_session_by_ref(Ref, Sessions) ->
|
||||
maps:fold(
|
||||
fun(_K, S, Acc) ->
|
||||
case maps:get(mref, S) =:= Ref of
|
||||
true -> S;
|
||||
@@ -118,95 +191,253 @@ handle_session_down(Ref, State) ->
|
||||
end,
|
||||
undefined,
|
||||
Sessions
|
||||
),
|
||||
|
||||
State1 =
|
||||
case DisconnectingSession of
|
||||
undefined ->
|
||||
State;
|
||||
Session ->
|
||||
UserId = maps:get(user_id, Session),
|
||||
SessionId = maps:get(session_id, Session),
|
||||
StateAfterPresence = unsubscribe_from_user_presence(UserId, State),
|
||||
StateAfterMemberList = guild_member_list:unsubscribe_session(
|
||||
SessionId, StateAfterPresence
|
||||
),
|
||||
MemberSubs = maps:get(
|
||||
member_subscriptions, StateAfterMemberList, guild_subscriptions:init_state()
|
||||
),
|
||||
NewMemberSubs = guild_subscriptions:unsubscribe_session(SessionId, MemberSubs),
|
||||
maps:put(member_subscriptions, NewMemberSubs, StateAfterMemberList)
|
||||
end,
|
||||
|
||||
NewSessions = maps:filter(fun(_K, S) -> maps:get(mref, S) =/= Ref end, Sessions),
|
||||
NewState = maps:put(sessions, NewSessions, State1),
|
||||
|
||||
case map_size(NewSessions) of
|
||||
0 ->
|
||||
{stop, normal, NewState};
|
||||
_ ->
|
||||
{noreply, NewState}
|
||||
end.
|
||||
|
||||
filter_sessions_for_channel(Sessions, ChannelId, SessionIdOpt, State) ->
|
||||
GuildId = maps:get(id, State, 0),
|
||||
lists:filter(
|
||||
fun({Sid, S}) ->
|
||||
UserId = maps:get(user_id, S),
|
||||
Member = find_member_by_user_id(UserId, State),
|
||||
|
||||
ExcludeSession =
|
||||
case SessionIdOpt of
|
||||
undefined -> false;
|
||||
SessionId -> Sid =:= SessionId
|
||||
end,
|
||||
|
||||
case {ExcludeSession, Member} of
|
||||
{true, _} ->
|
||||
false;
|
||||
{_, undefined} ->
|
||||
logger:warning(
|
||||
"[guild_sessions] Filtering out session with no member: "
|
||||
"guild_id=~p session_id=~p user_id=~p",
|
||||
[GuildId, Sid, UserId]
|
||||
),
|
||||
false;
|
||||
{false, _} ->
|
||||
can_view_channel(UserId, ChannelId, Member, State)
|
||||
end
|
||||
end,
|
||||
maps:to_list(Sessions)
|
||||
).
|
||||
|
||||
filter_sessions_for_manage_channels(Sessions, ChannelId, SessionIdOpt, State) ->
|
||||
-spec cleanup_disconnecting_session(session_data() | undefined, guild_state()) -> guild_state().
|
||||
cleanup_disconnecting_session(undefined, State) ->
|
||||
State;
|
||||
cleanup_disconnecting_session(Session, State) ->
|
||||
UserId = maps:get(user_id, Session),
|
||||
SessionId = maps:get(session_id, Session),
|
||||
_ = maybe_notify_coordinator(session_disconnected, SessionId, UserId, State),
|
||||
StateAfterPresence = unsubscribe_from_user_presence(UserId, State),
|
||||
StateAfterMemberList = guild_member_list:unsubscribe_session(SessionId, StateAfterPresence),
|
||||
MemberSubs = maps:get(
|
||||
member_subscriptions, StateAfterMemberList, guild_subscriptions:init_state()
|
||||
),
|
||||
NewMemberSubs = guild_subscriptions:unsubscribe_session(SessionId, MemberSubs),
|
||||
StateAfterSubs = maps:put(member_subscriptions, NewMemberSubs, StateAfterMemberList),
|
||||
cleanup_connect_admission_for_session(SessionId, StateAfterSubs).
|
||||
|
||||
-spec cleanup_connect_admission_for_session(session_id(), guild_state()) -> guild_state().
|
||||
cleanup_connect_admission_for_session(SessionId, State) ->
|
||||
Pending0 = maps:get(session_connect_pending, State, undefined),
|
||||
State1 =
|
||||
case Pending0 of
|
||||
Pending when is_map(Pending) ->
|
||||
maps:put(session_connect_pending, maps:remove(SessionId, Pending), State);
|
||||
_ ->
|
||||
State
|
||||
end,
|
||||
Queue0 = maps:get(session_connect_queue, State1, undefined),
|
||||
Queue = normalize_connect_queue(Queue0),
|
||||
case queue:is_queue(Queue) of
|
||||
true ->
|
||||
Filtered = queue:filter(
|
||||
fun(Item) ->
|
||||
Request = maps:get(request, Item, #{}),
|
||||
maps:get(session_id, Request, undefined) =/= SessionId
|
||||
end,
|
||||
Queue
|
||||
),
|
||||
maps:put(session_connect_queue, Filtered, State1);
|
||||
false ->
|
||||
State1
|
||||
end.
|
||||
|
||||
-spec normalize_connect_queue(term()) -> queue:queue() | undefined.
|
||||
normalize_connect_queue(Value) when is_list(Value) ->
|
||||
queue:from_list(Value);
|
||||
normalize_connect_queue(Value) ->
|
||||
case queue:is_queue(Value) of
|
||||
true -> Value;
|
||||
false -> undefined
|
||||
end.
|
||||
|
||||
-spec should_auto_stop_on_empty(guild_state()) -> boolean().
|
||||
should_auto_stop_on_empty(State) ->
|
||||
case maps:get(disable_auto_stop_on_empty, State, false) of
|
||||
true ->
|
||||
false;
|
||||
false ->
|
||||
case maps:get(very_large_guild_coordinator_pid, State, undefined) of
|
||||
Pid when is_pid(Pid) -> false;
|
||||
_ -> true
|
||||
end
|
||||
end.
|
||||
|
||||
-spec maybe_notify_coordinator(session_connected | session_disconnected, session_id(), user_id(), guild_state()) ->
|
||||
ok.
|
||||
maybe_notify_coordinator(Type, SessionId, UserId, State) ->
|
||||
case {maps:get(very_large_guild_coordinator_pid, State, undefined),
|
||||
maps:get(very_large_guild_shard_index, State, undefined)}
|
||||
of
|
||||
{CoordPid, ShardIndex} when is_pid(CoordPid), is_integer(ShardIndex) ->
|
||||
Msg =
|
||||
case Type of
|
||||
session_connected ->
|
||||
{very_large_guild_session_connected, ShardIndex, SessionId, UserId};
|
||||
session_disconnected ->
|
||||
{very_large_guild_session_disconnected, ShardIndex, SessionId, UserId}
|
||||
end,
|
||||
CoordPid ! Msg,
|
||||
ok;
|
||||
_ ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec filter_sessions_for_channel(
|
||||
sessions_map(), channel_id(), session_id() | undefined, guild_state()
|
||||
) ->
|
||||
[session_pair()].
|
||||
filter_sessions_for_channel(Sessions, ChannelId, SessionIdOpt, State) ->
|
||||
lists:filter(
|
||||
fun({Sid, S}) ->
|
||||
UserId = maps:get(user_id, S),
|
||||
|
||||
ExcludeSession =
|
||||
case SessionIdOpt of
|
||||
undefined -> false;
|
||||
SessionId -> Sid =:= SessionId
|
||||
end,
|
||||
|
||||
case maps:get(pending_connect, S, false) of
|
||||
true ->
|
||||
false;
|
||||
false ->
|
||||
ExcludeSession = should_exclude_session(Sid, SessionIdOpt),
|
||||
case ExcludeSession of
|
||||
true ->
|
||||
false;
|
||||
false ->
|
||||
can_manage_channel(UserId, ChannelId, State)
|
||||
session_can_view_channel(S, ChannelId, State)
|
||||
end
|
||||
end
|
||||
end,
|
||||
maps:to_list(Sessions)
|
||||
).
|
||||
|
||||
filter_sessions_exclude_session(Sessions, SessionIdOpt) ->
|
||||
case SessionIdOpt of
|
||||
-spec filter_sessions_for_message(
|
||||
sessions_map(), channel_id(), binary(), session_id() | undefined, guild_state()
|
||||
) ->
|
||||
[session_pair()].
|
||||
filter_sessions_for_message(Sessions, ChannelId, MessageId, SessionIdOpt, State) ->
|
||||
lists:filter(
|
||||
fun({Sid, S}) ->
|
||||
case maps:get(pending_connect, S, false) of
|
||||
true ->
|
||||
false;
|
||||
false ->
|
||||
ExcludeSession = should_exclude_session(Sid, SessionIdOpt),
|
||||
case ExcludeSession of
|
||||
true ->
|
||||
false;
|
||||
false ->
|
||||
UserId = maps:get(user_id, S, undefined),
|
||||
case session_can_view_channel(S, ChannelId, State) of
|
||||
false ->
|
||||
false;
|
||||
true ->
|
||||
Perms = guild_permissions:get_member_permissions(UserId, ChannelId, State),
|
||||
guild_permissions:can_access_message_by_permissions(
|
||||
Perms, MessageId, State
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
end,
|
||||
maps:to_list(Sessions)
|
||||
).
|
||||
|
||||
-spec filter_sessions_for_manage_channels(
|
||||
sessions_map(), channel_id(), session_id() | undefined, guild_state()
|
||||
) ->
|
||||
[session_pair()].
|
||||
filter_sessions_for_manage_channels(Sessions, ChannelId, SessionIdOpt, State) ->
|
||||
lists:filter(
|
||||
fun({Sid, S}) ->
|
||||
case maps:get(pending_connect, S, false) of
|
||||
true ->
|
||||
false;
|
||||
false ->
|
||||
UserId = maps:get(user_id, S),
|
||||
ExcludeSession = should_exclude_session(Sid, SessionIdOpt),
|
||||
case ExcludeSession of
|
||||
true -> false;
|
||||
false -> guild_permissions:can_manage_channel(UserId, ChannelId, State)
|
||||
end
|
||||
end
|
||||
end,
|
||||
maps:to_list(Sessions)
|
||||
).
|
||||
|
||||
-spec filter_sessions_exclude_session(sessions_map(), session_id() | undefined) -> [session_pair()].
|
||||
filter_sessions_exclude_session(Sessions, undefined) ->
|
||||
[{Sid, S} || {Sid, S} <- maps:to_list(Sessions), maps:get(pending_connect, S, false) =/= true];
|
||||
filter_sessions_exclude_session(Sessions, SessionId) ->
|
||||
[
|
||||
{Sid, S}
|
||||
|| {Sid, S} <- maps:to_list(Sessions),
|
||||
Sid =/= SessionId,
|
||||
maps:get(pending_connect, S, false) =/= true
|
||||
].
|
||||
|
||||
-spec should_exclude_session(session_id(), session_id() | undefined) -> boolean().
|
||||
should_exclude_session(_, undefined) -> false;
|
||||
should_exclude_session(Sid, SessionId) -> Sid =:= SessionId.
|
||||
|
||||
-spec set_session_viewable_channels(session_id(), map(), guild_state()) -> guild_state().
|
||||
set_session_viewable_channels(SessionId, ViewableChannels, State) ->
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
case maps:get(SessionId, Sessions, undefined) of
|
||||
undefined ->
|
||||
maps:to_list(Sessions);
|
||||
SessionId ->
|
||||
[{Sid, S} || {Sid, S} <- maps:to_list(Sessions), Sid =/= SessionId]
|
||||
State;
|
||||
SessionData ->
|
||||
NewSessionData = maps:put(viewable_channels, ViewableChannels, SessionData),
|
||||
NewSessions = maps:put(SessionId, NewSessionData, Sessions),
|
||||
maps:put(sessions, NewSessions, State)
|
||||
end.
|
||||
|
||||
-spec refresh_all_viewable_channels(guild_state()) -> guild_state().
|
||||
refresh_all_viewable_channels(State) ->
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
lists:foldl(
|
||||
fun({SessionId, SessionData}, AccState) ->
|
||||
UserId = maps:get(user_id, SessionData, undefined),
|
||||
case is_integer(UserId) of
|
||||
true ->
|
||||
ViewableChannels = build_viewable_channel_map(
|
||||
guild_visibility:get_user_viewable_channels(UserId, AccState)
|
||||
),
|
||||
set_session_viewable_channels(SessionId, ViewableChannels, AccState);
|
||||
false ->
|
||||
AccState
|
||||
end
|
||||
end,
|
||||
State,
|
||||
maps:to_list(Sessions)
|
||||
).
|
||||
|
||||
-spec session_can_view_channel(session_data(), channel_id(), guild_state()) -> boolean().
|
||||
session_can_view_channel(SessionData, ChannelId, State) ->
|
||||
UserId = maps:get(user_id, SessionData, undefined),
|
||||
case {UserId, maps:get(viewable_channels, SessionData, undefined)} of
|
||||
{Uid, ViewableChannels} when is_integer(Uid), is_map(ViewableChannels) ->
|
||||
case maps:is_key(ChannelId, ViewableChannels) of
|
||||
true ->
|
||||
true;
|
||||
false ->
|
||||
check_member_channel_access(Uid, ChannelId, State)
|
||||
end;
|
||||
{Uid, _} when is_integer(Uid) ->
|
||||
check_member_channel_access(Uid, ChannelId, State);
|
||||
_ ->
|
||||
false
|
||||
end.
|
||||
|
||||
-spec check_member_channel_access(user_id(), channel_id(), guild_state()) -> boolean().
|
||||
check_member_channel_access(UserId, ChannelId, State) ->
|
||||
Member = guild_permissions:find_member_by_user_id(UserId, State),
|
||||
case Member of
|
||||
undefined ->
|
||||
false;
|
||||
_ ->
|
||||
guild_permissions:can_view_channel(UserId, ChannelId, Member, State)
|
||||
end.
|
||||
|
||||
-spec build_viewable_channel_map([channel_id()]) -> #{channel_id() => true}.
|
||||
build_viewable_channel_map(ChannelIds) ->
|
||||
lists:foldl(
|
||||
fun(ChannelId, Acc) ->
|
||||
maps:put(ChannelId, true, Acc)
|
||||
end,
|
||||
#{},
|
||||
ChannelIds
|
||||
).
|
||||
|
||||
-spec subscribe_to_user_presence(user_id(), guild_state()) -> guild_state().
|
||||
subscribe_to_user_presence(UserId, State) ->
|
||||
PresenceSubs = maps:get(presence_subscriptions, State, #{}),
|
||||
CurrentCount = maps:get(UserId, PresenceSubs, 0),
|
||||
@@ -221,6 +452,7 @@ subscribe_to_user_presence(UserId, State) ->
|
||||
maps:put(presence_subscriptions, NewSubs, State)
|
||||
end.
|
||||
|
||||
-spec unsubscribe_from_user_presence(user_id(), guild_state()) -> guild_state().
|
||||
unsubscribe_from_user_presence(UserId, State) ->
|
||||
PresenceSubs = maps:get(presence_subscriptions, State, #{}),
|
||||
CurrentCount = maps:get(UserId, PresenceSubs, 0),
|
||||
@@ -235,30 +467,33 @@ unsubscribe_from_user_presence(UserId, State) ->
|
||||
maps:put(presence_subscriptions, NewSubs, State)
|
||||
end.
|
||||
|
||||
-spec handle_user_offline(user_id(), guild_state()) -> guild_state().
|
||||
handle_user_offline(UserId, State) ->
|
||||
PresenceSubs = maps:get(presence_subscriptions, State, #{}),
|
||||
case maps:get(UserId, PresenceSubs, undefined) of
|
||||
0 ->
|
||||
presence_bus:unsubscribe(UserId),
|
||||
NewSubs = maps:remove(UserId, PresenceSubs),
|
||||
maps:put(presence_subscriptions, NewSubs, State);
|
||||
undefined ->
|
||||
State;
|
||||
StateWithSubs = maps:put(presence_subscriptions, NewSubs, State),
|
||||
MemberPresence = maps:get(member_presence, StateWithSubs, #{}),
|
||||
UpdatedMemberPresence = maps:remove(UserId, MemberPresence),
|
||||
maps:put(member_presence, UpdatedMemberPresence, StateWithSubs);
|
||||
_ ->
|
||||
State
|
||||
end.
|
||||
|
||||
-spec maybe_send_cached_presence(user_id(), guild_state()) -> guild_state().
|
||||
maybe_send_cached_presence(UserId, State) ->
|
||||
case presence_cache:get(UserId) of
|
||||
{ok, Payload} ->
|
||||
case guild_presence:handle_bus_presence(UserId, Payload, State) of
|
||||
{noreply, UpdatedState} ->
|
||||
UpdatedState
|
||||
{noreply, UpdatedState} -> UpdatedState
|
||||
end;
|
||||
_ ->
|
||||
State
|
||||
end.
|
||||
|
||||
-spec set_session_active_guild(session_id(), guild_id(), guild_state()) -> guild_state().
|
||||
set_session_active_guild(SessionId, GuildId, State) ->
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
case maps:get(SessionId, Sessions, undefined) of
|
||||
@@ -270,6 +505,7 @@ set_session_active_guild(SessionId, GuildId, State) ->
|
||||
maps:put(sessions, NewSessions, State)
|
||||
end.
|
||||
|
||||
-spec set_session_passive_guild(session_id(), guild_id(), guild_state()) -> guild_state().
|
||||
set_session_passive_guild(SessionId, GuildId, State) ->
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
case maps:get(SessionId, Sessions, undefined) of
|
||||
@@ -282,39 +518,39 @@ set_session_passive_guild(SessionId, GuildId, State) ->
|
||||
maps:put(sessions, NewSessions, State)
|
||||
end.
|
||||
|
||||
-spec is_session_active(session_id(), guild_state()) -> boolean().
|
||||
is_session_active(SessionId, State) ->
|
||||
GuildId = maps:get(id, State, 0),
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
case maps:get(SessionId, Sessions, undefined) of
|
||||
undefined ->
|
||||
false;
|
||||
SessionData ->
|
||||
not session_passive:is_passive(GuildId, SessionData)
|
||||
undefined -> false;
|
||||
SessionData -> not session_passive:is_passive(GuildId, SessionData)
|
||||
end.
|
||||
|
||||
maybe_auto_sync_initial_guild(SessionId, GuildId, InitialGuildId, State) ->
|
||||
case InitialGuildId of
|
||||
GuildId ->
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
case maps:get(SessionId, Sessions, undefined) of
|
||||
undefined ->
|
||||
State;
|
||||
SessionData ->
|
||||
SyncedSessionData = session_passive:mark_guild_synced(GuildId, SessionData),
|
||||
NewSessions = maps:put(SessionId, SyncedSessionData, Sessions),
|
||||
maps:put(sessions, NewSessions, State)
|
||||
end;
|
||||
_ ->
|
||||
State
|
||||
end.
|
||||
-spec maybe_auto_sync_initial_guild(
|
||||
session_id(), guild_id(), guild_id() | undefined, guild_state()
|
||||
) ->
|
||||
guild_state().
|
||||
maybe_auto_sync_initial_guild(SessionId, GuildId, GuildId, State) ->
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
case maps:get(SessionId, Sessions, undefined) of
|
||||
undefined ->
|
||||
State;
|
||||
SessionData ->
|
||||
SyncedSessionData = session_passive:mark_guild_synced(GuildId, SessionData),
|
||||
NewSessions = maps:put(SessionId, SyncedSessionData, Sessions),
|
||||
maps:put(sessions, NewSessions, State)
|
||||
end;
|
||||
maybe_auto_sync_initial_guild(_SessionId, _GuildId, _InitialGuildId, State) ->
|
||||
State.
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
|
||||
build_initial_last_message_ids_empty_channels_test() ->
|
||||
GuildState = #{<<"channels">> => []},
|
||||
Result = build_initial_last_message_ids(GuildState),
|
||||
?assertEqual(#{}, Result),
|
||||
ok.
|
||||
?assertEqual(#{}, Result).
|
||||
|
||||
build_initial_last_message_ids_with_channels_test() ->
|
||||
GuildState = #{
|
||||
@@ -324,8 +560,7 @@ build_initial_last_message_ids_with_channels_test() ->
|
||||
]
|
||||
},
|
||||
Result = build_initial_last_message_ids(GuildState),
|
||||
?assertEqual(#{<<"100">> => <<"500">>, <<"101">> => <<"600">>}, Result),
|
||||
ok.
|
||||
?assertEqual(#{<<"100">> => <<"500">>, <<"101">> => <<"600">>}, Result).
|
||||
|
||||
build_initial_last_message_ids_filters_null_test() ->
|
||||
GuildState = #{
|
||||
@@ -336,13 +571,237 @@ build_initial_last_message_ids_filters_null_test() ->
|
||||
]
|
||||
},
|
||||
Result = build_initial_last_message_ids(GuildState),
|
||||
?assertEqual(#{<<"100">> => <<"500">>}, Result),
|
||||
ok.
|
||||
?assertEqual(#{<<"100">> => <<"500">>}, Result).
|
||||
|
||||
build_initial_last_message_ids_no_channels_key_test() ->
|
||||
GuildState = #{},
|
||||
Result = build_initial_last_message_ids(GuildState),
|
||||
?assertEqual(#{}, Result),
|
||||
?assertEqual(#{}, Result).
|
||||
|
||||
build_initial_channel_versions_test() ->
|
||||
GuildState = #{
|
||||
<<"channels">> => [
|
||||
#{<<"id">> => <<"100">>, <<"version">> => 5},
|
||||
#{<<"id">> => <<"101">>}
|
||||
]
|
||||
},
|
||||
Result = build_initial_channel_versions(GuildState),
|
||||
?assertEqual(#{<<"100">> => 5, <<"101">> => 0}, Result).
|
||||
|
||||
filter_sessions_for_channel_uses_cached_viewable_channels_test() ->
|
||||
SessionId = <<"s1">>,
|
||||
SessionData = #{
|
||||
session_id => SessionId,
|
||||
user_id => 10,
|
||||
pid => self(),
|
||||
viewable_channels => #{200 => true}
|
||||
},
|
||||
Sessions = #{SessionId => SessionData},
|
||||
State = #{
|
||||
sessions => Sessions,
|
||||
data => #{<<"members">> => #{}}
|
||||
},
|
||||
Result = filter_sessions_for_channel(Sessions, 200, undefined, State),
|
||||
?assertEqual([{SessionId, SessionData}], Result).
|
||||
|
||||
refresh_all_viewable_channels_populates_cache_test() ->
|
||||
SessionId = <<"s1">>,
|
||||
UserId = 10,
|
||||
GuildId = 42,
|
||||
ViewPerm = constants:view_channel_permission(),
|
||||
State = #{
|
||||
id => GuildId,
|
||||
sessions => #{
|
||||
SessionId => #{
|
||||
session_id => SessionId,
|
||||
user_id => UserId,
|
||||
pid => self()
|
||||
}
|
||||
},
|
||||
data => #{
|
||||
<<"guild">> => #{<<"owner_id">> => <<"999">>},
|
||||
<<"roles">> => [
|
||||
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(ViewPerm)}
|
||||
],
|
||||
<<"members">> => #{
|
||||
UserId => #{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
|
||||
},
|
||||
<<"channels">> => [
|
||||
#{<<"id">> => <<"200">>, <<"permission_overwrites">> => []}
|
||||
]
|
||||
}
|
||||
},
|
||||
UpdatedState = refresh_all_viewable_channels(State),
|
||||
UpdatedSessions = maps:get(sessions, UpdatedState, #{}),
|
||||
UpdatedSession = maps:get(SessionId, UpdatedSessions),
|
||||
ViewableChannels = maps:get(viewable_channels, UpdatedSession, #{}),
|
||||
?assertEqual(true, maps:is_key(200, ViewableChannels)).
|
||||
|
||||
filter_sessions_exclude_session_undefined_test() ->
|
||||
Sessions = #{<<"s1">> => #{}, <<"s2">> => #{}},
|
||||
Result = filter_sessions_exclude_session(Sessions, undefined),
|
||||
?assertEqual(2, length(Result)).
|
||||
|
||||
filter_sessions_exclude_session_specific_test() ->
|
||||
Sessions = #{<<"s1">> => #{}, <<"s2">> => #{}},
|
||||
Result = filter_sessions_exclude_session(Sessions, <<"s1">>),
|
||||
?assertEqual(1, length(Result)),
|
||||
?assertEqual([{<<"s2">>, #{}}], Result).
|
||||
|
||||
should_exclude_session_test() ->
|
||||
?assertEqual(false, should_exclude_session(<<"s1">>, undefined)),
|
||||
?assertEqual(true, should_exclude_session(<<"s1">>, <<"s1">>)),
|
||||
?assertEqual(false, should_exclude_session(<<"s1">>, <<"s2">>)).
|
||||
|
||||
find_session_by_ref_found_test() ->
|
||||
Ref = make_ref(),
|
||||
Sessions = #{<<"s1">> => #{mref => Ref, user_id => 1}},
|
||||
Result = find_session_by_ref(Ref, Sessions),
|
||||
?assertEqual(#{mref => Ref, user_id => 1}, Result).
|
||||
|
||||
find_session_by_ref_not_found_test() ->
|
||||
Ref = make_ref(),
|
||||
OtherRef = make_ref(),
|
||||
Sessions = #{<<"s1">> => #{mref => OtherRef, user_id => 1}},
|
||||
Result = find_session_by_ref(Ref, Sessions),
|
||||
?assertEqual(undefined, Result).
|
||||
|
||||
remove_session_removes_entry_test() ->
|
||||
SessionId = <<"s1">>,
|
||||
SessionData = #{
|
||||
session_id => SessionId,
|
||||
user_id => 1,
|
||||
pid => self(),
|
||||
mref => make_ref(),
|
||||
active_guilds => sets:new(),
|
||||
user_roles => [],
|
||||
bot => false,
|
||||
previous_passive_updates => #{},
|
||||
previous_passive_channel_versions => #{},
|
||||
previous_passive_voice_states => #{}
|
||||
},
|
||||
State = #{
|
||||
sessions => #{SessionId => SessionData},
|
||||
presence_subscriptions => #{1 => 1},
|
||||
member_list_subscriptions => #{},
|
||||
member_subscriptions => guild_subscriptions:init_state()
|
||||
},
|
||||
NewState = remove_session(SessionId, State),
|
||||
?assertEqual(#{}, maps:get(sessions, NewState, #{})),
|
||||
?assertEqual(0, maps:get(1, maps:get(presence_subscriptions, NewState, #{}), 0)).
|
||||
|
||||
remove_session_cleans_connect_pending_test() ->
|
||||
SessionId = <<"s1">>,
|
||||
SessionData = #{
|
||||
session_id => SessionId,
|
||||
user_id => 1,
|
||||
pid => self(),
|
||||
mref => make_ref(),
|
||||
active_guilds => sets:new(),
|
||||
user_roles => [],
|
||||
bot => false,
|
||||
previous_passive_updates => #{},
|
||||
previous_passive_channel_versions => #{},
|
||||
previous_passive_voice_states => #{}
|
||||
},
|
||||
State = #{
|
||||
sessions => #{SessionId => SessionData},
|
||||
presence_subscriptions => #{1 => 1},
|
||||
member_list_subscriptions => #{},
|
||||
member_subscriptions => guild_subscriptions:init_state(),
|
||||
session_connect_pending => #{SessionId => 3, <<"s2">> => 1},
|
||||
session_connect_queue => [
|
||||
#{request => #{session_id => SessionId}, attempt => 3},
|
||||
#{request => #{session_id => <<"s2">>}, attempt => 1}
|
||||
]
|
||||
},
|
||||
NewState = remove_session(SessionId, State),
|
||||
Pending = maps:get(session_connect_pending, NewState, #{}),
|
||||
?assertEqual(false, maps:is_key(SessionId, Pending)),
|
||||
?assertEqual(true, maps:is_key(<<"s2">>, Pending)),
|
||||
Queue0 = maps:get(session_connect_queue, NewState, queue:new()),
|
||||
Queue =
|
||||
case Queue0 of
|
||||
L when is_list(L) -> L;
|
||||
_ ->
|
||||
case queue:is_queue(Queue0) of
|
||||
true -> queue:to_list(Queue0);
|
||||
false -> []
|
||||
end
|
||||
end,
|
||||
?assertEqual(
|
||||
1,
|
||||
length([
|
||||
Item
|
||||
|| Item <- Queue,
|
||||
maps:get(session_id, maps:get(request, Item, #{}), undefined) =:= <<"s2">>
|
||||
])
|
||||
),
|
||||
?assertEqual(
|
||||
0,
|
||||
length([
|
||||
Item
|
||||
|| Item <- Queue,
|
||||
maps:get(session_id, maps:get(request, Item, #{}), undefined) =:= SessionId
|
||||
])
|
||||
),
|
||||
ok.
|
||||
|
||||
cleanup_connect_admission_queue_format_test() ->
|
||||
S1 = <<"s1">>,
|
||||
S2 = <<"s2">>,
|
||||
S3 = <<"s3">>,
|
||||
Queue = queue:from_list([
|
||||
#{request => #{session_id => S1}, attempt => 0},
|
||||
#{request => #{session_id => S2}, attempt => 1},
|
||||
#{request => #{session_id => S1}, attempt => 2},
|
||||
#{request => #{session_id => S3}, attempt => 3}
|
||||
]),
|
||||
State0 = #{
|
||||
sessions => #{},
|
||||
session_connect_queue => Queue,
|
||||
presence_subscriptions => #{},
|
||||
member_list_subscriptions => #{},
|
||||
member_subscriptions => guild_subscriptions:init_state()
|
||||
},
|
||||
State1 = cleanup_connect_admission_for_session(S1, State0),
|
||||
ResultQueue0 = maps:get(session_connect_queue, State1),
|
||||
ResultQueue =
|
||||
case queue:is_queue(ResultQueue0) of
|
||||
true -> queue:to_list(ResultQueue0);
|
||||
false -> ResultQueue0
|
||||
end,
|
||||
?assertEqual(2, length(ResultQueue)),
|
||||
SessionIds = [
|
||||
maps:get(session_id, maps:get(request, Item, #{}), undefined)
|
||||
|| Item <- ResultQueue
|
||||
],
|
||||
?assertEqual(false, lists:member(S1, SessionIds)),
|
||||
?assertEqual(true, lists:member(S2, SessionIds)),
|
||||
?assertEqual(true, lists:member(S3, SessionIds)).
|
||||
|
||||
pending_connect_filtered_from_channel_sessions_test() ->
|
||||
NormalSession = #{
|
||||
session_id => <<"s1">>,
|
||||
user_id => 10,
|
||||
pid => self(),
|
||||
viewable_channels => #{200 => true}
|
||||
},
|
||||
PendingSession = #{
|
||||
session_id => <<"s2">>,
|
||||
user_id => 11,
|
||||
pid => self(),
|
||||
pending_connect => true,
|
||||
viewable_channels => #{200 => true}
|
||||
},
|
||||
Sessions = #{<<"s1">> => NormalSession, <<"s2">> => PendingSession},
|
||||
State = #{
|
||||
sessions => Sessions,
|
||||
data => #{<<"members">> => #{}}
|
||||
},
|
||||
Result = filter_sessions_for_channel(Sessions, 200, undefined, State),
|
||||
?assertEqual(1, length(Result)),
|
||||
[{ResultSid, _}] = Result,
|
||||
?assertEqual(<<"s1">>, ResultSid).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -17,24 +17,26 @@
|
||||
|
||||
-module(guild_state).
|
||||
|
||||
-export([
|
||||
update_state/3
|
||||
]).
|
||||
-export([update_state/3]).
|
||||
|
||||
-import(guild_user_data, [maybe_update_cached_user_data/3]).
|
||||
-import(guild_availability, [handle_unavailability_transition/2]).
|
||||
-import(guild_visibility, [compute_and_dispatch_visibility_changes/2]).
|
||||
-import(guild, [update_counts/1]).
|
||||
-type guild_state() :: map().
|
||||
-type guild_data() :: map().
|
||||
-type event() :: atom().
|
||||
-type event_data() :: map().
|
||||
-type user_id() :: integer().
|
||||
-type role_id() :: binary().
|
||||
|
||||
-spec update_state(event(), event_data(), guild_state()) -> guild_state().
|
||||
update_state(Event, EventData, State) ->
|
||||
StateWithUpdatedUser = maybe_update_cached_user_data(Event, EventData, State),
|
||||
Data = maps:get(data, StateWithUpdatedUser),
|
||||
|
||||
StateWithUpdatedUser0 = guild_user_data:maybe_update_cached_user_data(Event, EventData, State),
|
||||
Data0 = maps:get(data, StateWithUpdatedUser0),
|
||||
Data = guild_data_index:normalize_data(Data0),
|
||||
StateWithUpdatedUser = maps:put(data, Data, StateWithUpdatedUser0),
|
||||
UpdatedData = update_data_for_event(Event, EventData, Data, State),
|
||||
UpdatedState = maps:put(data, UpdatedData, StateWithUpdatedUser),
|
||||
handle_post_update(Event, EventData, StateWithUpdatedUser, UpdatedState).
|
||||
|
||||
handle_post_update(Event, StateWithUpdatedUser, UpdatedState).
|
||||
|
||||
-spec update_data_for_event(event(), event_data(), guild_data(), guild_state()) -> guild_data().
|
||||
update_data_for_event(guild_update, EventData, Data, _State) ->
|
||||
handle_guild_update(EventData, Data);
|
||||
update_data_for_event(guild_member_add, EventData, Data, _State) ->
|
||||
@@ -70,23 +72,184 @@ update_data_for_event(guild_stickers_update, EventData, Data, _State) ->
|
||||
update_data_for_event(_Event, _EventData, Data, _State) ->
|
||||
Data.
|
||||
|
||||
handle_post_update(guild_update, StateWithUpdatedUser, UpdatedState) ->
|
||||
handle_unavailability_transition(StateWithUpdatedUser, UpdatedState),
|
||||
-spec handle_post_update(event(), event_data(), guild_state(), guild_state()) -> guild_state().
|
||||
handle_post_update(guild_update, _EventData, StateWithUpdatedUser, UpdatedState) ->
|
||||
guild_availability:handle_unavailability_transition(StateWithUpdatedUser, UpdatedState);
|
||||
handle_post_update(guild_member_add, _EventData, _StateWithUpdatedUser, UpdatedState) ->
|
||||
guild:update_counts(UpdatedState);
|
||||
handle_post_update(guild_member_update, EventData, StateWithUpdatedUser, UpdatedState) ->
|
||||
UserId = extract_user_id(EventData),
|
||||
case is_integer(UserId) andalso UserId > 0 of
|
||||
true ->
|
||||
guild_visibility:compute_and_dispatch_visibility_changes_for_users(
|
||||
[UserId],
|
||||
StateWithUpdatedUser,
|
||||
UpdatedState
|
||||
);
|
||||
false ->
|
||||
guild_visibility:compute_and_dispatch_visibility_changes(
|
||||
StateWithUpdatedUser,
|
||||
UpdatedState
|
||||
)
|
||||
end;
|
||||
handle_post_update(guild_role_create, _EventData, _StateWithUpdatedUser, UpdatedState) ->
|
||||
UpdatedState;
|
||||
handle_post_update(guild_member_add, _StateWithUpdatedUser, UpdatedState) ->
|
||||
update_counts(UpdatedState);
|
||||
handle_post_update(guild_member_remove, _StateWithUpdatedUser, UpdatedState) ->
|
||||
handle_post_update(guild_role_update, EventData, StateWithUpdatedUser, UpdatedState) ->
|
||||
recompute_visibility_for_roles(
|
||||
extract_role_ids_from_role_update(EventData),
|
||||
StateWithUpdatedUser,
|
||||
UpdatedState
|
||||
);
|
||||
handle_post_update(guild_role_update_bulk, EventData, StateWithUpdatedUser, UpdatedState) ->
|
||||
recompute_visibility_for_roles(
|
||||
extract_role_ids_from_role_update_bulk(EventData),
|
||||
StateWithUpdatedUser,
|
||||
UpdatedState
|
||||
);
|
||||
handle_post_update(guild_role_delete, EventData, StateWithUpdatedUser, UpdatedState) ->
|
||||
recompute_visibility_for_roles(
|
||||
extract_role_ids_from_role_delete(EventData),
|
||||
StateWithUpdatedUser,
|
||||
UpdatedState
|
||||
);
|
||||
handle_post_update(channel_update, EventData, StateWithUpdatedUser, UpdatedState) ->
|
||||
ChannelIds = extract_channel_ids_from_channel_update(EventData),
|
||||
guild_visibility:compute_and_dispatch_visibility_changes_for_channels(
|
||||
ChannelIds,
|
||||
StateWithUpdatedUser,
|
||||
UpdatedState
|
||||
);
|
||||
handle_post_update(channel_update_bulk, EventData, StateWithUpdatedUser, UpdatedState) ->
|
||||
ChannelIds = extract_channel_ids_from_channel_update_bulk(EventData),
|
||||
guild_visibility:compute_and_dispatch_visibility_changes_for_channels(
|
||||
ChannelIds,
|
||||
StateWithUpdatedUser,
|
||||
UpdatedState
|
||||
);
|
||||
handle_post_update(guild_member_remove, EventData, _StateWithUpdatedUser, UpdatedState) ->
|
||||
UserId = extract_user_id(EventData),
|
||||
State1 = cleanup_removed_member_sessions(UpdatedState),
|
||||
update_counts(State1);
|
||||
handle_post_update(Event, StateWithUpdatedUser, UpdatedState) ->
|
||||
State2 = maybe_disconnect_removed_member(UserId, State1),
|
||||
guild:update_counts(State2);
|
||||
handle_post_update(Event, _EventData, StateWithUpdatedUser, UpdatedState) ->
|
||||
case needs_visibility_check(Event) of
|
||||
true ->
|
||||
compute_and_dispatch_visibility_changes(StateWithUpdatedUser, UpdatedState),
|
||||
UpdatedState;
|
||||
guild_visibility:compute_and_dispatch_visibility_changes(
|
||||
StateWithUpdatedUser, UpdatedState
|
||||
);
|
||||
false ->
|
||||
UpdatedState
|
||||
end.
|
||||
|
||||
-spec maybe_disconnect_removed_member(user_id(), guild_state()) -> guild_state().
|
||||
maybe_disconnect_removed_member(UserId, State) when is_integer(UserId), UserId > 0 ->
|
||||
{reply, _Result, NewState} =
|
||||
guild_voice_disconnect:disconnect_voice_user(
|
||||
#{user_id => UserId, connection_id => null},
|
||||
State
|
||||
),
|
||||
NewState;
|
||||
maybe_disconnect_removed_member(_, State) ->
|
||||
State.
|
||||
|
||||
-spec recompute_visibility_for_roles([integer()], guild_state(), guild_state()) -> guild_state().
|
||||
recompute_visibility_for_roles(RoleIds, StateWithUpdatedUser, UpdatedState) ->
|
||||
GuildId = maps:get(id, UpdatedState, 0),
|
||||
case lists:any(fun(RoleId) -> RoleId =:= GuildId end, RoleIds) of
|
||||
true ->
|
||||
guild_visibility:compute_and_dispatch_visibility_changes(
|
||||
StateWithUpdatedUser,
|
||||
UpdatedState
|
||||
);
|
||||
false ->
|
||||
UserIds = affected_user_ids_for_roles(RoleIds, StateWithUpdatedUser, UpdatedState),
|
||||
case UserIds of
|
||||
[] ->
|
||||
UpdatedState;
|
||||
_ ->
|
||||
guild_visibility:compute_and_dispatch_visibility_changes_for_users(
|
||||
UserIds,
|
||||
StateWithUpdatedUser,
|
||||
UpdatedState
|
||||
)
|
||||
end
|
||||
end.
|
||||
|
||||
-spec affected_user_ids_for_roles([integer()], guild_state(), guild_state()) -> [user_id()].
|
||||
affected_user_ids_for_roles(RoleIds, StateWithUpdatedUser, UpdatedState) ->
|
||||
OldData = maps:get(data, StateWithUpdatedUser, #{}),
|
||||
NewData = maps:get(data, UpdatedState, #{}),
|
||||
OldMemberRoleIndex = guild_data_index:member_role_index(OldData),
|
||||
NewMemberRoleIndex = guild_data_index:member_role_index(NewData),
|
||||
UserIdSet = lists:foldl(
|
||||
fun(RoleId, AccSet) ->
|
||||
OldUsers = maps:keys(maps:get(RoleId, OldMemberRoleIndex, #{})),
|
||||
NewUsers = maps:keys(maps:get(RoleId, NewMemberRoleIndex, #{})),
|
||||
RoleUsers = OldUsers ++ NewUsers,
|
||||
lists:foldl(
|
||||
fun(UserId, InnerSet) -> sets:add_element(UserId, InnerSet) end,
|
||||
AccSet,
|
||||
RoleUsers
|
||||
)
|
||||
end,
|
||||
sets:new(),
|
||||
lists:usort(RoleIds)
|
||||
),
|
||||
sets:to_list(UserIdSet).
|
||||
|
||||
-spec extract_role_ids_from_role_update(event_data()) -> [integer()].
|
||||
extract_role_ids_from_role_update(EventData) ->
|
||||
RoleData = maps:get(<<"role">>, EventData, #{}),
|
||||
case type_conv:to_integer(maps:get(<<"id">>, RoleData, undefined)) of
|
||||
undefined -> [];
|
||||
RoleId -> [RoleId]
|
||||
end.
|
||||
|
||||
-spec extract_role_ids_from_role_update_bulk(event_data()) -> [integer()].
|
||||
extract_role_ids_from_role_update_bulk(EventData) ->
|
||||
Roles = maps:get(<<"roles">>, EventData, []),
|
||||
lists:filtermap(
|
||||
fun(RoleData) ->
|
||||
case type_conv:to_integer(maps:get(<<"id">>, RoleData, undefined)) of
|
||||
undefined ->
|
||||
false;
|
||||
RoleId ->
|
||||
{true, RoleId}
|
||||
end
|
||||
end,
|
||||
Roles
|
||||
).
|
||||
|
||||
-spec extract_role_ids_from_role_delete(event_data()) -> [integer()].
|
||||
extract_role_ids_from_role_delete(EventData) ->
|
||||
case type_conv:to_integer(maps:get(<<"role_id">>, EventData, undefined)) of
|
||||
undefined -> [];
|
||||
RoleId -> [RoleId]
|
||||
end.
|
||||
|
||||
-spec extract_channel_ids_from_channel_update(event_data()) -> [integer()].
|
||||
extract_channel_ids_from_channel_update(EventData) ->
|
||||
case type_conv:to_integer(maps:get(<<"id">>, EventData, undefined)) of
|
||||
undefined -> [];
|
||||
ChannelId -> [ChannelId]
|
||||
end.
|
||||
|
||||
-spec extract_channel_ids_from_channel_update_bulk(event_data()) -> [integer()].
|
||||
extract_channel_ids_from_channel_update_bulk(EventData) ->
|
||||
Channels = maps:get(<<"channels">>, EventData, []),
|
||||
lists:filtermap(
|
||||
fun(ChannelData) ->
|
||||
case type_conv:to_integer(maps:get(<<"id">>, ChannelData, undefined)) of
|
||||
undefined ->
|
||||
false;
|
||||
ChannelId ->
|
||||
{true, ChannelId}
|
||||
end
|
||||
end,
|
||||
Channels
|
||||
).
|
||||
|
||||
-spec needs_visibility_check(event()) -> boolean().
|
||||
needs_visibility_check(guild_role_create) -> true;
|
||||
needs_visibility_check(guild_role_update) -> true;
|
||||
needs_visibility_check(guild_role_update_bulk) -> true;
|
||||
@@ -96,33 +259,36 @@ needs_visibility_check(channel_update) -> true;
|
||||
needs_visibility_check(channel_update_bulk) -> true;
|
||||
needs_visibility_check(_) -> false.
|
||||
|
||||
-spec handle_guild_update(event_data(), guild_data()) -> guild_data().
|
||||
handle_guild_update(EventData, Data) ->
|
||||
Guild = maps:get(<<"guild">>, Data),
|
||||
UpdatedGuild = maps:merge(Guild, EventData),
|
||||
maps:put(<<"guild">>, UpdatedGuild, Data).
|
||||
|
||||
-spec handle_member_add(event_data(), guild_data()) -> guild_data().
|
||||
handle_member_add(EventData, Data) ->
|
||||
Members = maps:get(<<"members">>, Data, []),
|
||||
UpdatedData = maps:put(<<"members">>, Members ++ [EventData], Data),
|
||||
UpdatedData.
|
||||
guild_data_index:put_member(EventData, Data).
|
||||
|
||||
-spec handle_member_update(event_data(), guild_data()) -> guild_data().
|
||||
handle_member_update(EventData, Data) ->
|
||||
Members = maps:get(<<"members">>, Data, []),
|
||||
UserId = extract_user_id(EventData),
|
||||
UpdatedMembers = replace_member_by_id(Members, UserId, EventData),
|
||||
maps:put(<<"members">>, UpdatedMembers, Data).
|
||||
Members = guild_data_index:member_map(Data),
|
||||
case maps:is_key(UserId, Members) of
|
||||
true ->
|
||||
guild_data_index:put_member(EventData, Data);
|
||||
false ->
|
||||
Data
|
||||
end.
|
||||
|
||||
-spec handle_member_remove(event_data(), guild_data(), guild_state()) -> guild_data().
|
||||
handle_member_remove(EventData, Data, _State) ->
|
||||
Members = maps:get(<<"members">>, Data, []),
|
||||
UserId = extract_user_id(EventData),
|
||||
FilteredMembers = remove_member_by_id(Members, UserId),
|
||||
maps:put(<<"members">>, FilteredMembers, Data).
|
||||
guild_data_index:remove_member(UserId, Data).
|
||||
|
||||
-spec cleanup_removed_member_sessions(guild_state()) -> guild_state().
|
||||
cleanup_removed_member_sessions(State) ->
|
||||
Data = maps:get(data, State),
|
||||
Members = maps:get(<<"members">>, Data, []),
|
||||
MemberUserIds = sets:from_list([extract_user_id_from_member(M) || M <- Members]),
|
||||
|
||||
MemberUserIds = sets:from_list(guild_data_index:member_ids(Data)),
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
FilteredSessions = maps:filter(
|
||||
fun(_K, S) ->
|
||||
@@ -131,158 +297,141 @@ cleanup_removed_member_sessions(State) ->
|
||||
end,
|
||||
Sessions
|
||||
),
|
||||
maps:put(sessions, FilteredSessions, State).
|
||||
|
||||
Presences = maps:get(presences, State, #{}),
|
||||
FilteredPresences = maps:filter(
|
||||
fun(UserId, _V) ->
|
||||
sets:is_element(UserId, MemberUserIds)
|
||||
end,
|
||||
Presences
|
||||
),
|
||||
|
||||
State1 = maps:put(sessions, FilteredSessions, State),
|
||||
maps:put(presences, FilteredPresences, State1).
|
||||
|
||||
extract_user_id_from_member(Member) when is_map(Member) ->
|
||||
MUser = maps:get(<<"user">>, Member, #{}),
|
||||
utils:binary_to_integer_safe(maps:get(<<"id">>, MUser, <<"0">>));
|
||||
extract_user_id_from_member(_) ->
|
||||
0.
|
||||
|
||||
-spec extract_user_id(event_data()) -> user_id().
|
||||
extract_user_id(EventData) ->
|
||||
MUser = maps:get(<<"user">>, EventData, #{}),
|
||||
utils:binary_to_integer_safe(maps:get(<<"id">>, MUser, <<"0">>)).
|
||||
|
||||
replace_member_by_id(Members, UserId, NewMember) ->
|
||||
lists:map(
|
||||
fun(M) when is_map(M) ->
|
||||
MMUser = maps:get(<<"user">>, M, #{}),
|
||||
MUserId = utils:binary_to_integer_safe(maps:get(<<"id">>, MMUser, <<"0">>)),
|
||||
case MUserId =:= UserId of
|
||||
true -> NewMember;
|
||||
false -> M
|
||||
end
|
||||
end,
|
||||
Members
|
||||
).
|
||||
|
||||
remove_member_by_id(Members, UserId) ->
|
||||
lists:filter(
|
||||
fun(M) when is_map(M) ->
|
||||
MMUser = maps:get(<<"user">>, M, #{}),
|
||||
MUserId = utils:binary_to_integer_safe(maps:get(<<"id">>, MMUser, <<"0">>)),
|
||||
MUserId =/= UserId
|
||||
end,
|
||||
Members
|
||||
).
|
||||
|
||||
-spec handle_role_create(event_data(), guild_data()) -> guild_data().
|
||||
handle_role_create(EventData, Data) ->
|
||||
Roles = maps:get(<<"roles">>, Data, []),
|
||||
Roles = guild_data_index:role_list(Data),
|
||||
RoleData = maps:get(<<"role">>, EventData),
|
||||
maps:put(<<"roles">>, Roles ++ [RoleData], Data).
|
||||
guild_data_index:put_roles(Roles ++ [RoleData], Data).
|
||||
|
||||
-spec handle_role_update(event_data(), guild_data()) -> guild_data().
|
||||
handle_role_update(EventData, Data) ->
|
||||
Roles = maps:get(<<"roles">>, Data, []),
|
||||
Roles = guild_data_index:role_list(Data),
|
||||
RoleData = maps:get(<<"role">>, EventData),
|
||||
RoleId = maps:get(<<"id">>, RoleData),
|
||||
UpdatedRoles = replace_item_by_id(Roles, RoleId, RoleData),
|
||||
maps:put(<<"roles">>, UpdatedRoles, Data).
|
||||
guild_data_index:put_roles(UpdatedRoles, Data).
|
||||
|
||||
-spec handle_role_update_bulk(event_data(), guild_data()) -> guild_data().
|
||||
handle_role_update_bulk(EventData, Data) ->
|
||||
Roles = maps:get(<<"roles">>, Data, []),
|
||||
Roles = guild_data_index:role_list(Data),
|
||||
BulkRoles = maps:get(<<"roles">>, EventData, []),
|
||||
UpdatedRoles = bulk_update_items(Roles, BulkRoles),
|
||||
maps:put(<<"roles">>, UpdatedRoles, Data).
|
||||
guild_data_index:put_roles(UpdatedRoles, Data).
|
||||
|
||||
-spec handle_role_delete(event_data(), guild_data()) -> guild_data().
|
||||
handle_role_delete(EventData, Data) ->
|
||||
Roles = maps:get(<<"roles">>, Data, []),
|
||||
Roles = guild_data_index:role_list(Data),
|
||||
RoleId = maps:get(<<"role_id">>, EventData),
|
||||
FilteredRoles = remove_item_by_id(Roles, RoleId),
|
||||
Data1 = maps:put(<<"roles">>, FilteredRoles, Data),
|
||||
Data1 = guild_data_index:put_roles(FilteredRoles, Data),
|
||||
Data2 = strip_role_from_members(RoleId, Data1),
|
||||
strip_role_from_channel_overwrites(RoleId, Data2).
|
||||
|
||||
-spec strip_role_from_members(role_id(), guild_data()) -> guild_data().
|
||||
strip_role_from_members(RoleId, Data) ->
|
||||
Members = maps:get(<<"members">>, Data, []),
|
||||
UpdatedMembers = lists:map(
|
||||
fun(Member) when is_map(Member) ->
|
||||
MemberRoles = maps:get(<<"roles">>, Member, []),
|
||||
FilteredRoles = lists:filter(
|
||||
fun(R) ->
|
||||
RoleIdInt = utils:binary_to_integer_safe(RoleId),
|
||||
RInt = utils:binary_to_integer_safe(R),
|
||||
RInt =/= RoleIdInt
|
||||
end,
|
||||
MemberRoles
|
||||
),
|
||||
maps:put(<<"roles">>, FilteredRoles, Member);
|
||||
(Member) ->
|
||||
Member
|
||||
RoleIdInt = utils:binary_to_integer_safe(RoleId),
|
||||
MemberRoleIndex = guild_data_index:member_role_index(Data),
|
||||
AffectedUsers = maps:keys(maps:get(RoleIdInt, MemberRoleIndex, #{})),
|
||||
lists:foldl(
|
||||
fun(UserId, AccData) ->
|
||||
case guild_data_index:get_member(UserId, AccData) of
|
||||
undefined ->
|
||||
AccData;
|
||||
Member ->
|
||||
MemberRoles = maps:get(<<"roles">>, Member, []),
|
||||
FilteredRoles = lists:filter(
|
||||
fun(R) ->
|
||||
RInt = utils:binary_to_integer_safe(R),
|
||||
RInt =/= RoleIdInt
|
||||
end,
|
||||
MemberRoles
|
||||
),
|
||||
guild_data_index:put_member(maps:put(<<"roles">>, FilteredRoles, Member), AccData)
|
||||
end
|
||||
end,
|
||||
Members
|
||||
),
|
||||
maps:put(<<"members">>, UpdatedMembers, Data).
|
||||
Data,
|
||||
AffectedUsers
|
||||
).
|
||||
|
||||
-spec strip_role_from_channel_overwrites(role_id(), guild_data()) -> guild_data().
|
||||
strip_role_from_channel_overwrites(RoleId, Data) ->
|
||||
Channels = maps:get(<<"channels">>, Data, []),
|
||||
Channels = guild_data_index:channel_list(Data),
|
||||
RoleIdInt = utils:binary_to_integer_safe(RoleId),
|
||||
UpdatedChannels = lists:map(
|
||||
fun(Channel) when is_map(Channel) ->
|
||||
Overwrites = maps:get(<<"permission_overwrites">>, Channel, []),
|
||||
FilteredOverwrites = lists:filter(
|
||||
fun(Overwrite) when is_map(Overwrite) ->
|
||||
OverwriteType = maps:get(<<"type">>, Overwrite, 0),
|
||||
OverwriteId = utils:binary_to_integer_safe(maps:get(<<"id">>, Overwrite, <<"0">>)),
|
||||
not (OverwriteType =:= 0 andalso OverwriteId =:= RoleIdInt);
|
||||
(_) ->
|
||||
true
|
||||
end,
|
||||
Overwrites
|
||||
),
|
||||
maps:put(<<"permission_overwrites">>, FilteredOverwrites, Channel);
|
||||
(Channel) ->
|
||||
Channel
|
||||
fun
|
||||
(Channel) when is_map(Channel) ->
|
||||
Overwrites = maps:get(<<"permission_overwrites">>, Channel, []),
|
||||
FilteredOverwrites = lists:filter(
|
||||
fun
|
||||
(Overwrite) when is_map(Overwrite) ->
|
||||
OverwriteType = maps:get(<<"type">>, Overwrite, 0),
|
||||
OverwriteId = utils:binary_to_integer_safe(
|
||||
maps:get(<<"id">>, Overwrite, <<"0">>)
|
||||
),
|
||||
not (OverwriteType =:= 0 andalso OverwriteId =:= RoleIdInt);
|
||||
(_) ->
|
||||
true
|
||||
end,
|
||||
Overwrites
|
||||
),
|
||||
maps:put(<<"permission_overwrites">>, FilteredOverwrites, Channel);
|
||||
(Channel) ->
|
||||
Channel
|
||||
end,
|
||||
Channels
|
||||
),
|
||||
maps:put(<<"channels">>, UpdatedChannels, Data).
|
||||
guild_data_index:put_channels(UpdatedChannels, Data).
|
||||
|
||||
-spec handle_channel_create(event_data(), guild_data()) -> guild_data().
|
||||
handle_channel_create(EventData, Data) ->
|
||||
Channels = maps:get(<<"channels">>, Data, []),
|
||||
maps:put(<<"channels">>, Channels ++ [EventData], Data).
|
||||
Channels = guild_data_index:channel_list(Data),
|
||||
guild_data_index:put_channels(Channels ++ [EventData], Data).
|
||||
|
||||
-spec handle_channel_update(event_data(), guild_data()) -> guild_data().
|
||||
handle_channel_update(EventData, Data) ->
|
||||
Channels = maps:get(<<"channels">>, Data, []),
|
||||
Channels = guild_data_index:channel_list(Data),
|
||||
ChannelId = maps:get(<<"id">>, EventData),
|
||||
UpdatedChannels = replace_item_by_id(Channels, ChannelId, EventData),
|
||||
maps:put(<<"channels">>, UpdatedChannels, Data).
|
||||
guild_data_index:put_channels(UpdatedChannels, Data).
|
||||
|
||||
-spec handle_channel_update_bulk(event_data(), guild_data()) -> guild_data().
|
||||
handle_channel_update_bulk(EventData, Data) ->
|
||||
Channels = maps:get(<<"channels">>, Data, []),
|
||||
Channels = guild_data_index:channel_list(Data),
|
||||
BulkChannels = maps:get(<<"channels">>, EventData, []),
|
||||
UpdatedChannels = bulk_update_items(Channels, BulkChannels),
|
||||
maps:put(<<"channels">>, UpdatedChannels, Data).
|
||||
guild_data_index:put_channels(UpdatedChannels, Data).
|
||||
|
||||
-spec handle_channel_delete(event_data(), guild_data()) -> guild_data().
|
||||
handle_channel_delete(EventData, Data) ->
|
||||
Channels = maps:get(<<"channels">>, Data, []),
|
||||
Channels = guild_data_index:channel_list(Data),
|
||||
ChannelId = maps:get(<<"id">>, EventData),
|
||||
FilteredChannels = remove_item_by_id(Channels, ChannelId),
|
||||
maps:put(<<"channels">>, FilteredChannels, Data).
|
||||
guild_data_index:put_channels(FilteredChannels, Data).
|
||||
|
||||
-spec handle_message_create(event_data(), guild_data()) -> guild_data().
|
||||
handle_message_create(EventData, Data) ->
|
||||
Channels = maps:get(<<"channels">>, Data, []),
|
||||
Channels = guild_data_index:channel_list(Data),
|
||||
ChannelId = maps:get(<<"channel_id">>, EventData),
|
||||
MessageId = maps:get(<<"id">>, EventData),
|
||||
UpdatedChannels = update_channel_field(Channels, ChannelId, <<"last_message_id">>, MessageId),
|
||||
maps:put(<<"channels">>, UpdatedChannels, Data).
|
||||
guild_data_index:put_channels(UpdatedChannels, Data).
|
||||
|
||||
-spec handle_channel_pins_update(event_data(), guild_data()) -> guild_data().
|
||||
handle_channel_pins_update(EventData, Data) ->
|
||||
Channels = maps:get(<<"channels">>, Data, []),
|
||||
Channels = guild_data_index:channel_list(Data),
|
||||
ChannelId = maps:get(<<"channel_id">>, EventData),
|
||||
LastPin = maps:get(<<"last_pin_timestamp">>, EventData),
|
||||
UpdatedChannels = update_channel_field(Channels, ChannelId, <<"last_pin_timestamp">>, LastPin),
|
||||
maps:put(<<"channels">>, UpdatedChannels, Data).
|
||||
guild_data_index:put_channels(UpdatedChannels, Data).
|
||||
|
||||
-spec update_channel_field([map()], binary(), binary(), term()) -> [map()].
|
||||
update_channel_field(Channels, ChannelId, Field, Value) ->
|
||||
lists:map(
|
||||
fun(C) when is_map(C) ->
|
||||
@@ -294,12 +443,15 @@ update_channel_field(Channels, ChannelId, Field, Value) ->
|
||||
Channels
|
||||
).
|
||||
|
||||
-spec handle_emojis_update(event_data(), guild_data()) -> guild_data().
|
||||
handle_emojis_update(EventData, Data) ->
|
||||
maps:put(<<"emojis">>, maps:get(<<"emojis">>, EventData, []), Data).
|
||||
|
||||
-spec handle_stickers_update(event_data(), guild_data()) -> guild_data().
|
||||
handle_stickers_update(EventData, Data) ->
|
||||
maps:put(<<"stickers">>, maps:get(<<"stickers">>, EventData, []), Data).
|
||||
|
||||
-spec replace_item_by_id([map()], binary(), map()) -> [map()].
|
||||
replace_item_by_id(Items, Id, NewItem) ->
|
||||
lists:map(
|
||||
fun(Item) when is_map(Item) ->
|
||||
@@ -311,6 +463,7 @@ replace_item_by_id(Items, Id, NewItem) ->
|
||||
Items
|
||||
).
|
||||
|
||||
-spec remove_item_by_id([map()], binary()) -> [map()].
|
||||
remove_item_by_id(Items, Id) ->
|
||||
lists:filter(
|
||||
fun(Item) when is_map(Item) ->
|
||||
@@ -319,6 +472,7 @@ remove_item_by_id(Items, Id) ->
|
||||
Items
|
||||
).
|
||||
|
||||
-spec bulk_update_items([map()], [map()]) -> [map()].
|
||||
bulk_update_items(Items, BulkItems) ->
|
||||
BulkMap = lists:foldl(
|
||||
fun
|
||||
@@ -333,7 +487,6 @@ bulk_update_items(Items, BulkItems) ->
|
||||
#{},
|
||||
BulkItems
|
||||
),
|
||||
|
||||
lists:map(
|
||||
fun
|
||||
(Item) when is_map(Item) ->
|
||||
@@ -358,26 +511,19 @@ handle_role_delete_strips_from_members_test() ->
|
||||
#{<<"id">> => <<"100">>, <<"name">> => <<"Admin">>},
|
||||
#{<<"id">> => <<"200">>, <<"name">> => <<"Moderator">>}
|
||||
],
|
||||
<<"members">> => [
|
||||
#{
|
||||
<<"user">> => #{<<"id">> => <<"1">>},
|
||||
<<"roles">> => [<<"100">>, <<"200">>]
|
||||
},
|
||||
#{
|
||||
<<"user">> => #{<<"id">> => <<"2">>},
|
||||
<<"roles">> => [<<"200">>]
|
||||
},
|
||||
#{
|
||||
<<"user">> => #{<<"id">> => <<"3">>},
|
||||
<<"roles">> => [<<"100">>]
|
||||
}
|
||||
],
|
||||
<<"members">> => #{
|
||||
1 => #{<<"user">> => #{<<"id">> => <<"1">>}, <<"roles">> => [<<"100">>, <<"200">>]},
|
||||
2 => #{<<"user">> => #{<<"id">> => <<"2">>}, <<"roles">> => [<<"200">>]},
|
||||
3 => #{<<"user">> => #{<<"id">> => <<"3">>}, <<"roles">> => [<<"100">>]}
|
||||
},
|
||||
<<"channels">> => []
|
||||
},
|
||||
EventData = #{<<"role_id">> => RoleIdToDelete},
|
||||
Result = handle_role_delete(EventData, Data),
|
||||
Members = maps:get(<<"members">>, Result),
|
||||
[M1, M2, M3] = Members,
|
||||
M1 = maps:get(1, Members),
|
||||
M2 = maps:get(2, Members),
|
||||
M3 = maps:get(3, Members),
|
||||
?assertEqual([<<"100">>], maps:get(<<"roles">>, M1)),
|
||||
?assertEqual([], maps:get(<<"roles">>, M2)),
|
||||
?assertEqual([<<"100">>], maps:get(<<"roles">>, M3)).
|
||||
@@ -394,45 +540,24 @@ handle_role_delete_strips_from_channel_overwrites_test() ->
|
||||
#{
|
||||
<<"id">> => <<"500">>,
|
||||
<<"permission_overwrites">> => [
|
||||
#{<<"id">> => <<"100">>, <<"type">> => 0, <<"allow">> => <<"0">>, <<"deny">> => <<"1024">>},
|
||||
#{<<"id">> => <<"200">>, <<"type">> => 0, <<"allow">> => <<"1024">>, <<"deny">> => <<"0">>},
|
||||
#{<<"id">> => <<"1">>, <<"type">> => 1, <<"allow">> => <<"2048">>, <<"deny">> => <<"0">>}
|
||||
]
|
||||
},
|
||||
#{
|
||||
<<"id">> => <<"501">>,
|
||||
<<"permission_overwrites">> => [
|
||||
#{<<"id">> => <<"200">>, <<"type">> => 0, <<"allow">> => <<"1024">>, <<"deny">> => <<"0">>}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
EventData = #{<<"role_id">> => RoleIdToDelete},
|
||||
Result = handle_role_delete(EventData, Data),
|
||||
Channels = maps:get(<<"channels">>, Result),
|
||||
[Ch1, Ch2] = Channels,
|
||||
Ch1Overwrites = maps:get(<<"permission_overwrites">>, Ch1),
|
||||
Ch2Overwrites = maps:get(<<"permission_overwrites">>, Ch2),
|
||||
?assertEqual(2, length(Ch1Overwrites)),
|
||||
?assertEqual(0, length(Ch2Overwrites)),
|
||||
OverwriteIds = [maps:get(<<"id">>, O) || O <- Ch1Overwrites],
|
||||
?assert(lists:member(<<"100">>, OverwriteIds)),
|
||||
?assert(lists:member(<<"1">>, OverwriteIds)),
|
||||
?assertNot(lists:member(<<"200">>, OverwriteIds)).
|
||||
|
||||
handle_role_delete_preserves_user_overwrites_test() ->
|
||||
RoleIdToDelete = <<"200">>,
|
||||
Data = #{
|
||||
<<"roles">> => [
|
||||
#{<<"id">> => <<"200">>, <<"name">> => <<"Moderator">>}
|
||||
],
|
||||
<<"members">> => [],
|
||||
<<"channels">> => [
|
||||
#{
|
||||
<<"id">> => <<"500">>,
|
||||
<<"permission_overwrites">> => [
|
||||
#{<<"id">> => <<"200">>, <<"type">> => 0, <<"allow">> => <<"1024">>, <<"deny">> => <<"0">>},
|
||||
#{<<"id">> => <<"200">>, <<"type">> => 1, <<"allow">> => <<"2048">>, <<"deny">> => <<"0">>}
|
||||
#{
|
||||
<<"id">> => <<"100">>,
|
||||
<<"type">> => 0,
|
||||
<<"allow">> => <<"0">>,
|
||||
<<"deny">> => <<"1024">>
|
||||
},
|
||||
#{
|
||||
<<"id">> => <<"200">>,
|
||||
<<"type">> => 0,
|
||||
<<"allow">> => <<"1024">>,
|
||||
<<"deny">> => <<"0">>
|
||||
},
|
||||
#{
|
||||
<<"id">> => <<"1">>,
|
||||
<<"type">> => 1,
|
||||
<<"allow">> => <<"2048">>,
|
||||
<<"deny">> => <<"0">>
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
@@ -441,26 +566,141 @@ handle_role_delete_preserves_user_overwrites_test() ->
|
||||
Result = handle_role_delete(EventData, Data),
|
||||
Channels = maps:get(<<"channels">>, Result),
|
||||
[Ch1] = Channels,
|
||||
Overwrites = maps:get(<<"permission_overwrites">>, Ch1),
|
||||
?assertEqual(1, length(Overwrites)),
|
||||
[RemainingOverwrite] = Overwrites,
|
||||
?assertEqual(1, maps:get(<<"type">>, RemainingOverwrite)).
|
||||
Ch1Overwrites = maps:get(<<"permission_overwrites">>, Ch1),
|
||||
?assertEqual(2, length(Ch1Overwrites)).
|
||||
|
||||
handle_role_delete_removes_role_from_roles_list_test() ->
|
||||
RoleIdToDelete = <<"200">>,
|
||||
handle_member_add_test() ->
|
||||
Data = #{<<"members">> => #{1 => #{<<"user">> => #{<<"id">> => <<"1">>}}}},
|
||||
EventData = #{<<"user">> => #{<<"id">> => <<"2">>}},
|
||||
Result = handle_member_add(EventData, Data),
|
||||
Members = maps:get(<<"members">>, Result),
|
||||
?assertEqual(2, map_size(Members)).
|
||||
|
||||
handle_member_update_test() ->
|
||||
Data = #{
|
||||
<<"roles">> => [
|
||||
#{<<"id">> => <<"100">>, <<"name">> => <<"Admin">>},
|
||||
#{<<"id">> => <<"200">>, <<"name">> => <<"Moderator">>}
|
||||
],
|
||||
<<"members">> => [],
|
||||
<<"channels">> => []
|
||||
<<"members">> => #{
|
||||
1 => #{<<"user">> => #{<<"id">> => <<"1">>}, <<"nick">> => <<"OldNick">>}
|
||||
}
|
||||
},
|
||||
EventData = #{<<"role_id">> => RoleIdToDelete},
|
||||
Result = handle_role_delete(EventData, Data),
|
||||
Roles = maps:get(<<"roles">>, Result),
|
||||
?assertEqual(1, length(Roles)),
|
||||
[RemainingRole] = Roles,
|
||||
?assertEqual(<<"100">>, maps:get(<<"id">>, RemainingRole)).
|
||||
EventData = #{<<"user">> => #{<<"id">> => <<"1">>}, <<"nick">> => <<"NewNick">>},
|
||||
Result = handle_member_update(EventData, Data),
|
||||
Members = maps:get(<<"members">>, Result),
|
||||
Member = maps:get(1, Members),
|
||||
?assertEqual(<<"NewNick">>, maps:get(<<"nick">>, Member)).
|
||||
|
||||
handle_channel_create_test() ->
|
||||
Data = #{<<"channels">> => []},
|
||||
EventData = #{<<"id">> => <<"100">>, <<"name">> => <<"general">>},
|
||||
Result = handle_channel_create(EventData, Data),
|
||||
Channels = maps:get(<<"channels">>, Result),
|
||||
?assertEqual(1, length(Channels)).
|
||||
|
||||
bulk_update_items_test() ->
|
||||
Items = [
|
||||
#{<<"id">> => <<"1">>, <<"value">> => <<"old1">>},
|
||||
#{<<"id">> => <<"2">>, <<"value">> => <<"old2">>}
|
||||
],
|
||||
BulkItems = [
|
||||
#{<<"id">> => <<"1">>, <<"value">> => <<"new1">>}
|
||||
],
|
||||
Result = bulk_update_items(Items, BulkItems),
|
||||
[Item1, Item2] = Result,
|
||||
?assertEqual(<<"new1">>, maps:get(<<"value">>, Item1)),
|
||||
?assertEqual(<<"old2">>, maps:get(<<"value">>, Item2)).
|
||||
|
||||
needs_visibility_check_test() ->
|
||||
?assertEqual(true, needs_visibility_check(guild_role_update)),
|
||||
?assertEqual(true, needs_visibility_check(channel_update)),
|
||||
?assertEqual(false, needs_visibility_check(message_create)),
|
||||
?assertEqual(false, needs_visibility_check(unknown_event)).
|
||||
|
||||
extract_channel_ids_from_channel_update_test() ->
|
||||
?assertEqual([42], extract_channel_ids_from_channel_update(#{<<"id">> => <<"42">>})),
|
||||
?assertEqual([], extract_channel_ids_from_channel_update(#{})).
|
||||
|
||||
extract_channel_ids_from_channel_update_bulk_test() ->
|
||||
EventData = #{
|
||||
<<"channels">> => [
|
||||
#{<<"id">> => <<"10">>},
|
||||
#{<<"id">> => <<"11">>},
|
||||
#{<<"name">> => <<"missing_id">>}
|
||||
]
|
||||
},
|
||||
?assertEqual([10, 11], extract_channel_ids_from_channel_update_bulk(EventData)).
|
||||
|
||||
guild_member_remove_disconnects_voice_test() ->
|
||||
Self = self(),
|
||||
TestFun = fun(GuildId, ChannelId, UserId, ConnectionId) ->
|
||||
Self ! {force_disconnect, GuildId, ChannelId, UserId, ConnectionId},
|
||||
{ok, #{success => true}}
|
||||
end,
|
||||
GuildId = 42,
|
||||
UserId = 5,
|
||||
ChannelId = 20,
|
||||
Data = #{
|
||||
<<"guild">> => #{<<"owner_id">> => <<"999">>},
|
||||
<<"roles">> => [
|
||||
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => <<"0">>}
|
||||
],
|
||||
<<"members">> => #{
|
||||
UserId => #{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
|
||||
},
|
||||
<<"channels">> => [#{<<"id">> => integer_to_binary(ChannelId)}]
|
||||
},
|
||||
VoiceStates = #{
|
||||
<<"conn">> => #{
|
||||
<<"user_id">> => integer_to_binary(UserId),
|
||||
<<"guild_id">> => integer_to_binary(GuildId),
|
||||
<<"channel_id">> => integer_to_binary(ChannelId),
|
||||
<<"connection_id">> => <<"conn">>
|
||||
}
|
||||
},
|
||||
Sessions = #{<<"s1">> => #{user_id => UserId, pid => self()}},
|
||||
State = #{
|
||||
id => GuildId,
|
||||
data => Data,
|
||||
voice_states => VoiceStates,
|
||||
sessions => Sessions,
|
||||
test_force_disconnect_fun => TestFun
|
||||
},
|
||||
EventData = #{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}},
|
||||
UpdatedState = update_state(guild_member_remove, EventData, State),
|
||||
?assertEqual(#{}, maps:get(voice_states, UpdatedState)),
|
||||
?assertEqual(#{}, maps:get(sessions, UpdatedState, #{})),
|
||||
receive
|
||||
{force_disconnect, GuildId, ChannelId, UserId, <<"conn">>} -> ok
|
||||
after 200 ->
|
||||
?assert(false)
|
||||
end.
|
||||
|
||||
guild_update_syncs_unavailability_cache_test() ->
|
||||
GuildId = 420042,
|
||||
CleanupState = #{
|
||||
id => GuildId,
|
||||
data => #{
|
||||
<<"guild">> => #{<<"features">> => []}
|
||||
}
|
||||
},
|
||||
_ = guild_availability:update_unavailability_cache_for_state(CleanupState),
|
||||
try
|
||||
State0 = #{
|
||||
id => GuildId,
|
||||
data => #{
|
||||
<<"guild">> => #{<<"features">> => []},
|
||||
<<"members">> => []
|
||||
},
|
||||
sessions => #{}
|
||||
},
|
||||
State1 = update_state(
|
||||
guild_update,
|
||||
#{<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]},
|
||||
State0
|
||||
),
|
||||
?assertEqual(unavailable_for_everyone, guild_availability:get_cached_unavailability_mode(GuildId)),
|
||||
_State2 = update_state(guild_update, #{<<"features">> => []}, State1),
|
||||
?assertEqual(available, guild_availability:get_cached_unavailability_mode(GuildId))
|
||||
after
|
||||
_ = guild_availability:update_unavailability_cache_for_state(CleanupState)
|
||||
end.
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -77,10 +77,8 @@ unsubscribe_session(SessionId, State) ->
|
||||
update_subscriptions(SessionId, NewMemberIds, State) ->
|
||||
CurrentlySubscribed = get_user_ids_for_session(SessionId, State),
|
||||
NewMemberIdSet = sets:from_list(NewMemberIds),
|
||||
|
||||
ToRemove = sets:subtract(CurrentlySubscribed, NewMemberIdSet),
|
||||
ToAdd = sets:subtract(NewMemberIdSet, CurrentlySubscribed),
|
||||
|
||||
State1 = sets:fold(
|
||||
fun(UserId, AccState) ->
|
||||
unsubscribe(SessionId, UserId, AccState)
|
||||
@@ -88,7 +86,6 @@ update_subscriptions(SessionId, NewMemberIds, State) ->
|
||||
State,
|
||||
ToRemove
|
||||
),
|
||||
|
||||
sets:fold(
|
||||
fun(UserId, AccState) ->
|
||||
subscribe(SessionId, UserId, AccState)
|
||||
@@ -182,4 +179,11 @@ update_subscriptions_test() ->
|
||||
?assert(is_subscribed(<<"session1">>, 200, State3)),
|
||||
?assert(is_subscribed(<<"session1">>, 300, State3)).
|
||||
|
||||
get_user_ids_for_session_test() ->
|
||||
State0 = init_state(),
|
||||
State1 = subscribe(<<"session1">>, 100, State0),
|
||||
State2 = subscribe(<<"session1">>, 200, State1),
|
||||
UserIds = get_user_ids_for_session(<<"session1">>, State2),
|
||||
?assertEqual([100, 200], lists:sort(sets:to_list(UserIds))).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -19,5 +19,6 @@
|
||||
|
||||
-export([send_guild_sync/2]).
|
||||
|
||||
-spec send_guild_sync(pid(), binary()) -> ok.
|
||||
send_guild_sync(GuildPid, SessionId) ->
|
||||
gen_server:cast(GuildPid, {send_guild_sync, SessionId}).
|
||||
|
||||
@@ -23,95 +23,88 @@
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-spec handle_subscriptions(map(), pid(), map()) -> ok.
|
||||
-type session_state() :: map().
|
||||
-type session_id() :: binary() | undefined.
|
||||
|
||||
-spec handle_subscriptions(map(), pid(), session_state()) -> ok.
|
||||
handle_subscriptions(Data, SocketPid, SessionState) ->
|
||||
Subscriptions = maps:get(<<"subscriptions">>, Data, #{}),
|
||||
Guilds = maps:get(guilds, SessionState, #{}),
|
||||
SessionId = maps:get(id, SessionState, undefined),
|
||||
|
||||
logger:debug("[guild_unified_subscriptions] Processing ~p guild subscriptions for session ~p", [
|
||||
map_size(Subscriptions), SessionId
|
||||
]),
|
||||
|
||||
maps:foreach(
|
||||
fun(GuildIdBin, GuildSubData) ->
|
||||
process_guild_subscription(GuildIdBin, GuildSubData, Guilds, SessionId, SocketPid, SessionState)
|
||||
process_guild_subscription(
|
||||
GuildIdBin, GuildSubData, Guilds, SessionId, SocketPid, SessionState
|
||||
)
|
||||
end,
|
||||
Subscriptions
|
||||
),
|
||||
ok.
|
||||
|
||||
-spec process_guild_subscription(binary(), map(), map(), binary() | undefined, pid(), map()) -> ok.
|
||||
-spec process_guild_subscription(binary(), map(), map(), session_id(), pid(), session_state()) ->
|
||||
ok.
|
||||
process_guild_subscription(GuildIdBin, GuildSubData, Guilds, SessionId, SocketPid, SessionState) ->
|
||||
case validation:validate_snowflake(<<"guild_id">>, GuildIdBin) of
|
||||
{ok, GuildId} ->
|
||||
case maps:get(GuildId, Guilds, undefined) of
|
||||
{GuildPid, _Ref} when is_pid(GuildPid) ->
|
||||
process_guild_sub_options(GuildId, GuildPid, GuildSubData, SessionId, SocketPid, SessionState);
|
||||
process_guild_sub_options(
|
||||
GuildId, GuildPid, GuildSubData, SessionId, SocketPid, SessionState
|
||||
);
|
||||
undefined ->
|
||||
logger:warning("[guild_unified_subscriptions] Guild ~p not found in session state", [GuildId]),
|
||||
ok;
|
||||
_ ->
|
||||
ok
|
||||
end;
|
||||
{error, _, Reason} ->
|
||||
logger:warning("[guild_unified_subscriptions] Invalid guild_id ~p: ~p", [GuildIdBin, Reason]),
|
||||
{error, _, _} ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec process_guild_sub_options(integer(), pid(), map(), binary() | undefined, pid(), map()) -> ok.
|
||||
-spec process_guild_sub_options(integer(), pid(), map(), session_id(), pid(), session_state()) ->
|
||||
ok.
|
||||
process_guild_sub_options(GuildId, GuildPid, GuildSubData, SessionId, SocketPid, SessionState) ->
|
||||
WasActive = not session_passive:is_passive(GuildId, SessionState),
|
||||
ActiveChanged = process_active_flag(GuildSubData, GuildPid, SessionId, WasActive),
|
||||
|
||||
process_sync_flag(GuildSubData, GuildId, GuildPid, SessionId, ActiveChanged),
|
||||
|
||||
process_member_list_channels(GuildSubData, GuildId, GuildPid, SessionId, SocketPid),
|
||||
|
||||
process_member_subscriptions(GuildSubData, GuildPid, SessionId),
|
||||
|
||||
process_typing_flag(GuildSubData, GuildPid, SessionId),
|
||||
|
||||
ok.
|
||||
|
||||
-spec process_active_flag(map(), pid(), binary() | undefined, boolean()) -> boolean().
|
||||
-spec process_active_flag(map(), pid(), session_id(), boolean()) -> boolean().
|
||||
process_active_flag(GuildSubData, GuildPid, SessionId, WasActive) ->
|
||||
case maps:get(<<"active">>, GuildSubData, undefined) of
|
||||
undefined ->
|
||||
false;
|
||||
true ->
|
||||
gen_server:cast(GuildPid, {set_session_active, SessionId}),
|
||||
logger:debug("[guild_unified_subscriptions] Set session ~p active", [SessionId]),
|
||||
not WasActive;
|
||||
false ->
|
||||
gen_server:cast(GuildPid, {set_session_passive, SessionId}),
|
||||
logger:debug("[guild_unified_subscriptions] Set session ~p passive", [SessionId]),
|
||||
WasActive
|
||||
end.
|
||||
|
||||
-spec process_sync_flag(map(), integer(), pid(), binary() | undefined, boolean()) -> ok.
|
||||
process_sync_flag(GuildSubData, GuildId, GuildPid, SessionId, ActiveChanged) ->
|
||||
-spec process_sync_flag(map(), integer(), pid(), session_id(), boolean()) -> ok.
|
||||
process_sync_flag(GuildSubData, _GuildId, GuildPid, SessionId, ActiveChanged) ->
|
||||
ShouldSync = maps:get(<<"sync">>, GuildSubData, false) =:= true orelse ActiveChanged,
|
||||
case ShouldSync of
|
||||
true ->
|
||||
guild_sync:send_guild_sync(GuildPid, SessionId),
|
||||
logger:debug("[guild_unified_subscriptions] Sent guild sync for guild ~p", [GuildId]);
|
||||
guild_sync:send_guild_sync(GuildPid, SessionId);
|
||||
false ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec process_member_list_channels(map(), integer(), pid(), binary() | undefined, pid()) -> ok.
|
||||
-spec process_member_list_channels(map(), integer(), pid(), session_id(), pid()) -> ok.
|
||||
process_member_list_channels(GuildSubData, GuildId, GuildPid, SessionId, SocketPid) ->
|
||||
case maps:get(<<"member_list_channels">>, GuildSubData, undefined) of
|
||||
undefined ->
|
||||
ok;
|
||||
MemberListChannels when is_map(MemberListChannels) ->
|
||||
logger:debug("[guild_unified_subscriptions] Processing ~p member list channels for guild ~p", [
|
||||
map_size(MemberListChannels), GuildId
|
||||
]),
|
||||
maps:foreach(
|
||||
fun(ChannelIdBin, Ranges) ->
|
||||
process_channel_lazy_subscribe(ChannelIdBin, Ranges, GuildId, GuildPid, SessionId, SocketPid)
|
||||
process_channel_lazy_subscribe(
|
||||
ChannelIdBin, Ranges, GuildId, GuildPid, SessionId, SocketPid
|
||||
)
|
||||
end,
|
||||
MemberListChannels
|
||||
);
|
||||
@@ -119,28 +112,29 @@ process_member_list_channels(GuildSubData, GuildId, GuildPid, SessionId, SocketP
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec process_channel_lazy_subscribe(binary(), list(), integer(), pid(), binary() | undefined, pid()) -> ok.
|
||||
-spec process_channel_lazy_subscribe(binary(), list(), integer(), pid(), session_id(), pid()) -> ok.
|
||||
process_channel_lazy_subscribe(ChannelIdBin, Ranges, _GuildId, GuildPid, SessionId, _SocketPid) ->
|
||||
case validation:validate_snowflake(<<"channel_id">>, ChannelIdBin) of
|
||||
{ok, ChannelId} ->
|
||||
ParsedRanges = parse_ranges(Ranges),
|
||||
logger:debug("[guild_unified_subscriptions] Lazy subscribe channel ~p with ranges ~p", [
|
||||
ChannelId, ParsedRanges
|
||||
]),
|
||||
case gen_server:call(GuildPid, {lazy_subscribe, #{
|
||||
session_id => SessionId,
|
||||
channel_id => ChannelId,
|
||||
ranges => ParsedRanges
|
||||
}}, 10000) of
|
||||
case
|
||||
gen_server:call(
|
||||
GuildPid,
|
||||
{lazy_subscribe, #{
|
||||
session_id => SessionId,
|
||||
channel_id => ChannelId,
|
||||
ranges => ParsedRanges
|
||||
}},
|
||||
10000
|
||||
)
|
||||
of
|
||||
ok ->
|
||||
ok;
|
||||
Error ->
|
||||
logger:error("[guild_unified_subscriptions] lazy_subscribe failed for channel ~p: ~p", [
|
||||
ChannelId, Error
|
||||
])
|
||||
_Error ->
|
||||
ok
|
||||
end;
|
||||
{error, _, Reason} ->
|
||||
logger:warning("[guild_unified_subscriptions] Invalid channel_id ~p: ~p", [ChannelIdBin, Reason])
|
||||
{error, _, _} ->
|
||||
ok
|
||||
end,
|
||||
ok.
|
||||
|
||||
@@ -160,16 +154,13 @@ parse_ranges(Ranges) when is_list(Ranges) ->
|
||||
parse_ranges(_) ->
|
||||
[].
|
||||
|
||||
-spec process_member_subscriptions(map(), pid(), binary() | undefined) -> ok.
|
||||
-spec process_member_subscriptions(map(), pid(), session_id()) -> ok.
|
||||
process_member_subscriptions(GuildSubData, GuildPid, SessionId) ->
|
||||
case maps:get(<<"members">>, GuildSubData, undefined) of
|
||||
undefined ->
|
||||
ok;
|
||||
Members when is_list(Members) ->
|
||||
MemberIds = parse_member_ids(Members),
|
||||
logger:debug("[guild_unified_subscriptions] Updating member subscriptions with ~p members", [
|
||||
length(MemberIds)
|
||||
]),
|
||||
gen_server:cast(GuildPid, {update_member_subscriptions, SessionId, MemberIds});
|
||||
_ ->
|
||||
ok
|
||||
@@ -189,16 +180,13 @@ parse_member_ids(Members) when is_list(Members) ->
|
||||
parse_member_ids(_) ->
|
||||
[].
|
||||
|
||||
-spec process_typing_flag(map(), pid(), binary() | undefined) -> ok.
|
||||
-spec process_typing_flag(map(), pid(), session_id()) -> ok.
|
||||
process_typing_flag(GuildSubData, GuildPid, SessionId) ->
|
||||
case maps:get(<<"typing">>, GuildSubData, undefined) of
|
||||
undefined ->
|
||||
ok;
|
||||
TypingFlag when is_boolean(TypingFlag) ->
|
||||
gen_server:cast(GuildPid, {set_session_typing_override, SessionId, TypingFlag}),
|
||||
logger:debug("[guild_unified_subscriptions] Set typing override to ~p for session ~p", [
|
||||
TypingFlag, SessionId
|
||||
]);
|
||||
gen_server:cast(GuildPid, {set_session_typing_override, SessionId, TypingFlag});
|
||||
_ ->
|
||||
ok
|
||||
end.
|
||||
|
||||
@@ -24,129 +24,155 @@
|
||||
check_user_data_differs/2
|
||||
]).
|
||||
|
||||
-import(guild_permissions, [find_member_by_user_id/2]).
|
||||
-type guild_state() :: map().
|
||||
-type user_id() :: integer().
|
||||
-type member() :: map().
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-spec update_user_data(map(), guild_state()) -> {noreply, guild_state()}.
|
||||
update_user_data(EventData, State) ->
|
||||
UserId = utils:binary_to_integer_safe(maps:get(<<"id">>, EventData)),
|
||||
Data = maps:get(data, State),
|
||||
Members = maps:get(<<"members">>, Data, []),
|
||||
|
||||
UpdatedMembers = lists:map(
|
||||
fun(Member) when is_map(Member) ->
|
||||
MUser = maps:get(<<"user">>, Member, #{}),
|
||||
MemberId =
|
||||
case is_map(MUser) of
|
||||
true ->
|
||||
utils:binary_to_integer_safe(maps:get(<<"id">>, MUser, <<"0">>));
|
||||
false ->
|
||||
undefined
|
||||
end,
|
||||
if
|
||||
MemberId =:= UserId ->
|
||||
maps:put(<<"user">>, EventData, Member);
|
||||
true ->
|
||||
Member
|
||||
end
|
||||
Members = guild_data_index:member_map(Data),
|
||||
UpdatedMembers = maps:map(
|
||||
fun(_MemberUserId, Member) ->
|
||||
maybe_update_member_user(Member, UserId, EventData)
|
||||
end,
|
||||
Members
|
||||
),
|
||||
|
||||
UpdatedData = maps:put(<<"members">>, UpdatedMembers, Data),
|
||||
UpdatedData = guild_data_index:put_member_map(UpdatedMembers, Data),
|
||||
UpdatedState = maps:put(data, UpdatedData, State),
|
||||
|
||||
UpdatedMember = find_member_by_user_id(UserId, UpdatedState),
|
||||
case UpdatedMember of
|
||||
undefined -> ok;
|
||||
M -> gen_server:cast(self(), {dispatch, #{event => guild_member_update, data => M}})
|
||||
end,
|
||||
|
||||
dispatch_member_update_if_found(UserId, UpdatedState),
|
||||
{noreply, UpdatedState}.
|
||||
|
||||
handle_user_data_update(UserId, UserData, State) ->
|
||||
Data = maps:get(data, State),
|
||||
Members = maps:get(<<"members">>, Data, []),
|
||||
-spec maybe_update_member_user(member(), user_id(), map()) -> member().
|
||||
maybe_update_member_user(Member, UserId, EventData) ->
|
||||
MUser = maps:get(<<"user">>, Member, #{}),
|
||||
MemberId =
|
||||
case is_map(MUser) of
|
||||
true -> utils:binary_to_integer_safe(maps:get(<<"id">>, MUser, <<"0">>));
|
||||
false -> undefined
|
||||
end,
|
||||
case MemberId =:= UserId of
|
||||
true -> maps:put(<<"user">>, EventData, Member);
|
||||
false -> Member
|
||||
end.
|
||||
|
||||
CurrentMember = find_member_by_user_id(UserId, State),
|
||||
case CurrentMember of
|
||||
-spec dispatch_member_update_if_found(user_id(), guild_state()) -> ok.
|
||||
dispatch_member_update_if_found(UserId, State) ->
|
||||
case guild_permissions:find_member_by_user_id(UserId, State) of
|
||||
undefined -> ok;
|
||||
M -> gen_server:cast(self(), {dispatch, #{event => guild_member_update, data => M}})
|
||||
end.
|
||||
|
||||
-spec handle_user_data_update(user_id(), map(), guild_state()) -> guild_state().
|
||||
handle_user_data_update(UserId, UserData, State) ->
|
||||
case guild_permissions:find_member_by_user_id(UserId, State) of
|
||||
undefined ->
|
||||
State;
|
||||
Member ->
|
||||
CurrentUserData = maps:get(<<"user">>, Member, #{}),
|
||||
IsDifferent = check_user_data_differs(CurrentUserData, UserData),
|
||||
if
|
||||
IsDifferent ->
|
||||
UpdatedMembers = lists:map(
|
||||
fun(M) when is_map(M) ->
|
||||
MUser = maps:get(<<"user">>, M, #{}),
|
||||
MemberId =
|
||||
case is_map(MUser) of
|
||||
true ->
|
||||
utils:binary_to_integer_safe(
|
||||
maps:get(<<"id">>, MUser, <<"0">>)
|
||||
);
|
||||
false ->
|
||||
undefined
|
||||
end,
|
||||
if
|
||||
MemberId =:= UserId ->
|
||||
maps:put(<<"user">>, UserData, M);
|
||||
true ->
|
||||
M
|
||||
end
|
||||
end,
|
||||
Members
|
||||
),
|
||||
|
||||
UpdatedData = maps:put(<<"members">>, UpdatedMembers, Data),
|
||||
UpdatedState = maps:put(data, UpdatedData, State),
|
||||
|
||||
UpdatedMember = find_member_by_user_id(UserId, UpdatedState),
|
||||
case UpdatedMember of
|
||||
undefined ->
|
||||
ok;
|
||||
M ->
|
||||
GuildId = maps:get(id, UpdatedState),
|
||||
MemberUpdateData = maps:put(
|
||||
<<"guild_id">>, integer_to_binary(GuildId), M
|
||||
),
|
||||
gen_server:cast(
|
||||
self(),
|
||||
{dispatch, #{
|
||||
event => guild_member_update, data => MemberUpdateData
|
||||
}}
|
||||
)
|
||||
end,
|
||||
|
||||
UpdatedState;
|
||||
case check_user_data_differs(CurrentUserData, UserData) of
|
||||
false ->
|
||||
State;
|
||||
true ->
|
||||
State
|
||||
apply_user_data_update(UserId, UserData, State)
|
||||
end
|
||||
end.
|
||||
|
||||
-spec apply_user_data_update(user_id(), map(), guild_state()) -> guild_state().
|
||||
apply_user_data_update(UserId, UserData, State) ->
|
||||
Data = maps:get(data, State),
|
||||
Members = guild_data_index:member_map(Data),
|
||||
UpdatedMembers = maps:map(
|
||||
fun(_MemberUserId, Member) ->
|
||||
maybe_update_member_user(Member, UserId, UserData)
|
||||
end,
|
||||
Members
|
||||
),
|
||||
UpdatedData = guild_data_index:put_member_map(UpdatedMembers, Data),
|
||||
UpdatedState = maps:put(data, UpdatedData, State),
|
||||
dispatch_guild_member_update(UserId, UpdatedState),
|
||||
UpdatedState.
|
||||
|
||||
-spec dispatch_guild_member_update(user_id(), guild_state()) -> ok.
|
||||
dispatch_guild_member_update(UserId, State) ->
|
||||
case guild_permissions:find_member_by_user_id(UserId, State) of
|
||||
undefined ->
|
||||
ok;
|
||||
M ->
|
||||
GuildId = maps:get(id, State),
|
||||
MemberUpdateData = maps:put(<<"guild_id">>, integer_to_binary(GuildId), M),
|
||||
gen_server:cast(
|
||||
self(),
|
||||
{dispatch, #{event => guild_member_update, data => MemberUpdateData}}
|
||||
)
|
||||
end.
|
||||
|
||||
-spec check_user_data_differs(map(), map()) -> boolean().
|
||||
check_user_data_differs(CurrentUserData, NewUserData) ->
|
||||
utils:check_user_data_differs(CurrentUserData, NewUserData).
|
||||
|
||||
maybe_update_cached_user_data(Event, EventData, State) ->
|
||||
case Event of
|
||||
E when E =:= message_create; E =:= message_update ->
|
||||
case maps:get(<<"author">>, EventData, undefined) of
|
||||
-spec maybe_update_cached_user_data(atom(), map(), guild_state()) -> guild_state().
|
||||
maybe_update_cached_user_data(Event, EventData, State) when
|
||||
Event =:= message_create; Event =:= message_update
|
||||
->
|
||||
case maps:get(<<"author">>, EventData, undefined) of
|
||||
undefined ->
|
||||
State;
|
||||
AuthorData ->
|
||||
UserId = utils:binary_to_integer_safe(maps:get(<<"id">>, AuthorData, <<"0">>)),
|
||||
case guild_permissions:find_member_by_user_id(UserId, State) of
|
||||
undefined ->
|
||||
State;
|
||||
AuthorData ->
|
||||
UserId = utils:binary_to_integer_safe(maps:get(<<"id">>, AuthorData, <<"0">>)),
|
||||
case find_member_by_user_id(UserId, State) of
|
||||
undefined ->
|
||||
State;
|
||||
Member ->
|
||||
CurrentUserData = maps:get(<<"user">>, Member, #{}),
|
||||
case check_user_data_differs(CurrentUserData, AuthorData) of
|
||||
true ->
|
||||
handle_user_data_update(UserId, AuthorData, State);
|
||||
false ->
|
||||
State
|
||||
end
|
||||
Member ->
|
||||
CurrentUserData = maps:get(<<"user">>, Member, #{}),
|
||||
case check_user_data_differs(CurrentUserData, AuthorData) of
|
||||
true ->
|
||||
handle_user_data_update(UserId, AuthorData, State);
|
||||
false ->
|
||||
State
|
||||
end
|
||||
end;
|
||||
_ ->
|
||||
State
|
||||
end.
|
||||
end
|
||||
end;
|
||||
maybe_update_cached_user_data(_, _, State) ->
|
||||
State.
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
update_user_data_updates_member_test() ->
|
||||
State = test_state(),
|
||||
EventData = #{<<"id">> => <<"100">>, <<"username">> => <<"updated">>},
|
||||
{noreply, UpdatedState} = update_user_data(EventData, State),
|
||||
Data = maps:get(data, UpdatedState),
|
||||
Member = maps:get(100, maps:get(<<"members">>, Data)),
|
||||
User = maps:get(<<"user">>, Member),
|
||||
?assertEqual(<<"updated">>, maps:get(<<"username">>, User)).
|
||||
|
||||
handle_user_data_update_no_change_test() ->
|
||||
State = test_state(),
|
||||
UserData = #{<<"id">> => <<"100">>, <<"username">> => <<"alice">>},
|
||||
NewState = handle_user_data_update(100, UserData, State),
|
||||
?assertEqual(State, NewState).
|
||||
|
||||
check_user_data_differs_test() ->
|
||||
Current = #{<<"username">> => <<"alice">>},
|
||||
Same = #{<<"username">> => <<"alice">>},
|
||||
Different = #{<<"username">> => <<"bob">>},
|
||||
?assertEqual(false, check_user_data_differs(Current, Same)),
|
||||
?assertEqual(true, check_user_data_differs(Current, Different)).
|
||||
|
||||
test_state() ->
|
||||
#{
|
||||
id => 42,
|
||||
data => #{
|
||||
<<"members">> => #{
|
||||
100 => #{<<"user">> => #{<<"id">> => <<"100">>, <<"username">> => <<"alice">>}}
|
||||
}
|
||||
}
|
||||
}.
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -23,18 +23,37 @@
|
||||
has_virtual_access/3,
|
||||
get_virtual_channels_for_user/2,
|
||||
get_users_with_virtual_access/2,
|
||||
dispatch_channel_visibility_change/4
|
||||
dispatch_channel_visibility_change/4,
|
||||
mark_pending_join/3,
|
||||
clear_pending_join/3,
|
||||
is_pending_join/3,
|
||||
mark_preserve/3,
|
||||
clear_preserve/3,
|
||||
has_preserve/3,
|
||||
mark_move_pending/3,
|
||||
clear_move_pending/3,
|
||||
is_move_pending/3
|
||||
]).
|
||||
|
||||
-import(guild_permissions, [find_channel_by_id/2]).
|
||||
-type guild_state() :: map().
|
||||
-type user_id() :: integer().
|
||||
-type channel_id() :: integer().
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-spec add_virtual_access(user_id(), channel_id(), guild_state()) -> guild_state().
|
||||
add_virtual_access(UserId, ChannelId, State) ->
|
||||
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
|
||||
UserChannels = maps:get(UserId, VirtualAccess, sets:new()),
|
||||
UpdatedUserChannels = sets:add_element(ChannelId, UserChannels),
|
||||
UpdatedVirtualAccess = maps:put(UserId, UpdatedUserChannels, VirtualAccess),
|
||||
maps:put(virtual_channel_access, UpdatedVirtualAccess, State).
|
||||
State1 = maps:put(virtual_channel_access, UpdatedVirtualAccess, State),
|
||||
State2 = update_user_session_view_cache(UserId, ChannelId, add, State1),
|
||||
mark_pending_join(UserId, ChannelId, State2).
|
||||
|
||||
-spec remove_virtual_access(user_id(), channel_id(), guild_state()) -> guild_state().
|
||||
remove_virtual_access(UserId, ChannelId, State) ->
|
||||
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
|
||||
case maps:get(UserId, VirtualAccess, undefined) of
|
||||
@@ -44,32 +63,149 @@ remove_virtual_access(UserId, ChannelId, State) ->
|
||||
UpdatedUserChannels = sets:del_element(ChannelId, UserChannels),
|
||||
case sets:size(UpdatedUserChannels) of
|
||||
0 ->
|
||||
UpdatedVirtualAccess = maps:remove(UserId, VirtualAccess),
|
||||
maps:put(virtual_channel_access, UpdatedVirtualAccess, State);
|
||||
remove_all_user_virtual_access(UserId, State);
|
||||
_ ->
|
||||
UpdatedVirtualAccess = maps:put(UserId, UpdatedUserChannels, VirtualAccess),
|
||||
maps:put(virtual_channel_access, UpdatedVirtualAccess, State)
|
||||
update_user_virtual_access(UserId, ChannelId, UpdatedUserChannels, State)
|
||||
end
|
||||
end.
|
||||
|
||||
-spec remove_all_user_virtual_access(user_id(), guild_state()) -> guild_state().
|
||||
remove_all_user_virtual_access(UserId, State) ->
|
||||
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
|
||||
PendingMap = maps:get(virtual_channel_access_pending, State, #{}),
|
||||
PreserveMap = maps:get(virtual_channel_access_preserve, State, #{}),
|
||||
MoveMap = maps:get(virtual_channel_access_move_pending, State, #{}),
|
||||
UpdatedVirtualAccess = maps:remove(UserId, VirtualAccess),
|
||||
UpdatedPending = maps:remove(UserId, PendingMap),
|
||||
UpdatedPreserve = maps:remove(UserId, PreserveMap),
|
||||
UpdatedMove = maps:remove(UserId, MoveMap),
|
||||
State1 = maps:put(virtual_channel_access, UpdatedVirtualAccess, State),
|
||||
State2 = maps:put(virtual_channel_access_pending, UpdatedPending, State1),
|
||||
State3 = maps:put(virtual_channel_access_preserve, UpdatedPreserve, State2),
|
||||
State4 = maps:put(virtual_channel_access_move_pending, UpdatedMove, State3),
|
||||
clear_user_session_view_cache(UserId, State4).
|
||||
|
||||
-spec update_user_virtual_access(user_id(), channel_id(), sets:set(), guild_state()) ->
|
||||
guild_state().
|
||||
update_user_virtual_access(UserId, ChannelId, UpdatedUserChannels, State) ->
|
||||
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
|
||||
PendingMap = maps:get(virtual_channel_access_pending, State, #{}),
|
||||
PreserveMap = maps:get(virtual_channel_access_preserve, State, #{}),
|
||||
MoveMap = maps:get(virtual_channel_access_move_pending, State, #{}),
|
||||
UpdatedVirtualAccess = maps:put(UserId, UpdatedUserChannels, VirtualAccess),
|
||||
UpdatedUserPending = sets:del_element(ChannelId, maps:get(UserId, PendingMap, sets:new())),
|
||||
UpdatedPending = maps:put(UserId, UpdatedUserPending, PendingMap),
|
||||
UpdatedUserPreserve = sets:del_element(ChannelId, maps:get(UserId, PreserveMap, sets:new())),
|
||||
UpdatedPreserve = maps:put(UserId, UpdatedUserPreserve, PreserveMap),
|
||||
UpdatedUserMove = sets:del_element(ChannelId, maps:get(UserId, MoveMap, sets:new())),
|
||||
UpdatedMove = maps:put(UserId, UpdatedUserMove, MoveMap),
|
||||
State1 = maps:put(virtual_channel_access, UpdatedVirtualAccess, State),
|
||||
State2 = maps:put(virtual_channel_access_pending, UpdatedPending, State1),
|
||||
State3 = maps:put(virtual_channel_access_preserve, UpdatedPreserve, State2),
|
||||
State4 = maps:put(virtual_channel_access_move_pending, UpdatedMove, State3),
|
||||
update_user_session_view_cache(UserId, ChannelId, remove, State4).
|
||||
|
||||
-spec has_virtual_access(user_id(), channel_id(), guild_state()) -> boolean().
|
||||
has_virtual_access(UserId, ChannelId, State) ->
|
||||
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
|
||||
case maps:get(UserId, VirtualAccess, undefined) of
|
||||
undefined ->
|
||||
false;
|
||||
UserChannels ->
|
||||
sets:is_element(ChannelId, UserChannels)
|
||||
undefined -> false;
|
||||
UserChannels -> sets:is_element(ChannelId, UserChannels)
|
||||
end.
|
||||
|
||||
-spec get_virtual_channels_for_user(user_id(), guild_state()) -> [channel_id()].
|
||||
get_virtual_channels_for_user(UserId, State) ->
|
||||
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
|
||||
case maps:get(UserId, VirtualAccess, undefined) of
|
||||
undefined ->
|
||||
[];
|
||||
UserChannels ->
|
||||
sets:to_list(UserChannels)
|
||||
undefined -> [];
|
||||
UserChannels -> sets:to_list(UserChannels)
|
||||
end.
|
||||
|
||||
-spec mark_pending_join(user_id(), channel_id(), guild_state()) -> guild_state().
|
||||
mark_pending_join(UserId, ChannelId, State) ->
|
||||
PendingMap = maps:get(virtual_channel_access_pending, State, #{}),
|
||||
UserPending = maps:get(UserId, PendingMap, sets:new()),
|
||||
UpdatedUserPending = sets:add_element(ChannelId, UserPending),
|
||||
UpdatedPending = maps:put(UserId, UpdatedUserPending, PendingMap),
|
||||
maps:put(virtual_channel_access_pending, UpdatedPending, State).
|
||||
|
||||
-spec clear_pending_join(user_id(), channel_id(), guild_state()) -> guild_state().
|
||||
clear_pending_join(UserId, ChannelId, State) ->
|
||||
PendingMap = maps:get(virtual_channel_access_pending, State, #{}),
|
||||
UserPending = maps:get(UserId, PendingMap, sets:new()),
|
||||
UpdatedUserPending = sets:del_element(ChannelId, UserPending),
|
||||
UpdatedPending =
|
||||
case sets:size(UpdatedUserPending) of
|
||||
0 -> maps:remove(UserId, PendingMap);
|
||||
_ -> maps:put(UserId, UpdatedUserPending, PendingMap)
|
||||
end,
|
||||
maps:put(virtual_channel_access_pending, UpdatedPending, State).
|
||||
|
||||
-spec is_pending_join(user_id(), channel_id(), guild_state()) -> boolean().
|
||||
is_pending_join(UserId, ChannelId, State) ->
|
||||
PendingMap = maps:get(virtual_channel_access_pending, State, #{}),
|
||||
case maps:get(UserId, PendingMap, undefined) of
|
||||
undefined -> false;
|
||||
UserPending -> sets:is_element(ChannelId, UserPending)
|
||||
end.
|
||||
|
||||
-spec mark_preserve(user_id(), channel_id(), guild_state()) -> guild_state().
|
||||
mark_preserve(UserId, ChannelId, State) ->
|
||||
PreserveMap = maps:get(virtual_channel_access_preserve, State, #{}),
|
||||
UserPreserve = maps:get(UserId, PreserveMap, sets:new()),
|
||||
UpdatedUserPreserve = sets:add_element(ChannelId, UserPreserve),
|
||||
UpdatedPreserve = maps:put(UserId, UpdatedUserPreserve, PreserveMap),
|
||||
maps:put(virtual_channel_access_preserve, UpdatedPreserve, State).
|
||||
|
||||
-spec clear_preserve(user_id(), channel_id(), guild_state()) -> guild_state().
|
||||
clear_preserve(UserId, ChannelId, State) ->
|
||||
PreserveMap = maps:get(virtual_channel_access_preserve, State, #{}),
|
||||
UserPreserve = maps:get(UserId, PreserveMap, sets:new()),
|
||||
UpdatedUserPreserve = sets:del_element(ChannelId, UserPreserve),
|
||||
UpdatedPreserve =
|
||||
case sets:size(UpdatedUserPreserve) of
|
||||
0 -> maps:remove(UserId, PreserveMap);
|
||||
_ -> maps:put(UserId, UpdatedUserPreserve, PreserveMap)
|
||||
end,
|
||||
maps:put(virtual_channel_access_preserve, UpdatedPreserve, State).
|
||||
|
||||
-spec has_preserve(user_id(), channel_id(), guild_state()) -> boolean().
|
||||
has_preserve(UserId, ChannelId, State) ->
|
||||
PreserveMap = maps:get(virtual_channel_access_preserve, State, #{}),
|
||||
case maps:get(UserId, PreserveMap, undefined) of
|
||||
undefined -> false;
|
||||
UserPreserve -> sets:is_element(ChannelId, UserPreserve)
|
||||
end.
|
||||
|
||||
-spec mark_move_pending(user_id(), channel_id(), guild_state()) -> guild_state().
|
||||
mark_move_pending(UserId, ChannelId, State) ->
|
||||
MoveMap = maps:get(virtual_channel_access_move_pending, State, #{}),
|
||||
UserMoves = maps:get(UserId, MoveMap, sets:new()),
|
||||
UpdatedUserMoves = sets:add_element(ChannelId, UserMoves),
|
||||
UpdatedMoveMap = maps:put(UserId, UpdatedUserMoves, MoveMap),
|
||||
maps:put(virtual_channel_access_move_pending, UpdatedMoveMap, State).
|
||||
|
||||
-spec clear_move_pending(user_id(), channel_id(), guild_state()) -> guild_state().
|
||||
clear_move_pending(UserId, ChannelId, State) ->
|
||||
MoveMap = maps:get(virtual_channel_access_move_pending, State, #{}),
|
||||
UserMoves = maps:get(UserId, MoveMap, sets:new()),
|
||||
UpdatedUserMoves = sets:del_element(ChannelId, UserMoves),
|
||||
UpdatedMoveMap =
|
||||
case sets:size(UpdatedUserMoves) of
|
||||
0 -> maps:remove(UserId, MoveMap);
|
||||
_ -> maps:put(UserId, UpdatedUserMoves, MoveMap)
|
||||
end,
|
||||
maps:put(virtual_channel_access_move_pending, UpdatedMoveMap, State).
|
||||
|
||||
-spec is_move_pending(user_id(), channel_id(), guild_state()) -> boolean().
|
||||
is_move_pending(UserId, ChannelId, State) ->
|
||||
MoveMap = maps:get(virtual_channel_access_move_pending, State, #{}),
|
||||
case maps:get(UserId, MoveMap, undefined) of
|
||||
undefined -> false;
|
||||
UserMoves -> sets:is_element(ChannelId, UserMoves)
|
||||
end.
|
||||
|
||||
-spec get_users_with_virtual_access(channel_id(), guild_state()) -> [user_id()].
|
||||
get_users_with_virtual_access(ChannelId, State) ->
|
||||
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
|
||||
maps:fold(
|
||||
@@ -83,45 +219,160 @@ get_users_with_virtual_access(ChannelId, State) ->
|
||||
VirtualAccess
|
||||
).
|
||||
|
||||
-spec dispatch_channel_visibility_change(user_id(), channel_id(), add | remove, guild_state()) ->
|
||||
ok.
|
||||
dispatch_channel_visibility_change(UserId, ChannelId, Action, State) ->
|
||||
Channel = find_channel_by_id(ChannelId, State),
|
||||
Channel = guild_permissions:find_channel_by_id(ChannelId, State),
|
||||
case Channel of
|
||||
undefined ->
|
||||
ok;
|
||||
_ ->
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
GuildId = maps:get(id, State),
|
||||
|
||||
UserSessions = maps:filter(
|
||||
fun(_Sid, SessionData) ->
|
||||
maps:get(user_id, SessionData) =:= UserId
|
||||
end,
|
||||
Sessions
|
||||
),
|
||||
|
||||
case Action of
|
||||
add ->
|
||||
ChannelWithGuild = maps:put(
|
||||
<<"guild_id">>, integer_to_binary(GuildId), Channel
|
||||
),
|
||||
maps:foreach(
|
||||
fun(_Sid, SessionData) ->
|
||||
Pid = maps:get(pid, SessionData),
|
||||
gen_server:cast(Pid, {dispatch, channel_create, ChannelWithGuild})
|
||||
end,
|
||||
UserSessions
|
||||
);
|
||||
remove ->
|
||||
ChannelDelete = #{
|
||||
<<"id">> => integer_to_binary(ChannelId),
|
||||
<<"guild_id">> => integer_to_binary(GuildId)
|
||||
},
|
||||
maps:foreach(
|
||||
fun(_Sid, SessionData) ->
|
||||
Pid = maps:get(pid, SessionData),
|
||||
gen_server:cast(Pid, {dispatch, channel_delete, ChannelDelete})
|
||||
end,
|
||||
UserSessions
|
||||
)
|
||||
end
|
||||
dispatch_to_user_sessions(Action, Channel, ChannelId, GuildId, UserSessions)
|
||||
end.
|
||||
|
||||
-spec dispatch_to_user_sessions(add | remove, map(), channel_id(), integer(), map()) -> ok.
|
||||
dispatch_to_user_sessions(add, Channel, _ChannelId, GuildId, UserSessions) ->
|
||||
ChannelWithGuild = maps:put(<<"guild_id">>, integer_to_binary(GuildId), Channel),
|
||||
maps:foreach(
|
||||
fun(_Sid, SessionData) ->
|
||||
Pid = maps:get(pid, SessionData),
|
||||
gen_server:cast(Pid, {dispatch, channel_create, ChannelWithGuild})
|
||||
end,
|
||||
UserSessions
|
||||
);
|
||||
dispatch_to_user_sessions(remove, _Channel, ChannelId, GuildId, UserSessions) ->
|
||||
ChannelDelete = #{
|
||||
<<"id">> => integer_to_binary(ChannelId),
|
||||
<<"guild_id">> => integer_to_binary(GuildId)
|
||||
},
|
||||
maps:foreach(
|
||||
fun(_Sid, SessionData) ->
|
||||
Pid = maps:get(pid, SessionData),
|
||||
gen_server:cast(Pid, {dispatch, channel_delete, ChannelDelete})
|
||||
end,
|
||||
UserSessions
|
||||
).
|
||||
|
||||
-spec update_user_session_view_cache(user_id(), channel_id(), add | remove, guild_state()) ->
|
||||
guild_state().
|
||||
update_user_session_view_cache(UserId, ChannelId, Action, State) ->
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
UpdatedSessions = maps:map(
|
||||
fun(_SessionId, SessionData) ->
|
||||
case maps:get(user_id, SessionData, undefined) of
|
||||
UserId ->
|
||||
update_session_view_cache(SessionData, ChannelId, Action);
|
||||
_ ->
|
||||
SessionData
|
||||
end
|
||||
end,
|
||||
Sessions
|
||||
),
|
||||
maps:put(sessions, UpdatedSessions, State).
|
||||
|
||||
-spec clear_user_session_view_cache(user_id(), guild_state()) -> guild_state().
|
||||
clear_user_session_view_cache(UserId, State) ->
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
UpdatedSessions = maps:map(
|
||||
fun(_SessionId, SessionData) ->
|
||||
case maps:get(user_id, SessionData, undefined) of
|
||||
UserId ->
|
||||
maps:put(viewable_channels, #{}, SessionData);
|
||||
_ ->
|
||||
SessionData
|
||||
end
|
||||
end,
|
||||
Sessions
|
||||
),
|
||||
maps:put(sessions, UpdatedSessions, State).
|
||||
|
||||
-spec update_session_view_cache(map(), channel_id(), add | remove) -> map().
|
||||
update_session_view_cache(SessionData, ChannelId, add) ->
|
||||
ViewableChannels = ensure_viewable_channel_map(maps:get(viewable_channels, SessionData, #{})),
|
||||
UpdatedViewableChannels = maps:put(ChannelId, true, ViewableChannels),
|
||||
maps:put(viewable_channels, UpdatedViewableChannels, SessionData);
|
||||
update_session_view_cache(SessionData, ChannelId, remove) ->
|
||||
ViewableChannels = ensure_viewable_channel_map(maps:get(viewable_channels, SessionData, #{})),
|
||||
UpdatedViewableChannels = maps:remove(ChannelId, ViewableChannels),
|
||||
maps:put(viewable_channels, UpdatedViewableChannels, SessionData).
|
||||
|
||||
-spec ensure_viewable_channel_map(term()) -> map().
|
||||
ensure_viewable_channel_map(ViewableChannels) when is_map(ViewableChannels) ->
|
||||
ViewableChannels;
|
||||
ensure_viewable_channel_map(_) ->
|
||||
#{}.
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
add_virtual_access_test() ->
|
||||
State = #{},
|
||||
State1 = add_virtual_access(100, 500, State),
|
||||
?assertEqual(true, has_virtual_access(100, 500, State1)),
|
||||
?assertEqual(true, is_pending_join(100, 500, State1)).
|
||||
|
||||
add_virtual_access_updates_session_cache_test() ->
|
||||
State = #{
|
||||
sessions => #{
|
||||
<<"s1">> => #{user_id => 100, viewable_channels => #{}}
|
||||
}
|
||||
},
|
||||
State1 = add_virtual_access(100, 500, State),
|
||||
Session = maps:get(<<"s1">>, maps:get(sessions, State1)),
|
||||
ViewableChannels = maps:get(viewable_channels, Session, #{}),
|
||||
?assertEqual(true, maps:is_key(500, ViewableChannels)).
|
||||
|
||||
remove_virtual_access_test() ->
|
||||
State = add_virtual_access(100, 500, #{}),
|
||||
State1 = remove_virtual_access(100, 500, State),
|
||||
?assertEqual(false, has_virtual_access(100, 500, State1)).
|
||||
|
||||
remove_virtual_access_updates_session_cache_test() ->
|
||||
State = #{
|
||||
sessions => #{
|
||||
<<"s1">> => #{user_id => 100, viewable_channels => #{500 => true}}
|
||||
},
|
||||
virtual_channel_access => #{100 => sets:from_list([500])},
|
||||
virtual_channel_access_pending => #{100 => sets:from_list([500])},
|
||||
virtual_channel_access_preserve => #{100 => sets:new()},
|
||||
virtual_channel_access_move_pending => #{100 => sets:new()}
|
||||
},
|
||||
State1 = remove_virtual_access(100, 500, State),
|
||||
Session = maps:get(<<"s1">>, maps:get(sessions, State1)),
|
||||
ViewableChannels = maps:get(viewable_channels, Session, #{}),
|
||||
?assertEqual(false, maps:is_key(500, ViewableChannels)).
|
||||
|
||||
get_virtual_channels_for_user_test() ->
|
||||
State = add_virtual_access(100, 500, #{}),
|
||||
State1 = add_virtual_access(100, 501, State),
|
||||
Channels = lists:sort(get_virtual_channels_for_user(100, State1)),
|
||||
?assertEqual([500, 501], Channels).
|
||||
|
||||
get_users_with_virtual_access_test() ->
|
||||
State = add_virtual_access(100, 500, #{}),
|
||||
State1 = add_virtual_access(101, 500, State),
|
||||
Users = lists:sort(get_users_with_virtual_access(500, State1)),
|
||||
?assertEqual([100, 101], Users).
|
||||
|
||||
mark_and_clear_preserve_test() ->
|
||||
State = add_virtual_access(100, 500, #{}),
|
||||
State1 = mark_preserve(100, 500, State),
|
||||
?assertEqual(true, has_preserve(100, 500, State1)),
|
||||
State2 = clear_preserve(100, 500, State1),
|
||||
?assertEqual(false, has_preserve(100, 500, State2)).
|
||||
|
||||
mark_and_clear_move_pending_test() ->
|
||||
State = add_virtual_access(100, 500, #{}),
|
||||
State1 = mark_move_pending(100, 500, State),
|
||||
?assertEqual(true, is_move_pending(100, 500, State1)),
|
||||
State2 = clear_move_pending(100, 500, State1),
|
||||
?assertEqual(false, is_move_pending(100, 500, State2)).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -20,19 +20,25 @@
|
||||
-export([
|
||||
get_user_viewable_channels/2,
|
||||
compute_and_dispatch_visibility_changes/2,
|
||||
compute_and_dispatch_visibility_changes_for_users/3,
|
||||
compute_and_dispatch_visibility_changes_for_channels/3,
|
||||
viewable_channel_set/2,
|
||||
have_shared_viewable_channel/3
|
||||
]).
|
||||
|
||||
-import(guild_member_list, [calculate_list_id/2, build_sync_response/4]).
|
||||
-import(guild_permissions, [can_view_channel/4, find_member_by_user_id/2, find_channel_by_id/2]).
|
||||
-type guild_state() :: map().
|
||||
-type user_id() :: integer().
|
||||
-type channel_id() :: integer().
|
||||
|
||||
-spec get_user_viewable_channels(integer(), map()) -> [integer()].
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-spec get_user_viewable_channels(user_id(), guild_state()) -> [channel_id()].
|
||||
get_user_viewable_channels(UserId, State) ->
|
||||
Data = map_utils:ensure_map(map_utils:get_safe(State, data, #{})),
|
||||
Channels = map_utils:ensure_list(maps:get(<<"channels">>, Data, [])),
|
||||
Member = find_member_by_user_id(UserId, State),
|
||||
|
||||
Member = guild_permissions:find_member_by_user_id(UserId, State),
|
||||
case Member of
|
||||
undefined ->
|
||||
[];
|
||||
@@ -44,7 +50,9 @@ get_user_viewable_channels(UserId, State) ->
|
||||
undefined ->
|
||||
false;
|
||||
_ ->
|
||||
case can_view_channel(UserId, ChannelId, Member, State) of
|
||||
case
|
||||
guild_permissions:can_view_channel(UserId, ChannelId, Member, State)
|
||||
of
|
||||
true -> {true, ChannelId};
|
||||
false -> false
|
||||
end
|
||||
@@ -54,62 +62,346 @@ get_user_viewable_channels(UserId, State) ->
|
||||
)
|
||||
end.
|
||||
|
||||
-spec viewable_channel_set(integer(), map()) -> sets:set().
|
||||
-spec viewable_channel_set(user_id(), guild_state()) -> sets:set(channel_id()).
|
||||
viewable_channel_set(UserId, State) when is_integer(UserId) ->
|
||||
sets:from_list(get_user_viewable_channels(UserId, State));
|
||||
case get_cached_viewable_channel_map(UserId, State) of
|
||||
undefined ->
|
||||
sets:from_list(get_user_viewable_channels(UserId, State));
|
||||
ViewableChannelMap ->
|
||||
sets:from_list(maps:keys(ViewableChannelMap))
|
||||
end;
|
||||
viewable_channel_set(_, _) ->
|
||||
sets:new().
|
||||
|
||||
-spec have_shared_viewable_channel(integer(), integer(), map()) -> boolean().
|
||||
have_shared_viewable_channel(UserId, OtherUserId, State) when is_integer(UserId), is_integer(OtherUserId), UserId =/= OtherUserId ->
|
||||
-spec get_cached_viewable_channel_map(user_id(), guild_state()) -> map() | undefined.
|
||||
get_cached_viewable_channel_map(UserId, State) ->
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
maps:fold(
|
||||
fun(_SessionId, SessionData, Acc) ->
|
||||
case Acc of
|
||||
undefined ->
|
||||
case maps:get(user_id, SessionData, undefined) of
|
||||
UserId ->
|
||||
case maps:get(viewable_channels, SessionData, undefined) of
|
||||
ViewableChannels when is_map(ViewableChannels) ->
|
||||
ViewableChannels;
|
||||
_ ->
|
||||
undefined
|
||||
end;
|
||||
_ ->
|
||||
undefined
|
||||
end;
|
||||
_ ->
|
||||
Acc
|
||||
end
|
||||
end,
|
||||
undefined,
|
||||
Sessions
|
||||
).
|
||||
|
||||
-spec have_shared_viewable_channel(user_id(), user_id(), guild_state()) -> boolean().
|
||||
have_shared_viewable_channel(UserId, OtherUserId, State) when
|
||||
is_integer(UserId), is_integer(OtherUserId), UserId =/= OtherUserId
|
||||
->
|
||||
SetA = viewable_channel_set(UserId, State),
|
||||
SetB = viewable_channel_set(OtherUserId, State),
|
||||
not sets:is_empty(sets:intersection(SetA, SetB));
|
||||
have_shared_viewable_channel(_, _, _) ->
|
||||
false.
|
||||
|
||||
-spec compute_and_dispatch_visibility_changes(map(), map()) -> ok.
|
||||
compute_and_dispatch_visibility_changes(OldState, NewState) ->
|
||||
Sessions = maps:get(sessions, NewState, #{}),
|
||||
GuildId = maps:get(id, NewState, 0),
|
||||
-spec filter_connected_session_entries(map()) -> [{binary(), map()}].
|
||||
filter_connected_session_entries(Sessions) ->
|
||||
[{Sid, S} || {Sid, S} <- maps:to_list(Sessions), maps:get(pending_connect, S, false) =/= true].
|
||||
|
||||
lists:foreach(
|
||||
fun({SessionId, SessionData}) ->
|
||||
-spec compute_and_dispatch_visibility_changes(guild_state(), guild_state()) -> guild_state().
|
||||
compute_and_dispatch_visibility_changes(OldState, NewState) ->
|
||||
compute_and_dispatch_visibility_changes_for_sessions(
|
||||
filter_connected_session_entries(maps:get(sessions, NewState, #{})),
|
||||
OldState,
|
||||
NewState
|
||||
).
|
||||
|
||||
-spec compute_and_dispatch_visibility_changes_for_channels(
|
||||
[channel_id()], guild_state(), guild_state()
|
||||
) ->
|
||||
guild_state().
|
||||
compute_and_dispatch_visibility_changes_for_channels(ChannelIds, OldState, NewState) ->
|
||||
ValidChannelIds = lists:usort([Id || Id <- ChannelIds, is_integer(Id), Id > 0]),
|
||||
case ValidChannelIds of
|
||||
[] ->
|
||||
compute_and_dispatch_visibility_changes(OldState, NewState);
|
||||
_ ->
|
||||
compute_and_dispatch_channel_visibility_changes_for_sessions(
|
||||
filter_connected_session_entries(maps:get(sessions, NewState, #{})),
|
||||
ValidChannelIds,
|
||||
OldState,
|
||||
NewState
|
||||
)
|
||||
end.
|
||||
|
||||
-spec compute_and_dispatch_visibility_changes_for_users([user_id()], guild_state(), guild_state()) ->
|
||||
guild_state().
|
||||
compute_and_dispatch_visibility_changes_for_users(UserIds, OldState, NewState) ->
|
||||
Sessions = maps:get(sessions, NewState, #{}),
|
||||
UserIdSet = sets:from_list(UserIds),
|
||||
TargetSessions = lists:filter(
|
||||
fun({_SessionId, SessionData}) ->
|
||||
maps:get(pending_connect, SessionData, false) =/= true andalso
|
||||
begin
|
||||
SessionUserId = maps:get(user_id, SessionData, undefined),
|
||||
is_integer(SessionUserId) andalso sets:is_element(SessionUserId, UserIdSet)
|
||||
end
|
||||
end,
|
||||
maps:to_list(Sessions)
|
||||
),
|
||||
compute_and_dispatch_visibility_changes_for_sessions(TargetSessions, OldState, NewState).
|
||||
|
||||
-spec compute_and_dispatch_channel_visibility_changes_for_sessions(
|
||||
[{binary(), map()}], [channel_id()], guild_state(), guild_state()
|
||||
) ->
|
||||
guild_state().
|
||||
compute_and_dispatch_channel_visibility_changes_for_sessions(
|
||||
SessionEntries, ChannelIds, OldState, NewState
|
||||
) ->
|
||||
GuildId = maps:get(id, NewState, 0),
|
||||
lists:foldl(
|
||||
fun({SessionId, SessionData}, AccState) ->
|
||||
compute_channel_visibility_changes_for_session(
|
||||
SessionId,
|
||||
SessionData,
|
||||
ChannelIds,
|
||||
OldState,
|
||||
AccState,
|
||||
GuildId
|
||||
)
|
||||
end,
|
||||
NewState,
|
||||
SessionEntries
|
||||
).
|
||||
|
||||
-spec compute_channel_visibility_changes_for_session(
|
||||
binary(), map(), [channel_id()], guild_state(), guild_state(), integer()
|
||||
) ->
|
||||
guild_state().
|
||||
compute_channel_visibility_changes_for_session(
|
||||
SessionId, SessionData, ChannelIds, OldState, NewState, GuildId
|
||||
) ->
|
||||
UserId = maps:get(user_id, SessionData, undefined),
|
||||
case is_integer(UserId) of
|
||||
false ->
|
||||
NewState;
|
||||
true ->
|
||||
Pid = maps:get(pid, SessionData, undefined),
|
||||
OldMember = guild_permissions:find_member_by_user_id(UserId, OldState),
|
||||
ConnectedSet = connected_voice_channel_set(UserId, NewState),
|
||||
InitialViewableMap = ensure_viewable_channel_map(SessionData, UserId, OldState),
|
||||
{FinalViewableMap, StateAfterChannels} = lists:foldl(
|
||||
fun(ChannelId, {ViewableMapAcc, StateAcc}) ->
|
||||
OldVisible = channel_is_visible(UserId, ChannelId, OldMember, OldState),
|
||||
{StateAfterPreserve, NewVisible} = ensure_new_channel_visibility(
|
||||
UserId,
|
||||
ChannelId,
|
||||
ConnectedSet,
|
||||
StateAcc
|
||||
),
|
||||
UpdatedViewableMap = update_viewable_map_for_channel(
|
||||
ViewableMapAcc, ChannelId, NewVisible
|
||||
),
|
||||
case {OldVisible, NewVisible} of
|
||||
{true, false} ->
|
||||
dispatch_channel_delete(ChannelId, Pid, OldState, GuildId),
|
||||
{UpdatedViewableMap, StateAfterPreserve};
|
||||
{false, true} ->
|
||||
dispatch_channel_create(
|
||||
ChannelId, Pid, StateAfterPreserve, GuildId
|
||||
),
|
||||
send_member_list_sync(
|
||||
SessionId,
|
||||
SessionData,
|
||||
ChannelId,
|
||||
GuildId,
|
||||
StateAfterPreserve
|
||||
),
|
||||
{UpdatedViewableMap, StateAfterPreserve};
|
||||
_ ->
|
||||
{UpdatedViewableMap, StateAfterPreserve}
|
||||
end
|
||||
end,
|
||||
{InitialViewableMap, NewState},
|
||||
ChannelIds
|
||||
),
|
||||
guild_sessions:set_session_viewable_channels(
|
||||
SessionId,
|
||||
FinalViewableMap,
|
||||
StateAfterChannels
|
||||
)
|
||||
end.
|
||||
|
||||
-spec ensure_viewable_channel_map(map(), user_id(), guild_state()) -> #{channel_id() => true}.
|
||||
ensure_viewable_channel_map(SessionData, UserId, State) ->
|
||||
case maps:get(viewable_channels, SessionData, undefined) of
|
||||
ViewableChannels when is_map(ViewableChannels) ->
|
||||
ViewableChannels;
|
||||
_ ->
|
||||
viewable_channel_map(sets:from_list(get_user_viewable_channels(UserId, State)))
|
||||
end.
|
||||
|
||||
-spec ensure_new_channel_visibility(
|
||||
user_id(), channel_id(), sets:set(channel_id()), guild_state()
|
||||
) ->
|
||||
{guild_state(), boolean()}.
|
||||
ensure_new_channel_visibility(UserId, ChannelId, ConnectedSet, State) ->
|
||||
NewMember = guild_permissions:find_member_by_user_id(UserId, State),
|
||||
NewVisible0 = channel_is_visible(UserId, ChannelId, NewMember, State),
|
||||
case {NewVisible0, sets:is_element(ChannelId, ConnectedSet)} of
|
||||
{true, _} ->
|
||||
{State, true};
|
||||
{false, false} ->
|
||||
{State, false};
|
||||
{false, true} ->
|
||||
case guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, State) of
|
||||
true ->
|
||||
{State, true};
|
||||
false ->
|
||||
State1 = guild_virtual_channel_access:add_virtual_access(
|
||||
UserId,
|
||||
ChannelId,
|
||||
State
|
||||
),
|
||||
State2 = guild_virtual_channel_access:clear_pending_join(
|
||||
UserId,
|
||||
ChannelId,
|
||||
State1
|
||||
),
|
||||
{State2, true}
|
||||
end
|
||||
end.
|
||||
|
||||
-spec channel_is_visible(user_id(), channel_id(), map() | undefined, guild_state()) -> boolean().
|
||||
channel_is_visible(UserId, ChannelId, Member, State) ->
|
||||
guild_permissions:can_view_channel(UserId, ChannelId, Member, State).
|
||||
|
||||
-spec update_viewable_map_for_channel(map(), channel_id(), boolean()) -> map().
|
||||
update_viewable_map_for_channel(ViewableMap, ChannelId, true) ->
|
||||
maps:put(ChannelId, true, ViewableMap);
|
||||
update_viewable_map_for_channel(ViewableMap, ChannelId, false) ->
|
||||
maps:remove(ChannelId, ViewableMap).
|
||||
|
||||
-spec compute_and_dispatch_visibility_changes_for_sessions(
|
||||
[{binary(), map()}], guild_state(), guild_state()
|
||||
) -> guild_state().
|
||||
compute_and_dispatch_visibility_changes_for_sessions(SessionEntries, OldState, NewState) ->
|
||||
GuildId = maps:get(id, NewState, 0),
|
||||
lists:foldl(
|
||||
fun({SessionId, SessionData}, AccState) ->
|
||||
UserId = maps:get(user_id, SessionData),
|
||||
Pid = maps:get(pid, SessionData),
|
||||
|
||||
OldViewable = get_user_viewable_channels(UserId, OldState),
|
||||
NewViewable = get_user_viewable_channels(UserId, NewState),
|
||||
|
||||
OldSet = sets:from_list(OldViewable),
|
||||
OldSet = cached_viewable_channel_set(SessionData, UserId, OldState),
|
||||
NewViewable = get_user_viewable_channels(UserId, AccState),
|
||||
NewSet = sets:from_list(NewViewable),
|
||||
|
||||
Removed = sets:subtract(OldSet, NewSet),
|
||||
Added = sets:subtract(NewSet, OldSet),
|
||||
|
||||
ConnectedSet = connected_voice_channel_set(UserId, AccState),
|
||||
Removed0 = sets:subtract(OldSet, NewSet),
|
||||
{StateWithAccess, PreservedSet} = preserve_connected_channels(
|
||||
UserId,
|
||||
Removed0,
|
||||
ConnectedSet,
|
||||
AccState
|
||||
),
|
||||
NewSet2 = sets:union(NewSet, PreservedSet),
|
||||
StateWithCachedVisibility = guild_sessions:set_session_viewable_channels(
|
||||
SessionId,
|
||||
viewable_channel_map(NewSet2),
|
||||
StateWithAccess
|
||||
),
|
||||
Removed = sets:subtract(OldSet, NewSet2),
|
||||
Added = sets:subtract(NewSet2, OldSet),
|
||||
lists:foreach(
|
||||
fun(ChannelId) ->
|
||||
dispatch_channel_delete(ChannelId, Pid, OldState, GuildId)
|
||||
end,
|
||||
sets:to_list(Removed)
|
||||
),
|
||||
|
||||
lists:foreach(
|
||||
fun(ChannelId) ->
|
||||
dispatch_channel_create(ChannelId, Pid, NewState, GuildId),
|
||||
send_member_list_sync(SessionId, SessionData, ChannelId, GuildId, NewState)
|
||||
dispatch_channel_create(ChannelId, Pid, StateWithCachedVisibility, GuildId),
|
||||
send_member_list_sync(
|
||||
SessionId, SessionData, ChannelId, GuildId, StateWithCachedVisibility
|
||||
)
|
||||
end,
|
||||
sets:to_list(Added)
|
||||
)
|
||||
),
|
||||
StateWithCachedVisibility
|
||||
end,
|
||||
maps:to_list(Sessions)
|
||||
),
|
||||
ok.
|
||||
NewState,
|
||||
SessionEntries
|
||||
).
|
||||
|
||||
-spec cached_viewable_channel_set(map(), user_id(), guild_state()) -> sets:set(channel_id()).
|
||||
cached_viewable_channel_set(SessionData, UserId, State) ->
|
||||
case maps:get(viewable_channels, SessionData, undefined) of
|
||||
ViewableMap when is_map(ViewableMap) ->
|
||||
sets:from_list(maps:keys(ViewableMap));
|
||||
_ ->
|
||||
sets:from_list(get_user_viewable_channels(UserId, State))
|
||||
end.
|
||||
|
||||
-spec viewable_channel_map(sets:set(channel_id())) -> #{channel_id() => true}.
|
||||
viewable_channel_map(ChannelSet) ->
|
||||
sets:fold(
|
||||
fun(ChannelId, Acc) ->
|
||||
maps:put(ChannelId, true, Acc)
|
||||
end,
|
||||
#{},
|
||||
ChannelSet
|
||||
).
|
||||
|
||||
-spec connected_voice_channel_set(user_id(), guild_state()) -> sets:set(channel_id()).
|
||||
connected_voice_channel_set(UserId, State) ->
|
||||
VoiceStates = voice_state_utils:voice_states(State),
|
||||
maps:fold(
|
||||
fun(_ConnId, VoiceState, Acc) ->
|
||||
case voice_state_utils:voice_state_user_id(VoiceState) of
|
||||
UserId ->
|
||||
case voice_state_utils:voice_state_channel_id(VoiceState) of
|
||||
undefined -> Acc;
|
||||
ChannelId -> sets:add_element(ChannelId, Acc)
|
||||
end;
|
||||
_ ->
|
||||
Acc
|
||||
end
|
||||
end,
|
||||
sets:new(),
|
||||
VoiceStates
|
||||
).
|
||||
|
||||
-spec preserve_connected_channels(
|
||||
user_id(), sets:set(channel_id()), sets:set(channel_id()), guild_state()
|
||||
) ->
|
||||
{guild_state(), sets:set(channel_id())}.
|
||||
preserve_connected_channels(UserId, RemovedSet, ConnectedSet, State) ->
|
||||
ToPreserve = sets:intersection(RemovedSet, ConnectedSet),
|
||||
UpdatedState = sets:fold(
|
||||
fun(ChannelId, AccState) ->
|
||||
case guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, AccState) of
|
||||
true ->
|
||||
AccState;
|
||||
false ->
|
||||
State1 = guild_virtual_channel_access:add_virtual_access(
|
||||
UserId, ChannelId, AccState
|
||||
),
|
||||
guild_virtual_channel_access:clear_pending_join(UserId, ChannelId, State1)
|
||||
end
|
||||
end,
|
||||
State,
|
||||
ToPreserve
|
||||
),
|
||||
{UpdatedState, ToPreserve}.
|
||||
|
||||
-spec dispatch_channel_delete(channel_id(), pid(), guild_state(), integer()) -> ok.
|
||||
dispatch_channel_delete(ChannelId, SessionPid, OldState, GuildId) ->
|
||||
case is_pid(SessionPid) of
|
||||
true ->
|
||||
case find_channel_by_id(ChannelId, OldState) of
|
||||
case guild_permissions:find_channel_by_id(ChannelId, OldState) of
|
||||
undefined ->
|
||||
ok;
|
||||
_Channel ->
|
||||
@@ -123,10 +415,11 @@ dispatch_channel_delete(ChannelId, SessionPid, OldState, GuildId) ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec dispatch_channel_create(channel_id(), pid(), guild_state(), integer()) -> ok.
|
||||
dispatch_channel_create(ChannelId, SessionPid, NewState, GuildId) ->
|
||||
case is_pid(SessionPid) of
|
||||
true ->
|
||||
case find_channel_by_id(ChannelId, NewState) of
|
||||
case guild_permissions:find_channel_by_id(ChannelId, NewState) of
|
||||
undefined ->
|
||||
ok;
|
||||
Channel ->
|
||||
@@ -139,13 +432,14 @@ dispatch_channel_create(ChannelId, SessionPid, NewState, GuildId) ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec send_member_list_sync(binary(), map(), channel_id(), integer(), guild_state()) -> ok.
|
||||
send_member_list_sync(SessionId, SessionData, ChannelId, GuildId, State) ->
|
||||
SessionPid = maps:get(pid, SessionData),
|
||||
case is_pid(SessionPid) of
|
||||
false ->
|
||||
ok;
|
||||
true ->
|
||||
ListId = calculate_list_id(ChannelId, State),
|
||||
ListId = guild_member_list:calculate_list_id(ChannelId, State),
|
||||
MemberListSubs = maps:get(member_list_subscriptions, State, #{}),
|
||||
ListSubs = maps:get(ListId, MemberListSubs, #{}),
|
||||
Ranges = maps:get(SessionId, ListSubs, []),
|
||||
@@ -156,15 +450,254 @@ send_member_list_sync(SessionId, SessionData, ChannelId, GuildId, State) ->
|
||||
SessionUserId = maps:get(user_id, SessionData),
|
||||
case can_send_member_list(SessionUserId, ChannelId, State) of
|
||||
true ->
|
||||
SyncResponse = build_sync_response(GuildId, ListId, Ranges, State),
|
||||
SyncResponseWithChannel = maps:put(<<"channel_id">>, integer_to_binary(ChannelId), SyncResponse),
|
||||
gen_server:cast(SessionPid, {dispatch, guild_member_list_update, SyncResponseWithChannel});
|
||||
SyncResponse = guild_member_list:build_sync_response(
|
||||
GuildId, ListId, Ranges, State
|
||||
),
|
||||
SyncResponseWithChannel = maps:put(
|
||||
<<"channel_id">>, integer_to_binary(ChannelId), SyncResponse
|
||||
),
|
||||
gen_server:cast(
|
||||
SessionPid,
|
||||
{dispatch, guild_member_list_update, SyncResponseWithChannel}
|
||||
);
|
||||
false ->
|
||||
ok
|
||||
end
|
||||
end
|
||||
end.
|
||||
|
||||
-spec can_send_member_list(user_id() | undefined, channel_id(), guild_state()) -> boolean().
|
||||
can_send_member_list(UserId, ChannelId, State) ->
|
||||
is_integer(UserId) andalso
|
||||
guild_permissions:can_view_channel(UserId, ChannelId, undefined, State).
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
get_user_viewable_channels_returns_empty_for_non_member_test() ->
|
||||
State = #{
|
||||
data => #{
|
||||
<<"channels">> => [#{<<"id">> => <<"100">>, <<"type">> => 0}],
|
||||
<<"members">> => []
|
||||
}
|
||||
},
|
||||
?assertEqual([], get_user_viewable_channels(999, State)).
|
||||
|
||||
viewable_channel_set_returns_empty_for_invalid_user_test() ->
|
||||
State = #{data => #{}},
|
||||
?assertEqual(sets:new(), viewable_channel_set(undefined, State)).
|
||||
|
||||
have_shared_viewable_channel_same_user_test() ->
|
||||
State = #{data => #{}},
|
||||
?assertEqual(false, have_shared_viewable_channel(100, 100, State)).
|
||||
|
||||
preserves_connected_channel_visibility_on_permission_loss_test() ->
|
||||
UserId = 10,
|
||||
GuildId = 1,
|
||||
ChannelId = 5,
|
||||
ViewPerm = constants:view_channel_permission(),
|
||||
OldState = visibility_state(GuildId, UserId, ChannelId, ViewPerm, true),
|
||||
NewState = visibility_state(GuildId, UserId, ChannelId, 0, true),
|
||||
UpdatedState = compute_and_dispatch_visibility_changes(OldState, NewState),
|
||||
?assert(guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, UpdatedState)),
|
||||
?assertEqual(
|
||||
false,
|
||||
guild_virtual_channel_access:is_pending_join(UserId, ChannelId, UpdatedState)
|
||||
),
|
||||
?assert(guild_permissions:can_view_channel(UserId, ChannelId, undefined, UpdatedState)).
|
||||
|
||||
does_not_add_virtual_access_when_not_connected_test() ->
|
||||
UserId = 20,
|
||||
GuildId = 2,
|
||||
ChannelId = 6,
|
||||
ViewPerm = constants:view_channel_permission(),
|
||||
OldState = visibility_state(GuildId, UserId, ChannelId, ViewPerm, false),
|
||||
NewState = visibility_state(GuildId, UserId, ChannelId, 0, false),
|
||||
UpdatedState = compute_and_dispatch_visibility_changes(OldState, NewState),
|
||||
?assertNot(guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, UpdatedState)).
|
||||
|
||||
does_not_add_virtual_access_when_permission_remains_test() ->
|
||||
UserId = 30,
|
||||
GuildId = 3,
|
||||
ChannelId = 7,
|
||||
ViewPerm = constants:view_channel_permission(),
|
||||
OldState = visibility_state(GuildId, UserId, ChannelId, ViewPerm, true),
|
||||
NewState = visibility_state(GuildId, UserId, ChannelId, ViewPerm, true),
|
||||
UpdatedState = compute_and_dispatch_visibility_changes(OldState, NewState),
|
||||
?assertNot(guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, UpdatedState)).
|
||||
|
||||
compute_and_dispatch_visibility_changes_for_users_targets_selected_users_test() ->
|
||||
GuildId = 33,
|
||||
ChannelId = 77,
|
||||
RoleId = 101,
|
||||
ViewPerm = constants:view_channel_permission(),
|
||||
Session10 = #{
|
||||
session_id => <<"s10">>,
|
||||
user_id => 10,
|
||||
pid => self(),
|
||||
viewable_channels => #{ChannelId => true}
|
||||
},
|
||||
Session20 = #{
|
||||
session_id => <<"s20">>,
|
||||
user_id => 20,
|
||||
pid => self(),
|
||||
viewable_channels => #{ChannelId => true}
|
||||
},
|
||||
OldState = #{
|
||||
id => GuildId,
|
||||
sessions => #{<<"s10">> => Session10, <<"s20">> => Session20},
|
||||
data => #{
|
||||
<<"guild">> => #{<<"owner_id">> => <<"999">>},
|
||||
<<"roles">> => [
|
||||
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => <<"0">>},
|
||||
#{<<"id">> => integer_to_binary(RoleId), <<"permissions">> => integer_to_binary(ViewPerm)}
|
||||
],
|
||||
<<"members">> => #{
|
||||
10 => #{<<"user">> => #{<<"id">> => <<"10">>}, <<"roles">> => [integer_to_binary(RoleId)]},
|
||||
20 => #{<<"user">> => #{<<"id">> => <<"20">>}, <<"roles">> => [integer_to_binary(RoleId)]}
|
||||
},
|
||||
<<"channels">> => [#{<<"id">> => integer_to_binary(ChannelId), <<"permission_overwrites">> => []}]
|
||||
}
|
||||
},
|
||||
NewState = #{
|
||||
id => GuildId,
|
||||
sessions => #{<<"s10">> => Session10, <<"s20">> => Session20},
|
||||
data => #{
|
||||
<<"guild">> => #{<<"owner_id">> => <<"999">>},
|
||||
<<"roles">> => [
|
||||
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => <<"0">>},
|
||||
#{<<"id">> => integer_to_binary(RoleId), <<"permissions">> => integer_to_binary(ViewPerm)}
|
||||
],
|
||||
<<"members">> => #{
|
||||
10 => #{<<"user">> => #{<<"id">> => <<"10">>}, <<"roles">> => []},
|
||||
20 => #{<<"user">> => #{<<"id">> => <<"20">>}, <<"roles">> => [integer_to_binary(RoleId)]}
|
||||
},
|
||||
<<"channels">> => [#{<<"id">> => integer_to_binary(ChannelId), <<"permission_overwrites">> => []}]
|
||||
}
|
||||
},
|
||||
UpdatedState = compute_and_dispatch_visibility_changes_for_users([10], OldState, NewState),
|
||||
UpdatedSessions = maps:get(sessions, UpdatedState),
|
||||
UpdatedSession10 = maps:get(<<"s10">>, UpdatedSessions),
|
||||
UpdatedSession20 = maps:get(<<"s20">>, UpdatedSessions),
|
||||
?assertEqual(false, maps:is_key(ChannelId, maps:get(viewable_channels, UpdatedSession10, #{}))),
|
||||
?assertEqual(true, maps:is_key(ChannelId, maps:get(viewable_channels, UpdatedSession20, #{}))).
|
||||
|
||||
compute_and_dispatch_visibility_changes_for_channels_limits_to_changed_channels_test() ->
|
||||
GuildId = 44,
|
||||
UserId = 10,
|
||||
ChannelA = 100,
|
||||
ChannelB = 101,
|
||||
ViewPerm = constants:view_channel_permission(),
|
||||
BaseRole = #{
|
||||
<<"id">> => integer_to_binary(GuildId),
|
||||
<<"permissions">> => integer_to_binary(ViewPerm)
|
||||
},
|
||||
Session = #{
|
||||
session_id => <<"s10">>,
|
||||
user_id => UserId,
|
||||
pid => self(),
|
||||
viewable_channels => #{ChannelA => true, ChannelB => true}
|
||||
},
|
||||
OldState = #{
|
||||
id => GuildId,
|
||||
sessions => #{<<"s10">> => Session},
|
||||
data => #{
|
||||
<<"guild">> => #{<<"owner_id">> => <<"999">>},
|
||||
<<"roles">> => [BaseRole],
|
||||
<<"members">> => #{
|
||||
UserId => #{
|
||||
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
|
||||
<<"roles">> => []
|
||||
}
|
||||
},
|
||||
<<"channels">> => [
|
||||
#{<<"id">> => integer_to_binary(ChannelA), <<"permission_overwrites">> => []},
|
||||
#{<<"id">> => integer_to_binary(ChannelB), <<"permission_overwrites">> => []}
|
||||
]
|
||||
}
|
||||
},
|
||||
NewState = #{
|
||||
id => GuildId,
|
||||
sessions => #{<<"s10">> => Session},
|
||||
data => #{
|
||||
<<"guild">> => #{<<"owner_id">> => <<"999">>},
|
||||
<<"roles">> => [BaseRole],
|
||||
<<"members">> => #{
|
||||
UserId => #{
|
||||
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
|
||||
<<"roles">> => []
|
||||
}
|
||||
},
|
||||
<<"channels">> => [
|
||||
#{
|
||||
<<"id">> => integer_to_binary(ChannelA),
|
||||
<<"permission_overwrites">> => [
|
||||
#{
|
||||
<<"id">> => integer_to_binary(GuildId),
|
||||
<<"type">> => 0,
|
||||
<<"allow">> => <<"0">>,
|
||||
<<"deny">> => integer_to_binary(ViewPerm)
|
||||
}
|
||||
]
|
||||
},
|
||||
#{<<"id">> => integer_to_binary(ChannelB), <<"permission_overwrites">> => []}
|
||||
]
|
||||
}
|
||||
},
|
||||
UpdatedState = compute_and_dispatch_visibility_changes_for_channels(
|
||||
[ChannelA],
|
||||
OldState,
|
||||
NewState
|
||||
),
|
||||
UpdatedSession = maps:get(<<"s10">>, maps:get(sessions, UpdatedState)),
|
||||
UpdatedViewable = maps:get(viewable_channels, UpdatedSession, #{}),
|
||||
?assertEqual(false, maps:is_key(ChannelA, UpdatedViewable)),
|
||||
?assertEqual(true, maps:is_key(ChannelB, UpdatedViewable)).
|
||||
|
||||
visibility_state(GuildId, UserId, ChannelId, Perms, Connected) ->
|
||||
VoiceStates =
|
||||
case Connected of
|
||||
true ->
|
||||
#{
|
||||
<<"conn">> => #{
|
||||
<<"user_id">> => integer_to_binary(UserId),
|
||||
<<"guild_id">> => integer_to_binary(GuildId),
|
||||
<<"channel_id">> => integer_to_binary(ChannelId),
|
||||
<<"connection_id">> => <<"conn">>
|
||||
}
|
||||
};
|
||||
false ->
|
||||
#{}
|
||||
end,
|
||||
Sessions = #{<<"s1">> => #{user_id => UserId, pid => self()}},
|
||||
Data = #{
|
||||
<<"guild">> => #{<<"owner_id">> => <<"999">>},
|
||||
<<"roles">> => [
|
||||
#{
|
||||
<<"id">> => integer_to_binary(GuildId),
|
||||
<<"permissions">> => integer_to_binary(Perms)
|
||||
}
|
||||
],
|
||||
<<"members">> => [
|
||||
#{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
|
||||
],
|
||||
<<"channels">> => [
|
||||
#{<<"id">> => integer_to_binary(ChannelId), <<"permission_overwrites">> => []}
|
||||
]
|
||||
},
|
||||
#{
|
||||
id => GuildId,
|
||||
data => Data,
|
||||
sessions => Sessions,
|
||||
voice_states => VoiceStates
|
||||
}.
|
||||
|
||||
filter_connected_session_entries_excludes_pending_test() ->
|
||||
Normal = #{session_id => <<"s1">>, user_id => 1, pending_connect => false},
|
||||
Pending = #{session_id => <<"s2">>, user_id => 2, pending_connect => true},
|
||||
NoPending = #{session_id => <<"s3">>, user_id => 3},
|
||||
Sessions = #{<<"s1">> => Normal, <<"s2">> => Pending, <<"s3">> => NoPending},
|
||||
Result = filter_connected_session_entries(Sessions),
|
||||
ResultIds = lists:sort([Sid || {Sid, _} <- Result]),
|
||||
?assertEqual([<<"s1">>, <<"s3">>], ResultIds).
|
||||
|
||||
-endif.
|
||||
|
||||
1552
fluxer_gateway/src/guild/very_large_guild.erl
Normal file
1552
fluxer_gateway/src/guild/very_large_guild.erl
Normal file
File diff suppressed because it is too large
Load Diff
1489
fluxer_gateway/src/guild/very_large_guild_member_list.erl
Normal file
1489
fluxer_gateway/src/guild/very_large_guild_member_list.erl
Normal file
File diff suppressed because it is too large
Load Diff
@@ -17,6 +17,8 @@
|
||||
|
||||
-module(dm_voice).
|
||||
|
||||
-import(guild_voice_unclaimed_account_utils, [parse_unclaimed_error/1]).
|
||||
|
||||
-export([voice_state_update/2]).
|
||||
-export([get_voice_state/2]).
|
||||
-export([get_voice_token/6]).
|
||||
@@ -24,58 +26,48 @@
|
||||
-export([broadcast_voice_state_update/3]).
|
||||
-export([join_or_create_call/5, join_or_create_call/6]).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-type dm_state() :: map().
|
||||
-type voice_state() :: map().
|
||||
-type voice_state_map() :: #{binary() => voice_state()}.
|
||||
|
||||
-spec voice_state_update(map(), dm_state()) ->
|
||||
{reply, map(), dm_state()} | {reply, {error, atom(), atom()}, dm_state()}.
|
||||
voice_state_update(Request, State) ->
|
||||
#{
|
||||
user_id := UserId,
|
||||
channel_id := ChannelId
|
||||
} = Request,
|
||||
|
||||
ConnectionId = maps:get(connection_id, Request, undefined),
|
||||
VoiceStates = maps:get(dm_voice_states, State, #{}),
|
||||
|
||||
case ChannelId of
|
||||
null ->
|
||||
handle_dm_disconnect(ConnectionId, UserId, VoiceStates, State);
|
||||
ChannelIdValue ->
|
||||
Channels = maps:get(channels, State, #{}),
|
||||
UserId = maps:get(user_id, State),
|
||||
logger:info(
|
||||
"[dm_voice] Looking up channel ~p for user ~p, channels map has ~p entries",
|
||||
[ChannelIdValue, UserId, maps:size(Channels)]
|
||||
),
|
||||
case maps:get(ChannelIdValue, Channels, undefined) of
|
||||
undefined ->
|
||||
logger:info(
|
||||
"[dm_voice] Channel ~p not found locally for user ~p, trying RPC fallback",
|
||||
[ChannelIdValue, UserId]
|
||||
),
|
||||
case fetch_dm_channel_via_rpc(ChannelIdValue, UserId) of
|
||||
{ok, Channel} ->
|
||||
NewChannels = maps:put(ChannelIdValue, Channel, Channels),
|
||||
NewState = maps:put(channels, NewChannels, State),
|
||||
logger:info(
|
||||
"[dm_voice] RPC fallback found channel ~p for user ~p, added to local map",
|
||||
[ChannelIdValue, UserId]
|
||||
),
|
||||
handle_dm_voice_with_channel(
|
||||
Channel, ChannelIdValue, UserId, Request, NewState
|
||||
);
|
||||
{error, Reason} ->
|
||||
logger:warning(
|
||||
"[dm_voice] Channel ~p not found for user ~p via RPC: ~p",
|
||||
[ChannelIdValue, UserId, Reason]
|
||||
),
|
||||
{error, _Reason} ->
|
||||
{reply, gateway_errors:error(dm_channel_not_found), State}
|
||||
end;
|
||||
Channel ->
|
||||
logger:info(
|
||||
"[dm_voice] Found channel ~p for user ~p, type: ~p",
|
||||
[ChannelIdValue, UserId, maps:get(<<"type">>, Channel, 0)]
|
||||
),
|
||||
handle_dm_voice_with_channel(Channel, ChannelIdValue, UserId, Request, State)
|
||||
end
|
||||
end.
|
||||
|
||||
-spec handle_dm_voice_with_channel(map(), integer(), integer(), map(), dm_state()) ->
|
||||
{reply, map(), dm_state()} | {reply, {error, atom(), atom()}, dm_state()}.
|
||||
handle_dm_voice_with_channel(Channel, ChannelIdValue, UserId, Request, State) ->
|
||||
#{
|
||||
session_id := SessionId,
|
||||
@@ -86,11 +78,10 @@ handle_dm_voice_with_channel(Channel, ChannelIdValue, UserId, Request, State) ->
|
||||
SelfStream = maps:get(self_stream, Request, false),
|
||||
ConnectionId = maps:get(connection_id, Request, undefined),
|
||||
IsMobile = maps:get(is_mobile, Request, false),
|
||||
ViewerStreamKey = maps:get(viewer_stream_key, Request, undefined),
|
||||
ViewerStreamKeys = maps:get(viewer_stream_keys, Request, undefined),
|
||||
Latitude = maps:get(latitude, Request, null),
|
||||
Longitude = maps:get(longitude, Request, null),
|
||||
VoiceStates = maps:get(dm_voice_states, State, #{}),
|
||||
|
||||
ChannelType = maps:get(<<"type">>, Channel, 0),
|
||||
case is_dm_channel_type(ChannelType) of
|
||||
false ->
|
||||
@@ -109,7 +100,7 @@ handle_dm_voice_with_channel(Channel, ChannelIdValue, UserId, Request, State) ->
|
||||
SelfDeaf,
|
||||
SelfVideo,
|
||||
SelfStream,
|
||||
ViewerStreamKey,
|
||||
ViewerStreamKeys,
|
||||
IsMobile,
|
||||
Latitude,
|
||||
Longitude,
|
||||
@@ -119,6 +110,8 @@ handle_dm_voice_with_channel(Channel, ChannelIdValue, UserId, Request, State) ->
|
||||
end
|
||||
end.
|
||||
|
||||
-spec handle_dm_disconnect(binary() | undefined, integer(), voice_state_map(), dm_state()) ->
|
||||
{reply, map(), dm_state()} | {reply, {error, atom(), atom()}, dm_state()}.
|
||||
handle_dm_disconnect(undefined, _UserId, _VoiceStates, State) ->
|
||||
{reply, gateway_errors:error(voice_missing_connection_id), State};
|
||||
handle_dm_disconnect(ConnectionId, _UserId, VoiceStates, State) ->
|
||||
@@ -128,13 +121,11 @@ handle_dm_disconnect(ConnectionId, _UserId, VoiceStates, State) ->
|
||||
OldVoiceState ->
|
||||
NewVoiceStates = maps:remove(ConnectionId, VoiceStates),
|
||||
NewState = maps:put(dm_voice_states, NewVoiceStates, State),
|
||||
|
||||
OldChannelId = maps:get(<<"channel_id">>, OldVoiceState, null),
|
||||
DisconnectVoiceState = maps:put(
|
||||
<<"channel_id">>, null, maps:put(<<"connection_id">>, ConnectionId, OldVoiceState)
|
||||
),
|
||||
SessionId = maps:get(id, State),
|
||||
|
||||
case OldChannelId of
|
||||
null ->
|
||||
ok;
|
||||
@@ -154,7 +145,6 @@ handle_dm_disconnect(ConnectionId, _UserId, VoiceStates, State) ->
|
||||
end
|
||||
end)
|
||||
end,
|
||||
|
||||
case OldChannelId of
|
||||
null ->
|
||||
ok;
|
||||
@@ -164,15 +154,29 @@ handle_dm_disconnect(ConnectionId, _UserId, VoiceStates, State) ->
|
||||
broadcast_voice_state_update(
|
||||
OldChannelIdInt, DisconnectVoiceState, NewState
|
||||
);
|
||||
{error, _, Reason} ->
|
||||
logger:warning("[dm_voice] Invalid channel_id: ~p", [Reason]),
|
||||
{error, _, _Reason} ->
|
||||
ok
|
||||
end
|
||||
end,
|
||||
|
||||
{reply, #{success => true}, NewState}
|
||||
end.
|
||||
|
||||
-spec handle_dm_connect_or_update(
|
||||
binary() | undefined | null,
|
||||
integer(),
|
||||
integer(),
|
||||
binary(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
term(),
|
||||
boolean(),
|
||||
term(),
|
||||
term(),
|
||||
voice_state_map(),
|
||||
dm_state()
|
||||
) -> {reply, map(), dm_state()} | {reply, {error, atom(), atom()}, dm_state()}.
|
||||
handle_dm_connect_or_update(
|
||||
ConnectionId,
|
||||
ChannelIdValue,
|
||||
@@ -182,7 +186,7 @@ handle_dm_connect_or_update(
|
||||
SelfDeaf,
|
||||
SelfVideo,
|
||||
SelfStream,
|
||||
ViewerStreamKey,
|
||||
ViewerStreamKeys,
|
||||
IsMobile,
|
||||
Latitude,
|
||||
Longitude,
|
||||
@@ -190,7 +194,7 @@ handle_dm_connect_or_update(
|
||||
State
|
||||
) when ConnectionId =:= undefined; ConnectionId =:= null ->
|
||||
VoiceStates = maps:get(dm_voice_states, State, #{}),
|
||||
case validate_dm_viewer_stream_key(ViewerStreamKey, ChannelIdValue, VoiceStates) of
|
||||
case validate_dm_viewer_stream_keys(ViewerStreamKeys, ChannelIdValue, VoiceStates) of
|
||||
{error, ErrorAtom} ->
|
||||
{reply, gateway_errors:error(ErrorAtom), State};
|
||||
{ok, ParsedViewerKey} ->
|
||||
@@ -218,7 +222,7 @@ handle_dm_connect_or_update(
|
||||
SelfDeaf,
|
||||
SelfVideo,
|
||||
SelfStream,
|
||||
ViewerStreamKey,
|
||||
ViewerStreamKeys,
|
||||
IsMobile,
|
||||
_Latitude,
|
||||
_Longitude,
|
||||
@@ -229,43 +233,49 @@ handle_dm_connect_or_update(
|
||||
undefined ->
|
||||
{reply, gateway_errors:error(voice_connection_not_found), State};
|
||||
ExistingVoiceState ->
|
||||
ExistingSessionId = maps:get(<<"session_id">>, ExistingVoiceState, undefined),
|
||||
EffectiveSessionId = resolve_effective_session_id(ExistingSessionId, SessionId),
|
||||
ValidViewerKey = validate_dm_viewer_stream_key(
|
||||
ViewerStreamKey, ChannelIdValue, VoiceStates
|
||||
),
|
||||
case ValidViewerKey of
|
||||
{error, ErrorAtom} ->
|
||||
{reply, gateway_errors:error(ErrorAtom), State};
|
||||
{ok, ParsedViewerKey} ->
|
||||
UpdatedVoiceState = ExistingVoiceState#{
|
||||
<<"channel_id">> => integer_to_binary(ChannelIdValue),
|
||||
<<"session_id">> => EffectiveSessionId,
|
||||
<<"self_mute">> => SelfMute,
|
||||
<<"self_deaf">> => SelfDeaf,
|
||||
<<"self_video">> => SelfVideo,
|
||||
<<"self_stream">> => SelfStream,
|
||||
<<"is_mobile">> => IsMobile,
|
||||
<<"viewer_stream_key">> => ParsedViewerKey
|
||||
},
|
||||
|
||||
NewVoiceStates = maps:put(ConnectionId, UpdatedVoiceState, VoiceStates),
|
||||
NewState = maps:put(dm_voice_states, NewVoiceStates, State),
|
||||
|
||||
broadcast_voice_state_update(ChannelIdValue, UpdatedVoiceState, NewState),
|
||||
|
||||
OldChannelId = maps:get(<<"channel_id">>, ExistingVoiceState, null),
|
||||
NewChannelIdBin = integer_to_binary(ChannelIdValue),
|
||||
NeedsToken = OldChannelId =/= NewChannelIdBin,
|
||||
|
||||
maybe_spawn_join_call(
|
||||
NeedsToken, ChannelIdValue, UserId, UpdatedVoiceState, SessionId
|
||||
case guild_voice_state:user_matches_voice_state(ExistingVoiceState, UserId) of
|
||||
false ->
|
||||
{reply, gateway_errors:error(voice_user_mismatch), State};
|
||||
true ->
|
||||
ExistingSessionId = maps:get(<<"session_id">>, ExistingVoiceState, undefined),
|
||||
EffectiveSessionId = resolve_effective_session_id(ExistingSessionId, SessionId),
|
||||
ValidViewerKey = validate_dm_viewer_stream_keys(
|
||||
ViewerStreamKeys, ChannelIdValue, VoiceStates
|
||||
),
|
||||
|
||||
{reply, #{success => true, needs_token => NeedsToken}, NewState}
|
||||
case ValidViewerKey of
|
||||
{error, ErrorAtom} ->
|
||||
{reply, gateway_errors:error(ErrorAtom), State};
|
||||
{ok, ParsedViewerKey} ->
|
||||
UpdatedVoiceState = ExistingVoiceState#{
|
||||
<<"channel_id">> => integer_to_binary(ChannelIdValue),
|
||||
<<"session_id">> => EffectiveSessionId,
|
||||
<<"self_mute">> => SelfMute,
|
||||
<<"self_deaf">> => SelfDeaf,
|
||||
<<"self_video">> => SelfVideo,
|
||||
<<"self_stream">> => SelfStream,
|
||||
<<"is_mobile">> => IsMobile,
|
||||
<<"viewer_stream_keys">> => ParsedViewerKey
|
||||
},
|
||||
NewVoiceStates = maps:put(ConnectionId, UpdatedVoiceState, VoiceStates),
|
||||
NewState = maps:put(dm_voice_states, NewVoiceStates, State),
|
||||
broadcast_voice_state_update(ChannelIdValue, UpdatedVoiceState, NewState),
|
||||
OldChannelId = maps:get(<<"channel_id">>, ExistingVoiceState, null),
|
||||
NewChannelIdBin = integer_to_binary(ChannelIdValue),
|
||||
NeedsToken = OldChannelId =/= NewChannelIdBin,
|
||||
maybe_spawn_join_call(
|
||||
NeedsToken,
|
||||
ChannelIdValue,
|
||||
UserId,
|
||||
UpdatedVoiceState,
|
||||
EffectiveSessionId,
|
||||
State
|
||||
),
|
||||
{reply, #{success => true, needs_token => NeedsToken}, NewState}
|
||||
end
|
||||
end
|
||||
end.
|
||||
|
||||
-spec normalize_session_id(term()) -> binary() | undefined.
|
||||
normalize_session_id(undefined) ->
|
||||
undefined;
|
||||
normalize_session_id(SessionId) when is_binary(SessionId) ->
|
||||
@@ -281,32 +291,50 @@ normalize_session_id(SessionId) ->
|
||||
_:_ -> SessionId
|
||||
end.
|
||||
|
||||
validate_dm_viewer_stream_key(RawKey, ChannelIdValue, VoiceStates) ->
|
||||
case RawKey of
|
||||
undefined ->
|
||||
{ok, null};
|
||||
null ->
|
||||
{ok, null};
|
||||
_ when not is_binary(RawKey) -> {error, voice_invalid_state};
|
||||
_ ->
|
||||
case voice_state_utils:parse_stream_key(RawKey) of
|
||||
{ok, #{scope := dm, channel_id := ParsedChannelId, connection_id := ConnId}} when
|
||||
ParsedChannelId =:= ChannelIdValue
|
||||
->
|
||||
case maps:get(ConnId, VoiceStates, undefined) of
|
||||
undefined ->
|
||||
{error, voice_connection_not_found};
|
||||
StreamVS ->
|
||||
case map_utils:get_integer(StreamVS, <<"channel_id">>, undefined) of
|
||||
ChannelIdValue -> {ok, RawKey};
|
||||
_ -> {error, voice_invalid_state}
|
||||
end
|
||||
end;
|
||||
_ ->
|
||||
{error, voice_invalid_state}
|
||||
end
|
||||
-spec validate_dm_viewer_stream_keys(term(), integer(), voice_state_map()) ->
|
||||
{ok, list()} | {error, atom()}.
|
||||
validate_dm_viewer_stream_keys(RawKeys, _ChannelIdValue, _VoiceStates) when RawKeys =:= undefined; RawKeys =:= null ->
|
||||
{ok, []};
|
||||
validate_dm_viewer_stream_keys(RawKeys, _ChannelIdValue, _VoiceStates) when not is_list(RawKeys) ->
|
||||
{error, voice_invalid_state};
|
||||
validate_dm_viewer_stream_keys(Keys, ChannelIdValue, VoiceStates) ->
|
||||
validate_dm_viewer_stream_keys_list(Keys, ChannelIdValue, VoiceStates, []).
|
||||
|
||||
-spec validate_dm_viewer_stream_keys_list(list(), integer(), voice_state_map(), list()) ->
|
||||
{ok, list()} | {error, atom()}.
|
||||
validate_dm_viewer_stream_keys_list([], _ChannelIdValue, _VoiceStates, Acc) ->
|
||||
{ok, lists:reverse(Acc)};
|
||||
validate_dm_viewer_stream_keys_list([Key | Rest], ChannelIdValue, VoiceStates, Acc) ->
|
||||
case validate_single_dm_viewer_stream_key(Key, ChannelIdValue, VoiceStates) of
|
||||
{ok, ValidKey} ->
|
||||
validate_dm_viewer_stream_keys_list(Rest, ChannelIdValue, VoiceStates, [ValidKey | Acc]);
|
||||
{error, _} = Error ->
|
||||
Error
|
||||
end.
|
||||
|
||||
-spec validate_single_dm_viewer_stream_key(term(), integer(), voice_state_map()) ->
|
||||
{ok, binary()} | {error, atom()}.
|
||||
validate_single_dm_viewer_stream_key(RawKey, _ChannelIdValue, _VoiceStates) when not is_binary(RawKey) ->
|
||||
{error, voice_invalid_state};
|
||||
validate_single_dm_viewer_stream_key(RawKey, ChannelIdValue, VoiceStates) ->
|
||||
case voice_state_utils:parse_stream_key(RawKey) of
|
||||
{ok, #{scope := dm, channel_id := ParsedChannelId, connection_id := ConnId}} when
|
||||
ParsedChannelId =:= ChannelIdValue
|
||||
->
|
||||
case maps:get(ConnId, VoiceStates, undefined) of
|
||||
undefined ->
|
||||
{error, voice_connection_not_found};
|
||||
StreamVS ->
|
||||
case map_utils:get_integer(StreamVS, <<"channel_id">>, undefined) of
|
||||
ChannelIdValue -> {ok, RawKey};
|
||||
_ -> {error, voice_invalid_state}
|
||||
end
|
||||
end;
|
||||
_ ->
|
||||
{error, voice_invalid_state}
|
||||
end.
|
||||
|
||||
-spec resolve_effective_session_id(term(), term()) -> binary() | undefined.
|
||||
resolve_effective_session_id(ExistingSessionId, RequestSessionId) ->
|
||||
ExistingNormalized = normalize_session_id(ExistingSessionId),
|
||||
RequestNormalized = normalize_session_id(RequestSessionId),
|
||||
@@ -316,28 +344,43 @@ resolve_effective_session_id(ExistingSessionId, RequestSessionId) ->
|
||||
_ -> ExistingNormalized
|
||||
end.
|
||||
|
||||
maybe_spawn_join_call(false, _ChannelId, _UserId, _VoiceState, _SessionId) ->
|
||||
-spec maybe_spawn_join_call(
|
||||
boolean(), integer(), integer(), voice_state(), binary() | undefined, dm_state()
|
||||
) -> ok.
|
||||
maybe_spawn_join_call(false, _ChannelId, _UserId, _VoiceState, _SessionId, _State) ->
|
||||
ok;
|
||||
maybe_spawn_join_call(true, ChannelId, UserId, VoiceState, SessionId) ->
|
||||
spawn(fun() ->
|
||||
try
|
||||
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, self())
|
||||
catch
|
||||
_:_ -> ok
|
||||
end
|
||||
end).
|
||||
maybe_spawn_join_call(true, ChannelId, UserId, VoiceState, SessionId, State) when
|
||||
is_binary(SessionId)
|
||||
->
|
||||
SessionPid = maps:get(session_pid, State, undefined),
|
||||
case SessionPid of
|
||||
Pid when is_pid(Pid) ->
|
||||
spawn(fun() ->
|
||||
try
|
||||
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, Pid)
|
||||
catch
|
||||
_:_ -> ok
|
||||
end
|
||||
end),
|
||||
ok;
|
||||
_ ->
|
||||
ok
|
||||
end;
|
||||
maybe_spawn_join_call(true, _ChannelId, _UserId, _VoiceState, _SessionId, _State) ->
|
||||
ok.
|
||||
|
||||
-spec get_voice_token(integer(), integer(), binary(), pid(), term(), term()) -> ok | error.
|
||||
get_voice_token(ChannelId, UserId, _SessionId, SessionPid, Latitude, Longitude) ->
|
||||
Req = voice_utils:build_voice_token_rpc_request(
|
||||
null, ChannelId, UserId, null, Latitude, Longitude
|
||||
),
|
||||
|
||||
case rpc_client:call(Req) of
|
||||
Region = resolve_call_region(ChannelId),
|
||||
ReqWithRegion = voice_utils:add_rtc_region_to_request(Req, Region),
|
||||
case rpc_client:call(ReqWithRegion) of
|
||||
{ok, Data} ->
|
||||
Token = maps:get(<<"token">>, Data),
|
||||
Endpoint = maps:get(<<"endpoint">>, Data),
|
||||
ConnectionId = maps:get(<<"connectionId">>, Data),
|
||||
|
||||
SessionPid !
|
||||
{voice_server_update, #{
|
||||
channel_id => integer_to_binary(ChannelId),
|
||||
@@ -357,6 +400,20 @@ get_voice_token(ChannelId, UserId, _SessionId, SessionPid, Latitude, Longitude)
|
||||
error
|
||||
end.
|
||||
|
||||
-spec get_dm_voice_token_and_create_state(
|
||||
integer(),
|
||||
integer(),
|
||||
binary(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
term(),
|
||||
boolean(),
|
||||
term(),
|
||||
term(),
|
||||
dm_state()
|
||||
) -> {reply, map(), dm_state()} | {reply, {error, atom(), atom()}, dm_state()}.
|
||||
get_dm_voice_token_and_create_state(
|
||||
UserId,
|
||||
ChannelId,
|
||||
@@ -365,7 +422,7 @@ get_dm_voice_token_and_create_state(
|
||||
SelfDeaf,
|
||||
SelfVideo,
|
||||
SelfStream,
|
||||
ViewerStreamKey,
|
||||
ViewerStreamKeys,
|
||||
IsMobile,
|
||||
Latitude,
|
||||
Longitude,
|
||||
@@ -374,8 +431,9 @@ get_dm_voice_token_and_create_state(
|
||||
Req = voice_utils:build_voice_token_rpc_request(
|
||||
null, ChannelId, UserId, null, Latitude, Longitude
|
||||
),
|
||||
|
||||
case rpc_client:call(Req) of
|
||||
Region = resolve_call_region(ChannelId, State),
|
||||
ReqWithRegion = voice_utils:add_rtc_region_to_request(Req, Region),
|
||||
case rpc_client:call(ReqWithRegion) of
|
||||
{ok, Data} ->
|
||||
handle_dm_token_success(
|
||||
Data,
|
||||
@@ -386,7 +444,7 @@ get_dm_voice_token_and_create_state(
|
||||
SelfDeaf,
|
||||
SelfVideo,
|
||||
SelfStream,
|
||||
ViewerStreamKey,
|
||||
ViewerStreamKeys,
|
||||
IsMobile,
|
||||
State
|
||||
);
|
||||
@@ -399,6 +457,48 @@ get_dm_voice_token_and_create_state(
|
||||
{reply, gateway_errors:error(voice_token_failed), State}
|
||||
end.
|
||||
|
||||
-spec resolve_call_region(integer()) -> binary() | null.
|
||||
resolve_call_region(ChannelId) ->
|
||||
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
|
||||
{ok, CallPid} ->
|
||||
case gen_server:call(CallPid, {get_state}, 5000) of
|
||||
{ok, CallData} -> maps:get(region, CallData, null);
|
||||
_ -> null
|
||||
end;
|
||||
_ ->
|
||||
null
|
||||
end.
|
||||
|
||||
-spec resolve_call_region(integer(), dm_state()) -> binary() | null.
|
||||
resolve_call_region(ChannelId, State) ->
|
||||
Calls = maps:get(calls, State, #{}),
|
||||
case maps:get(ChannelId, Calls, undefined) of
|
||||
{CallPid, _Ref} when is_pid(CallPid) ->
|
||||
try
|
||||
case gen_server:call(CallPid, {get_state}, 250) of
|
||||
{ok, CallData} -> maps:get(region, CallData, null);
|
||||
_ -> null
|
||||
end
|
||||
catch
|
||||
_:_ -> null
|
||||
end;
|
||||
_ ->
|
||||
null
|
||||
end.
|
||||
|
||||
-spec handle_dm_token_success(
|
||||
map(),
|
||||
integer(),
|
||||
integer(),
|
||||
binary(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
boolean(),
|
||||
term(),
|
||||
boolean(),
|
||||
dm_state()
|
||||
) -> {reply, map(), dm_state()}.
|
||||
handle_dm_token_success(
|
||||
Data,
|
||||
UserId,
|
||||
@@ -408,14 +508,13 @@ handle_dm_token_success(
|
||||
SelfDeaf,
|
||||
SelfVideo,
|
||||
SelfStream,
|
||||
ViewerStreamKey,
|
||||
ViewerStreamKeys,
|
||||
IsMobile,
|
||||
State
|
||||
) ->
|
||||
Token = maps:get(<<"token">>, Data),
|
||||
Endpoint = maps:get(<<"endpoint">>, Data),
|
||||
ConnectionId = maps:get(<<"connectionId">>, Data),
|
||||
|
||||
VoiceState = #{
|
||||
<<"user_id">> => integer_to_binary(UserId),
|
||||
<<"channel_id">> => integer_to_binary(ChannelId),
|
||||
@@ -426,15 +525,12 @@ handle_dm_token_success(
|
||||
<<"self_deaf">> => SelfDeaf,
|
||||
<<"self_video">> => SelfVideo,
|
||||
<<"self_stream">> => SelfStream,
|
||||
<<"viewer_stream_key">> => ViewerStreamKey
|
||||
<<"viewer_stream_keys">> => ViewerStreamKeys
|
||||
},
|
||||
|
||||
VoiceStates = maps:get(dm_voice_states, State, #{}),
|
||||
NewVoiceStates = maps:put(ConnectionId, VoiceState, VoiceStates),
|
||||
NewState = maps:put(dm_voice_states, NewVoiceStates, State),
|
||||
|
||||
broadcast_voice_state_update(ChannelId, VoiceState, NewState),
|
||||
|
||||
SessionPid = maps:get(session_pid, State),
|
||||
VoiceServerUpdate = #{
|
||||
<<"token">> => Token,
|
||||
@@ -443,7 +539,6 @@ handle_dm_token_success(
|
||||
<<"connection_id">> => ConnectionId
|
||||
},
|
||||
gen_server:cast(SessionPid, {dispatch, voice_server_update, VoiceServerUpdate}),
|
||||
|
||||
GatewaySessionId = maps:get(id, State),
|
||||
spawn(fun() ->
|
||||
try
|
||||
@@ -452,23 +547,22 @@ handle_dm_token_success(
|
||||
_:_ -> ok
|
||||
end
|
||||
end),
|
||||
|
||||
{reply, #{success => true, needs_token => false, connection_id => ConnectionId}, NewState}.
|
||||
|
||||
-spec get_voice_state(binary(), dm_state()) -> voice_state() | undefined.
|
||||
get_voice_state(ConnectionId, State) ->
|
||||
VoiceStates = maps:get(dm_voice_states, State, #{}),
|
||||
maps:get(ConnectionId, VoiceStates, undefined).
|
||||
|
||||
-spec disconnect_voice_user(integer(), dm_state()) -> {reply, map(), dm_state()}.
|
||||
disconnect_voice_user(UserId, State) ->
|
||||
VoiceStates = maps:get(dm_voice_states, State, #{}),
|
||||
|
||||
UserVoiceStates = maps:filter(
|
||||
fun(_ConnectionId, VoiceState) ->
|
||||
maps:get(<<"user_id">>, VoiceState) =:= integer_to_binary(UserId)
|
||||
end,
|
||||
VoiceStates
|
||||
),
|
||||
|
||||
case maps:size(UserVoiceStates) of
|
||||
0 ->
|
||||
{reply, #{success => true}, State};
|
||||
@@ -481,71 +575,58 @@ disconnect_voice_user(UserId, State) ->
|
||||
UserVoiceStates
|
||||
),
|
||||
NewState = maps:put(dm_voice_states, NewVoiceStates, State),
|
||||
|
||||
maps:foreach(
|
||||
fun(_ConnectionId, VoiceState) ->
|
||||
ChannelId = maps:get(<<"channel_id">>, VoiceState, null),
|
||||
DisconnectVoiceState = maps:put(
|
||||
<<"channel_id">>,
|
||||
null,
|
||||
maps:put(<<"connection_id">>, _ConnectionId, VoiceState)
|
||||
),
|
||||
case ChannelId of
|
||||
null ->
|
||||
ok;
|
||||
_ ->
|
||||
case validation:validate_snowflake(<<"channel_id">>, ChannelId) of
|
||||
{ok, ChannelIdInt} ->
|
||||
broadcast_voice_state_update(
|
||||
ChannelIdInt, DisconnectVoiceState, NewState
|
||||
);
|
||||
{error, _, Reason} ->
|
||||
logger:warning(
|
||||
"[dm_voice] Invalid channel_id in voice state: ~p", [Reason]
|
||||
),
|
||||
ok
|
||||
end
|
||||
end
|
||||
end,
|
||||
UserVoiceStates
|
||||
),
|
||||
|
||||
spawn(fun() ->
|
||||
maps:foreach(
|
||||
fun(ConnId, VoiceState) ->
|
||||
ChannelId = maps:get(<<"channel_id">>, VoiceState, null),
|
||||
DisconnectVoiceState = maps:put(
|
||||
<<"channel_id">>,
|
||||
null,
|
||||
maps:put(<<"connection_id">>, ConnId, VoiceState)
|
||||
),
|
||||
case ChannelId of
|
||||
null ->
|
||||
ok;
|
||||
_ ->
|
||||
case validation:validate_snowflake(<<"channel_id">>, ChannelId) of
|
||||
{ok, ChannelIdInt} ->
|
||||
broadcast_voice_state_update(
|
||||
ChannelIdInt, DisconnectVoiceState, NewState
|
||||
);
|
||||
{error, _, _Reason} ->
|
||||
ok
|
||||
end
|
||||
end
|
||||
end,
|
||||
UserVoiceStates
|
||||
)
|
||||
end),
|
||||
{reply, #{success => true}, NewState}
|
||||
end.
|
||||
|
||||
parse_unclaimed_error(Body) when is_binary(Body) ->
|
||||
try jsx:decode(Body, [return_maps]) of
|
||||
#{<<"code">> := <<"UNCLAIMED_ACCOUNT_RESTRICTED">>} -> true;
|
||||
#{<<"error">> := #{<<"code">> := <<"UNCLAIMED_ACCOUNT_RESTRICTED">>}} -> true;
|
||||
_ -> false
|
||||
catch
|
||||
_:_ -> false
|
||||
end;
|
||||
parse_unclaimed_error(_) ->
|
||||
false.
|
||||
|
||||
-spec broadcast_voice_state_update(integer(), voice_state(), dm_state()) -> ok.
|
||||
broadcast_voice_state_update(ChannelId, VoiceState, State) ->
|
||||
Channels = maps:get(channels, State, #{}),
|
||||
|
||||
case maps:get(ChannelId, Channels, undefined) of
|
||||
undefined ->
|
||||
ok;
|
||||
Channel ->
|
||||
Recipients = maps:get(<<"recipient_ids">>, Channel, []),
|
||||
UserId = maps:get(user_id, State),
|
||||
|
||||
AllRecipients = lists:usort([UserId | Recipients]),
|
||||
|
||||
Event = voice_state_update,
|
||||
|
||||
lists:foreach(
|
||||
fun(RecipientId) ->
|
||||
presence_manager:dispatch_to_user(RecipientId, Event, VoiceState)
|
||||
end,
|
||||
AllRecipients
|
||||
)
|
||||
spawn(fun() ->
|
||||
lists:foreach(
|
||||
fun(RecipientId) ->
|
||||
presence_manager:dispatch_to_user(RecipientId, Event, VoiceState)
|
||||
end,
|
||||
AllRecipients
|
||||
)
|
||||
end),
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec check_recipient(integer(), integer(), dm_state()) -> boolean().
|
||||
check_recipient(UserId, ChannelId, State) ->
|
||||
Channels = maps:get(channels, State, #{}),
|
||||
case maps:get(ChannelId, Channels, undefined) of
|
||||
@@ -556,20 +637,24 @@ check_recipient(UserId, ChannelId, State) ->
|
||||
is_dm_channel_type(ChannelType) andalso is_channel_recipient(UserId, Channel, State)
|
||||
end.
|
||||
|
||||
-spec is_dm_channel_type(integer()) -> boolean().
|
||||
is_dm_channel_type(1) -> true;
|
||||
is_dm_channel_type(3) -> true;
|
||||
is_dm_channel_type(_) -> false.
|
||||
|
||||
-spec is_channel_recipient(integer(), map(), dm_state()) -> boolean().
|
||||
is_channel_recipient(UserId, Channel, State) ->
|
||||
Recipients = maps:get(<<"recipient_ids">>, Channel, []),
|
||||
CurrentUserId = maps:get(user_id, State),
|
||||
lists:member(UserId, [CurrentUserId | Recipients]).
|
||||
|
||||
-spec join_or_create_call(integer(), integer(), voice_state(), binary(), pid()) -> ok.
|
||||
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, SessionPid) ->
|
||||
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, SessionPid, 10).
|
||||
|
||||
join_or_create_call(_ChannelId, UserId, _VoiceState, _SessionId, _SessionPid, 0) ->
|
||||
logger:warning("[dm_voice] Failed to join call after retries, user ~p could not join", [UserId]),
|
||||
-spec join_or_create_call(integer(), integer(), voice_state(), binary(), pid(), non_neg_integer()) ->
|
||||
ok.
|
||||
join_or_create_call(_ChannelId, _UserId, _VoiceState, _SessionId, _SessionPid, 0) ->
|
||||
ok;
|
||||
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, SessionPid, Retries) ->
|
||||
ConnectionId = maps:get(<<"connection_id">>, VoiceState, undefined),
|
||||
@@ -597,6 +682,7 @@ join_or_create_call(ChannelId, UserId, VoiceState, SessionId, SessionPid, Retrie
|
||||
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, SessionPid, Retries - 1)
|
||||
end.
|
||||
|
||||
-spec fetch_dm_channel_via_rpc(integer(), integer()) -> {ok, map()} | {error, term()}.
|
||||
fetch_dm_channel_via_rpc(ChannelId, UserId) ->
|
||||
Req = #{
|
||||
<<"type">> => <<"get_dm_channel">>,
|
||||
@@ -614,6 +700,7 @@ fetch_dm_channel_via_rpc(ChannelId, UserId) ->
|
||||
{error, Reason}
|
||||
end.
|
||||
|
||||
-spec convert_api_channel_to_gateway_format(map(), integer()) -> map().
|
||||
convert_api_channel_to_gateway_format(Channel, CurrentUserId) ->
|
||||
ChannelType = maps:get(<<"type">>, Channel, 0),
|
||||
Recipients = maps:get(<<"recipients">>, Channel, []),
|
||||
@@ -627,6 +714,7 @@ convert_api_channel_to_gateway_format(Channel, CurrentUserId) ->
|
||||
<<"recipient_ids">> => RecipientIds
|
||||
}.
|
||||
|
||||
-spec extract_recipient_id(term(), integer()) -> {true, integer()} | false.
|
||||
extract_recipient_id(Recipient, CurrentUserId) when is_map(Recipient) ->
|
||||
case maps:get(<<"id">>, Recipient, undefined) of
|
||||
undefined -> false;
|
||||
@@ -635,6 +723,7 @@ extract_recipient_id(Recipient, CurrentUserId) when is_map(Recipient) ->
|
||||
extract_recipient_id(Id, CurrentUserId) ->
|
||||
filter_recipient_id(parse_id(Id), CurrentUserId).
|
||||
|
||||
-spec parse_id(term()) -> integer() | null.
|
||||
parse_id(Id) when is_integer(Id) -> Id;
|
||||
parse_id(Id) when is_binary(Id) ->
|
||||
case validation:validate_snowflake(<<"id">>, Id) of
|
||||
@@ -644,6 +733,111 @@ parse_id(Id) when is_binary(Id) ->
|
||||
parse_id(_) ->
|
||||
null.
|
||||
|
||||
-spec filter_recipient_id(integer() | null, integer()) -> {true, integer()} | false.
|
||||
filter_recipient_id(null, _CurrentUserId) -> false;
|
||||
filter_recipient_id(Id, Id) -> false;
|
||||
filter_recipient_id(Id, _CurrentUserId) -> {true, Id}.
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
is_dm_channel_type_test() ->
|
||||
?assert(is_dm_channel_type(1)),
|
||||
?assert(is_dm_channel_type(3)),
|
||||
?assertNot(is_dm_channel_type(0)),
|
||||
?assertNot(is_dm_channel_type(2)),
|
||||
?assertNot(is_dm_channel_type(4)).
|
||||
|
||||
normalize_session_id_test() ->
|
||||
?assertEqual(undefined, normalize_session_id(undefined)),
|
||||
?assertEqual(<<"abc">>, normalize_session_id(<<"abc">>)),
|
||||
?assertEqual(<<"123">>, normalize_session_id(123)),
|
||||
?assertEqual(<<"test">>, normalize_session_id("test")).
|
||||
|
||||
validate_dm_viewer_stream_keys_null_test() ->
|
||||
?assertEqual({ok, []}, validate_dm_viewer_stream_keys(undefined, 123, #{})),
|
||||
?assertEqual({ok, []}, validate_dm_viewer_stream_keys(null, 123, #{})).
|
||||
|
||||
validate_dm_viewer_stream_keys_invalid_type_test() ->
|
||||
?assertEqual({error, voice_invalid_state}, validate_dm_viewer_stream_keys(123, 456, #{})).
|
||||
|
||||
resolve_effective_session_id_test() ->
|
||||
?assertEqual(<<"req">>, resolve_effective_session_id(undefined, <<"req">>)),
|
||||
?assertEqual(<<"existing">>, resolve_effective_session_id(<<"existing">>, <<"req">>)),
|
||||
?assertEqual(<<"same">>, resolve_effective_session_id(<<"same">>, <<"same">>)).
|
||||
|
||||
filter_recipient_id_test() ->
|
||||
?assertEqual(false, filter_recipient_id(null, 1)),
|
||||
?assertEqual(false, filter_recipient_id(1, 1)),
|
||||
?assertEqual({true, 2}, filter_recipient_id(2, 1)).
|
||||
|
||||
parse_id_test() ->
|
||||
?assertEqual(123, parse_id(123)),
|
||||
?assertEqual(null, parse_id(invalid)).
|
||||
|
||||
handle_dm_connect_or_update_user_mismatch_test() ->
|
||||
VoiceStates = #{
|
||||
<<"conn-1">> => #{
|
||||
<<"channel_id">> => <<"100">>,
|
||||
<<"user_id">> => <<"20">>,
|
||||
<<"session_id">> => <<"sess">>
|
||||
}
|
||||
},
|
||||
State = #{dm_voice_states => VoiceStates},
|
||||
{reply, {error, validation_error, voice_user_mismatch}, _} =
|
||||
handle_dm_connect_or_update(
|
||||
<<"conn-1">>,
|
||||
100,
|
||||
10,
|
||||
<<"sess">>,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
undefined,
|
||||
false,
|
||||
null,
|
||||
null,
|
||||
VoiceStates,
|
||||
State
|
||||
).
|
||||
|
||||
handle_dm_connect_or_update_owner_match_proceeds_test() ->
|
||||
VoiceStates = #{
|
||||
<<"conn-1">> => #{
|
||||
<<"channel_id">> => <<"100">>,
|
||||
<<"user_id">> => <<"10">>,
|
||||
<<"session_id">> => <<"sess">>
|
||||
}
|
||||
},
|
||||
State = #{
|
||||
dm_voice_states => VoiceStates,
|
||||
channels => #{100 => #{<<"type">> => 1, <<"recipient_ids">> => [10]}},
|
||||
user_id => 10,
|
||||
id => <<"sess">>,
|
||||
session_pid => self()
|
||||
},
|
||||
case
|
||||
handle_dm_connect_or_update(
|
||||
<<"conn-1">>,
|
||||
100,
|
||||
10,
|
||||
<<"sess">>,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
undefined,
|
||||
false,
|
||||
null,
|
||||
null,
|
||||
VoiceStates,
|
||||
State
|
||||
)
|
||||
of
|
||||
{reply, {error, validation_error, voice_user_mismatch}, _} ->
|
||||
error(should_not_get_user_mismatch);
|
||||
{reply, _, _} ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -26,7 +26,7 @@
|
||||
-export([confirm_voice_connection_from_livekit/2]).
|
||||
-export([move_member/2]).
|
||||
-export([broadcast_voice_state_update/3]).
|
||||
-export([broadcast_voice_server_update_to_session/6]).
|
||||
-export([broadcast_voice_server_update_to_session/7]).
|
||||
-export([send_voice_server_update_for_move/5]).
|
||||
-export([send_voice_server_updates_for_move/4]).
|
||||
-export([switch_voice_region_handler/2]).
|
||||
@@ -35,67 +35,88 @@
|
||||
-export([handle_virtual_channel_access_for_move/4]).
|
||||
-export([cleanup_virtual_access_on_disconnect/2]).
|
||||
|
||||
voice_state_update(Request, State) ->
|
||||
case guild_voice_connection:voice_state_update(Request, State) of
|
||||
{reply, Response, NewState} ->
|
||||
{reply, Response, NewState};
|
||||
{error, Category, Message} ->
|
||||
{reply, {error, Category, Message}, State}
|
||||
end.
|
||||
-type guild_state() :: map().
|
||||
-type voice_state() :: map().
|
||||
|
||||
-spec voice_state_update(map(), guild_state()) ->
|
||||
{reply, map(), guild_state()} | {reply, {error, atom(), atom()}, guild_state()}.
|
||||
voice_state_update(Request, State) ->
|
||||
guild_voice_connection:voice_state_update(Request, State).
|
||||
|
||||
-spec get_voice_state(map(), guild_state()) -> {reply, map(), guild_state()}.
|
||||
get_voice_state(Request, State) ->
|
||||
guild_voice_state:get_voice_state(Request, State).
|
||||
|
||||
-spec get_voice_states_list(guild_state()) -> [voice_state()].
|
||||
get_voice_states_list(State) ->
|
||||
guild_voice_state:get_voice_states_list(State).
|
||||
|
||||
-spec update_member_voice(map(), guild_state()) -> {reply, map(), guild_state()}.
|
||||
update_member_voice(Request, State) ->
|
||||
guild_voice_member:update_member_voice(Request, State).
|
||||
|
||||
-spec disconnect_voice_user(map(), guild_state()) -> {reply, map(), guild_state()}.
|
||||
disconnect_voice_user(Request, State) ->
|
||||
guild_voice_disconnect:disconnect_voice_user(Request, State).
|
||||
|
||||
-spec disconnect_voice_user_if_in_channel(map(), guild_state()) -> {reply, map(), guild_state()}.
|
||||
disconnect_voice_user_if_in_channel(Request, State) ->
|
||||
guild_voice_disconnect:disconnect_voice_user_if_in_channel(Request, State).
|
||||
|
||||
-spec disconnect_all_voice_users_in_channel(map(), guild_state()) -> {reply, map(), guild_state()}.
|
||||
disconnect_all_voice_users_in_channel(Request, State) ->
|
||||
guild_voice_disconnect:disconnect_all_voice_users_in_channel(Request, State).
|
||||
|
||||
-spec confirm_voice_connection_from_livekit(map(), guild_state()) ->
|
||||
{reply, map(), guild_state()} | {error, atom(), atom()}.
|
||||
confirm_voice_connection_from_livekit(Request, State) ->
|
||||
case guild_voice_connection:confirm_voice_connection_from_livekit(Request, State) of
|
||||
{reply, Response, NewState} ->
|
||||
{reply, Response, NewState};
|
||||
{error, Category, Message} ->
|
||||
{reply, {error, Category, Message}, State}
|
||||
end.
|
||||
guild_voice_connection:confirm_voice_connection_from_livekit(Request, State).
|
||||
|
||||
-spec move_member(map(), guild_state()) -> {reply, map(), guild_state()}.
|
||||
move_member(Request, State) ->
|
||||
guild_voice_move:move_member(Request, State).
|
||||
|
||||
-spec send_voice_server_update_for_move(integer(), integer(), integer(), binary(), pid()) -> ok.
|
||||
send_voice_server_update_for_move(GuildId, ChannelId, UserId, SessionId, GuildPid) ->
|
||||
guild_voice_move:send_voice_server_update_for_move(
|
||||
GuildId, ChannelId, UserId, SessionId, GuildPid
|
||||
).
|
||||
|
||||
-spec send_voice_server_updates_for_move(integer(), integer(), [map()], pid()) -> ok.
|
||||
send_voice_server_updates_for_move(GuildId, ChannelId, SessionDataList, GuildPid) ->
|
||||
guild_voice_move:send_voice_server_updates_for_move(
|
||||
GuildId, ChannelId, SessionDataList, GuildPid
|
||||
).
|
||||
|
||||
-spec broadcast_voice_state_update(voice_state(), guild_state(), binary() | null) -> ok.
|
||||
broadcast_voice_state_update(VoiceState, State, OldChannelIdBin) ->
|
||||
guild_voice_broadcast:broadcast_voice_state_update(VoiceState, State, OldChannelIdBin).
|
||||
|
||||
broadcast_voice_server_update_to_session(GuildId, SessionId, Token, Endpoint, ConnectionId, State) ->
|
||||
-spec broadcast_voice_server_update_to_session(
|
||||
integer(), integer(), binary(), binary(), binary(), binary(), guild_state()
|
||||
) -> ok.
|
||||
broadcast_voice_server_update_to_session(
|
||||
GuildId,
|
||||
ChannelId,
|
||||
SessionId,
|
||||
Token,
|
||||
Endpoint,
|
||||
ConnectionId,
|
||||
State
|
||||
) ->
|
||||
guild_voice_broadcast:broadcast_voice_server_update_to_session(
|
||||
GuildId, SessionId, Token, Endpoint, ConnectionId, State
|
||||
GuildId, ChannelId, SessionId, Token, Endpoint, ConnectionId, State
|
||||
).
|
||||
|
||||
-spec switch_voice_region_handler(map(), guild_state()) -> {reply, map(), guild_state()}.
|
||||
switch_voice_region_handler(Request, State) ->
|
||||
guild_voice_region:switch_voice_region_handler(Request, State).
|
||||
|
||||
-spec switch_voice_region(integer(), integer(), pid()) -> ok | {error, term()}.
|
||||
switch_voice_region(GuildId, ChannelId, GuildPid) ->
|
||||
guild_voice_region:switch_voice_region(GuildId, ChannelId, GuildPid).
|
||||
|
||||
-spec handle_virtual_channel_access_for_move(integer(), integer(), map(), pid()) -> ok.
|
||||
handle_virtual_channel_access_for_move(UserId, ChannelId, _ConnectionsToMove, GuildPid) ->
|
||||
case gen_server:call(GuildPid, {get_sessions}, 10000) of
|
||||
State when is_map(State) ->
|
||||
@@ -104,16 +125,21 @@ handle_virtual_channel_access_for_move(UserId, ChannelId, _ConnectionsToMove, Gu
|
||||
undefined ->
|
||||
ok;
|
||||
_ ->
|
||||
HasViewPermission = guild_permissions:can_view_channel_by_permissions(
|
||||
UserId, ChannelId, Member, State
|
||||
Permissions = guild_permissions:get_member_permissions(
|
||||
UserId, ChannelId, State
|
||||
),
|
||||
case HasViewPermission of
|
||||
ViewPerm = constants:view_channel_permission(),
|
||||
ConnectPerm = constants:connect_permission(),
|
||||
HasView = (Permissions band ViewPerm) =:= ViewPerm,
|
||||
HasConnect = (Permissions band ConnectPerm) =:= ConnectPerm,
|
||||
case HasView andalso HasConnect of
|
||||
true ->
|
||||
ok;
|
||||
false ->
|
||||
gen_server:cast(
|
||||
gen_server:call(
|
||||
GuildPid,
|
||||
{add_virtual_channel_access, UserId, ChannelId}
|
||||
{add_virtual_channel_access, UserId, ChannelId},
|
||||
10000
|
||||
)
|
||||
end
|
||||
end;
|
||||
@@ -121,5 +147,6 @@ handle_virtual_channel_access_for_move(UserId, ChannelId, _ConnectionsToMove, Gu
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec cleanup_virtual_access_on_disconnect(integer(), pid()) -> ok.
|
||||
cleanup_virtual_access_on_disconnect(UserId, GuildPid) ->
|
||||
gen_server:cast(GuildPid, {cleanup_virtual_access_for_user, UserId}).
|
||||
|
||||
@@ -18,28 +18,23 @@
|
||||
-module(guild_voice_broadcast).
|
||||
|
||||
-export([broadcast_voice_state_update/3]).
|
||||
-export([broadcast_voice_server_update_to_session/6]).
|
||||
-export([broadcast_voice_server_update_to_session/7]).
|
||||
|
||||
-ifdef(TEST).
|
||||
-define(WARN_MISSING_CONN(_VoiceState), ok).
|
||||
-else.
|
||||
-define(WARN_MISSING_CONN(VoiceState),
|
||||
logger:warning(
|
||||
"[guild_voice_broadcast] Skipping VOICE_STATE_UPDATE broadcast - missing connection_id: ~p",
|
||||
[VoiceState]
|
||||
)
|
||||
).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-type guild_state() :: map().
|
||||
-type voice_state() :: map().
|
||||
|
||||
-spec broadcast_voice_state_update(voice_state(), guild_state(), binary() | null) -> ok.
|
||||
broadcast_voice_state_update(VoiceState, State, OldChannelIdBin) ->
|
||||
case maps:get(<<"connection_id">>, VoiceState, undefined) of
|
||||
undefined ->
|
||||
?WARN_MISSING_CONN(VoiceState),
|
||||
ok;
|
||||
ConnectionId ->
|
||||
_ConnectionId ->
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
ChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
|
||||
|
||||
FilterChannelIdBin =
|
||||
case ChannelIdBin of
|
||||
null ->
|
||||
@@ -47,71 +42,100 @@ broadcast_voice_state_update(VoiceState, State, OldChannelIdBin) ->
|
||||
_ ->
|
||||
ChannelIdBin
|
||||
end,
|
||||
|
||||
FilterChannelId = utils:binary_to_integer_safe(FilterChannelIdBin),
|
||||
|
||||
UserId = maps:get(<<"user_id">>, VoiceState, <<"unknown">>),
|
||||
GuildId = maps:get(id, State, 0),
|
||||
AllSessionDetails = [{Sid, maps:get(user_id, S)} || {Sid, S} <- maps:to_list(Sessions)],
|
||||
logger:info(
|
||||
"[guild_voice_broadcast] Broadcasting voice state update: "
|
||||
"guild_id=~p user_id=~p channel_id=~p connection_id=~p "
|
||||
"total_sessions=~p all_sessions=~p filter_channel_id=~p",
|
||||
[
|
||||
GuildId,
|
||||
UserId,
|
||||
ChannelIdBin,
|
||||
ConnectionId,
|
||||
maps:size(Sessions),
|
||||
AllSessionDetails,
|
||||
FilterChannelId
|
||||
]
|
||||
),
|
||||
|
||||
FilteredSessions = guild_sessions:filter_sessions_for_channel(
|
||||
Sessions, FilterChannelId, undefined, State
|
||||
),
|
||||
|
||||
SessionDetails = [{Sid, maps:get(user_id, S)} || {Sid, S} <- FilteredSessions],
|
||||
Pids = [maps:get(pid, S) || {_Sid, S} <- FilteredSessions],
|
||||
|
||||
logger:info(
|
||||
"[guild_voice_broadcast] Filtered sessions: "
|
||||
"guild_id=~p user_id=~p filtered_count=~p session_details=~p pids=~p",
|
||||
[GuildId, UserId, length(FilteredSessions), SessionDetails, Pids]
|
||||
),
|
||||
|
||||
lists:foreach(
|
||||
fun(Pid) when is_pid(Pid) ->
|
||||
logger:info(
|
||||
"[guild_voice_broadcast] Sending voice_state_update to session pid ~p",
|
||||
[Pid]
|
||||
),
|
||||
gen_server:cast(Pid, {dispatch, voice_state_update, VoiceState})
|
||||
end,
|
||||
Pids
|
||||
)
|
||||
),
|
||||
maybe_relay_voice_state_update(VoiceState, OldChannelIdBin, State),
|
||||
ok
|
||||
end.
|
||||
|
||||
broadcast_voice_server_update_to_session(GuildId, SessionId, Token, Endpoint, ConnectionId, State) ->
|
||||
-spec broadcast_voice_server_update_to_session(
|
||||
integer(), integer(), binary(), binary(), binary(), binary(), guild_state()
|
||||
) -> ok.
|
||||
broadcast_voice_server_update_to_session(
|
||||
GuildId,
|
||||
ChannelId,
|
||||
SessionId,
|
||||
Token,
|
||||
Endpoint,
|
||||
ConnectionId,
|
||||
State
|
||||
) ->
|
||||
VoiceServerUpdate = #{
|
||||
<<"token">> => Token,
|
||||
<<"endpoint">> => Endpoint,
|
||||
<<"guild_id">> => integer_to_binary(GuildId),
|
||||
<<"channel_id">> => integer_to_binary(ChannelId),
|
||||
<<"connection_id">> => ConnectionId
|
||||
},
|
||||
|
||||
Sessions = maps:get(sessions, State, #{}),
|
||||
|
||||
case maps:get(SessionId, Sessions, undefined) of
|
||||
undefined ->
|
||||
maybe_relay_voice_server_update(
|
||||
GuildId, ChannelId, SessionId, Token, Endpoint, ConnectionId, State
|
||||
),
|
||||
ok;
|
||||
SessionData ->
|
||||
SessionPid = maps:get(pid, SessionData, null),
|
||||
case SessionPid of
|
||||
Pid when is_pid(Pid) ->
|
||||
gen_server:cast(Pid, {dispatch, voice_server_update, VoiceServerUpdate});
|
||||
gen_server:cast(Pid, {dispatch, voice_server_update, VoiceServerUpdate}),
|
||||
ok;
|
||||
_ ->
|
||||
maybe_relay_voice_server_update(
|
||||
GuildId, ChannelId, SessionId, Token, Endpoint, ConnectionId, State
|
||||
),
|
||||
ok
|
||||
end
|
||||
end.
|
||||
|
||||
-spec maybe_relay_voice_state_update(map(), binary() | null, guild_state()) -> ok.
|
||||
maybe_relay_voice_state_update(VoiceState, OldChannelIdBin, State) ->
|
||||
case {maps:get(very_large_guild_coordinator_pid, State, undefined),
|
||||
maps:get(very_large_guild_shard_index, State, undefined)}
|
||||
of
|
||||
{CoordPid, ShardIndex} when is_pid(CoordPid), is_integer(ShardIndex) ->
|
||||
CoordPid ! {very_large_guild_voice_state_update, ShardIndex, VoiceState, OldChannelIdBin},
|
||||
ok;
|
||||
_ ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec maybe_relay_voice_server_update(
|
||||
integer(),
|
||||
integer(),
|
||||
binary(),
|
||||
binary(),
|
||||
binary(),
|
||||
binary(),
|
||||
guild_state()
|
||||
) -> ok.
|
||||
maybe_relay_voice_server_update(GuildId, ChannelId, SessionId, Token, Endpoint, ConnectionId, State) ->
|
||||
case {maps:get(very_large_guild_coordinator_pid, State, undefined),
|
||||
maps:get(very_large_guild_shard_index, State, undefined)}
|
||||
of
|
||||
{CoordPid, ShardIndex} when is_pid(CoordPid), is_integer(ShardIndex) ->
|
||||
CoordPid !
|
||||
{very_large_guild_voice_server_update, ShardIndex, GuildId, ChannelId, SessionId,
|
||||
Token, Endpoint, ConnectionId},
|
||||
ok;
|
||||
_ ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
broadcast_voice_state_update_missing_connection_id_test() ->
|
||||
VoiceState = #{<<"user_id">> => <<"1">>},
|
||||
State = #{sessions => #{}},
|
||||
?assertEqual(ok, broadcast_voice_state_update(VoiceState, State, null)).
|
||||
|
||||
-endif.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,15 +23,18 @@
|
||||
-export([disconnect_voice_user_if_in_channel/2]).
|
||||
-export([disconnect_all_voice_users_in_channel/2]).
|
||||
-export([cleanup_virtual_channel_access_for_user/2]).
|
||||
|
||||
-type guild_state() :: map().
|
||||
-type voice_state() :: map().
|
||||
-type voice_state_map() :: #{binary() => voice_state()}.
|
||||
-export([recently_disconnected_voice_states/1]).
|
||||
-export([clear_recently_disconnected/2]).
|
||||
-export([clear_recently_disconnected_for_channel/2]).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-type guild_state() :: map().
|
||||
-type voice_state() :: map().
|
||||
-type voice_state_map() :: #{binary() => voice_state()}.
|
||||
|
||||
-spec handle_voice_disconnect(
|
||||
binary() | undefined,
|
||||
term(),
|
||||
@@ -45,7 +48,10 @@ handle_voice_disconnect(ConnectionId, _SessionId, UserId, VoiceStates0, State) -
|
||||
VoiceStates = voice_state_utils:ensure_voice_states(VoiceStates0),
|
||||
case maps:get(ConnectionId, VoiceStates, undefined) of
|
||||
undefined ->
|
||||
{reply, #{success => true}, State};
|
||||
%% Voice state not in voice_states - check if it's still pending
|
||||
%% (user disconnected before LiveKit confirmation)
|
||||
State1 = clear_pending_voice_connection(ConnectionId, State),
|
||||
{reply, #{success => true}, State1};
|
||||
OldVoiceState ->
|
||||
case guild_voice_state:user_matches_voice_state(OldVoiceState, UserId) of
|
||||
false ->
|
||||
@@ -64,16 +70,43 @@ handle_voice_disconnect(ConnectionId, _SessionId, UserId, VoiceStates0, State) -
|
||||
{GuildId, ChannelId} ->
|
||||
maybe_force_disconnect(GuildId, ChannelId, UserId, ConnectionId, State),
|
||||
NewVoiceStates = maps:remove(ConnectionId, VoiceStates),
|
||||
NewState = maps:put(voice_states, NewVoiceStates, State),
|
||||
NewState0 = maps:put(voice_states, NewVoiceStates, State),
|
||||
NewState = clear_recently_disconnected(ConnectionId, NewState0),
|
||||
voice_state_utils:broadcast_disconnects(
|
||||
#{ConnectionId => OldVoiceState}, NewState
|
||||
),
|
||||
FinalState = cleanup_virtual_channel_access_for_user(UserId, NewState),
|
||||
FinalState = maybe_cleanup_after_disconnect(
|
||||
UserId, ChannelId, NewState
|
||||
),
|
||||
{reply, #{success => true}, FinalState}
|
||||
end
|
||||
end
|
||||
end.
|
||||
|
||||
-spec clear_pending_voice_connection(binary(), guild_state()) -> guild_state().
|
||||
clear_pending_voice_connection(ConnectionId, State) ->
|
||||
PendingConnections = maps:get(pending_voice_connections, State, #{}),
|
||||
case maps:is_key(ConnectionId, PendingConnections) of
|
||||
false ->
|
||||
State;
|
||||
true ->
|
||||
NewPendingConnections = maps:remove(ConnectionId, PendingConnections),
|
||||
maps:put(pending_voice_connections, NewPendingConnections, State)
|
||||
end.
|
||||
|
||||
-spec maybe_cleanup_after_disconnect(integer(), integer(), guild_state()) -> guild_state().
|
||||
maybe_cleanup_after_disconnect(UserId, ChannelId, State) ->
|
||||
case
|
||||
guild_virtual_channel_access:is_pending_join(UserId, ChannelId, State) orelse
|
||||
guild_virtual_channel_access:has_preserve(UserId, ChannelId, State) orelse
|
||||
guild_virtual_channel_access:is_move_pending(UserId, ChannelId, State)
|
||||
of
|
||||
true ->
|
||||
State;
|
||||
false ->
|
||||
cleanup_virtual_channel_access_for_user(UserId, State)
|
||||
end.
|
||||
|
||||
-spec disconnect_voice_user(map(), guild_state()) -> {reply, map(), guild_state()}.
|
||||
disconnect_voice_user(#{user_id := UserId} = Request, State) ->
|
||||
ConnectionId = maps:get(connection_id, Request, null),
|
||||
@@ -85,12 +118,22 @@ disconnect_voice_user(#{user_id := UserId} = Request, State) ->
|
||||
end),
|
||||
case maps:size(UserVoiceStates) of
|
||||
0 ->
|
||||
{reply, #{success => true}, State};
|
||||
%% No active voice states - also clean up any pending connections
|
||||
State1 = clear_pending_voice_connections_for_user(UserId, State),
|
||||
{reply, #{success => true}, State1};
|
||||
_ ->
|
||||
maybe_force_disconnect_voice_states(UserVoiceStates, State),
|
||||
NewVoiceStates = voice_state_utils:drop_voice_states(
|
||||
UserVoiceStates, VoiceStates
|
||||
),
|
||||
NewState = maps:put(voice_states, NewVoiceStates, State),
|
||||
NewState0 = maps:put(voice_states, NewVoiceStates, State),
|
||||
NewState = maps:fold(
|
||||
fun(ConnId, _, AccState) ->
|
||||
clear_recently_disconnected(ConnId, AccState)
|
||||
end,
|
||||
NewState0,
|
||||
UserVoiceStates
|
||||
),
|
||||
voice_state_utils:broadcast_disconnects(UserVoiceStates, NewState),
|
||||
FinalState = cleanup_virtual_channel_access_for_user(UserId, NewState),
|
||||
{reply, #{success => true}, FinalState}
|
||||
@@ -98,14 +141,18 @@ disconnect_voice_user(#{user_id := UserId} = Request, State) ->
|
||||
SpecificConnection ->
|
||||
case maps:get(SpecificConnection, VoiceStates, undefined) of
|
||||
undefined ->
|
||||
{reply, #{success => true}, State};
|
||||
%% Not found in voice_states - also clean up pending connection
|
||||
State1 = clear_pending_voice_connection(SpecificConnection, State),
|
||||
{reply, #{success => true}, State1};
|
||||
VoiceState ->
|
||||
case voice_state_utils:voice_state_user_id(VoiceState) of
|
||||
undefined ->
|
||||
{reply, gateway_errors:error(voice_invalid_state), State};
|
||||
VoiceStateUserId when VoiceStateUserId =:= UserId ->
|
||||
maybe_force_disconnect_voice_state(SpecificConnection, VoiceState, State),
|
||||
NewVoiceStates = maps:remove(SpecificConnection, VoiceStates),
|
||||
NewState = maps:put(voice_states, NewVoiceStates, State),
|
||||
NewState0 = maps:put(voice_states, NewVoiceStates, State),
|
||||
NewState = clear_recently_disconnected(SpecificConnection, NewState0),
|
||||
voice_state_utils:broadcast_disconnects(
|
||||
#{SpecificConnection => VoiceState}, NewState
|
||||
),
|
||||
@@ -117,6 +164,19 @@ disconnect_voice_user(#{user_id := UserId} = Request, State) ->
|
||||
end
|
||||
end.
|
||||
|
||||
-spec clear_pending_voice_connections_for_user(integer(), guild_state()) -> guild_state().
|
||||
clear_pending_voice_connections_for_user(UserId, State) ->
|
||||
PendingConnections = maps:get(pending_voice_connections, State, #{}),
|
||||
FilteredPending = maps:filter(
|
||||
fun(_ConnId, PendingData) ->
|
||||
PendingUserId = maps:get(user_id, PendingData, undefined),
|
||||
PendingUserId =/= UserId
|
||||
end,
|
||||
PendingConnections
|
||||
),
|
||||
maps:put(pending_voice_connections, FilteredPending, State).
|
||||
|
||||
-spec disconnect_voice_user_if_in_channel(map(), guild_state()) -> {reply, map(), guild_state()}.
|
||||
disconnect_voice_user_if_in_channel(
|
||||
#{user_id := UserId, expected_channel_id := ExpectedChannelId} = Request,
|
||||
State
|
||||
@@ -131,27 +191,36 @@ disconnect_voice_user_if_in_channel(
|
||||
end),
|
||||
case maps:size(UserVoiceStates) of
|
||||
0 ->
|
||||
%% Not found in voice_states - also clean up any pending connections
|
||||
%% for this user/channel (user disconnected before LiveKit confirmation)
|
||||
State1 = clear_pending_voice_connections_for_user_channel(
|
||||
UserId, ExpectedChannelId, State
|
||||
),
|
||||
{reply,
|
||||
#{
|
||||
success => true,
|
||||
ignored => true,
|
||||
reason => <<"not_in_expected_channel">>
|
||||
},
|
||||
State};
|
||||
State1};
|
||||
_ ->
|
||||
NewVoiceStates = voice_state_utils:drop_voice_states(
|
||||
UserVoiceStates, VoiceStates
|
||||
),
|
||||
NewState = maps:put(voice_states, NewVoiceStates, State),
|
||||
NewState0 = maps:put(voice_states, NewVoiceStates, State),
|
||||
NewState = cache_recently_disconnected(UserVoiceStates, NewState0),
|
||||
voice_state_utils:broadcast_disconnects(UserVoiceStates, NewState),
|
||||
{reply, #{success => true}, NewState}
|
||||
end;
|
||||
ConnId ->
|
||||
case maps:get(ConnId, VoiceStates, undefined) of
|
||||
undefined ->
|
||||
%% Not found in voice_states - also clean up pending connection
|
||||
%% (user disconnected before LiveKit confirmation)
|
||||
State1 = clear_pending_voice_connection(ConnId, State),
|
||||
{reply,
|
||||
#{success => true, ignored => true, reason => <<"connection_not_found">>},
|
||||
State};
|
||||
State1};
|
||||
VoiceState ->
|
||||
case
|
||||
{
|
||||
@@ -161,7 +230,10 @@ disconnect_voice_user_if_in_channel(
|
||||
of
|
||||
{UserId, ExpectedChannelId} ->
|
||||
NewVoiceStates = maps:remove(ConnId, VoiceStates),
|
||||
NewState = maps:put(voice_states, NewVoiceStates, State),
|
||||
NewState0 = maps:put(voice_states, NewVoiceStates, State),
|
||||
NewState = cache_recently_disconnected(
|
||||
#{ConnId => VoiceState}, NewState0
|
||||
),
|
||||
voice_state_utils:broadcast_disconnects(
|
||||
#{ConnId => VoiceState}, NewState
|
||||
),
|
||||
@@ -178,105 +250,157 @@ disconnect_voice_user_if_in_channel(
|
||||
end
|
||||
end.
|
||||
|
||||
-spec clear_pending_voice_connections_for_user_channel(integer(), integer(), guild_state()) ->
|
||||
guild_state().
|
||||
clear_pending_voice_connections_for_user_channel(UserId, ChannelId, State) ->
|
||||
PendingConnections = maps:get(pending_voice_connections, State, #{}),
|
||||
FilteredPending = maps:filter(
|
||||
fun(_ConnId, PendingData) ->
|
||||
PendingUserId = maps:get(user_id, PendingData, undefined),
|
||||
PendingChannelId = maps:get(channel_id, PendingData, undefined),
|
||||
not (PendingUserId =:= UserId andalso PendingChannelId =:= ChannelId)
|
||||
end,
|
||||
PendingConnections
|
||||
),
|
||||
maps:put(pending_voice_connections, FilteredPending, State).
|
||||
|
||||
-spec clear_pending_voice_connections_for_channel(integer(), guild_state()) -> guild_state().
|
||||
clear_pending_voice_connections_for_channel(ChannelId, State) ->
|
||||
PendingConnections = maps:get(pending_voice_connections, State, #{}),
|
||||
FilteredPending = maps:filter(
|
||||
fun(_ConnId, PendingData) ->
|
||||
PendingChannelId = maps:get(channel_id, PendingData, undefined),
|
||||
PendingChannelId =/= ChannelId
|
||||
end,
|
||||
PendingConnections
|
||||
),
|
||||
maps:put(pending_voice_connections, FilteredPending, State).
|
||||
|
||||
-spec disconnect_all_voice_users_in_channel(map(), guild_state()) -> {reply, map(), guild_state()}.
|
||||
disconnect_all_voice_users_in_channel(#{channel_id := ChannelId}, State) ->
|
||||
VoiceStates = voice_state_utils:voice_states(State),
|
||||
ChannelVoiceStates = voice_state_utils:filter_voice_states(VoiceStates, fun(_, V) ->
|
||||
voice_state_utils:voice_state_channel_id(V) =:= ChannelId
|
||||
end),
|
||||
%% Also clean up any pending connections for this channel
|
||||
%% (users that requested tokens but haven't confirmed via LiveKit yet)
|
||||
State1 = clear_pending_voice_connections_for_channel(ChannelId, State),
|
||||
case maps:size(ChannelVoiceStates) of
|
||||
0 ->
|
||||
{reply, #{success => true, disconnected_count => 0}, State};
|
||||
{reply, #{success => true, disconnected_count => 0}, State1};
|
||||
Count ->
|
||||
maybe_force_disconnect_voice_states(ChannelVoiceStates, State1),
|
||||
NewVoiceStates = voice_state_utils:drop_voice_states(ChannelVoiceStates, VoiceStates),
|
||||
NewState = maps:put(voice_states, NewVoiceStates, State),
|
||||
NewState0 = maps:put(voice_states, NewVoiceStates, State1),
|
||||
NewState = clear_recently_disconnected_for_channel(ChannelId, NewState0),
|
||||
voice_state_utils:broadcast_disconnects(ChannelVoiceStates, NewState),
|
||||
{reply, #{success => true, disconnected_count => Count}, NewState}
|
||||
end.
|
||||
|
||||
-spec maybe_force_disconnect_voice_states(voice_state_map(), guild_state()) -> ok.
|
||||
maybe_force_disconnect_voice_states(VoiceStates, State) ->
|
||||
maps:foreach(
|
||||
fun(ConnId, VoiceState) ->
|
||||
maybe_force_disconnect_voice_state(ConnId, VoiceState, State)
|
||||
end,
|
||||
VoiceStates
|
||||
),
|
||||
ok.
|
||||
|
||||
-spec maybe_force_disconnect_voice_state(binary(), voice_state(), guild_state()) -> ok.
|
||||
maybe_force_disconnect_voice_state(ConnectionId, VoiceState, State) ->
|
||||
UserId = voice_state_utils:voice_state_user_id(VoiceState),
|
||||
ChannelId = voice_state_utils:voice_state_channel_id(VoiceState),
|
||||
GuildId = resolve_guild_id(VoiceState, State),
|
||||
case {GuildId, ChannelId, UserId} of
|
||||
{GId, CId, UId} when is_integer(GId), is_integer(CId), is_integer(UId) ->
|
||||
_ = maybe_force_disconnect(GId, CId, UId, ConnectionId, State),
|
||||
ok;
|
||||
_ ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec resolve_guild_id(voice_state(), guild_state()) -> integer() | undefined.
|
||||
resolve_guild_id(VoiceState, State) ->
|
||||
case voice_state_utils:voice_state_guild_id(VoiceState) of
|
||||
undefined ->
|
||||
map_utils:get_integer(State, id, undefined);
|
||||
GuildId ->
|
||||
GuildId
|
||||
end.
|
||||
|
||||
-spec force_disconnect_participant(integer(), integer(), integer(), binary()) ->
|
||||
{ok, map()} | {error, term()}.
|
||||
force_disconnect_participant(GuildId, ChannelId, UserId, ConnectionId) ->
|
||||
Req = voice_utils:build_force_disconnect_rpc_request(GuildId, ChannelId, UserId, ConnectionId),
|
||||
case rpc_client:call(Req) of
|
||||
{ok, _Data} ->
|
||||
logger:debug(
|
||||
"[guild_voice_disconnect] Force disconnected participant via RPC ~p",
|
||||
[
|
||||
[
|
||||
{guildId, GuildId},
|
||||
{channelId, ChannelId},
|
||||
{userId, UserId},
|
||||
{connectionId, ConnectionId}
|
||||
]
|
||||
]
|
||||
),
|
||||
{ok, #{success => true}};
|
||||
{error, Reason} ->
|
||||
logger:error(
|
||||
"[guild_voice_disconnect] Failed to force disconnect participant via RPC ~p",
|
||||
[
|
||||
[
|
||||
{guildId, GuildId},
|
||||
{channelId, ChannelId},
|
||||
{userId, UserId},
|
||||
{connectionId, ConnectionId},
|
||||
{error, Reason}
|
||||
]
|
||||
]
|
||||
),
|
||||
{error, Reason}
|
||||
end.
|
||||
|
||||
-spec cleanup_virtual_channel_access_for_user(integer(), guild_state()) -> guild_state().
|
||||
cleanup_virtual_channel_access_for_user(UserId, State) ->
|
||||
VoiceStates = voice_state_utils:voice_states(State),
|
||||
HasVoiceConnection = maps:fold(
|
||||
VirtualChannels = guild_virtual_channel_access:get_virtual_channels_for_user(UserId, State),
|
||||
lists:foldl(
|
||||
fun(ChannelId, AccState) ->
|
||||
case user_has_voice_connection_in_channel(UserId, ChannelId, VoiceStates) of
|
||||
true ->
|
||||
AccState;
|
||||
false ->
|
||||
case
|
||||
guild_virtual_channel_access:is_pending_join(UserId, ChannelId, AccState) orelse
|
||||
guild_virtual_channel_access:has_preserve(UserId, ChannelId, AccState) orelse
|
||||
guild_virtual_channel_access:is_move_pending(
|
||||
UserId, ChannelId, AccState
|
||||
)
|
||||
of
|
||||
true ->
|
||||
AccState;
|
||||
false ->
|
||||
ok = maybe_dispatch_visibility_remove(UserId, ChannelId, AccState),
|
||||
guild_virtual_channel_access:remove_virtual_access(
|
||||
UserId, ChannelId, AccState
|
||||
)
|
||||
end
|
||||
end
|
||||
end,
|
||||
State,
|
||||
VirtualChannels
|
||||
).
|
||||
|
||||
-spec user_has_voice_connection_in_channel(integer(), integer(), voice_state_map()) -> boolean().
|
||||
user_has_voice_connection_in_channel(UserId, ChannelId, VoiceStates) ->
|
||||
ChannelIdBin = integer_to_binary(ChannelId),
|
||||
maps:fold(
|
||||
fun(_ConnId, VoiceState, Acc) ->
|
||||
case Acc of
|
||||
true -> true;
|
||||
false -> voice_state_utils:voice_state_user_id(VoiceState) =:= UserId
|
||||
true ->
|
||||
true;
|
||||
false ->
|
||||
voice_state_utils:voice_state_user_id(VoiceState) =:= UserId andalso
|
||||
maps:get(<<"channel_id">>, VoiceState, null) =:= ChannelIdBin
|
||||
end
|
||||
end,
|
||||
false,
|
||||
VoiceStates
|
||||
),
|
||||
case HasVoiceConnection of
|
||||
).
|
||||
|
||||
-spec maybe_dispatch_visibility_remove(integer(), integer(), guild_state()) -> ok.
|
||||
maybe_dispatch_visibility_remove(UserId, ChannelId, State) ->
|
||||
case guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, State) of
|
||||
true ->
|
||||
State;
|
||||
guild_virtual_channel_access:dispatch_channel_visibility_change(
|
||||
UserId, ChannelId, remove, State
|
||||
);
|
||||
false ->
|
||||
VirtualChannels = guild_virtual_channel_access:get_virtual_channels_for_user(
|
||||
UserId, State
|
||||
),
|
||||
lists:foldl(
|
||||
fun(ChannelId, AccState) ->
|
||||
Member = guild_permissions:find_member_by_user_id(UserId, AccState),
|
||||
case Member of
|
||||
undefined ->
|
||||
AccState;
|
||||
_ ->
|
||||
HasViewPermission = guild_permissions:can_view_channel_by_permissions(
|
||||
UserId, ChannelId, Member, AccState
|
||||
),
|
||||
case HasViewPermission of
|
||||
true ->
|
||||
guild_virtual_channel_access:remove_virtual_access(
|
||||
UserId, ChannelId, AccState
|
||||
);
|
||||
false ->
|
||||
guild_virtual_channel_access:dispatch_channel_visibility_change(
|
||||
UserId, ChannelId, remove, AccState
|
||||
),
|
||||
guild_virtual_channel_access:remove_virtual_access(
|
||||
UserId, ChannelId, AccState
|
||||
)
|
||||
end
|
||||
end
|
||||
end,
|
||||
State,
|
||||
VirtualChannels
|
||||
)
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec maybe_force_disconnect(integer(), integer(), integer(), binary(), guild_state()) ->
|
||||
{ok, map()} | {error, term()}.
|
||||
maybe_force_disconnect(GuildId, ChannelId, UserId, ConnectionId, State) ->
|
||||
case maps:get(test_force_disconnect_fun, State, undefined) of
|
||||
Fun when is_function(Fun, 4) ->
|
||||
@@ -285,6 +409,59 @@ maybe_force_disconnect(GuildId, ChannelId, UserId, ConnectionId, State) ->
|
||||
force_disconnect_participant(GuildId, ChannelId, UserId, ConnectionId)
|
||||
end.
|
||||
|
||||
-define(RECENTLY_DISCONNECTED_TTL_MS, 60000).
|
||||
|
||||
-spec recently_disconnected_voice_states(guild_state()) -> map().
|
||||
recently_disconnected_voice_states(State) ->
|
||||
case maps:get(recently_disconnected_voice_states, State, undefined) of
|
||||
Map when is_map(Map) -> Map;
|
||||
_ -> #{}
|
||||
end.
|
||||
|
||||
-spec cache_recently_disconnected(voice_state_map(), guild_state()) -> guild_state().
|
||||
cache_recently_disconnected(VoiceStatesToCache, State) ->
|
||||
Now = erlang:system_time(millisecond),
|
||||
Existing = recently_disconnected_voice_states(State),
|
||||
Swept = sweep_expired_recently_disconnected(Existing, Now),
|
||||
NewEntries = maps:fold(
|
||||
fun(ConnId, VoiceState, Acc) ->
|
||||
maps:put(ConnId, #{voice_state => VoiceState, disconnected_at => Now}, Acc)
|
||||
end,
|
||||
Swept,
|
||||
VoiceStatesToCache
|
||||
),
|
||||
maps:put(recently_disconnected_voice_states, NewEntries, State).
|
||||
|
||||
-spec sweep_expired_recently_disconnected(map(), integer()) -> map().
|
||||
sweep_expired_recently_disconnected(Cache, Now) ->
|
||||
maps:filter(
|
||||
fun(_ConnId, #{disconnected_at := DisconnectedAt}) ->
|
||||
(Now - DisconnectedAt) < ?RECENTLY_DISCONNECTED_TTL_MS;
|
||||
(_ConnId, _) ->
|
||||
false
|
||||
end,
|
||||
Cache
|
||||
).
|
||||
|
||||
-spec clear_recently_disconnected(binary(), guild_state()) -> guild_state().
|
||||
clear_recently_disconnected(ConnectionId, State) ->
|
||||
Cache = recently_disconnected_voice_states(State),
|
||||
NewCache = maps:remove(ConnectionId, Cache),
|
||||
maps:put(recently_disconnected_voice_states, NewCache, State).
|
||||
|
||||
-spec clear_recently_disconnected_for_channel(integer(), guild_state()) -> guild_state().
|
||||
clear_recently_disconnected_for_channel(ChannelId, State) ->
|
||||
Cache = recently_disconnected_voice_states(State),
|
||||
NewCache = maps:filter(
|
||||
fun(_ConnId, #{voice_state := VS}) ->
|
||||
voice_state_utils:voice_state_channel_id(VS) =/= ChannelId;
|
||||
(_ConnId, _) ->
|
||||
false
|
||||
end,
|
||||
Cache
|
||||
),
|
||||
maps:put(recently_disconnected_voice_states, NewCache, State).
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
disconnect_voice_user_removes_all_connections_test() ->
|
||||
@@ -292,7 +469,10 @@ disconnect_voice_user_removes_all_connections_test() ->
|
||||
<<"a">> => voice_state_fixture(5, 10, 20),
|
||||
<<"b">> => voice_state_fixture(5, 10, 21)
|
||||
},
|
||||
State = #{voice_states => VoiceStates},
|
||||
State = #{
|
||||
voice_states => VoiceStates,
|
||||
test_force_disconnect_fun => fun(_, _, _, _) -> {ok, #{success => true}} end
|
||||
},
|
||||
{reply, #{success := true}, #{voice_states := #{}}} =
|
||||
disconnect_voice_user(#{user_id => 5, connection_id => null}, State).
|
||||
|
||||
@@ -309,6 +489,353 @@ disconnect_voice_user_if_in_channel_ignored_test() ->
|
||||
{reply, #{ignored := true}, _} =
|
||||
disconnect_voice_user_if_in_channel(#{user_id => 5, expected_channel_id => 99}, State).
|
||||
|
||||
user_has_voice_connection_in_channel_test() ->
|
||||
VoiceStates = #{
|
||||
<<"conn">> => #{<<"user_id">> => <<"10">>, <<"channel_id">> => <<"100">>}
|
||||
},
|
||||
?assert(user_has_voice_connection_in_channel(10, 100, VoiceStates)),
|
||||
?assertNot(user_has_voice_connection_in_channel(10, 200, VoiceStates)),
|
||||
?assertNot(user_has_voice_connection_in_channel(20, 100, VoiceStates)).
|
||||
|
||||
recently_disconnected_voice_states_default_test() ->
|
||||
?assertEqual(#{}, recently_disconnected_voice_states(#{})).
|
||||
|
||||
recently_disconnected_voice_states_returns_map_test() ->
|
||||
Cache = #{<<"conn">> => #{voice_state => #{}, disconnected_at => 1000}},
|
||||
State = #{recently_disconnected_voice_states => Cache},
|
||||
?assertEqual(Cache, recently_disconnected_voice_states(State)).
|
||||
|
||||
cache_recently_disconnected_test() ->
|
||||
VS = voice_state_fixture(5, 10, 20),
|
||||
State = #{},
|
||||
NewState = cache_recently_disconnected(#{<<"conn">> => VS}, State),
|
||||
Cache = recently_disconnected_voice_states(NewState),
|
||||
?assert(maps:is_key(<<"conn">>, Cache)),
|
||||
#{<<"conn">> := #{voice_state := CachedVS}} = Cache,
|
||||
?assertEqual(VS, CachedVS).
|
||||
|
||||
clear_recently_disconnected_test() ->
|
||||
VS = voice_state_fixture(5, 10, 20),
|
||||
State0 = cache_recently_disconnected(#{<<"conn">> => VS}, #{}),
|
||||
State1 = clear_recently_disconnected(<<"conn">>, State0),
|
||||
?assertEqual(#{}, recently_disconnected_voice_states(State1)).
|
||||
|
||||
clear_recently_disconnected_for_channel_test() ->
|
||||
VS1 = voice_state_fixture(5, 10, 20),
|
||||
VS2 = voice_state_fixture(6, 10, 30),
|
||||
State0 = cache_recently_disconnected(#{<<"a">> => VS1, <<"b">> => VS2}, #{}),
|
||||
State1 = clear_recently_disconnected_for_channel(20, State0),
|
||||
Cache = recently_disconnected_voice_states(State1),
|
||||
?assertNot(maps:is_key(<<"a">>, Cache)),
|
||||
?assert(maps:is_key(<<"b">>, Cache)).
|
||||
|
||||
sweep_expired_recently_disconnected_test() ->
|
||||
Now = erlang:system_time(millisecond),
|
||||
Cache = #{
|
||||
<<"fresh">> => #{voice_state => #{}, disconnected_at => Now - 1000},
|
||||
<<"expired">> => #{voice_state => #{}, disconnected_at => Now - 70000}
|
||||
},
|
||||
Swept = sweep_expired_recently_disconnected(Cache, Now),
|
||||
?assert(maps:is_key(<<"fresh">>, Swept)),
|
||||
?assertNot(maps:is_key(<<"expired">>, Swept)).
|
||||
|
||||
disconnect_voice_user_if_in_channel_caches_by_connection_test() ->
|
||||
VS = voice_state_fixture(5, 10, 20),
|
||||
VoiceStates = #{<<"conn">> => VS},
|
||||
State = #{voice_states => VoiceStates},
|
||||
{reply, #{success := true}, NewState} =
|
||||
disconnect_voice_user_if_in_channel(
|
||||
#{user_id => 5, expected_channel_id => 20, connection_id => <<"conn">>},
|
||||
State
|
||||
),
|
||||
Cache = recently_disconnected_voice_states(NewState),
|
||||
?assert(maps:is_key(<<"conn">>, Cache)).
|
||||
|
||||
disconnect_voice_user_calls_force_disconnect_for_all_connections_test() ->
|
||||
Self = self(),
|
||||
TestFun = fun(GuildId, ChannelId, UserId, ConnectionId) ->
|
||||
Self ! {force_disconnect, GuildId, ChannelId, UserId, ConnectionId},
|
||||
{ok, #{success => true}}
|
||||
end,
|
||||
VoiceStates = #{
|
||||
<<"a">> => voice_state_fixture(5, 10, 20),
|
||||
<<"b">> => voice_state_fixture(5, 10, 21)
|
||||
},
|
||||
State = #{
|
||||
id => 10,
|
||||
voice_states => VoiceStates,
|
||||
test_force_disconnect_fun => TestFun
|
||||
},
|
||||
{reply, #{success := true}, #{voice_states := #{}}} =
|
||||
disconnect_voice_user(#{user_id => 5, connection_id => null}, State),
|
||||
Msgs = collect_force_disconnect_messages(2),
|
||||
?assertEqual(2, length(Msgs)),
|
||||
?assert(lists:member({force_disconnect, 10, 20, 5, <<"a">>}, Msgs)),
|
||||
?assert(lists:member({force_disconnect, 10, 21, 5, <<"b">>}, Msgs)).
|
||||
|
||||
disconnect_voice_user_calls_force_disconnect_for_specific_connection_test() ->
|
||||
Self = self(),
|
||||
TestFun = fun(GuildId, ChannelId, UserId, ConnectionId) ->
|
||||
Self ! {force_disconnect, GuildId, ChannelId, UserId, ConnectionId},
|
||||
{ok, #{success => true}}
|
||||
end,
|
||||
VoiceStates = #{
|
||||
<<"a">> => voice_state_fixture(5, 10, 20),
|
||||
<<"b">> => voice_state_fixture(5, 10, 21)
|
||||
},
|
||||
State = #{
|
||||
id => 10,
|
||||
voice_states => VoiceStates,
|
||||
test_force_disconnect_fun => TestFun
|
||||
},
|
||||
{reply, #{success := true}, NewState} =
|
||||
disconnect_voice_user(#{user_id => 5, connection_id => <<"a">>}, State),
|
||||
Msgs = collect_force_disconnect_messages(1),
|
||||
?assertEqual(1, length(Msgs)),
|
||||
?assert(lists:member({force_disconnect, 10, 20, 5, <<"a">>}, Msgs)),
|
||||
Remaining = maps:get(voice_states, NewState),
|
||||
?assert(maps:is_key(<<"b">>, Remaining)),
|
||||
?assertNot(maps:is_key(<<"a">>, Remaining)).
|
||||
|
||||
disconnect_all_voice_users_in_channel_calls_force_disconnect_test() ->
|
||||
Self = self(),
|
||||
TestFun = fun(GuildId, ChannelId, UserId, ConnectionId) ->
|
||||
Self ! {force_disconnect, GuildId, ChannelId, UserId, ConnectionId},
|
||||
{ok, #{success => true}}
|
||||
end,
|
||||
VoiceStates = #{
|
||||
<<"a">> => voice_state_fixture(5, 10, 20),
|
||||
<<"b">> => voice_state_fixture(6, 10, 20),
|
||||
<<"c">> => voice_state_fixture(7, 10, 30)
|
||||
},
|
||||
State = #{
|
||||
id => 10,
|
||||
voice_states => VoiceStates,
|
||||
test_force_disconnect_fun => TestFun
|
||||
},
|
||||
{reply, #{success := true, disconnected_count := 2}, NewState} =
|
||||
disconnect_all_voice_users_in_channel(#{channel_id => 20}, State),
|
||||
Msgs = collect_force_disconnect_messages(2),
|
||||
?assertEqual(2, length(Msgs)),
|
||||
?assert(lists:member({force_disconnect, 10, 20, 5, <<"a">>}, Msgs)),
|
||||
?assert(lists:member({force_disconnect, 10, 20, 6, <<"b">>}, Msgs)),
|
||||
Remaining = maps:get(voice_states, NewState),
|
||||
?assert(maps:is_key(<<"c">>, Remaining)),
|
||||
?assertNot(maps:is_key(<<"a">>, Remaining)),
|
||||
?assertNot(maps:is_key(<<"b">>, Remaining)).
|
||||
|
||||
disconnect_voice_user_if_in_channel_skips_force_disconnect_test() ->
|
||||
Self = self(),
|
||||
TestFun = fun(_, _, _, _) ->
|
||||
Self ! force_disconnect_called,
|
||||
{ok, #{success => true}}
|
||||
end,
|
||||
VoiceStates = #{<<"a">> => voice_state_fixture(5, 10, 20)},
|
||||
State = #{
|
||||
voice_states => VoiceStates,
|
||||
test_force_disconnect_fun => TestFun
|
||||
},
|
||||
{reply, #{success := true}, _} =
|
||||
disconnect_voice_user_if_in_channel(
|
||||
#{user_id => 5, expected_channel_id => 20, connection_id => <<"a">>},
|
||||
State
|
||||
),
|
||||
receive
|
||||
force_disconnect_called -> ?assert(false)
|
||||
after 0 ->
|
||||
ok
|
||||
end.
|
||||
|
||||
%% Tests for clear_pending_voice_connection/2
|
||||
|
||||
clear_pending_voice_connection_removes_connection_test() ->
|
||||
PendingConnections = #{
|
||||
<<"conn1">> => #{user_id => 1, channel_id => 100},
|
||||
<<"conn2">> => #{user_id => 2, channel_id => 200}
|
||||
},
|
||||
State = #{pending_voice_connections => PendingConnections},
|
||||
NewState = clear_pending_voice_connection(<<"conn1">>, State),
|
||||
NewPending = maps:get(pending_voice_connections, NewState),
|
||||
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
|
||||
?assert(maps:is_key(<<"conn2">>, NewPending)).
|
||||
|
||||
clear_pending_voice_connection_ignores_missing_test() ->
|
||||
PendingConnections = #{<<"conn1">> => #{user_id => 1, channel_id => 100}},
|
||||
State = #{pending_voice_connections => PendingConnections},
|
||||
NewState = clear_pending_voice_connection(<<"missing">>, State),
|
||||
?assertEqual(State, NewState).
|
||||
|
||||
clear_pending_voice_connection_handles_empty_pending_test() ->
|
||||
State = #{voice_states => #{}},
|
||||
NewState = clear_pending_voice_connection(<<"conn">>, State),
|
||||
?assertEqual(#{}, maps:get(pending_voice_connections, NewState, #{})).
|
||||
|
||||
%% Tests for clear_pending_voice_connections_for_user/2
|
||||
|
||||
clear_pending_voice_connections_for_user_removes_all_user_connections_test() ->
|
||||
PendingConnections = #{
|
||||
<<"conn1">> => #{user_id => 5, channel_id => 100},
|
||||
<<"conn2">> => #{user_id => 5, channel_id => 200},
|
||||
<<"conn3">> => #{user_id => 6, channel_id => 100}
|
||||
},
|
||||
State = #{pending_voice_connections => PendingConnections},
|
||||
NewState = clear_pending_voice_connections_for_user(5, State),
|
||||
NewPending = maps:get(pending_voice_connections, NewState),
|
||||
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
|
||||
?assertNot(maps:is_key(<<"conn2">>, NewPending)),
|
||||
?assert(maps:is_key(<<"conn3">>, NewPending)).
|
||||
|
||||
%% Tests for clear_pending_voice_connections_for_user_channel/3
|
||||
|
||||
clear_pending_voice_connections_for_user_channel_removes_matching_test() ->
|
||||
PendingConnections = #{
|
||||
<<"conn1">> => #{user_id => 5, channel_id => 100},
|
||||
<<"conn2">> => #{user_id => 5, channel_id => 200},
|
||||
<<"conn3">> => #{user_id => 6, channel_id => 100}
|
||||
},
|
||||
State = #{pending_voice_connections => PendingConnections},
|
||||
NewState = clear_pending_voice_connections_for_user_channel(5, 100, State),
|
||||
NewPending = maps:get(pending_voice_connections, NewState),
|
||||
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
|
||||
?assert(maps:is_key(<<"conn2">>, NewPending)),
|
||||
?assert(maps:is_key(<<"conn3">>, NewPending)).
|
||||
|
||||
%% Tests for clear_pending_voice_connections_for_channel/2
|
||||
|
||||
clear_pending_voice_connections_for_channel_removes_all_channel_connections_test() ->
|
||||
PendingConnections = #{
|
||||
<<"conn1">> => #{user_id => 5, channel_id => 100},
|
||||
<<"conn2">> => #{user_id => 6, channel_id => 100},
|
||||
<<"conn3">> => #{user_id => 7, channel_id => 200}
|
||||
},
|
||||
State = #{pending_voice_connections => PendingConnections},
|
||||
NewState = clear_pending_voice_connections_for_channel(100, State),
|
||||
NewPending = maps:get(pending_voice_connections, NewState),
|
||||
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
|
||||
?assertNot(maps:is_key(<<"conn2">>, NewPending)),
|
||||
?assert(maps:is_key(<<"conn3">>, NewPending)).
|
||||
|
||||
%% Tests for disconnect handlers cleaning up pending connections
|
||||
|
||||
handle_voice_disconnect_cleans_pending_when_not_in_voice_states_test() ->
|
||||
PendingConnections = #{<<"conn1">> => #{user_id => 5, channel_id => 100}},
|
||||
State = #{
|
||||
voice_states => #{},
|
||||
pending_voice_connections => PendingConnections
|
||||
},
|
||||
{reply, #{success := true}, NewState} =
|
||||
handle_voice_disconnect(<<"conn1">>, undefined, 5, #{}, State),
|
||||
NewPending = maps:get(pending_voice_connections, NewState),
|
||||
?assertNot(maps:is_key(<<"conn1">>, NewPending)).
|
||||
|
||||
disconnect_voice_user_cleans_pending_when_no_active_states_test() ->
|
||||
PendingConnections = #{
|
||||
<<"conn1">> => #{user_id => 5, channel_id => 100},
|
||||
<<"conn2">> => #{user_id => 6, channel_id => 100}
|
||||
},
|
||||
State = #{
|
||||
id => 10,
|
||||
voice_states => #{},
|
||||
pending_voice_connections => PendingConnections
|
||||
},
|
||||
{reply, #{success := true}, NewState} =
|
||||
disconnect_voice_user(#{user_id => 5}, State),
|
||||
NewPending = maps:get(pending_voice_connections, NewState),
|
||||
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
|
||||
?assert(maps:is_key(<<"conn2">>, NewPending)).
|
||||
|
||||
disconnect_voice_user_cleans_pending_for_specific_connection_test() ->
|
||||
PendingConnections = #{
|
||||
<<"conn1">> => #{user_id => 5, channel_id => 100},
|
||||
<<"conn2">> => #{user_id => 5, channel_id => 200}
|
||||
},
|
||||
State = #{
|
||||
id => 10,
|
||||
voice_states => #{},
|
||||
pending_voice_connections => PendingConnections
|
||||
},
|
||||
{reply, #{success := true}, NewState} =
|
||||
disconnect_voice_user(#{user_id => 5, connection_id => <<"conn1">>}, State),
|
||||
NewPending = maps:get(pending_voice_connections, NewState),
|
||||
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
|
||||
?assert(maps:is_key(<<"conn2">>, NewPending)).
|
||||
|
||||
disconnect_voice_user_if_in_channel_cleans_pending_when_not_found_test() ->
|
||||
PendingConnections = #{
|
||||
<<"conn1">> => #{user_id => 5, channel_id => 100},
|
||||
<<"conn2">> => #{user_id => 6, channel_id => 100}
|
||||
},
|
||||
State = #{
|
||||
id => 10,
|
||||
voice_states => #{},
|
||||
pending_voice_connections => PendingConnections
|
||||
},
|
||||
{reply, #{success := true, ignored := true}, NewState} =
|
||||
disconnect_voice_user_if_in_channel(
|
||||
#{user_id => 5, expected_channel_id => 100},
|
||||
State
|
||||
),
|
||||
NewPending = maps:get(pending_voice_connections, NewState),
|
||||
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
|
||||
?assert(maps:is_key(<<"conn2">>, NewPending)).
|
||||
|
||||
disconnect_all_voice_users_in_channel_cleans_pending_test() ->
|
||||
Self = self(),
|
||||
TestFun = fun(GuildId, ChannelId, UserId, ConnectionId) ->
|
||||
Self ! {force_disconnect, GuildId, ChannelId, UserId, ConnectionId},
|
||||
{ok, #{success => true}}
|
||||
end,
|
||||
VoiceStates = #{
|
||||
<<"a">> => voice_state_fixture(5, 10, 20)
|
||||
},
|
||||
PendingConnections = #{
|
||||
<<"pending1">> => #{user_id => 6, channel_id => 20},
|
||||
<<"pending2">> => #{user_id => 7, channel_id => 30}
|
||||
},
|
||||
State = #{
|
||||
id => 10,
|
||||
voice_states => VoiceStates,
|
||||
pending_voice_connections => PendingConnections,
|
||||
test_force_disconnect_fun => TestFun
|
||||
},
|
||||
{reply, #{success := true, disconnected_count := 1}, NewState} =
|
||||
disconnect_all_voice_users_in_channel(#{channel_id => 20}, State),
|
||||
_ = collect_force_disconnect_messages(1),
|
||||
NewPending = maps:get(pending_voice_connections, NewState),
|
||||
?assertNot(maps:is_key(<<"pending1">>, NewPending)),
|
||||
?assert(maps:is_key(<<"pending2">>, NewPending)).
|
||||
|
||||
disconnect_all_voice_users_in_channel_cleans_pending_when_no_active_states_test() ->
|
||||
PendingConnections = #{
|
||||
<<"pending1">> => #{user_id => 5, channel_id => 20},
|
||||
<<"pending2">> => #{user_id => 6, channel_id => 30}
|
||||
},
|
||||
State = #{
|
||||
id => 10,
|
||||
voice_states => #{},
|
||||
pending_voice_connections => PendingConnections
|
||||
},
|
||||
{reply, #{success := true, disconnected_count := 0}, NewState} =
|
||||
disconnect_all_voice_users_in_channel(#{channel_id => 20}, State),
|
||||
NewPending = maps:get(pending_voice_connections, NewState),
|
||||
?assertNot(maps:is_key(<<"pending1">>, NewPending)),
|
||||
?assert(maps:is_key(<<"pending2">>, NewPending)).
|
||||
|
||||
collect_force_disconnect_messages(Count) ->
|
||||
collect_force_disconnect_messages(Count, []).
|
||||
|
||||
collect_force_disconnect_messages(0, Acc) ->
|
||||
lists:reverse(Acc);
|
||||
collect_force_disconnect_messages(Count, Acc) when Count > 0 ->
|
||||
receive
|
||||
{force_disconnect, _, _, _, _} = Msg ->
|
||||
collect_force_disconnect_messages(Count - 1, [Msg | Acc]);
|
||||
_Other ->
|
||||
collect_force_disconnect_messages(Count, Acc)
|
||||
after 200 ->
|
||||
lists:reverse(Acc)
|
||||
end.
|
||||
|
||||
voice_state_fixture(UserId, GuildId, ChannelId) ->
|
||||
#{
|
||||
<<"user_id">> => integer_to_binary(UserId),
|
||||
|
||||
@@ -40,7 +40,6 @@ update_member_voice(Request, State) ->
|
||||
#{user_id := UserId, mute := Mute, deaf := Deaf} = Request,
|
||||
VoiceStates = voice_state_utils:voice_states(State),
|
||||
GuildId = map_utils:get_integer(State, id, 0),
|
||||
|
||||
case find_member_by_user_id(UserId, State) of
|
||||
undefined ->
|
||||
{reply, gateway_errors:error(voice_member_not_found), State};
|
||||
@@ -48,7 +47,6 @@ update_member_voice(Request, State) ->
|
||||
UpdatedMember = set_member_voice_flags(Member, Mute, Deaf),
|
||||
StateWithUpdatedMember = store_member(UpdatedMember, State),
|
||||
UserVoiceStates = user_voice_states(UserId, VoiceStates),
|
||||
|
||||
case maps:size(UserVoiceStates) of
|
||||
0 ->
|
||||
{reply, #{success => true}, StateWithUpdatedMember};
|
||||
@@ -64,43 +62,22 @@ update_member_voice(Request, State) ->
|
||||
end
|
||||
end.
|
||||
|
||||
-spec find_member_by_user_id(integer(), guild_state()) -> member() | undefined.
|
||||
find_member_by_user_id(UserId, State) ->
|
||||
guild_permissions:find_member_by_user_id(UserId, State).
|
||||
|
||||
-spec find_channel_by_id(integer(), guild_state()) -> map() | undefined.
|
||||
find_channel_by_id(ChannelId, State) ->
|
||||
guild_permissions:find_channel_by_id(ChannelId, State).
|
||||
|
||||
-spec enforce_participant_state_in_livekit(integer(), integer(), integer(), boolean(), boolean()) ->
|
||||
ok.
|
||||
enforce_participant_state_in_livekit(GuildId, ChannelId, UserId, Mute, Deaf) ->
|
||||
Req = voice_utils:build_update_participant_rpc_request(GuildId, ChannelId, UserId, Mute, Deaf),
|
||||
case rpc_client:call(Req) of
|
||||
{ok, _Data} ->
|
||||
logger:debug(
|
||||
"[guild_voice_member] Enforced participant state in LiveKit ~p",
|
||||
[
|
||||
[
|
||||
{guildId, GuildId},
|
||||
{channelId, ChannelId},
|
||||
{userId, UserId},
|
||||
{mute, Mute},
|
||||
{deaf, Deaf}
|
||||
]
|
||||
]
|
||||
),
|
||||
ok;
|
||||
{error, Reason} ->
|
||||
logger:warning(
|
||||
"[guild_voice_member] Failed to enforce participant state in LiveKit ~p",
|
||||
[
|
||||
[
|
||||
{guildId, GuildId},
|
||||
{channelId, ChannelId},
|
||||
{userId, UserId},
|
||||
{mute, Mute},
|
||||
{deaf, Deaf},
|
||||
{error, Reason}
|
||||
]
|
||||
]
|
||||
),
|
||||
{error, _Reason} ->
|
||||
ok
|
||||
end.
|
||||
|
||||
@@ -108,16 +85,10 @@ enforce_participant_state_in_livekit(GuildId, ChannelId, UserId, Mute, Deaf) ->
|
||||
guild_data(State) ->
|
||||
map_utils:ensure_map(map_utils:get_safe(State, data, #{})).
|
||||
|
||||
-spec guild_members(guild_state()) -> [member()].
|
||||
guild_members(State) ->
|
||||
map_utils:ensure_list(maps:get(<<"members">>, guild_data(State), [])).
|
||||
|
||||
-spec member_user_id(member()) -> integer() | undefined.
|
||||
member_user_id(Member) when is_map(Member) ->
|
||||
User = map_utils:ensure_map(maps:get(<<"user">>, Member, #{})),
|
||||
map_utils:get_integer(User, <<"id">>, undefined);
|
||||
member_user_id(_) ->
|
||||
undefined.
|
||||
map_utils:get_integer(User, <<"id">>, undefined).
|
||||
|
||||
-spec set_member_voice_flags(member(), boolean(), boolean()) -> member().
|
||||
set_member_voice_flags(Member, Mute, Deaf) ->
|
||||
@@ -128,19 +99,9 @@ store_member(Member, State) ->
|
||||
case member_user_id(Member) of
|
||||
undefined ->
|
||||
State;
|
||||
TargetId ->
|
||||
_TargetId ->
|
||||
Data = guild_data(State),
|
||||
Members = guild_members(State),
|
||||
UpdatedMembers = lists:map(
|
||||
fun(Current) ->
|
||||
case member_user_id(Current) of
|
||||
TargetId -> Member;
|
||||
_ -> Current
|
||||
end
|
||||
end,
|
||||
Members
|
||||
),
|
||||
UpdatedData = maps:put(<<"members">>, UpdatedMembers, Data),
|
||||
UpdatedData = guild_data_index:put_member(Member, Data),
|
||||
maps:put(data, UpdatedData, State)
|
||||
end.
|
||||
|
||||
@@ -246,9 +207,17 @@ update_member_voice_updates_voice_states_test() ->
|
||||
?assert(false)
|
||||
end.
|
||||
|
||||
update_member_voice_member_not_found_test() ->
|
||||
State = voice_member_test_state(#{}),
|
||||
Request = #{user_id => 999, mute => true, deaf => false},
|
||||
{reply, Error, _} = update_member_voice(Request, State),
|
||||
?assertEqual({error, not_found, voice_member_not_found}, Error).
|
||||
|
||||
voice_member_test_state(Overrides) ->
|
||||
BaseData = #{
|
||||
<<"members">> => [member_fixture(10)]
|
||||
<<"members">> => #{
|
||||
10 => member_fixture(10)
|
||||
}
|
||||
},
|
||||
BaseState = #{
|
||||
id => 42,
|
||||
|
||||
@@ -19,9 +19,16 @@
|
||||
|
||||
-export([move_member/2]).
|
||||
-export([send_voice_server_update_for_move/5]).
|
||||
-export([send_voice_server_update_for_move/6]).
|
||||
-export([send_voice_server_updates_for_move/4]).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-type guild_state() :: map().
|
||||
-type voice_state() :: map().
|
||||
-type voice_state_map() :: #{binary() => voice_state()}.
|
||||
-type move_request() :: #{
|
||||
user_id := integer(),
|
||||
moderator_id := integer(),
|
||||
@@ -31,10 +38,6 @@
|
||||
deaf := boolean()
|
||||
}.
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-spec move_member(move_request(), guild_state()) -> {reply, map(), guild_state()}.
|
||||
move_member(Request, State) ->
|
||||
#{
|
||||
@@ -44,10 +47,17 @@ move_member(Request, State) ->
|
||||
} = Request,
|
||||
ConnectionId = maps:get(connection_id, Request, null),
|
||||
ChannelId = normalize_channel_id(ChannelIdRaw),
|
||||
logger:debug(
|
||||
"Handling voice move_member request",
|
||||
#{
|
||||
user_id => UserId,
|
||||
moderator_id => ModeratorId,
|
||||
channel_id => ChannelId,
|
||||
connection_id => ConnectionId
|
||||
}
|
||||
),
|
||||
VoiceStates = voice_state_utils:voice_states(State),
|
||||
|
||||
UserVoiceStates = find_user_voice_states(UserId, VoiceStates),
|
||||
|
||||
case maps:size(UserVoiceStates) of
|
||||
0 ->
|
||||
{reply, gateway_errors:error(voice_user_not_in_voice), State};
|
||||
@@ -55,11 +65,20 @@ move_member(Request, State) ->
|
||||
ConnectionsToMove = select_connections_to_move(
|
||||
ConnectionId, UserId, VoiceStates, UserVoiceStates
|
||||
),
|
||||
logger:debug(
|
||||
"Selected voice connections to move",
|
||||
#{
|
||||
user_id => UserId,
|
||||
connection_id => ConnectionId,
|
||||
connections_to_move_count => maps:size(ConnectionsToMove)
|
||||
}
|
||||
),
|
||||
handle_move(
|
||||
ConnectionsToMove, ChannelId, UserId, ModeratorId, ConnectionId, VoiceStates, State
|
||||
)
|
||||
end.
|
||||
|
||||
-spec find_user_voice_states(integer(), voice_state_map()) -> voice_state_map().
|
||||
find_user_voice_states(UserId, VoiceStates) ->
|
||||
maps:filter(
|
||||
fun(_ConnId, VoiceState) ->
|
||||
@@ -68,6 +87,8 @@ find_user_voice_states(UserId, VoiceStates) ->
|
||||
VoiceStates
|
||||
).
|
||||
|
||||
-spec select_connections_to_move(binary() | null, integer(), voice_state_map(), voice_state_map()) ->
|
||||
voice_state_map().
|
||||
select_connections_to_move(null, _UserId, _VoiceStates, UserVoiceStates) ->
|
||||
UserVoiceStates;
|
||||
select_connections_to_move(ConnectionId, UserId, VoiceStates, _UserVoiceStates) ->
|
||||
@@ -83,11 +104,16 @@ select_connections_to_move(ConnectionId, UserId, VoiceStates, _UserVoiceStates)
|
||||
end
|
||||
end.
|
||||
|
||||
-spec handle_move(
|
||||
voice_state_map(),
|
||||
integer() | null,
|
||||
integer(),
|
||||
integer(),
|
||||
binary() | null,
|
||||
voice_state_map(),
|
||||
guild_state()
|
||||
) -> {reply, map(), guild_state()}.
|
||||
handle_move(ConnectionsToMove, ChannelId, UserId, ModeratorId, ConnectionId, VoiceStates, State) ->
|
||||
logger:info(
|
||||
"[guild_voice_move] handle_move user_id=~p moderator_id=~p channel_id=~p connection_id=~p connections=~p",
|
||||
[UserId, ModeratorId, ChannelId, ConnectionId, maps:keys(ConnectionsToMove)]
|
||||
),
|
||||
case maps:size(ConnectionsToMove) of
|
||||
0 ->
|
||||
Error =
|
||||
@@ -99,14 +125,24 @@ handle_move(ConnectionsToMove, ChannelId, UserId, ModeratorId, ConnectionId, Voi
|
||||
_ ->
|
||||
case ChannelId of
|
||||
null ->
|
||||
logger:debug(
|
||||
"Disconnect move requested",
|
||||
#{user_id => UserId, connection_id => ConnectionId}
|
||||
),
|
||||
handle_disconnect_move(ConnectionsToMove, UserId, VoiceStates, State);
|
||||
ChannelIdValue ->
|
||||
logger:debug(
|
||||
"Channel move requested",
|
||||
#{user_id => UserId, channel_id => ChannelIdValue, connection_id => ConnectionId}
|
||||
),
|
||||
handle_channel_move(
|
||||
ConnectionsToMove, ChannelIdValue, UserId, ModeratorId, VoiceStates, State
|
||||
)
|
||||
end
|
||||
end.
|
||||
|
||||
-spec handle_disconnect_move(voice_state_map(), integer(), voice_state_map(), guild_state()) ->
|
||||
{reply, map(), guild_state()}.
|
||||
handle_disconnect_move(ConnectionsToMove, UserId, VoiceStates, State) ->
|
||||
NewVoiceStates = maps:fold(
|
||||
fun(ConnId, _VoiceState, Acc) -> maps:remove(ConnId, Acc) end,
|
||||
@@ -114,79 +150,107 @@ handle_disconnect_move(ConnectionsToMove, UserId, VoiceStates, State) ->
|
||||
ConnectionsToMove
|
||||
),
|
||||
NewState = maps:put(voice_states, NewVoiceStates, State),
|
||||
|
||||
maps:foreach(
|
||||
fun(_ConnId, VoiceState) ->
|
||||
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
|
||||
DisconnectVoiceState = maps:put(<<"channel_id">>, null, VoiceState),
|
||||
guild_voice_broadcast:broadcast_voice_state_update(
|
||||
DisconnectVoiceState, NewState, OldChannelIdBin
|
||||
)
|
||||
end,
|
||||
ConnectionsToMove
|
||||
),
|
||||
|
||||
spawn(fun() ->
|
||||
maps:foreach(
|
||||
fun(_ConnId, VoiceState) ->
|
||||
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
|
||||
DisconnectVoiceState = maps:put(<<"channel_id">>, null, VoiceState),
|
||||
guild_voice_broadcast:broadcast_voice_state_update(
|
||||
DisconnectVoiceState, NewState, OldChannelIdBin
|
||||
)
|
||||
end,
|
||||
ConnectionsToMove
|
||||
)
|
||||
end),
|
||||
{reply, #{success => true, user_id => UserId, connections_moved => ConnectionsToMove},
|
||||
NewState}.
|
||||
|
||||
-spec handle_channel_move(
|
||||
voice_state_map(), integer(), integer(), integer(), voice_state_map(), guild_state()
|
||||
) -> {reply, map(), guild_state()}.
|
||||
handle_channel_move(ConnectionsToMove, ChannelIdValue, UserId, ModeratorId, VoiceStates, State) ->
|
||||
logger:info(
|
||||
"[guild_voice_move] handle_channel_move user_id=~p moderator_id=~p target_channel_id=~p connections=~p",
|
||||
[UserId, ModeratorId, ChannelIdValue, maps:keys(ConnectionsToMove)]
|
||||
),
|
||||
Channel = guild_voice_member:find_channel_by_id(ChannelIdValue, State),
|
||||
case Channel of
|
||||
undefined ->
|
||||
{reply, gateway_errors:error(voice_channel_not_found), State};
|
||||
_ ->
|
||||
StateWithPending0 = guild_virtual_channel_access:mark_pending_join(
|
||||
UserId, ChannelIdValue, State
|
||||
),
|
||||
StateWithPending1 = guild_virtual_channel_access:mark_preserve(
|
||||
UserId, ChannelIdValue, StateWithPending0
|
||||
),
|
||||
StateWithPending2 = guild_virtual_channel_access:mark_move_pending(
|
||||
UserId, ChannelIdValue, StateWithPending1
|
||||
),
|
||||
ChannelType = maps:get(<<"type">>, Channel, 0),
|
||||
case ChannelType of
|
||||
2 ->
|
||||
check_move_permissions_and_execute(
|
||||
ConnectionsToMove, ChannelIdValue, UserId, ModeratorId, VoiceStates, State
|
||||
ConnectionsToMove,
|
||||
ChannelIdValue,
|
||||
UserId,
|
||||
ModeratorId,
|
||||
VoiceStates,
|
||||
StateWithPending2
|
||||
);
|
||||
_ ->
|
||||
{reply, gateway_errors:error(voice_channel_not_voice), State}
|
||||
end
|
||||
end.
|
||||
|
||||
-spec check_move_permissions_and_execute(
|
||||
voice_state_map(), integer(), integer(), integer(), voice_state_map(), guild_state()
|
||||
) -> {reply, map(), guild_state()}.
|
||||
check_move_permissions_and_execute(
|
||||
ConnectionsToMove, ChannelIdValue, _UserId, ModeratorId, VoiceStates, State
|
||||
ConnectionsToMove, ChannelIdValue, UserId, ModeratorId, VoiceStates, State
|
||||
) ->
|
||||
ViewPerm = constants:view_channel_permission(),
|
||||
ConnectPerm = constants:connect_permission(),
|
||||
ModPerms = guild_permissions:get_member_permissions(ModeratorId, ChannelIdValue, State),
|
||||
ModHasConnect = (ModPerms band ConnectPerm) =:= ConnectPerm,
|
||||
ModHasView = (ModPerms band ViewPerm) =:= ViewPerm,
|
||||
|
||||
case ModHasConnect andalso ModHasView of
|
||||
false ->
|
||||
{reply, gateway_errors:error(voice_moderator_missing_connect), State};
|
||||
true ->
|
||||
execute_move(ConnectionsToMove, VoiceStates, State)
|
||||
execute_move(ConnectionsToMove, ChannelIdValue, UserId, VoiceStates, State)
|
||||
end.
|
||||
|
||||
execute_move(ConnectionsToMove, VoiceStates, State) ->
|
||||
-spec execute_move(voice_state_map(), integer(), integer(), voice_state_map(), guild_state()) ->
|
||||
{reply, map(), guild_state()}.
|
||||
execute_move(ConnectionsToMove, ChannelIdValue, UserId, VoiceStates, State) ->
|
||||
StatePending = guild_virtual_channel_access:mark_pending_join(UserId, ChannelIdValue, State),
|
||||
StatePending2 = guild_virtual_channel_access:mark_preserve(
|
||||
UserId, ChannelIdValue, StatePending
|
||||
),
|
||||
StatePending3 = guild_virtual_channel_access:mark_move_pending(
|
||||
UserId, ChannelIdValue, StatePending2
|
||||
),
|
||||
logger:debug(
|
||||
"Executing voice channel move",
|
||||
#{user_id => UserId, channel_id => ChannelIdValue}
|
||||
),
|
||||
NewVoiceStates = maps:fold(
|
||||
fun(ConnId, _VoiceState, Acc) -> maps:remove(ConnId, Acc) end,
|
||||
VoiceStates,
|
||||
ConnectionsToMove
|
||||
),
|
||||
StateAfterDisconnect = maps:put(voice_states, NewVoiceStates, State),
|
||||
|
||||
maps:foreach(
|
||||
fun(_ConnId, VoiceState) ->
|
||||
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
|
||||
DisconnectVoiceState = maps:put(<<"channel_id">>, null, VoiceState),
|
||||
guild_voice_broadcast:broadcast_voice_state_update(
|
||||
DisconnectVoiceState, StateAfterDisconnect, OldChannelIdBin
|
||||
)
|
||||
end,
|
||||
ConnectionsToMove
|
||||
),
|
||||
|
||||
StateAfterDisconnect = maps:put(voice_states, NewVoiceStates, StatePending3),
|
||||
StateWithVirtualAccess = maybe_add_virtual_access(UserId, ChannelIdValue, StateAfterDisconnect),
|
||||
spawn(fun() ->
|
||||
maps:foreach(
|
||||
fun(_ConnId, VoiceState) ->
|
||||
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
|
||||
DisconnectVoiceState = maps:put(<<"channel_id">>, null, VoiceState),
|
||||
guild_voice_broadcast:broadcast_voice_state_update(
|
||||
DisconnectVoiceState, StateWithVirtualAccess, OldChannelIdBin
|
||||
)
|
||||
end,
|
||||
ConnectionsToMove
|
||||
)
|
||||
end),
|
||||
SessionData = extract_session_data(ConnectionsToMove),
|
||||
|
||||
{reply,
|
||||
#{
|
||||
success => true,
|
||||
@@ -194,8 +258,9 @@ execute_move(ConnectionsToMove, VoiceStates, State) ->
|
||||
session_data => SessionData,
|
||||
connections_to_move => ConnectionsToMove
|
||||
},
|
||||
StateAfterDisconnect}.
|
||||
StateWithVirtualAccess}.
|
||||
|
||||
-spec extract_session_data(voice_state_map()) -> [map()].
|
||||
extract_session_data(ConnectionsToMove) ->
|
||||
{_ConnectionIds, SessionData} = maps:fold(
|
||||
fun(ConnId, VoiceState, {AccConnIds, AccSessionData}) ->
|
||||
@@ -223,6 +288,159 @@ member_user_id(Member) ->
|
||||
User = map_utils:ensure_map(maps:get(<<"user">>, map_utils:ensure_map(Member), #{})),
|
||||
map_utils:get_integer(User, <<"id">>, undefined).
|
||||
|
||||
-spec send_voice_server_update_for_move(
|
||||
integer(), integer(), integer(), binary() | undefined, pid()
|
||||
) -> ok.
|
||||
send_voice_server_update_for_move(GuildId, ChannelId, UserId, SessionId, GuildPid) ->
|
||||
send_voice_server_update_for_move(GuildId, ChannelId, UserId, SessionId, null, GuildPid).
|
||||
|
||||
-spec send_voice_server_update_for_move(
|
||||
integer(), integer(), integer(), binary() | undefined, binary() | null, pid()
|
||||
) -> ok.
|
||||
send_voice_server_update_for_move(GuildId, ChannelId, UserId, SessionId, OldConnectionId, GuildPid) ->
|
||||
case SessionId of
|
||||
undefined ->
|
||||
ok;
|
||||
_ ->
|
||||
spawn(fun() ->
|
||||
case gen_server:call(GuildPid, {get_sessions}, 10000) of
|
||||
State when is_map(State) ->
|
||||
VoicePermissions = voice_utils:compute_voice_permissions(
|
||||
UserId, ChannelId, State
|
||||
),
|
||||
case
|
||||
guild_voice_connection:request_voice_token(
|
||||
GuildId, ChannelId, UserId, OldConnectionId, VoicePermissions
|
||||
)
|
||||
of
|
||||
{ok, TokenData} ->
|
||||
Token = maps:get(token, TokenData),
|
||||
Endpoint = maps:get(endpoint, TokenData),
|
||||
ConnectionId = maps:get(connection_id, TokenData),
|
||||
guild_voice_broadcast:broadcast_voice_server_update_to_session(
|
||||
GuildId,
|
||||
ChannelId,
|
||||
SessionId,
|
||||
Token,
|
||||
Endpoint,
|
||||
ConnectionId,
|
||||
State
|
||||
);
|
||||
{error, _Reason} ->
|
||||
ok
|
||||
end;
|
||||
_ ->
|
||||
ok
|
||||
end
|
||||
end),
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec maybe_add_virtual_access(integer(), integer(), guild_state()) -> guild_state().
|
||||
maybe_add_virtual_access(UserId, ChannelId, State) ->
|
||||
Member = guild_permissions:find_member_by_user_id(UserId, State),
|
||||
case Member of
|
||||
undefined ->
|
||||
State;
|
||||
_ ->
|
||||
Permissions = guild_permissions:get_member_permissions(UserId, ChannelId, State),
|
||||
ViewPerm = constants:view_channel_permission(),
|
||||
ConnectPerm = constants:connect_permission(),
|
||||
HasView = (Permissions band ViewPerm) =:= ViewPerm,
|
||||
HasConnect = (Permissions band ConnectPerm) =:= ConnectPerm,
|
||||
case HasView andalso HasConnect of
|
||||
true ->
|
||||
State;
|
||||
false ->
|
||||
NewState = guild_virtual_channel_access:add_virtual_access(
|
||||
UserId, ChannelId, State
|
||||
),
|
||||
guild_virtual_channel_access:dispatch_channel_visibility_change(
|
||||
UserId, ChannelId, add, NewState
|
||||
),
|
||||
NewState
|
||||
end
|
||||
end.
|
||||
|
||||
-spec send_voice_server_updates_for_move(integer(), integer(), [map()], pid()) -> ok.
|
||||
send_voice_server_updates_for_move(GuildId, ChannelId, SessionDataList, GuildPid) ->
|
||||
spawn(fun() ->
|
||||
lists:foreach(
|
||||
fun(SessionInfo) ->
|
||||
send_single_voice_server_update(GuildId, ChannelId, SessionInfo, GuildPid)
|
||||
end,
|
||||
SessionDataList
|
||||
)
|
||||
end),
|
||||
ok.
|
||||
|
||||
-spec send_single_voice_server_update(integer(), integer(), map(), pid()) -> ok.
|
||||
send_single_voice_server_update(GuildId, ChannelId, SessionInfo, GuildPid) ->
|
||||
SessionId = maps:get(session_id, SessionInfo),
|
||||
SelfMute = maps:get(self_mute, SessionInfo),
|
||||
SelfDeaf = maps:get(self_deaf, SessionInfo),
|
||||
SelfVideo = maps:get(self_video, SessionInfo),
|
||||
SelfStream = maps:get(self_stream, SessionInfo),
|
||||
IsMobile = maps:get(is_mobile, SessionInfo),
|
||||
OldConnectionId = maps:get(connection_id, SessionInfo, null),
|
||||
Member = maps:get(member, SessionInfo),
|
||||
ServerMute = maps:get(<<"mute">>, Member, false),
|
||||
ServerDeaf = maps:get(<<"deaf">>, Member, false),
|
||||
case member_user_id(Member) of
|
||||
undefined ->
|
||||
ok;
|
||||
UserId ->
|
||||
case gen_server:call(GuildPid, {get_sessions}, 10000) of
|
||||
StateData when is_map(StateData) ->
|
||||
VoicePermissions = voice_utils:compute_voice_permissions(
|
||||
UserId, ChannelId, StateData
|
||||
),
|
||||
case
|
||||
guild_voice_connection:request_voice_token(
|
||||
GuildId, ChannelId, UserId, OldConnectionId, VoicePermissions
|
||||
)
|
||||
of
|
||||
{ok, TokenData} ->
|
||||
Token = maps:get(token, TokenData),
|
||||
Endpoint = maps:get(endpoint, TokenData),
|
||||
NewConnectionId = maps:get(connection_id, TokenData),
|
||||
PendingMetadata = #{
|
||||
<<"user_id">> => UserId,
|
||||
<<"guild_id">> => GuildId,
|
||||
<<"channel_id">> => ChannelId,
|
||||
<<"connection_id">> => NewConnectionId,
|
||||
<<"session_id">> => SessionId,
|
||||
<<"self_mute">> => SelfMute,
|
||||
<<"self_deaf">> => SelfDeaf,
|
||||
<<"self_video">> => SelfVideo,
|
||||
<<"self_stream">> => SelfStream,
|
||||
<<"is_mobile">> => IsMobile,
|
||||
<<"server_mute">> => ServerMute,
|
||||
<<"server_deaf">> => ServerDeaf,
|
||||
<<"member">> => Member
|
||||
},
|
||||
_ = gen_server:call(
|
||||
GuildPid,
|
||||
{store_pending_connection, NewConnectionId, PendingMetadata},
|
||||
10000
|
||||
),
|
||||
guild_voice_broadcast:broadcast_voice_server_update_to_session(
|
||||
GuildId,
|
||||
ChannelId,
|
||||
SessionId,
|
||||
Token,
|
||||
Endpoint,
|
||||
NewConnectionId,
|
||||
StateData
|
||||
);
|
||||
{error, _Reason} ->
|
||||
ok
|
||||
end;
|
||||
_ ->
|
||||
ok
|
||||
end
|
||||
end.
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
move_member_user_not_in_voice_test() ->
|
||||
@@ -253,6 +471,12 @@ select_connections_to_move_specific_connection_test() ->
|
||||
?assertEqual(#{<<"conn-b">> => maps:get(<<"conn-b">>, VoiceStates)}, Selected),
|
||||
?assertEqual(#{}, select_connections_to_move(<<"conn-b">>, 10, VoiceStates, #{})).
|
||||
|
||||
normalize_channel_id_test() ->
|
||||
?assertEqual(null, normalize_channel_id(null)),
|
||||
?assertEqual(123, normalize_channel_id(123)),
|
||||
?assertEqual(456, normalize_channel_id(<<"456">>)),
|
||||
?assertEqual(null, normalize_channel_id(undefined)).
|
||||
|
||||
test_state(VoiceStates) ->
|
||||
#{
|
||||
id => 1,
|
||||
@@ -274,105 +498,3 @@ voice_state_fixture(UserId, ChannelId, ConnId) ->
|
||||
}.
|
||||
|
||||
-endif.
|
||||
|
||||
send_voice_server_update_for_move(GuildId, ChannelId, UserId, SessionId, GuildPid) ->
|
||||
case SessionId of
|
||||
undefined ->
|
||||
ok;
|
||||
_ ->
|
||||
case gen_server:call(GuildPid, {get_sessions}, 10000) of
|
||||
State when is_map(State) ->
|
||||
VoicePermissions = voice_utils:compute_voice_permissions(
|
||||
UserId, ChannelId, State
|
||||
),
|
||||
case
|
||||
guild_voice_connection:request_voice_token(
|
||||
GuildId, ChannelId, UserId, VoicePermissions
|
||||
)
|
||||
of
|
||||
{ok, TokenData} ->
|
||||
Token = maps:get(token, TokenData),
|
||||
Endpoint = maps:get(endpoint, TokenData),
|
||||
ConnectionId = maps:get(connection_id, TokenData),
|
||||
guild_voice_broadcast:broadcast_voice_server_update_to_session(
|
||||
GuildId, SessionId, Token, Endpoint, ConnectionId, State
|
||||
);
|
||||
{error, _Reason} ->
|
||||
ok
|
||||
end;
|
||||
_ ->
|
||||
ok
|
||||
end
|
||||
end.
|
||||
|
||||
send_voice_server_updates_for_move(GuildId, ChannelId, SessionDataList, GuildPid) ->
|
||||
lists:foreach(
|
||||
fun(SessionInfo) ->
|
||||
send_single_voice_server_update(GuildId, ChannelId, SessionInfo, GuildPid)
|
||||
end,
|
||||
SessionDataList
|
||||
).
|
||||
|
||||
send_single_voice_server_update(GuildId, ChannelId, SessionInfo, GuildPid) ->
|
||||
SessionId = maps:get(session_id, SessionInfo),
|
||||
SelfMute = maps:get(self_mute, SessionInfo),
|
||||
SelfDeaf = maps:get(self_deaf, SessionInfo),
|
||||
SelfVideo = maps:get(self_video, SessionInfo),
|
||||
SelfStream = maps:get(self_stream, SessionInfo),
|
||||
IsMobile = maps:get(is_mobile, SessionInfo),
|
||||
Member = maps:get(member, SessionInfo),
|
||||
ServerMute = maps:get(<<"mute">>, Member, false),
|
||||
ServerDeaf = maps:get(<<"deaf">>, Member, false),
|
||||
case member_user_id(Member) of
|
||||
undefined ->
|
||||
logger:warning(
|
||||
"[guild_voice_move] Missing user_id in member while sending voice server update: ~p",
|
||||
[SessionInfo]
|
||||
),
|
||||
ok;
|
||||
UserId ->
|
||||
case gen_server:call(GuildPid, {get_sessions}, 10000) of
|
||||
StateData when is_map(StateData) ->
|
||||
VoicePermissions = voice_utils:compute_voice_permissions(
|
||||
UserId, ChannelId, StateData
|
||||
),
|
||||
case
|
||||
guild_voice_connection:request_voice_token(
|
||||
GuildId, ChannelId, UserId, VoicePermissions
|
||||
)
|
||||
of
|
||||
{ok, TokenData} ->
|
||||
Token = maps:get(token, TokenData),
|
||||
Endpoint = maps:get(endpoint, TokenData),
|
||||
NewConnectionId = maps:get(connection_id, TokenData),
|
||||
|
||||
PendingMetadata = #{
|
||||
<<"user_id">> => UserId,
|
||||
<<"guild_id">> => GuildId,
|
||||
<<"channel_id">> => ChannelId,
|
||||
<<"connection_id">> => NewConnectionId,
|
||||
<<"session_id">> => SessionId,
|
||||
<<"self_mute">> => SelfMute,
|
||||
<<"self_deaf">> => SelfDeaf,
|
||||
<<"self_video">> => SelfVideo,
|
||||
<<"self_stream">> => SelfStream,
|
||||
<<"is_mobile">> => IsMobile,
|
||||
<<"server_mute">> => ServerMute,
|
||||
<<"server_deaf">> => ServerDeaf,
|
||||
<<"member">> => Member
|
||||
},
|
||||
gen_server:cast(
|
||||
GuildPid,
|
||||
{store_pending_connection, NewConnectionId, PendingMetadata}
|
||||
),
|
||||
|
||||
guild_voice_broadcast:broadcast_voice_server_update_to_session(
|
||||
GuildId, SessionId, Token, Endpoint, NewConnectionId, StateData
|
||||
);
|
||||
{error, _Reason} ->
|
||||
ok
|
||||
end;
|
||||
_ ->
|
||||
ok
|
||||
end
|
||||
end.
|
||||
|
||||
@@ -26,23 +26,23 @@
|
||||
|
||||
-type guild_state() :: map().
|
||||
-type voice_state() :: map().
|
||||
-type user_id() :: integer().
|
||||
-type channel_id() :: integer().
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-spec sync_user_voice_permissions(integer(), guild_state()) -> ok.
|
||||
-spec sync_user_voice_permissions(user_id(), guild_state()) -> ok.
|
||||
sync_user_voice_permissions(UserId, State) ->
|
||||
VoiceStates = voice_state_utils:voice_states(State),
|
||||
GuildId = map_utils:get_integer(State, id, 0),
|
||||
|
||||
UserVoiceStates = maps:filter(
|
||||
fun(_ConnId, VoiceState) ->
|
||||
voice_state_utils:voice_state_user_id(VoiceState) =:= UserId
|
||||
end,
|
||||
VoiceStates
|
||||
),
|
||||
|
||||
maps:foreach(
|
||||
fun(_ConnId, VoiceState) ->
|
||||
sync_voice_state_permissions(GuildId, UserId, VoiceState, State)
|
||||
@@ -51,18 +51,16 @@ sync_user_voice_permissions(UserId, State) ->
|
||||
),
|
||||
ok.
|
||||
|
||||
-spec sync_all_voice_permissions_for_channel(integer(), guild_state()) -> ok.
|
||||
-spec sync_all_voice_permissions_for_channel(channel_id(), guild_state()) -> ok.
|
||||
sync_all_voice_permissions_for_channel(ChannelId, State) ->
|
||||
VoiceStates = voice_state_utils:voice_states(State),
|
||||
GuildId = map_utils:get_integer(State, id, 0),
|
||||
|
||||
ChannelVoiceStates = maps:filter(
|
||||
fun(_ConnId, VoiceState) ->
|
||||
voice_state_utils:voice_state_channel_id(VoiceState) =:= ChannelId
|
||||
end,
|
||||
VoiceStates
|
||||
),
|
||||
|
||||
maps:foreach(
|
||||
fun(_ConnId, VoiceState) ->
|
||||
UserId = voice_state_utils:voice_state_user_id(VoiceState),
|
||||
@@ -84,15 +82,12 @@ maybe_sync_permissions_on_role_update(RoleUpdate, State) ->
|
||||
_ ->
|
||||
OldPermissions = maps:get(<<"old_permissions">>, RoleUpdate, 0),
|
||||
NewPermissions = maps:get(<<"permissions">>, RoleUpdate, 0),
|
||||
|
||||
AdminPerm = constants:administrator_permission(),
|
||||
SpeakPerm = constants:speak_permission(),
|
||||
StreamPerm = constants:stream_permission(),
|
||||
VoicePerms = AdminPerm bor SpeakPerm bor StreamPerm,
|
||||
|
||||
OldVoicePerms = OldPermissions band VoicePerms,
|
||||
NewVoicePerms = NewPermissions band VoicePerms,
|
||||
|
||||
case OldVoicePerms =/= NewVoicePerms of
|
||||
true ->
|
||||
sync_users_with_role(RoleId, State);
|
||||
@@ -110,7 +105,6 @@ maybe_sync_permissions_on_member_update(MemberUpdate, State) ->
|
||||
_ ->
|
||||
OldRoles = maps:get(<<"old_roles">>, MemberUpdate, []),
|
||||
NewRoles = maps:get(<<"roles">>, MemberUpdate, []),
|
||||
|
||||
case OldRoles =/= NewRoles of
|
||||
true ->
|
||||
sync_user_voice_permissions(UserId, State);
|
||||
@@ -119,11 +113,10 @@ maybe_sync_permissions_on_member_update(MemberUpdate, State) ->
|
||||
end
|
||||
end.
|
||||
|
||||
-spec sync_voice_state_permissions(integer(), integer(), voice_state(), guild_state()) -> ok.
|
||||
-spec sync_voice_state_permissions(integer(), user_id(), voice_state(), guild_state()) -> ok.
|
||||
sync_voice_state_permissions(GuildId, UserId, VoiceState, State) ->
|
||||
ChannelId = voice_state_utils:voice_state_channel_id(VoiceState),
|
||||
ConnectionId = maps:get(<<"connection_id">>, VoiceState, undefined),
|
||||
|
||||
case {ChannelId, ConnectionId} of
|
||||
{undefined, _} ->
|
||||
ok;
|
||||
@@ -134,7 +127,9 @@ sync_voice_state_permissions(GuildId, UserId, VoiceState, State) ->
|
||||
dispatch_permission_update(GuildId, ChId, UserId, ConnId, VoicePermissions, State)
|
||||
end.
|
||||
|
||||
-spec dispatch_permission_update(integer(), integer(), integer(), binary(), map(), guild_state()) ->
|
||||
-spec dispatch_permission_update(
|
||||
integer(), channel_id(), user_id(), binary(), map(), guild_state()
|
||||
) ->
|
||||
ok.
|
||||
dispatch_permission_update(GuildId, ChannelId, UserId, ConnectionId, VoicePermissions, State) ->
|
||||
case maps:get(test_permission_sync_fun, State, undefined) of
|
||||
@@ -149,7 +144,7 @@ dispatch_permission_update(GuildId, ChannelId, UserId, ConnectionId, VoicePermis
|
||||
end.
|
||||
|
||||
-spec enforce_voice_permissions_in_livekit(
|
||||
integer(), integer(), integer(), binary(), map()
|
||||
integer(), channel_id(), user_id(), binary(), map()
|
||||
) -> ok.
|
||||
enforce_voice_permissions_in_livekit(GuildId, ChannelId, UserId, ConnectionId, VoicePermissions) ->
|
||||
Req = voice_utils:build_update_participant_permissions_rpc_request(
|
||||
@@ -157,33 +152,8 @@ enforce_voice_permissions_in_livekit(GuildId, ChannelId, UserId, ConnectionId, V
|
||||
),
|
||||
case rpc_client:call(Req) of
|
||||
{ok, _Data} ->
|
||||
logger:debug(
|
||||
"[guild_voice_permission_sync] Synced voice permissions ~p",
|
||||
[
|
||||
[
|
||||
{guildId, GuildId},
|
||||
{channelId, ChannelId},
|
||||
{userId, UserId},
|
||||
{connectionId, ConnectionId},
|
||||
{permissions, VoicePermissions}
|
||||
]
|
||||
]
|
||||
),
|
||||
ok;
|
||||
{error, Reason} ->
|
||||
logger:warning(
|
||||
"[guild_voice_permission_sync] Failed to sync voice permissions ~p",
|
||||
[
|
||||
[
|
||||
{guildId, GuildId},
|
||||
{channelId, ChannelId},
|
||||
{userId, UserId},
|
||||
{connectionId, ConnectionId},
|
||||
{permissions, VoicePermissions},
|
||||
{error, Reason}
|
||||
]
|
||||
]
|
||||
),
|
||||
{error, _Reason} ->
|
||||
ok
|
||||
end.
|
||||
|
||||
@@ -192,7 +162,6 @@ sync_users_with_role(RoleId, State) ->
|
||||
RoleIdBin = ensure_binary(RoleId),
|
||||
VoiceStates = voice_state_utils:voice_states(State),
|
||||
GuildId = map_utils:get_integer(State, id, 0),
|
||||
|
||||
maps:foreach(
|
||||
fun(_ConnId, VoiceState) ->
|
||||
UserId = voice_state_utils:voice_state_user_id(VoiceState),
|
||||
@@ -212,7 +181,7 @@ sync_users_with_role(RoleId, State) ->
|
||||
),
|
||||
ok.
|
||||
|
||||
-spec user_has_role(integer(), binary(), guild_state()) -> boolean().
|
||||
-spec user_has_role(user_id(), binary(), guild_state()) -> boolean().
|
||||
user_has_role(UserId, RoleIdBin, State) ->
|
||||
case guild_voice_member:find_member_by_user_id(UserId, State) of
|
||||
undefined ->
|
||||
@@ -222,7 +191,7 @@ user_has_role(UserId, RoleIdBin, State) ->
|
||||
lists:member(RoleIdBin, Roles)
|
||||
end.
|
||||
|
||||
-spec get_member_user_id(map()) -> integer() | undefined.
|
||||
-spec get_member_user_id(map()) -> user_id() | undefined.
|
||||
get_member_user_id(MemberUpdate) ->
|
||||
User = maps:get(<<"user">>, MemberUpdate, #{}),
|
||||
map_utils:get_integer(User, <<"id">>, undefined).
|
||||
@@ -242,19 +211,16 @@ sync_user_voice_permissions_syncs_connected_user_test() ->
|
||||
ChannelId = 500,
|
||||
GuildId = 42,
|
||||
RoleId = 999,
|
||||
|
||||
VoiceState = #{
|
||||
<<"user_id">> => integer_to_binary(UserId),
|
||||
<<"channel_id">> => integer_to_binary(ChannelId),
|
||||
<<"connection_id">> => <<"test-conn">>
|
||||
},
|
||||
|
||||
Permissions =
|
||||
constants:view_channel_permission() bor
|
||||
constants:connect_permission() bor
|
||||
constants:speak_permission() bor
|
||||
constants:stream_permission(),
|
||||
|
||||
State = #{
|
||||
id => GuildId,
|
||||
voice_states => #{<<"conn">> => VoiceState},
|
||||
@@ -301,4 +267,17 @@ sync_user_voice_permissions_no_voice_state_test() ->
|
||||
},
|
||||
ok = sync_user_voice_permissions(10, State).
|
||||
|
||||
maybe_sync_permissions_on_member_update_no_role_change_test() ->
|
||||
State = #{id => 42, voice_states => #{}},
|
||||
MemberUpdate = #{
|
||||
<<"user">> => #{<<"id">> => <<"10">>},
|
||||
<<"roles">> => [<<"1">>],
|
||||
<<"old_roles">> => [<<"1">>]
|
||||
},
|
||||
?assertEqual(ok, maybe_sync_permissions_on_member_update(MemberUpdate, State)).
|
||||
|
||||
ensure_binary_test() ->
|
||||
?assertEqual(<<"123">>, ensure_binary(123)),
|
||||
?assertEqual(<<"abc">>, ensure_binary(<<"abc">>)).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -19,15 +19,15 @@
|
||||
|
||||
-export([check_voice_permissions_and_limits/6]).
|
||||
|
||||
-type guild_state() :: map().
|
||||
-type voice_state_map() :: #{binary() => map()}.
|
||||
-type channel() :: map().
|
||||
-import(utils, [parse_iso8601_to_unix_ms/1]).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-import(utils, [parse_iso8601_to_unix_ms/1]).
|
||||
-type guild_state() :: map().
|
||||
-type voice_state_map() :: #{binary() => map()}.
|
||||
-type channel() :: map().
|
||||
|
||||
-spec check_voice_permissions_and_limits(
|
||||
integer(), integer(), channel(), voice_state_map(), guild_state(), boolean()
|
||||
@@ -45,8 +45,10 @@ check_voice_permissions_and_limits(UserId, ChannelIdValue, Channel, VoiceStates,
|
||||
case
|
||||
channel_has_capacity(UserId, ChannelIdValue, Channel, VoiceStates, IsUpdate)
|
||||
of
|
||||
true -> {ok, allowed};
|
||||
false -> gateway_errors:error(voice_channel_full)
|
||||
true ->
|
||||
{ok, allowed};
|
||||
false ->
|
||||
gateway_errors:error(voice_channel_full)
|
||||
end
|
||||
end
|
||||
end.
|
||||
@@ -57,18 +59,26 @@ has_view_and_connect_perms(UserId, ChannelIdValue, State) ->
|
||||
true ->
|
||||
true;
|
||||
false ->
|
||||
Permissions = resolve_permissions(UserId, ChannelIdValue, State),
|
||||
ViewPerm = constants:view_channel_permission(),
|
||||
ConnectPerm = constants:connect_permission(),
|
||||
(Permissions band ViewPerm) =:= ViewPerm andalso
|
||||
(Permissions band ConnectPerm) =:= ConnectPerm
|
||||
case guild_virtual_channel_access:is_move_pending(UserId, ChannelIdValue, State) of
|
||||
true ->
|
||||
true;
|
||||
false ->
|
||||
Permissions = resolve_permissions(UserId, ChannelIdValue, State),
|
||||
ViewPerm = constants:view_channel_permission(),
|
||||
ConnectPerm = constants:connect_permission(),
|
||||
HasView = (Permissions band ViewPerm) =:= ViewPerm,
|
||||
HasConnect = (Permissions band ConnectPerm) =:= ConnectPerm,
|
||||
HasView andalso HasConnect
|
||||
end
|
||||
end.
|
||||
|
||||
-spec channel_has_capacity(integer(), integer(), channel(), voice_state_map(), boolean()) ->
|
||||
boolean().
|
||||
channel_has_capacity(UserId, ChannelIdValue, Channel, VoiceStates, IsUpdate) ->
|
||||
UserLimit = maps:get(<<"user_limit">>, Channel, 0),
|
||||
case UserLimit of
|
||||
AnyCameraActive = any_camera_active_in_channel(ChannelIdValue, VoiceStates),
|
||||
EffectiveLimit = effective_user_limit(UserLimit, AnyCameraActive),
|
||||
case EffectiveLimit of
|
||||
0 ->
|
||||
true;
|
||||
Limit when Limit > 0 ->
|
||||
@@ -85,6 +95,26 @@ channel_has_capacity(UserId, ChannelIdValue, Channel, VoiceStates, IsUpdate) ->
|
||||
true
|
||||
end.
|
||||
|
||||
-spec any_camera_active_in_channel(integer(), voice_state_map()) -> boolean().
|
||||
any_camera_active_in_channel(ChannelIdValue, VoiceStates) ->
|
||||
lists:any(
|
||||
fun({_ConnId, VS}) ->
|
||||
case map_utils:get_integer(VS, <<"channel_id">>, undefined) of
|
||||
ChannelIdValue ->
|
||||
maps:get(<<"self_video">>, VS, false) =:= true;
|
||||
_ ->
|
||||
false
|
||||
end
|
||||
end,
|
||||
maps:to_list(VoiceStates)
|
||||
).
|
||||
|
||||
-spec effective_user_limit(integer(), boolean()) -> integer().
|
||||
effective_user_limit(0, false) -> 0;
|
||||
effective_user_limit(0, true) -> 25;
|
||||
effective_user_limit(Limit, false) -> Limit;
|
||||
effective_user_limit(Limit, true) -> min(Limit, 25).
|
||||
|
||||
-spec is_member_timed_out(integer(), guild_state()) -> boolean().
|
||||
is_member_timed_out(UserId, State) ->
|
||||
case guild_permissions:find_member_by_user_id(UserId, State) of
|
||||
@@ -98,13 +128,11 @@ is_member_timed_out(UserId, State) ->
|
||||
undefined ->
|
||||
false;
|
||||
Value when is_integer(Value) ->
|
||||
Value > erlang:system_time(millisecond);
|
||||
_ ->
|
||||
false
|
||||
Value > erlang:system_time(millisecond)
|
||||
end
|
||||
end.
|
||||
|
||||
-spec users_in_channel(integer(), voice_state_map()) -> sets:set().
|
||||
-spec users_in_channel(integer(), voice_state_map()) -> sets:set(integer()).
|
||||
users_in_channel(ChannelIdValue, VoiceStates0) ->
|
||||
VoiceStates = voice_state_utils:ensure_voice_states(VoiceStates0),
|
||||
maps:fold(
|
||||
@@ -161,6 +189,18 @@ voice_permissions_existing_user_update_test() ->
|
||||
),
|
||||
?assertEqual({ok, allowed}, Result).
|
||||
|
||||
users_in_channel_test() ->
|
||||
VoiceStates = #{
|
||||
<<"conn1">> => #{<<"channel_id">> => <<"10">>, <<"user_id">> => <<"1">>},
|
||||
<<"conn2">> => #{<<"channel_id">> => <<"10">>, <<"user_id">> => <<"2">>},
|
||||
<<"conn3">> => #{<<"channel_id">> => <<"20">>, <<"user_id">> => <<"3">>}
|
||||
},
|
||||
Result = users_in_channel(10, VoiceStates),
|
||||
?assertEqual(2, sets:size(Result)),
|
||||
?assert(sets:is_element(1, Result)),
|
||||
?assert(sets:is_element(2, Result)),
|
||||
?assertNot(sets:is_element(3, Result)).
|
||||
|
||||
required_voice_perms() ->
|
||||
constants:view_channel_permission() bor constants:connect_permission().
|
||||
|
||||
|
||||
@@ -20,9 +20,17 @@
|
||||
-export([switch_voice_region_handler/2]).
|
||||
-export([switch_voice_region/3]).
|
||||
|
||||
-type guild_state() :: map().
|
||||
-type guild_reply(T) :: {reply, T, guild_state()}.
|
||||
-type voice_state() :: map().
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-spec switch_voice_region_handler(map(), guild_state()) -> guild_reply(map()).
|
||||
switch_voice_region_handler(Request, State) ->
|
||||
#{channel_id := ChannelId} = Request,
|
||||
|
||||
Channel = guild_voice_member:find_channel_by_id(ChannelId, State),
|
||||
case Channel of
|
||||
undefined ->
|
||||
@@ -37,42 +45,21 @@ switch_voice_region_handler(Request, State) ->
|
||||
end
|
||||
end.
|
||||
|
||||
-spec switch_voice_region(integer(), integer(), pid()) -> ok.
|
||||
switch_voice_region(GuildId, ChannelId, GuildPid) ->
|
||||
case gen_server:call(GuildPid, {get_sessions}, 10000) of
|
||||
State when is_map(State) ->
|
||||
VoiceStates = voice_state_utils:voice_states(State),
|
||||
|
||||
UsersInChannel = maps:fold(
|
||||
fun(ConnectionId, VoiceState, Acc) ->
|
||||
case voice_state_utils:voice_state_channel_id(VoiceState) of
|
||||
ChannelId ->
|
||||
case voice_state_utils:voice_state_user_id(VoiceState) of
|
||||
undefined ->
|
||||
logger:warning(
|
||||
"[guild_voice_region] Missing user_id for connection ~p",
|
||||
[ConnectionId]
|
||||
),
|
||||
Acc;
|
||||
UserId ->
|
||||
SessionId = maps:get(<<"session_id">>, VoiceState, undefined),
|
||||
[{UserId, SessionId, VoiceState} | Acc]
|
||||
end;
|
||||
_ ->
|
||||
Acc
|
||||
end
|
||||
end,
|
||||
[],
|
||||
VoiceStates
|
||||
),
|
||||
|
||||
UsersInChannel = collect_users_in_channel(VoiceStates, ChannelId),
|
||||
lists:foreach(
|
||||
fun({UserId, SessionId, VoiceState}) ->
|
||||
fun({UserId, SessionId, ExistingConnectionId, VoiceState}) ->
|
||||
case SessionId of
|
||||
undefined ->
|
||||
ok;
|
||||
_ ->
|
||||
send_voice_server_update_for_region_switch(
|
||||
GuildId, ChannelId, UserId, SessionId, VoiceState, GuildPid
|
||||
GuildId, ChannelId, UserId, SessionId, ExistingConnectionId,
|
||||
VoiceState, GuildPid
|
||||
)
|
||||
end
|
||||
end,
|
||||
@@ -82,42 +69,61 @@ switch_voice_region(GuildId, ChannelId, GuildPid) ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec collect_users_in_channel(map(), integer()) ->
|
||||
[{integer(), binary() | undefined, binary(), voice_state()}].
|
||||
collect_users_in_channel(VoiceStates, ChannelId) ->
|
||||
maps:fold(
|
||||
fun(ConnectionId, VoiceState, Acc) ->
|
||||
case voice_state_utils:voice_state_channel_id(VoiceState) of
|
||||
ChannelId ->
|
||||
case voice_state_utils:voice_state_user_id(VoiceState) of
|
||||
undefined ->
|
||||
Acc;
|
||||
UserId ->
|
||||
SessionId = maps:get(<<"session_id">>, VoiceState, undefined),
|
||||
[{UserId, SessionId, ConnectionId, VoiceState} | Acc]
|
||||
end;
|
||||
_ ->
|
||||
Acc
|
||||
end
|
||||
end,
|
||||
[],
|
||||
VoiceStates
|
||||
).
|
||||
|
||||
-spec send_voice_server_update_for_region_switch(
|
||||
integer(), integer(), integer(), binary(), binary(), voice_state(), pid()
|
||||
) -> ok.
|
||||
send_voice_server_update_for_region_switch(
|
||||
GuildId, ChannelId, UserId, SessionId, ExistingVoiceState, GuildPid
|
||||
GuildId, ChannelId, UserId, SessionId, ExistingConnectionId, ExistingVoiceState, GuildPid
|
||||
) ->
|
||||
case gen_server:call(GuildPid, {get_sessions}, 10000) of
|
||||
State when is_map(State) ->
|
||||
VoicePermissions = voice_utils:compute_voice_permissions(UserId, ChannelId, State),
|
||||
TokenNonce = voice_utils:generate_token_nonce(),
|
||||
case
|
||||
guild_voice_connection:request_voice_token(
|
||||
GuildId, ChannelId, UserId, VoicePermissions
|
||||
GuildId, ChannelId, UserId, ExistingConnectionId, VoicePermissions, TokenNonce
|
||||
)
|
||||
of
|
||||
{ok, TokenData} ->
|
||||
Token = maps:get(token, TokenData),
|
||||
Endpoint = maps:get(endpoint, TokenData),
|
||||
ConnectionId = maps:get(connection_id, TokenData),
|
||||
|
||||
PendingMetadata = #{
|
||||
user_id => UserId,
|
||||
guild_id => GuildId,
|
||||
channel_id => ChannelId,
|
||||
session_id => SessionId,
|
||||
self_mute => maps:get(<<"self_mute">>, ExistingVoiceState, false),
|
||||
self_deaf => maps:get(<<"self_deaf">>, ExistingVoiceState, false),
|
||||
self_video => maps:get(<<"self_video">>, ExistingVoiceState, false),
|
||||
self_stream => maps:get(<<"self_stream">>, ExistingVoiceState, false),
|
||||
is_mobile => maps:get(<<"is_mobile">>, ExistingVoiceState, false),
|
||||
server_mute => maps:get(<<"mute">>, ExistingVoiceState, false),
|
||||
server_deaf => maps:get(<<"deaf">>, ExistingVoiceState, false),
|
||||
member => maps:get(<<"member">>, ExistingVoiceState, #{})
|
||||
},
|
||||
gen_server:cast(
|
||||
GuildPid, {store_pending_connection, ConnectionId, PendingMetadata}
|
||||
PendingMetadata = build_pending_metadata(
|
||||
UserId, GuildId, ChannelId, SessionId, ExistingVoiceState, TokenNonce
|
||||
),
|
||||
_ = gen_server:call(
|
||||
GuildPid, {store_pending_connection, ConnectionId, PendingMetadata}, 10000
|
||||
),
|
||||
|
||||
guild_voice_broadcast:broadcast_voice_server_update_to_session(
|
||||
GuildId, SessionId, Token, Endpoint, ConnectionId, State
|
||||
GuildId,
|
||||
ChannelId,
|
||||
SessionId,
|
||||
Token,
|
||||
Endpoint,
|
||||
ConnectionId,
|
||||
State
|
||||
);
|
||||
{error, _Reason} ->
|
||||
ok
|
||||
@@ -125,3 +131,73 @@ send_voice_server_update_for_region_switch(
|
||||
_ ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec build_pending_metadata(integer(), integer(), integer(), binary(), voice_state(), binary()) -> map().
|
||||
build_pending_metadata(UserId, GuildId, ChannelId, SessionId, ExistingVoiceState, TokenNonce) ->
|
||||
Now = erlang:system_time(millisecond),
|
||||
#{
|
||||
user_id => UserId,
|
||||
guild_id => GuildId,
|
||||
channel_id => ChannelId,
|
||||
session_id => SessionId,
|
||||
self_mute => maps:get(<<"self_mute">>, ExistingVoiceState, false),
|
||||
self_deaf => maps:get(<<"self_deaf">>, ExistingVoiceState, false),
|
||||
self_video => maps:get(<<"self_video">>, ExistingVoiceState, false),
|
||||
self_stream => maps:get(<<"self_stream">>, ExistingVoiceState, false),
|
||||
is_mobile => maps:get(<<"is_mobile">>, ExistingVoiceState, false),
|
||||
server_mute => maps:get(<<"mute">>, ExistingVoiceState, false),
|
||||
server_deaf => maps:get(<<"deaf">>, ExistingVoiceState, false),
|
||||
member => maps:get(<<"member">>, ExistingVoiceState, #{}),
|
||||
viewer_stream_keys => [],
|
||||
token_nonce => TokenNonce,
|
||||
created_at => Now,
|
||||
expires_at => Now + 30000
|
||||
}.
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
switch_voice_region_handler_not_found_test() ->
|
||||
State = #{data => #{<<"channels">> => []}},
|
||||
Request = #{channel_id => 999},
|
||||
{reply, Error, _} = switch_voice_region_handler(Request, State),
|
||||
?assertEqual({error, not_found, voice_channel_not_found}, Error).
|
||||
|
||||
switch_voice_region_handler_not_voice_test() ->
|
||||
State = #{
|
||||
data => #{
|
||||
<<"channels">> => [
|
||||
#{<<"id">> => <<"100">>, <<"type">> => 0}
|
||||
]
|
||||
}
|
||||
},
|
||||
Request = #{channel_id => 100},
|
||||
{reply, Error, _} = switch_voice_region_handler(Request, State),
|
||||
?assertEqual({error, validation_error, voice_channel_not_voice}, Error).
|
||||
|
||||
switch_voice_region_handler_success_test() ->
|
||||
State = #{
|
||||
data => #{
|
||||
<<"channels">> => [
|
||||
#{<<"id">> => <<"100">>, <<"type">> => 2}
|
||||
]
|
||||
}
|
||||
},
|
||||
Request = #{channel_id => 100},
|
||||
{reply, Reply, _} = switch_voice_region_handler(Request, State),
|
||||
?assertEqual(true, maps:get(success, Reply)).
|
||||
|
||||
collect_users_in_channel_test() ->
|
||||
VoiceState = #{
|
||||
<<"channel_id">> => <<"100">>,
|
||||
<<"user_id">> => <<"10">>,
|
||||
<<"session_id">> => <<"sess1">>
|
||||
},
|
||||
VoiceStates = #{<<"conn1">> => VoiceState},
|
||||
Result = collect_users_in_channel(VoiceStates, 100),
|
||||
?assertEqual(1, length(Result)),
|
||||
[{UserId, SessionId, ConnectionId, _}] = Result,
|
||||
?assertEqual(10, UserId),
|
||||
?assertEqual(<<"sess1">>, SessionId),
|
||||
?assertEqual(<<"conn1">>, ConnectionId).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -26,14 +26,14 @@
|
||||
-export([create_voice_state/8]).
|
||||
-export([extract_session_info_from_voice_state/2]).
|
||||
|
||||
-type guild_state() :: map().
|
||||
-type voice_state() :: map().
|
||||
-type voice_state_map() :: #{binary() => voice_state()}.
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-type guild_state() :: map().
|
||||
-type voice_state() :: map().
|
||||
-type voice_state_map() :: #{binary() => voice_state()}.
|
||||
|
||||
-spec get_voice_state(map(), guild_state()) -> {reply, map(), guild_state()}.
|
||||
get_voice_state(Request, State) ->
|
||||
case maps:get(connection_id, Request, null) of
|
||||
@@ -58,7 +58,7 @@ get_voice_states_list(State) ->
|
||||
voice_state_map(),
|
||||
guild_state(),
|
||||
boolean(),
|
||||
term()
|
||||
list()
|
||||
) -> {reply, map(), guild_state()}.
|
||||
update_voice_state_data(
|
||||
ConnectionId,
|
||||
@@ -69,39 +69,85 @@ update_voice_state_data(
|
||||
VoiceStates,
|
||||
State,
|
||||
NeedsToken,
|
||||
ViewerStreamKey
|
||||
ViewerStreamKeys
|
||||
) ->
|
||||
#voice_flags{
|
||||
self_mute = SelfMute,
|
||||
self_deaf = SelfDeaf,
|
||||
self_video = SelfVideo,
|
||||
self_stream = SelfStream,
|
||||
is_mobile = IsMobile
|
||||
#{
|
||||
self_mute := SelfMute,
|
||||
self_deaf := SelfDeaf,
|
||||
self_video := SelfVideo,
|
||||
self_stream := SelfStream,
|
||||
is_mobile := IsMobile
|
||||
} = Flags,
|
||||
ServerMute = maps:get(<<"mute">>, Member, false),
|
||||
ServerDeaf = maps:get(<<"deaf">>, Member, false),
|
||||
OldVersion = maps:get(<<"version">>, ExistingVoiceState, 0),
|
||||
UpdatedVoiceState = ExistingVoiceState#{
|
||||
<<"channel_id">> => ChannelIdBin,
|
||||
<<"mute">> => ServerMute,
|
||||
<<"deaf">> => ServerDeaf,
|
||||
<<"self_mute">> => SelfMute,
|
||||
<<"self_deaf">> => SelfDeaf,
|
||||
<<"self_video">> => SelfVideo,
|
||||
<<"self_stream">> => SelfStream,
|
||||
<<"is_mobile">> => IsMobile,
|
||||
<<"viewer_stream_key">> => ViewerStreamKey,
|
||||
<<"version">> => OldVersion + 1
|
||||
},
|
||||
NewVoiceStates = maps:put(ConnectionId, UpdatedVoiceState, VoiceStates),
|
||||
NewState = maps:put(voice_states, NewVoiceStates, State),
|
||||
guild_voice_broadcast:broadcast_voice_state_update(UpdatedVoiceState, NewState, ChannelIdBin),
|
||||
Reply =
|
||||
case NeedsToken of
|
||||
true -> #{success => true, voice_state => UpdatedVoiceState, needs_token => true};
|
||||
false -> #{success => true, voice_state => UpdatedVoiceState}
|
||||
end,
|
||||
{reply, Reply, NewState}.
|
||||
OldChannelIdBin = maps:get(<<"channel_id">>, ExistingVoiceState, null),
|
||||
IsChannelChange = OldChannelIdBin =/= ChannelIdBin,
|
||||
HasStateChange = has_voice_state_change(
|
||||
ExistingVoiceState, ChannelIdBin, ServerMute, ServerDeaf,
|
||||
SelfMute, SelfDeaf, SelfVideo, SelfStream, IsMobile, ViewerStreamKeys
|
||||
),
|
||||
case HasStateChange of
|
||||
false ->
|
||||
Reply = #{success => true, voice_state => ExistingVoiceState},
|
||||
{reply, Reply, State};
|
||||
true ->
|
||||
OldVersion = maps:get(<<"version">>, ExistingVoiceState, 0),
|
||||
UpdatedVoiceState = ExistingVoiceState#{
|
||||
<<"channel_id">> => ChannelIdBin,
|
||||
<<"mute">> => ServerMute,
|
||||
<<"deaf">> => ServerDeaf,
|
||||
<<"self_mute">> => SelfMute,
|
||||
<<"self_deaf">> => SelfDeaf,
|
||||
<<"self_video">> => SelfVideo,
|
||||
<<"self_stream">> => SelfStream,
|
||||
<<"is_mobile">> => IsMobile,
|
||||
<<"viewer_stream_keys">> => ViewerStreamKeys,
|
||||
<<"version">> => OldVersion + 1
|
||||
},
|
||||
NewVoiceStates = maps:put(ConnectionId, UpdatedVoiceState, VoiceStates),
|
||||
NewState = maps:put(voice_states, NewVoiceStates, State),
|
||||
case IsChannelChange of
|
||||
true ->
|
||||
DisconnectState = ExistingVoiceState#{
|
||||
<<"channel_id">> => null,
|
||||
<<"connection_id">> => ConnectionId
|
||||
},
|
||||
guild_voice_broadcast:broadcast_voice_state_update(
|
||||
DisconnectState, NewState, OldChannelIdBin
|
||||
),
|
||||
guild_voice_broadcast:broadcast_voice_state_update(
|
||||
UpdatedVoiceState, NewState, ChannelIdBin
|
||||
);
|
||||
false ->
|
||||
guild_voice_broadcast:broadcast_voice_state_update(
|
||||
UpdatedVoiceState, NewState, ChannelIdBin
|
||||
)
|
||||
end,
|
||||
Reply =
|
||||
case NeedsToken of
|
||||
true -> #{success => true, voice_state => UpdatedVoiceState, needs_token => true};
|
||||
false -> #{success => true, voice_state => UpdatedVoiceState}
|
||||
end,
|
||||
{reply, Reply, NewState}
|
||||
end.
|
||||
|
||||
-spec has_voice_state_change(
|
||||
voice_state(), binary(), boolean(), boolean(),
|
||||
boolean(), boolean(), boolean(), boolean(), boolean(), term()
|
||||
) -> boolean().
|
||||
has_voice_state_change(
|
||||
ExistingVoiceState, ChannelIdBin, ServerMute, ServerDeaf,
|
||||
SelfMute, SelfDeaf, SelfVideo, SelfStream, IsMobile, ViewerStreamKeys
|
||||
) ->
|
||||
maps:get(<<"channel_id">>, ExistingVoiceState, null) =/= ChannelIdBin orelse
|
||||
maps:get(<<"mute">>, ExistingVoiceState, false) =/= ServerMute orelse
|
||||
maps:get(<<"deaf">>, ExistingVoiceState, false) =/= ServerDeaf orelse
|
||||
maps:get(<<"self_mute">>, ExistingVoiceState, false) =/= SelfMute orelse
|
||||
maps:get(<<"self_deaf">>, ExistingVoiceState, false) =/= SelfDeaf orelse
|
||||
maps:get(<<"self_video">>, ExistingVoiceState, false) =/= SelfVideo orelse
|
||||
maps:get(<<"self_stream">>, ExistingVoiceState, false) =/= SelfStream orelse
|
||||
maps:get(<<"is_mobile">>, ExistingVoiceState, false) =/= IsMobile orelse
|
||||
maps:get(<<"viewer_stream_keys">>, ExistingVoiceState, []) =/= ViewerStreamKeys.
|
||||
|
||||
-spec user_matches_voice_state(voice_state(), integer() | binary()) -> boolean().
|
||||
user_matches_voice_state(VoiceState, UserId) when is_integer(UserId) ->
|
||||
@@ -122,7 +168,7 @@ user_matches_voice_state(_VoiceState, _UserId) ->
|
||||
boolean(),
|
||||
boolean(),
|
||||
voice_flags(),
|
||||
term()
|
||||
list()
|
||||
) -> voice_state().
|
||||
create_voice_state(
|
||||
GuildIdBin,
|
||||
@@ -132,14 +178,14 @@ create_voice_state(
|
||||
ServerMute,
|
||||
ServerDeaf,
|
||||
Flags,
|
||||
ViewerStreamKey
|
||||
ViewerStreamKeys
|
||||
) ->
|
||||
#voice_flags{
|
||||
self_mute = SelfMute,
|
||||
self_deaf = SelfDeaf,
|
||||
self_video = SelfVideo,
|
||||
self_stream = SelfStream,
|
||||
is_mobile = IsMobile
|
||||
#{
|
||||
self_mute := SelfMute,
|
||||
self_deaf := SelfDeaf,
|
||||
self_video := SelfVideo,
|
||||
self_stream := SelfStream,
|
||||
is_mobile := IsMobile
|
||||
} = Flags,
|
||||
#{
|
||||
<<"guild_id">> => GuildIdBin,
|
||||
@@ -153,7 +199,7 @@ create_voice_state(
|
||||
<<"self_video">> => SelfVideo,
|
||||
<<"self_stream">> => SelfStream,
|
||||
<<"is_mobile">> => IsMobile,
|
||||
<<"viewer_stream_key">> => ViewerStreamKey,
|
||||
<<"viewer_stream_keys">> => ViewerStreamKeys,
|
||||
<<"version">> => 0
|
||||
}.
|
||||
|
||||
@@ -177,29 +223,138 @@ user_matches_voice_state_integer_test() ->
|
||||
?assert(user_matches_voice_state(VoiceState, 10)),
|
||||
?assertNot(user_matches_voice_state(VoiceState, 11)).
|
||||
|
||||
update_voice_state_data_updates_version_test() ->
|
||||
VoiceState = #{<<"version">> => 1, <<"channel_id">> => <<"1">>},
|
||||
Member = #{<<"mute">> => true, <<"deaf">> => false},
|
||||
Flags = #voice_flags{
|
||||
self_mute = true,
|
||||
self_deaf = false,
|
||||
self_video = false,
|
||||
self_stream = false,
|
||||
is_mobile = false
|
||||
user_matches_voice_state_binary_test() ->
|
||||
VoiceState = #{<<"user_id">> => <<"123">>},
|
||||
?assert(user_matches_voice_state(VoiceState, <<"123">>)),
|
||||
?assertNot(user_matches_voice_state(VoiceState, <<"456">>)).
|
||||
|
||||
user_matches_voice_state_undefined_test() ->
|
||||
VoiceState = #{},
|
||||
?assertNot(user_matches_voice_state(VoiceState, 10)).
|
||||
|
||||
create_voice_state_test() ->
|
||||
Flags = #{
|
||||
self_mute => true,
|
||||
self_deaf => false,
|
||||
self_video => true,
|
||||
self_stream => false,
|
||||
is_mobile => true
|
||||
},
|
||||
{reply, #{voice_state := Updated}, _} =
|
||||
update_voice_state_data(
|
||||
<<"conn">>,
|
||||
<<"2">>,
|
||||
Flags,
|
||||
Member,
|
||||
VoiceState,
|
||||
#{<<"conn">> => VoiceState},
|
||||
#{voice_states => #{}},
|
||||
false,
|
||||
null
|
||||
),
|
||||
?assertEqual(2, maps:get(<<"version">>, Updated)),
|
||||
?assertEqual(<<"2">>, maps:get(<<"channel_id">>, Updated)).
|
||||
VS = create_voice_state(
|
||||
<<"1">>, <<"2">>, <<"3">>, <<"conn">>, false, false, Flags, []
|
||||
),
|
||||
?assertEqual(<<"1">>, maps:get(<<"guild_id">>, VS)),
|
||||
?assertEqual(<<"2">>, maps:get(<<"channel_id">>, VS)),
|
||||
?assertEqual(<<"3">>, maps:get(<<"user_id">>, VS)),
|
||||
?assertEqual(<<"conn">>, maps:get(<<"connection_id">>, VS)),
|
||||
?assertEqual(true, maps:get(<<"self_mute">>, VS)),
|
||||
?assertEqual(false, maps:get(<<"self_deaf">>, VS)),
|
||||
?assertEqual(true, maps:get(<<"self_video">>, VS)),
|
||||
?assertEqual(false, maps:get(<<"self_stream">>, VS)),
|
||||
?assertEqual(true, maps:get(<<"is_mobile">>, VS)),
|
||||
?assertEqual(0, maps:get(<<"version">>, VS)).
|
||||
|
||||
extract_session_info_from_voice_state_test() ->
|
||||
VoiceState = #{
|
||||
<<"session_id">> => <<"sess">>,
|
||||
<<"self_mute">> => true,
|
||||
<<"self_deaf">> => false,
|
||||
<<"self_video">> => true,
|
||||
<<"self_stream">> => false,
|
||||
<<"is_mobile">> => true,
|
||||
<<"member">> => #{<<"id">> => <<"m">>}
|
||||
},
|
||||
Info = extract_session_info_from_voice_state(<<"conn">>, VoiceState),
|
||||
?assertEqual(<<"conn">>, maps:get(connection_id, Info)),
|
||||
?assertEqual(<<"sess">>, maps:get(session_id, Info)),
|
||||
?assertEqual(true, maps:get(self_mute, Info)),
|
||||
?assertEqual(#{<<"id">> => <<"m">>}, maps:get(member, Info)).
|
||||
|
||||
has_voice_state_change_no_change_test() ->
|
||||
ExistingVoiceState = #{
|
||||
<<"channel_id">> => <<"100">>,
|
||||
<<"mute">> => false,
|
||||
<<"deaf">> => false,
|
||||
<<"self_mute">> => true,
|
||||
<<"self_deaf">> => false,
|
||||
<<"self_video">> => false,
|
||||
<<"self_stream">> => false,
|
||||
<<"is_mobile">> => false,
|
||||
<<"viewer_stream_keys">> => []
|
||||
},
|
||||
?assertNot(has_voice_state_change(
|
||||
ExistingVoiceState, <<"100">>, false, false, true, false, false, false, false, []
|
||||
)).
|
||||
|
||||
has_voice_state_change_channel_change_test() ->
|
||||
ExistingVoiceState = #{
|
||||
<<"channel_id">> => <<"100">>,
|
||||
<<"mute">> => false,
|
||||
<<"deaf">> => false,
|
||||
<<"self_mute">> => false,
|
||||
<<"self_deaf">> => false,
|
||||
<<"self_video">> => false,
|
||||
<<"self_stream">> => false,
|
||||
<<"is_mobile">> => false,
|
||||
<<"viewer_stream_keys">> => []
|
||||
},
|
||||
?assert(has_voice_state_change(
|
||||
ExistingVoiceState, <<"200">>, false, false, false, false, false, false, false, []
|
||||
)).
|
||||
|
||||
has_voice_state_change_self_mute_change_test() ->
|
||||
ExistingVoiceState = #{
|
||||
<<"channel_id">> => <<"100">>,
|
||||
<<"mute">> => false,
|
||||
<<"deaf">> => false,
|
||||
<<"self_mute">> => false,
|
||||
<<"self_deaf">> => false,
|
||||
<<"self_video">> => false,
|
||||
<<"self_stream">> => false,
|
||||
<<"is_mobile">> => false,
|
||||
<<"viewer_stream_keys">> => []
|
||||
},
|
||||
?assert(has_voice_state_change(
|
||||
ExistingVoiceState, <<"100">>, false, false, true, false, false, false, false, []
|
||||
)).
|
||||
|
||||
has_voice_state_change_server_mute_change_test() ->
|
||||
ExistingVoiceState = #{
|
||||
<<"channel_id">> => <<"100">>,
|
||||
<<"mute">> => false,
|
||||
<<"deaf">> => false,
|
||||
<<"self_mute">> => false,
|
||||
<<"self_deaf">> => false,
|
||||
<<"self_video">> => false,
|
||||
<<"self_stream">> => false,
|
||||
<<"is_mobile">> => false,
|
||||
<<"viewer_stream_keys">> => []
|
||||
},
|
||||
?assert(has_voice_state_change(
|
||||
ExistingVoiceState, <<"100">>, true, false, false, false, false, false, false, []
|
||||
)).
|
||||
|
||||
has_voice_state_change_viewer_stream_keys_change_test() ->
|
||||
ExistingVoiceState = #{
|
||||
<<"channel_id">> => <<"100">>,
|
||||
<<"mute">> => false,
|
||||
<<"deaf">> => false,
|
||||
<<"self_mute">> => false,
|
||||
<<"self_deaf">> => false,
|
||||
<<"self_video">> => false,
|
||||
<<"self_stream">> => false,
|
||||
<<"is_mobile">> => false,
|
||||
<<"viewer_stream_keys">> => []
|
||||
},
|
||||
?assert(has_voice_state_change(
|
||||
ExistingVoiceState, <<"100">>, false, false, false, false, false, false, false,
|
||||
[<<"999:100:conn">>]
|
||||
)).
|
||||
|
||||
has_voice_state_change_defaults_test() ->
|
||||
ExistingVoiceState = #{},
|
||||
?assertNot(has_voice_state_change(
|
||||
ExistingVoiceState, null, false, false, false, false, false, false, false, []
|
||||
)).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
%% Copyright (C) 2026 Fluxer Contributors
|
||||
%%
|
||||
%% This file is part of Fluxer.
|
||||
%%
|
||||
%% Fluxer is free software: you can redistribute it and/or modify
|
||||
%% it under the terms of the GNU Affero General Public License as published by
|
||||
%% the Free Software Foundation, either version 3 of the License, or
|
||||
%% (at your option) any later version.
|
||||
%%
|
||||
%% Fluxer is distributed in the hope that it will be useful,
|
||||
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
%% GNU Affero General Public License for more details.
|
||||
%%
|
||||
%% You should have received a copy of the GNU Affero General Public License
|
||||
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
-module(guild_voice_unclaimed_account_utils).
|
||||
|
||||
-export([parse_unclaimed_error/1]).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-spec parse_unclaimed_error(iodata() | term()) -> boolean().
|
||||
parse_unclaimed_error(Body) when is_binary(Body); is_list(Body) ->
|
||||
try json:decode(iolist_to_binary(Body)) of
|
||||
Map when is_map(Map) ->
|
||||
case get_unclaimed_error_code(Map) of
|
||||
Code when is_binary(Code) -> is_voice_unclaimed_error_code(Code);
|
||||
_ -> false
|
||||
end;
|
||||
_ ->
|
||||
false
|
||||
catch
|
||||
_:_ -> false
|
||||
end;
|
||||
parse_unclaimed_error(_) ->
|
||||
false.
|
||||
|
||||
-spec get_unclaimed_error_code(map()) -> binary() | undefined.
|
||||
get_unclaimed_error_code(Map) when is_map(Map) ->
|
||||
case maps:get(<<"code">>, Map, undefined) of
|
||||
Code when is_binary(Code) ->
|
||||
Code;
|
||||
_ ->
|
||||
case maps:get(<<"error">>, Map, undefined) of
|
||||
Error when is_map(Error) ->
|
||||
maps:get(<<"code">>, Error, undefined);
|
||||
_ ->
|
||||
undefined
|
||||
end
|
||||
end.
|
||||
|
||||
-spec is_voice_unclaimed_error_code(binary()) -> boolean().
|
||||
is_voice_unclaimed_error_code(Code) when is_binary(Code) ->
|
||||
lists:member(
|
||||
Code,
|
||||
[
|
||||
<<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_ONE_ON_ONE_VOICE_CALLS">>,
|
||||
<<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_VOICE_CHANNELS">>
|
||||
]
|
||||
).
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
parse_unclaimed_error_with_direct_code_test() ->
|
||||
Body = json:encode(#{<<"code">> => <<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_VOICE_CHANNELS">>}),
|
||||
?assertEqual(true, parse_unclaimed_error(Body)).
|
||||
|
||||
parse_unclaimed_error_with_nested_code_test() ->
|
||||
Body = json:encode(#{
|
||||
<<"error">> => #{<<"code">> => <<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_ONE_ON_ONE_VOICE_CALLS">>}
|
||||
}),
|
||||
?assertEqual(true, parse_unclaimed_error(Body)).
|
||||
|
||||
parse_unclaimed_error_with_unknown_code_test() ->
|
||||
Body = json:encode(#{<<"code">> => <<"SOME_OTHER_ERROR">>}),
|
||||
?assertEqual(false, parse_unclaimed_error(Body)).
|
||||
|
||||
parse_unclaimed_error_with_invalid_json_test() ->
|
||||
?assertEqual(false, parse_unclaimed_error(<<"not json">>)).
|
||||
|
||||
parse_unclaimed_error_with_non_binary_test() ->
|
||||
?assertEqual(false, parse_unclaimed_error(undefined)),
|
||||
?assertEqual(false, parse_unclaimed_error(123)).
|
||||
|
||||
is_voice_unclaimed_error_code_test() ->
|
||||
?assertEqual(
|
||||
true, is_voice_unclaimed_error_code(<<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_VOICE_CHANNELS">>)
|
||||
),
|
||||
?assertEqual(
|
||||
true,
|
||||
is_voice_unclaimed_error_code(<<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_ONE_ON_ONE_VOICE_CALLS">>)
|
||||
),
|
||||
?assertEqual(false, is_voice_unclaimed_error_code(<<"OTHER_ERROR">>)).
|
||||
|
||||
-endif.
|
||||
@@ -24,6 +24,10 @@
|
||||
channel_has_capacity/3
|
||||
]).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-type user_id() :: integer().
|
||||
-type session_id() :: binary().
|
||||
-type session_pid() :: pid().
|
||||
@@ -103,3 +107,56 @@ channel_has_capacity(ChannelId, UserLimit, VoiceStates) ->
|
||||
-spec ensure_binary(binary() | integer()) -> binary().
|
||||
ensure_binary(Value) when is_binary(Value) -> Value;
|
||||
ensure_binary(Value) when is_integer(Value) -> integer_to_binary(Value).
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
find_session_by_user_id_test() ->
|
||||
Pid = self(),
|
||||
Ref = make_ref(),
|
||||
Sessions = #{
|
||||
<<"session1">> => {100, Pid, Ref},
|
||||
<<"session2">> => {200, Pid, make_ref()}
|
||||
},
|
||||
?assertMatch({ok, <<"session1">>, _, _}, find_session_by_user_id(100, Sessions)),
|
||||
?assertEqual(not_found, find_session_by_user_id(999, Sessions)).
|
||||
|
||||
disconnect_user_not_found_test() ->
|
||||
VoiceStates = #{},
|
||||
Sessions = #{},
|
||||
CleanupFun = fun(_, _) -> ok end,
|
||||
?assertMatch({not_found, _, _}, disconnect_user(100, VoiceStates, Sessions, CleanupFun)).
|
||||
|
||||
disconnect_user_if_in_channel_mismatch_test() ->
|
||||
VoiceStates = #{100 => #{<<"channel_id">> => <<"999">>}},
|
||||
Sessions = #{},
|
||||
CleanupFun = fun(_, _) -> ok end,
|
||||
Result = disconnect_user_if_in_channel(100, 123, VoiceStates, Sessions, CleanupFun),
|
||||
?assertMatch({channel_mismatch, _, _}, Result).
|
||||
|
||||
disconnect_user_if_in_channel_not_found_test() ->
|
||||
VoiceStates = #{},
|
||||
Sessions = #{},
|
||||
CleanupFun = fun(_, _) -> ok end,
|
||||
Result = disconnect_user_if_in_channel(100, 123, VoiceStates, Sessions, CleanupFun),
|
||||
?assertMatch({not_found, _, _}, Result).
|
||||
|
||||
channel_has_capacity_unlimited_test() ->
|
||||
VoiceStates = #{
|
||||
1 => #{<<"channel_id">> => <<"100">>},
|
||||
2 => #{<<"channel_id">> => <<"100">>}
|
||||
},
|
||||
?assert(channel_has_capacity(100, 0, VoiceStates)).
|
||||
|
||||
channel_has_capacity_limited_test() ->
|
||||
VoiceStates = #{
|
||||
1 => #{<<"channel_id">> => <<"100">>},
|
||||
2 => #{<<"channel_id">> => <<"100">>}
|
||||
},
|
||||
?assertNot(channel_has_capacity(100, 2, VoiceStates)),
|
||||
?assert(channel_has_capacity(100, 3, VoiceStates)).
|
||||
|
||||
ensure_binary_test() ->
|
||||
?assertEqual(<<"123">>, ensure_binary(123)),
|
||||
?assertEqual(<<"abc">>, ensure_binary(<<"abc">>)).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -24,6 +24,10 @@
|
||||
confirm_pending_connection/2
|
||||
]).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-type connection_id() :: binary().
|
||||
-type pending_metadata() :: map().
|
||||
-type pending_map() :: #{connection_id() => pending_metadata()}.
|
||||
@@ -56,3 +60,35 @@ confirm_pending_connection(ConnectionId, PendingMap) ->
|
||||
_Metadata ->
|
||||
{confirmed, maps:remove(ConnectionId, PendingMap)}
|
||||
end.
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
add_pending_connection_test() ->
|
||||
PendingMap = #{},
|
||||
Metadata = #{user_id => 1, channel_id => 2},
|
||||
Result = add_pending_connection(<<"conn">>, Metadata, PendingMap),
|
||||
?assert(maps:is_key(<<"conn">>, Result)),
|
||||
StoredMetadata = maps:get(<<"conn">>, Result),
|
||||
?assertEqual(1, maps:get(user_id, StoredMetadata)),
|
||||
?assertEqual(2, maps:get(channel_id, StoredMetadata)),
|
||||
?assert(maps:is_key(joined_at, StoredMetadata)).
|
||||
|
||||
remove_pending_connection_test() ->
|
||||
PendingMap = #{<<"conn">> => #{user_id => 1}},
|
||||
?assertEqual(#{}, remove_pending_connection(<<"conn">>, PendingMap)),
|
||||
?assertEqual(PendingMap, remove_pending_connection(undefined, PendingMap)),
|
||||
?assertEqual(PendingMap, remove_pending_connection(<<"other">>, PendingMap)).
|
||||
|
||||
get_pending_connection_test() ->
|
||||
PendingMap = #{<<"conn">> => #{user_id => 1}},
|
||||
?assertEqual(#{user_id => 1}, get_pending_connection(<<"conn">>, PendingMap)),
|
||||
?assertEqual(undefined, get_pending_connection(<<"other">>, PendingMap)),
|
||||
?assertEqual(undefined, get_pending_connection(undefined, PendingMap)).
|
||||
|
||||
confirm_pending_connection_test() ->
|
||||
PendingMap = #{<<"conn">> => #{user_id => 1}},
|
||||
?assertMatch({confirmed, #{}}, confirm_pending_connection(<<"conn">>, PendingMap)),
|
||||
?assertMatch({not_found, _}, confirm_pending_connection(<<"other">>, PendingMap)),
|
||||
?assertMatch({not_found, _}, confirm_pending_connection(undefined, PendingMap)).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -33,66 +33,86 @@
|
||||
build_stream_key/3
|
||||
]).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-type voice_state() :: map().
|
||||
-type voice_state_map() :: #{binary() => voice_state()}.
|
||||
-type guild_state() :: map().
|
||||
-type stream_key_result() :: #{
|
||||
scope := guild | dm,
|
||||
guild_id := integer() | undefined,
|
||||
channel_id := integer(),
|
||||
connection_id := binary()
|
||||
}.
|
||||
|
||||
-spec voice_states(guild_state()) -> voice_state_map().
|
||||
voice_states(State) when is_map(State) ->
|
||||
case maps:get(voice_states, State, undefined) of
|
||||
Map when is_map(Map) -> Map;
|
||||
_ -> #{}
|
||||
end.
|
||||
|
||||
-spec ensure_voice_states(term()) -> voice_state_map().
|
||||
ensure_voice_states(Map) when is_map(Map) ->
|
||||
Map;
|
||||
ensure_voice_states(_) ->
|
||||
#{}.
|
||||
|
||||
-spec voice_state_user_id(voice_state()) -> integer() | undefined.
|
||||
voice_state_user_id(VoiceState) ->
|
||||
map_utils:get_integer(VoiceState, <<"user_id">>, undefined).
|
||||
|
||||
-spec voice_state_channel_id(voice_state()) -> integer() | undefined.
|
||||
voice_state_channel_id(VoiceState) ->
|
||||
map_utils:get_integer(VoiceState, <<"channel_id">>, undefined).
|
||||
|
||||
-spec voice_state_guild_id(voice_state()) -> integer() | undefined.
|
||||
voice_state_guild_id(VoiceState) ->
|
||||
map_utils:get_integer(VoiceState, <<"guild_id">>, undefined).
|
||||
|
||||
-spec filter_voice_states(voice_state_map(), fun((binary(), voice_state()) -> boolean())) ->
|
||||
voice_state_map().
|
||||
filter_voice_states(VoiceStates, Predicate) when is_map(VoiceStates) ->
|
||||
maps:filter(Predicate, VoiceStates);
|
||||
filter_voice_states(_, _) ->
|
||||
#{}.
|
||||
|
||||
-spec drop_voice_states(voice_state_map(), voice_state_map()) -> voice_state_map().
|
||||
drop_voice_states(ToDrop, VoiceStates) ->
|
||||
maps:fold(fun(ConnId, _VoiceState, Acc) -> maps:remove(ConnId, Acc) end, VoiceStates, ToDrop).
|
||||
|
||||
-spec broadcast_disconnects(voice_state_map(), guild_state()) -> ok.
|
||||
broadcast_disconnects(VoiceStates, State) ->
|
||||
maps:foreach(
|
||||
fun(ConnId, VoiceState) ->
|
||||
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
|
||||
DisconnectVoiceState = VoiceState#{
|
||||
<<"channel_id">> => null,
|
||||
<<"connection_id">> => ConnId
|
||||
},
|
||||
guild_voice_broadcast:broadcast_voice_state_update(
|
||||
DisconnectVoiceState, State, OldChannelIdBin
|
||||
)
|
||||
end,
|
||||
VoiceStates
|
||||
).
|
||||
spawn(fun() ->
|
||||
maps:foreach(
|
||||
fun(ConnId, VoiceState) ->
|
||||
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
|
||||
DisconnectVoiceState = VoiceState#{
|
||||
<<"channel_id">> => null,
|
||||
<<"connection_id">> => ConnId
|
||||
},
|
||||
guild_voice_broadcast:broadcast_voice_state_update(
|
||||
DisconnectVoiceState, State, OldChannelIdBin
|
||||
)
|
||||
end,
|
||||
VoiceStates
|
||||
)
|
||||
end),
|
||||
ok.
|
||||
|
||||
-spec voice_flags_from_context(map()) -> voice_flags().
|
||||
voice_flags_from_context(Context) ->
|
||||
#voice_flags{
|
||||
self_mute = maps:get(self_mute, Context, false),
|
||||
self_deaf = maps:get(self_deaf, Context, false),
|
||||
self_video = maps:get(self_video, Context, false),
|
||||
self_stream = maps:get(self_stream, Context, false),
|
||||
is_mobile = maps:get(is_mobile, Context, false)
|
||||
#{
|
||||
self_mute => maps:get(self_mute, Context, false),
|
||||
self_deaf => maps:get(self_deaf, Context, false),
|
||||
self_video => maps:get(self_video, Context, false),
|
||||
self_stream => maps:get(self_stream, Context, false),
|
||||
is_mobile => maps:get(is_mobile, Context, false)
|
||||
}.
|
||||
|
||||
-spec parse_stream_key(term()) ->
|
||||
{ok, #{
|
||||
scope := guild | dm,
|
||||
guild_id := integer() | undefined,
|
||||
channel_id := integer(),
|
||||
connection_id := binary()
|
||||
}}
|
||||
| {error, invalid_stream_key}.
|
||||
-spec parse_stream_key(term()) -> {ok, stream_key_result()} | {error, invalid_stream_key}.
|
||||
parse_stream_key(StreamKey) when is_binary(StreamKey) ->
|
||||
Parts = binary:split(StreamKey, <<":">>, [global]),
|
||||
case Parts of
|
||||
@@ -126,12 +146,7 @@ parse_channel_bin(ChannelBin) ->
|
||||
Chan.
|
||||
|
||||
-spec build_stream_key_result({dm, undefined} | {guild, integer()}, integer(), binary()) ->
|
||||
{ok, #{
|
||||
scope := guild | dm,
|
||||
guild_id := integer() | undefined,
|
||||
channel_id := integer(),
|
||||
connection_id := binary()
|
||||
}}.
|
||||
{ok, stream_key_result()}.
|
||||
build_stream_key_result({dm, _}, ChannelId, ConnId) ->
|
||||
{ok, #{
|
||||
scope => dm,
|
||||
@@ -162,3 +177,84 @@ build_stream_key(GuildId, ChannelId, ConnectionId) when
|
||||
":",
|
||||
ConnectionId/binary
|
||||
>>.
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
voice_states_returns_map_test() ->
|
||||
State = #{voice_states => #{<<"a">> => #{}}},
|
||||
?assertEqual(#{<<"a">> => #{}}, voice_states(State)),
|
||||
?assertEqual(#{}, voice_states(#{})),
|
||||
?assertEqual(#{}, voice_states(#{voice_states => not_a_map})).
|
||||
|
||||
ensure_voice_states_test() ->
|
||||
?assertEqual(#{<<"a">> => 1}, ensure_voice_states(#{<<"a">> => 1})),
|
||||
?assertEqual(#{}, ensure_voice_states(not_a_map)).
|
||||
|
||||
voice_state_user_id_test() ->
|
||||
?assertEqual(123, voice_state_user_id(#{<<"user_id">> => <<"123">>})),
|
||||
?assertEqual(undefined, voice_state_user_id(#{})).
|
||||
|
||||
voice_state_channel_id_test() ->
|
||||
?assertEqual(456, voice_state_channel_id(#{<<"channel_id">> => <<"456">>})),
|
||||
?assertEqual(undefined, voice_state_channel_id(#{})).
|
||||
|
||||
voice_state_guild_id_test() ->
|
||||
?assertEqual(789, voice_state_guild_id(#{<<"guild_id">> => <<"789">>})),
|
||||
?assertEqual(undefined, voice_state_guild_id(#{})).
|
||||
|
||||
filter_voice_states_test() ->
|
||||
VoiceStates = #{
|
||||
<<"a">> => #{<<"user_id">> => <<"1">>},
|
||||
<<"b">> => #{<<"user_id">> => <<"2">>}
|
||||
},
|
||||
Filtered = filter_voice_states(VoiceStates, fun(_, V) ->
|
||||
maps:get(<<"user_id">>, V) =:= <<"1">>
|
||||
end),
|
||||
?assertEqual(#{<<"a">> => #{<<"user_id">> => <<"1">>}}, Filtered).
|
||||
|
||||
drop_voice_states_test() ->
|
||||
VoiceStates = #{<<"a">> => #{}, <<"b">> => #{}, <<"c">> => #{}},
|
||||
ToDrop = #{<<"a">> => #{}, <<"c">> => #{}},
|
||||
Result = drop_voice_states(ToDrop, VoiceStates),
|
||||
?assertEqual(#{<<"b">> => #{}}, Result).
|
||||
|
||||
voice_flags_from_context_test() ->
|
||||
Context = #{
|
||||
self_mute => true,
|
||||
self_deaf => false,
|
||||
self_video => true,
|
||||
self_stream => false,
|
||||
is_mobile => true
|
||||
},
|
||||
Flags = voice_flags_from_context(Context),
|
||||
?assertEqual(true, maps:get(self_mute, Flags)),
|
||||
?assertEqual(false, maps:get(self_deaf, Flags)),
|
||||
?assertEqual(true, maps:get(self_video, Flags)),
|
||||
?assertEqual(false, maps:get(self_stream, Flags)),
|
||||
?assertEqual(true, maps:get(is_mobile, Flags)).
|
||||
|
||||
parse_stream_key_dm_test() ->
|
||||
Result = parse_stream_key(<<"dm:123:conn-id">>),
|
||||
?assertMatch({ok, #{scope := dm, channel_id := 123, connection_id := <<"conn-id">>}}, Result).
|
||||
|
||||
parse_stream_key_guild_test() ->
|
||||
Result = parse_stream_key(<<"999:123:conn-id">>),
|
||||
?assertMatch(
|
||||
{ok, #{scope := guild, guild_id := 999, channel_id := 123, connection_id := <<"conn-id">>}},
|
||||
Result
|
||||
).
|
||||
|
||||
parse_stream_key_invalid_test() ->
|
||||
?assertEqual({error, invalid_stream_key}, parse_stream_key(<<"invalid">>)),
|
||||
?assertEqual({error, invalid_stream_key}, parse_stream_key(<<"a:b">>)),
|
||||
?assertEqual({error, invalid_stream_key}, parse_stream_key(123)).
|
||||
|
||||
build_stream_key_dm_test() ->
|
||||
Result = build_stream_key(undefined, 123, <<"conn">>),
|
||||
?assertEqual(<<"dm:123:conn">>, Result).
|
||||
|
||||
build_stream_key_guild_test() ->
|
||||
Result = build_stream_key(999, 123, <<"conn">>),
|
||||
?assertEqual(<<"999:123:conn">>, Result).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -20,45 +20,56 @@
|
||||
-export([
|
||||
build_voice_token_rpc_request/6,
|
||||
build_voice_token_rpc_request/7,
|
||||
build_voice_token_rpc_request/8,
|
||||
build_force_disconnect_rpc_request/4,
|
||||
build_update_participant_rpc_request/5,
|
||||
build_update_participant_permissions_rpc_request/5,
|
||||
add_geolocation_to_request/3,
|
||||
compute_voice_permissions/3
|
||||
add_rtc_region_to_request/2,
|
||||
compute_voice_permissions/3,
|
||||
generate_token_nonce/0
|
||||
]).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-type guild_state() :: map().
|
||||
-type voice_permissions() :: #{
|
||||
can_speak := boolean(),
|
||||
can_stream := boolean(),
|
||||
can_video := boolean()
|
||||
}.
|
||||
|
||||
-spec build_voice_token_rpc_request(
|
||||
integer() | null,
|
||||
integer(),
|
||||
integer(),
|
||||
binary() | integer() | null,
|
||||
binary() | number() | list() | null,
|
||||
binary() | number() | list() | null
|
||||
) -> map().
|
||||
build_voice_token_rpc_request(GuildId, ChannelId, UserId, ConnectionId, Latitude, Longitude) ->
|
||||
BaseReq0 = #{
|
||||
<<"type">> => <<"voice_get_token">>,
|
||||
<<"channel_id">> => integer_to_binary(ChannelId),
|
||||
<<"user_id">> => integer_to_binary(UserId)
|
||||
},
|
||||
BaseReq =
|
||||
case GuildId of
|
||||
null ->
|
||||
#{
|
||||
<<"type">> => <<"voice_get_token">>,
|
||||
<<"channel_id">> => integer_to_binary(ChannelId),
|
||||
<<"user_id">> => integer_to_binary(UserId)
|
||||
};
|
||||
_ ->
|
||||
BaseMap = #{
|
||||
<<"type">> => <<"voice_get_token">>,
|
||||
<<"guild_id">> => integer_to_binary(GuildId),
|
||||
<<"channel_id">> => integer_to_binary(ChannelId),
|
||||
<<"user_id">> => integer_to_binary(UserId)
|
||||
},
|
||||
case ConnectionId of
|
||||
null ->
|
||||
BaseMap;
|
||||
ConnectionId when is_binary(ConnectionId) ->
|
||||
maps:put(<<"connection_id">>, ConnectionId, BaseMap);
|
||||
ConnectionId when is_integer(ConnectionId) ->
|
||||
maps:put(<<"connection_id">>, integer_to_binary(ConnectionId), BaseMap);
|
||||
_ ->
|
||||
BaseMap
|
||||
end
|
||||
null -> BaseReq0;
|
||||
_ -> maps:put(<<"guild_id">>, integer_to_binary(GuildId), BaseReq0)
|
||||
end,
|
||||
WithConnection = add_connection_id_to_request(BaseReq, ConnectionId),
|
||||
add_geolocation_to_request(WithConnection, Latitude, Longitude).
|
||||
|
||||
add_geolocation_to_request(BaseReq, Latitude, Longitude).
|
||||
|
||||
-spec add_geolocation_to_request(
|
||||
map(),
|
||||
binary() | number() | list() | null,
|
||||
binary() | number() | list() | null
|
||||
) -> map().
|
||||
add_geolocation_to_request(RequestMap, Latitude, Longitude) ->
|
||||
case {Latitude, Longitude} of
|
||||
case {normalise_coordinate(Latitude), normalise_coordinate(Longitude)} of
|
||||
{Lat, Long} when is_binary(Lat) andalso is_binary(Long) ->
|
||||
maps:merge(RequestMap, #{
|
||||
<<"latitude">> => Lat,
|
||||
@@ -68,6 +79,47 @@ add_geolocation_to_request(RequestMap, Latitude, Longitude) ->
|
||||
RequestMap
|
||||
end.
|
||||
|
||||
-spec normalise_coordinate(binary() | number() | list() | null) -> binary() | undefined.
|
||||
normalise_coordinate(null) ->
|
||||
undefined;
|
||||
normalise_coordinate(Value) when is_binary(Value) ->
|
||||
Value;
|
||||
normalise_coordinate(Value) when is_integer(Value) ->
|
||||
integer_to_binary(Value);
|
||||
normalise_coordinate(Value) when is_float(Value) ->
|
||||
float_to_binary(Value, [short]);
|
||||
normalise_coordinate(Value) when is_list(Value) ->
|
||||
try
|
||||
list_to_binary(Value)
|
||||
catch
|
||||
error:badarg -> undefined
|
||||
end;
|
||||
normalise_coordinate(_Value) ->
|
||||
undefined.
|
||||
|
||||
-spec add_rtc_region_to_request(map(), binary() | null) -> map().
|
||||
add_rtc_region_to_request(RequestMap, Region) ->
|
||||
case Region of
|
||||
RegionBin when is_binary(RegionBin) ->
|
||||
maps:put(<<"rtc_region">>, RegionBin, RequestMap);
|
||||
_ ->
|
||||
RequestMap
|
||||
end.
|
||||
|
||||
-spec add_connection_id_to_request(map(), binary() | integer() | null) -> map().
|
||||
add_connection_id_to_request(RequestMap, ConnectionId) ->
|
||||
case ConnectionId of
|
||||
null ->
|
||||
RequestMap;
|
||||
ConnectionIdBin when is_binary(ConnectionIdBin) ->
|
||||
maps:put(<<"connection_id">>, ConnectionIdBin, RequestMap);
|
||||
ConnectionIdInt when is_integer(ConnectionIdInt) ->
|
||||
maps:put(<<"connection_id">>, integer_to_binary(ConnectionIdInt), RequestMap);
|
||||
_ ->
|
||||
RequestMap
|
||||
end.
|
||||
|
||||
-spec build_force_disconnect_rpc_request(integer() | null, integer(), integer(), binary()) -> map().
|
||||
build_force_disconnect_rpc_request(GuildId, ChannelId, UserId, ConnectionId) ->
|
||||
BaseReq = #{
|
||||
<<"type">> => <<"voice_force_disconnect_participant">>,
|
||||
@@ -82,6 +134,9 @@ build_force_disconnect_rpc_request(GuildId, ChannelId, UserId, ConnectionId) ->
|
||||
maps:put(<<"guild_id">>, integer_to_binary(GuildId), BaseReq)
|
||||
end.
|
||||
|
||||
-spec build_update_participant_rpc_request(
|
||||
integer() | null, integer(), integer(), boolean(), boolean()
|
||||
) -> map().
|
||||
build_update_participant_rpc_request(GuildId, ChannelId, UserId, Mute, Deaf) ->
|
||||
BaseReq = #{
|
||||
<<"type">> => <<"voice_update_participant">>,
|
||||
@@ -97,6 +152,9 @@ build_update_participant_rpc_request(GuildId, ChannelId, UserId, Mute, Deaf) ->
|
||||
maps:put(<<"guild_id">>, integer_to_binary(GuildId), BaseReq)
|
||||
end.
|
||||
|
||||
-spec build_update_participant_permissions_rpc_request(
|
||||
integer() | null, integer(), integer(), binary(), voice_permissions()
|
||||
) -> map().
|
||||
build_update_participant_permissions_rpc_request(
|
||||
GuildId, ChannelId, UserId, ConnectionId, VoicePermissions
|
||||
) ->
|
||||
@@ -116,35 +174,181 @@ build_update_participant_permissions_rpc_request(
|
||||
maps:put(<<"guild_id">>, integer_to_binary(GuildId), BaseReq)
|
||||
end.
|
||||
|
||||
-spec compute_voice_permissions(integer(), integer(), map()) -> map().
|
||||
-spec compute_voice_permissions(integer(), integer(), guild_state()) -> voice_permissions().
|
||||
compute_voice_permissions(UserId, ChannelId, State) ->
|
||||
Permissions = guild_permissions:get_member_permissions(UserId, ChannelId, State),
|
||||
SpeakPerm = constants:speak_permission(),
|
||||
StreamPerm = constants:stream_permission(),
|
||||
AdminPerm = constants:administrator_permission(),
|
||||
|
||||
IsAdmin = (Permissions band AdminPerm) =:= AdminPerm,
|
||||
CanSpeak = IsAdmin orelse ((Permissions band SpeakPerm) =:= SpeakPerm),
|
||||
CanStream = IsAdmin orelse ((Permissions band StreamPerm) =:= StreamPerm),
|
||||
|
||||
HasVirtualAccess = guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, State),
|
||||
FinalCanSpeak = CanSpeak orelse HasVirtualAccess,
|
||||
FinalCanStream = CanStream orelse HasVirtualAccess,
|
||||
|
||||
#{
|
||||
can_speak => FinalCanSpeak,
|
||||
can_stream => FinalCanStream,
|
||||
can_video => FinalCanStream
|
||||
}.
|
||||
|
||||
-spec build_voice_token_rpc_request(
|
||||
integer() | null,
|
||||
integer(),
|
||||
integer(),
|
||||
binary() | integer() | null,
|
||||
binary() | null,
|
||||
binary() | null,
|
||||
voice_permissions()
|
||||
) -> map().
|
||||
build_voice_token_rpc_request(
|
||||
GuildId, ChannelId, UserId, ConnectionId, Latitude, Longitude, VoicePermissions
|
||||
) ->
|
||||
build_voice_token_rpc_request(
|
||||
GuildId, ChannelId, UserId, ConnectionId, Latitude, Longitude, VoicePermissions, null
|
||||
).
|
||||
|
||||
-spec build_voice_token_rpc_request(
|
||||
integer() | null,
|
||||
integer(),
|
||||
integer(),
|
||||
binary() | integer() | null,
|
||||
binary() | null,
|
||||
binary() | null,
|
||||
voice_permissions(),
|
||||
binary() | null
|
||||
) -> map().
|
||||
build_voice_token_rpc_request(
|
||||
GuildId, ChannelId, UserId, ConnectionId, Latitude, Longitude, VoicePermissions, TokenNonce
|
||||
) ->
|
||||
BaseReq = build_voice_token_rpc_request(
|
||||
GuildId, ChannelId, UserId, ConnectionId, Latitude, Longitude
|
||||
),
|
||||
maps:merge(BaseReq, #{
|
||||
Req0 = maps:merge(BaseReq, #{
|
||||
<<"can_speak">> => maps:get(can_speak, VoicePermissions, true),
|
||||
<<"can_stream">> => maps:get(can_stream, VoicePermissions, true),
|
||||
<<"can_video">> => maps:get(can_video, VoicePermissions, true)
|
||||
}).
|
||||
}),
|
||||
case TokenNonce of
|
||||
null -> Req0;
|
||||
undefined -> Req0;
|
||||
_ when is_binary(TokenNonce) -> maps:put(<<"token_nonce">>, TokenNonce, Req0);
|
||||
_ -> Req0
|
||||
end.
|
||||
|
||||
-spec generate_token_nonce() -> binary().
|
||||
generate_token_nonce() ->
|
||||
Bytes = crypto:strong_rand_bytes(16),
|
||||
binary:encode_hex(Bytes, lowercase).
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
build_voice_token_rpc_request_guild_test() ->
|
||||
Req = build_voice_token_rpc_request(123, 456, 789, null, null, null),
|
||||
?assertEqual(<<"voice_get_token">>, maps:get(<<"type">>, Req)),
|
||||
?assertEqual(<<"123">>, maps:get(<<"guild_id">>, Req)),
|
||||
?assertEqual(<<"456">>, maps:get(<<"channel_id">>, Req)),
|
||||
?assertEqual(<<"789">>, maps:get(<<"user_id">>, Req)),
|
||||
?assertNot(maps:is_key(<<"connection_id">>, Req)).
|
||||
|
||||
build_voice_token_rpc_request_dm_test() ->
|
||||
Req = build_voice_token_rpc_request(null, 456, 789, null, null, null),
|
||||
?assertEqual(<<"voice_get_token">>, maps:get(<<"type">>, Req)),
|
||||
?assertNot(maps:is_key(<<"guild_id">>, Req)),
|
||||
?assertEqual(<<"456">>, maps:get(<<"channel_id">>, Req)).
|
||||
|
||||
build_voice_token_rpc_request_with_connection_test() ->
|
||||
Req = build_voice_token_rpc_request(123, 456, 789, <<"conn-id">>, null, null),
|
||||
?assertEqual(<<"conn-id">>, maps:get(<<"connection_id">>, Req)).
|
||||
|
||||
build_voice_token_rpc_request_dm_with_connection_test() ->
|
||||
Req = build_voice_token_rpc_request(null, 456, 789, <<"conn-id">>, null, null),
|
||||
?assertEqual(<<"conn-id">>, maps:get(<<"connection_id">>, Req)).
|
||||
|
||||
add_geolocation_to_request_test() ->
|
||||
BaseReq = #{<<"type">> => <<"test">>},
|
||||
WithGeo = add_geolocation_to_request(BaseReq, <<"1.0">>, <<"2.0">>),
|
||||
?assertEqual(<<"1.0">>, maps:get(<<"latitude">>, WithGeo)),
|
||||
?assertEqual(<<"2.0">>, maps:get(<<"longitude">>, WithGeo)),
|
||||
WithoutGeo = add_geolocation_to_request(BaseReq, null, null),
|
||||
?assertNot(maps:is_key(<<"latitude">>, WithoutGeo)).
|
||||
|
||||
add_geolocation_to_request_number_test() ->
|
||||
BaseReq = #{<<"type">> => <<"test">>},
|
||||
WithGeo = add_geolocation_to_request(BaseReq, 1.5, 2.25),
|
||||
?assertEqual(<<"1.5">>, maps:get(<<"latitude">>, WithGeo)),
|
||||
?assertEqual(<<"2.25">>, maps:get(<<"longitude">>, WithGeo)).
|
||||
|
||||
add_rtc_region_to_request_test() ->
|
||||
BaseReq = #{<<"type">> => <<"test">>},
|
||||
WithRegion = add_rtc_region_to_request(BaseReq, <<"us-east">>),
|
||||
?assertEqual(<<"us-east">>, maps:get(<<"rtc_region">>, WithRegion)),
|
||||
WithoutRegion = add_rtc_region_to_request(BaseReq, null),
|
||||
?assertNot(maps:is_key(<<"rtc_region">>, WithoutRegion)).
|
||||
|
||||
build_force_disconnect_rpc_request_test() ->
|
||||
Req = build_force_disconnect_rpc_request(123, 456, 789, <<"conn">>),
|
||||
?assertEqual(<<"voice_force_disconnect_participant">>, maps:get(<<"type">>, Req)),
|
||||
?assertEqual(<<"123">>, maps:get(<<"guild_id">>, Req)),
|
||||
?assertEqual(<<"conn">>, maps:get(<<"connection_id">>, Req)).
|
||||
|
||||
build_update_participant_rpc_request_test() ->
|
||||
Req = build_update_participant_rpc_request(123, 456, 789, true, false),
|
||||
?assertEqual(<<"voice_update_participant">>, maps:get(<<"type">>, Req)),
|
||||
?assertEqual(true, maps:get(<<"mute">>, Req)),
|
||||
?assertEqual(false, maps:get(<<"deaf">>, Req)).
|
||||
|
||||
generate_token_nonce_format_test() ->
|
||||
Nonce = generate_token_nonce(),
|
||||
?assert(is_binary(Nonce)),
|
||||
?assertEqual(32, byte_size(Nonce)),
|
||||
?assert(lists:all(fun(C) ->
|
||||
(C >= $0 andalso C =< $9) orelse (C >= $a andalso C =< $f)
|
||||
end, binary_to_list(Nonce))).
|
||||
|
||||
generate_token_nonce_unique_test() ->
|
||||
Nonce1 = generate_token_nonce(),
|
||||
Nonce2 = generate_token_nonce(),
|
||||
Nonce3 = generate_token_nonce(),
|
||||
?assertNot(Nonce1 =:= Nonce2),
|
||||
?assertNot(Nonce2 =:= Nonce3),
|
||||
?assertNot(Nonce1 =:= Nonce3).
|
||||
|
||||
build_voice_token_rpc_request_with_nonce_test() ->
|
||||
VoicePerms = #{
|
||||
can_speak => true,
|
||||
can_stream => false,
|
||||
can_video => false
|
||||
},
|
||||
Req = build_voice_token_rpc_request(
|
||||
123, 456, 789, null, null, null, VoicePerms, <<"test-nonce-123">>
|
||||
),
|
||||
?assertEqual(<<"test-nonce-123">>, maps:get(<<"token_nonce">>, Req)),
|
||||
?assertEqual(true, maps:get(<<"can_speak">>, Req)),
|
||||
?assertEqual(false, maps:get(<<"can_stream">>, Req)).
|
||||
|
||||
build_voice_token_rpc_request_without_nonce_test() ->
|
||||
VoicePerms = #{
|
||||
can_speak => true,
|
||||
can_stream => true,
|
||||
can_video => true
|
||||
},
|
||||
Req = build_voice_token_rpc_request(
|
||||
123, 456, 789, null, null, null, VoicePerms, null
|
||||
),
|
||||
?assertNot(maps:is_key(<<"token_nonce">>, Req)),
|
||||
?assertEqual(true, maps:get(<<"can_speak">>, Req)).
|
||||
|
||||
build_voice_token_rpc_request_undefined_nonce_test() ->
|
||||
VoicePerms = #{
|
||||
can_speak => false,
|
||||
can_stream => true,
|
||||
can_video => true
|
||||
},
|
||||
Req = build_voice_token_rpc_request(
|
||||
123, 456, 789, null, null, null, VoicePerms, undefined
|
||||
),
|
||||
?assertNot(maps:is_key(<<"token_nonce">>, Req)),
|
||||
?assertEqual(false, maps:get(<<"can_speak">>, Req)).
|
||||
|
||||
-endif.
|
||||
|
||||
Reference in New Issue
Block a user