refactor progress

This commit is contained in:
Hampus Kraft
2026-02-17 12:22:36 +00:00
parent cb31608523
commit d5abd1a7e4
8257 changed files with 1190207 additions and 761040 deletions

View File

@@ -1,5 +1,7 @@
FROM erlang:28-slim AS build
ARG LOGGER_LEVEL=info
WORKDIR /usr/src/app
RUN apt-get update && apt-get install -y --no-install-recommends \
@@ -9,6 +11,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
g++ \
libc6-dev \
gettext-base \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
@@ -19,8 +22,7 @@ COPY rebar.config rebar.lock* ./
RUN rebar3 compile --deps_only
COPY . .
RUN cp config/vm.args.template config/vm.args && \
cp config/sys.config.template config/sys.config && \
RUN LOGGER_LEVEL=${LOGGER_LEVEL} envsubst < config/sys.config.template > config/sys.config && \
rebar3 as prod release
FROM erlang:28-slim
@@ -33,14 +35,14 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
&& rm -rf /var/lib/apt/lists/*
COPY --from=build /usr/src/app/_build/prod/rel/fluxer_gateway .
COPY scripts/docker_entrypoint.sh /opt/fluxer_gateway/bin/docker_entrypoint.sh
RUN useradd -r -s /sbin/nologin -d /opt/fluxer_gateway fluxer && \
RUN chmod +x /opt/fluxer_gateway/bin/docker_entrypoint.sh && \
useradd -r -s /sbin/nologin -d /opt/fluxer_gateway fluxer && \
chown -R fluxer:fluxer /opt/fluxer_gateway
USER fluxer
EXPOSE 8080 8081
ENV RELX_REPLACE_OS_VARS=true
ENTRYPOINT ["/opt/fluxer_gateway/bin/fluxer_gateway", "foreground"]
ENTRYPOINT ["/opt/fluxer_gateway/bin/docker_entrypoint.sh"]

View File

@@ -1,27 +1,9 @@
[
{fluxer_gateway, [
{ws_port, 8080},
{rpc_port, 8081},
{api_host, "${FLUXER_API_HOST}"},
{rpc_secret_key, <<"${GATEWAY_RPC_SECRET}">>},
{max_payload_size, 4096},
{heartbeat_interval, 41250},
{heartbeat_timeout, 45000},
{resume_timeout, 10000},
{identify_rate_limit_enabled, false},
{push_enabled, true},
{push_user_guild_settings_cache_mb, 8192},
{push_subscriptions_cache_mb, 8192},
{push_blocked_ids_cache_mb, 8192},
{push_badge_counts_cache_mb, 256},
{push_badge_counts_cache_ttl_seconds, 60},
{media_proxy_endpoint, "${MEDIA_PROXY_ENDPOINT}"}
]},
{kernel, [
{logger_level, debug},
{logger_level, ${LOGGER_LEVEL}},
{logger, [
{handler, default, logger_std_h, #{
level => debug,
level => ${LOGGER_LEVEL},
config => #{
type => standard_io
}

View File

@@ -0,0 +1,11 @@
${FLUXER_GATEWAY_NODE_FLAG} ${FLUXER_GATEWAY_NODE_NAME}
+K true
+A30
-env ERL_MAX_PORTS 4096
-env ERL_PROCESSES 262144
+P 262144
+Q 65536
+S 4:4
+zdbbl 32768
-kernel logger_level ${LOGGER_LEVEL}
-kernel inet_backend socket

View File

@@ -1,4 +1,4 @@
-name ${RELEASE_NODE}
${FLUXER_GATEWAY_NODE_FLAG} ${FLUXER_GATEWAY_NODE_NAME}
+K true
+A30
-env ERL_MAX_PORTS 4096

View File

@@ -0,0 +1 @@
cowboy_req.erl.*Unknown type ranch

View File

@@ -1,44 +0,0 @@
#!/bin/bash
# Copyright (C) 2026 Fluxer Contributors
#
# This file is part of Fluxer.
#
# Fluxer is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Fluxer is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
set -e
# Set default values for environment variables if not set
export LOGGER_LEVEL=${LOGGER_LEVEL:-debug}
export RELEASE_NODE=${RELEASE_NODE:-fluxer_gateway@gateway}
# Substitute environment variables in config files
envsubst < /workspace/config/vm.args.template > /workspace/config/vm.args
envsubst < /workspace/config/sys.config.template > /workspace/config/sys.config
# Start inotify watcher in the background for auto-recompilation
(while true; do
inotifywait -r -e modify,create,delete,move src/ config/ 2>/dev/null && \
sleep 0.5 && \
rebar3 compile && \
envsubst < /workspace/config/vm.args.template > /workspace/config/vm.args && \
envsubst < /workspace/config/sys.config.template > /workspace/config/sys.config
done) &
# Start the Erlang application
exec erl -pa _build/default/lib/*/ebin \
-config config/sys.config \
-args_file config/vm.args \
-eval 'application:ensure_all_started(fluxer_gateway)' \
-noshell

View File

@@ -15,12 +15,10 @@
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-record(voice_flags, {
self_mute = false :: boolean(),
self_deaf = false :: boolean(),
self_video = false :: boolean(),
self_stream = false :: boolean(),
is_mobile = false :: boolean()
}).
-type voice_flags() :: #voice_flags{}.
-type voice_flags() :: #{
self_mute := boolean(),
self_deaf := boolean(),
self_video := boolean(),
self_stream := boolean(),
is_mobile := boolean()
}.

View File

@@ -1,11 +1,9 @@
{erl_opts, [debug_info, nowarn_deprecated_function]}.
{deps, [
{jsx, "3.1.0"},
{cowboy, "2.14.2"},
{hackney, "1.25.0"},
{base64url, "1.0.1"},
{ezstd, "1.1.0"},
{jose, "1.11.10"}
{jose, "1.11.10"},
{ezstd, "1.1.0"}
]}.
{overrides, [
@@ -21,15 +19,27 @@
{include_erts, false},
{extended_start_script, true},
{sys_config, "./config/sys.config"},
{vm_args, "./config/vm.args"}
{vm_args_src, "./config/vm.args.src"}
]}.
{profiles, [
{prod, [
{deps, [
{opentelemetry_api, "1.5.0"},
{opentelemetry, "1.7.0"},
{opentelemetry_exporter, "1.10.0"},
{opentelemetry_experimental, "0.5.1"},
{opentelemetry_api_experimental, "0.5.1"},
{opentelemetry_semantic_conventions, "1.27.0"}
]},
{relx, [
{dev_mode, false},
{include_erts, true}
]}
]},
{erl_opts, [{d, 'HAS_OPENTELEMETRY'}]}
]},
{dev, [
{erl_opts, [debug_info, nowarn_deprecated_function, {d, 'DEV_MODE'}]}
]}
]}.
@@ -43,5 +53,8 @@
{erlfmt, [write]}.
{dialyzer, [
{plt_extra_apps, [ezstd, jose]}
{plt_extra_apps, [
jose, ranch
]},
{warnings_file, "dialyzer.ignore-warnings"}
]}.

View File

@@ -1,50 +1,23 @@
{"1.2.0",
[{<<"base64url">>,{pkg,<<"base64url">>,<<"1.0.1">>},0},
{<<"certifi">>,{pkg,<<"certifi">>,<<"2.15.0">>},1},
{<<"cowboy">>,{pkg,<<"cowboy">>,<<"2.14.2">>},0},
{<<"cowlib">>,{pkg,<<"cowlib">>,<<"2.16.0">>},1},
{<<"ezstd">>,{pkg,<<"ezstd">>,<<"1.1.0">>},0},
{<<"hackney">>,{pkg,<<"hackney">>,<<"1.25.0">>},0},
{<<"idna">>,{pkg,<<"idna">>,<<"6.1.1">>},1},
{<<"jose">>,{pkg,<<"jose">>,<<"1.11.10">>},0},
{<<"jsx">>,{pkg,<<"jsx">>,<<"3.1.0">>},0},
{<<"metrics">>,{pkg,<<"metrics">>,<<"1.0.1">>},1},
{<<"mimerl">>,{pkg,<<"mimerl">>,<<"1.4.0">>},1},
{<<"parse_trans">>,{pkg,<<"parse_trans">>,<<"3.4.1">>},1},
{<<"ranch">>,{pkg,<<"ranch">>,<<"2.2.0">>},1},
{<<"ssl_verify_fun">>,{pkg,<<"ssl_verify_fun">>,<<"1.1.7">>},1},
{<<"unicode_util_compat">>,{pkg,<<"unicode_util_compat">>,<<"0.7.1">>},1}]}.
{<<"ranch">>,{pkg,<<"ranch">>,<<"2.2.0">>},1}]}.
[
{pkg_hash,[
{<<"base64url">>, <<"F8C7F2DA04CA9A5D0F5F50258F055E1D699F0E8BF4CFDB30B750865368403CF6">>},
{<<"certifi">>, <<"0E6E882FCDAAA0A5A9F2B3DB55B1394DBA07E8D6D9BCAD08318FB604C6839712">>},
{<<"cowboy">>, <<"4008BE1DF6ADE45E4F2A4E9E2D22B36D0B5ABA4E20B0A0D7049E28D124E34847">>},
{<<"cowlib">>, <<"54592074EBBBB92EE4746C8A8846E5605052F29309D3A873468D76CDF932076F">>},
{<<"ezstd">>, <<"D3B483D6ACFADFB65DBA4015371E6D54526DBF3D9EF0941B5ADD8BF5890731F4">>},
{<<"hackney">>, <<"390E9B83F31E5B325B9F43B76E1A785CBDB69B5B6CD4E079AA67835DED046867">>},
{<<"idna">>, <<"8A63070E9F7D0C62EB9D9FCB360A7DE382448200FBBD1B106CC96D3D8099DF8D">>},
{<<"jose">>, <<"A903F5227417BD2A08C8A00A0CBCC458118BE84480955E8D251297A425723F83">>},
{<<"jsx">>, <<"D12516BAA0BB23A59BB35DCCAF02A1BD08243FCBB9EFE24F2D9D056CCFF71268">>},
{<<"metrics">>, <<"25F094DEA2CDA98213CECC3AEFF09E940299D950904393B2A29D191C346A8486">>},
{<<"mimerl">>, <<"3882A5CA67FBBE7117BA8947F27643557ADEC38FA2307490C4C4207624CB213B">>},
{<<"parse_trans">>, <<"6E6AA8167CB44CC8F39441D05193BE6E6F4E7C2946CB2759F015F8C56B76E5FF">>},
{<<"ranch">>, <<"25528F82BC8D7C6152C57666CA99EC716510FE0925CB188172F41CE93117B1B0">>},
{<<"ssl_verify_fun">>, <<"354C321CF377240C7B8716899E182CE4890C5938111A1296ADD3EC74CF1715DF">>},
{<<"unicode_util_compat">>, <<"A48703A25C170EEDADCA83B11E88985AF08D35F37C6F664D6DCFB106A97782FC">>}]},
{<<"ranch">>, <<"25528F82BC8D7C6152C57666CA99EC716510FE0925CB188172F41CE93117B1B0">>}]},
{pkg_hash_ext,[
{<<"base64url">>, <<"F9B3ADD4731A02A9B0410398B475B33E7566A695365237A6BDEE1BB447719F5C">>},
{<<"certifi">>, <<"B147ED22CE71D72EAFDAD94F055165C1C182F61A2FF49DF28BCC71D1D5B94A60">>},
{<<"cowboy">>, <<"569081DA046E7B41B5DF36AA359BE71A0C8874E5B9CFF6F747073FC57BAF1AB9">>},
{<<"cowlib">>, <<"7F478D80D66B747344F0EA7708C187645CFCC08B11AA424632F78E25BF05DB51">>},
{<<"ezstd">>, <<"28CFA0ED6CC3922095AD5BA0F23392A1664273358B17184BAA909868361184E7">>},
{<<"hackney">>, <<"7209BFD75FD1F42467211FF8F59EA74D6F2A9E81CBCEE95A56711EE79FD6B1D4">>},
{<<"idna">>, <<"92376EB7894412ED19AC475E4A86F7B413C1B9FBB5BD16DCCD57934157944CEA">>},
{<<"jose">>, <<"0D6CD36FF8BA174DB29148FC112B5842186B68A90CE9FC2B3EC3AFE76593E614">>},
{<<"jsx">>, <<"0C5CC8FDC11B53CC25CF65AC6705AD39E54ECC56D1C22E4ADB8F5A53FB9427F3">>},
{<<"metrics">>, <<"69B09ADDDC4F74A40716AE54D140F93BEB0FB8978D8636EADED0C31B6F099F16">>},
{<<"mimerl">>, <<"13AF15F9F68C65884ECCA3A3891D50A7B57D82152792F3E19D88650AA126B144">>},
{<<"parse_trans">>, <<"620A406CE75DADA827B82E453C19CF06776BE266F5A67CFF34E1EF2CBB60E49A">>},
{<<"ranch">>, <<"FA0B99A1780C80218A4197A59EA8D3BDAE32FBFF7E88527D7D8A4787EFF4F8E7">>},
{<<"ssl_verify_fun">>, <<"FE4C190E8F37401D30167C8C405EDA19469F34577987C76DDE613E838BBC67F8">>},
{<<"unicode_util_compat">>, <<"B3A917854CE3AE233619744AD1E0102E05673136776FB2FA76234F3E03B23642">>}]}
{<<"ranch">>, <<"FA0B99A1780C80218A4197A59EA8D3BDAE32FBFF7E88527D7D8A4787EFF4F8E7">>}]}
].

View File

@@ -0,0 +1,99 @@
#!/usr/bin/env sh
set -eu
NODE_BASE_NAME="${FLUXER_GATEWAY_NODE_BASENAME:-fluxer_gateway}"
if [ -z "${FLUXER_GATEWAY_NODE_FLAG:-}" ] && [ -n "${FLUXER_GATEWAY_NODE_NAME:-}" ]; then
case "${FLUXER_GATEWAY_NODE_NAME}" in
*@*)
FLUXER_GATEWAY_NODE_FLAG="-name"
;;
*)
FLUXER_GATEWAY_NODE_FLAG="-sname"
;;
esac
export FLUXER_GATEWAY_NODE_FLAG
fi
if [ -z "${FLUXER_GATEWAY_NODE_HOST:-}" ]; then
if [ -n "${HOSTNAME:-}" ]; then
FLUXER_GATEWAY_NODE_HOST="$HOSTNAME"
else
FLUXER_GATEWAY_NODE_HOST="$(hostname)"
fi
export FLUXER_GATEWAY_NODE_HOST
fi
if [ -n "${FLUXER_GATEWAY_NODE_FLAG:-}" ]; then
case "$FLUXER_GATEWAY_NODE_FLAG" in
-name | -sname)
;;
*)
echo "Invalid FLUXER_GATEWAY_NODE_FLAG: $FLUXER_GATEWAY_NODE_FLAG" >&2
exit 64
;;
esac
fi
if [ -z "${FLUXER_GATEWAY_NODE_FLAG:-}" ]; then
NODE_MODE=""
if [ -n "${FLUXER_GATEWAY_NODE_MODE:-}" ]; then
NODE_MODE="$FLUXER_GATEWAY_NODE_MODE"
fi
if [ -z "$NODE_MODE" ]; then
FQDN_HOST=""
if command -v hostname >/dev/null 2>&1; then
FQDN_HOST="$(hostname -f 2>/dev/null || true)"
fi
if [ -n "$FQDN_HOST" ] && printf '%s' "$FQDN_HOST" | grep -q '\.'; then
NODE_MODE="long"
FLUXER_GATEWAY_NODE_HOST="$FQDN_HOST"
else
if printf '%s' "$FLUXER_GATEWAY_NODE_HOST" | grep -q '\.'; then
NODE_MODE="long"
else
NODE_MODE="short"
fi
fi
fi
case "$NODE_MODE" in
long)
FLUXER_GATEWAY_NODE_FLAG="-name"
;;
short)
FLUXER_GATEWAY_NODE_FLAG="-sname"
;;
*)
echo "Invalid FLUXER_GATEWAY_NODE_MODE: $NODE_MODE" >&2
exit 64
;;
esac
export FLUXER_GATEWAY_NODE_FLAG
export FLUXER_GATEWAY_NODE_HOST
fi
if [ -z "${FLUXER_GATEWAY_NODE_NAME:-}" ]; then
if [ "$FLUXER_GATEWAY_NODE_FLAG" = "-name" ]; then
FLUXER_GATEWAY_NODE_NAME="${NODE_BASE_NAME}@${FLUXER_GATEWAY_NODE_HOST}"
else
SAFE_HOST="$(printf '%s' "$FLUXER_GATEWAY_NODE_HOST" | tr -c 'A-Za-z0-9' '_' | tr 'A-Z' 'a-z')"
FLUXER_GATEWAY_NODE_NAME="${NODE_BASE_NAME}_${SAFE_HOST}"
fi
export FLUXER_GATEWAY_NODE_NAME
fi
if [ "$FLUXER_GATEWAY_NODE_FLAG" = "-sname" ]; then
case "$FLUXER_GATEWAY_NODE_NAME" in
*@*)
echo "FLUXER_GATEWAY_NODE_NAME must not include '@' when using -sname." >&2
exit 64
;;
esac
fi
exec /opt/fluxer_gateway/bin/fluxer_gateway foreground

View File

@@ -0,0 +1,61 @@
#!/usr/bin/env bash
# Copyright (C) 2026 Fluxer Contributors
#
# This file is part of Fluxer.
#
# Fluxer is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Fluxer is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
set -euo pipefail
if [ "$#" -eq 0 ]; then
echo "Usage: $0 <rebar3 args...>" >&2
exit 64
fi
ASDF_SHIMS_PATH="${ASDF_DATA_DIR:-$HOME/.asdf}/shims"
if [ -d "$ASDF_SHIMS_PATH" ]; then
export PATH="$ASDF_SHIMS_PATH:$PATH"
fi
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
if [ -f "$SCRIPT_DIR/../rebar.config" ]; then
GATEWAY_DIR=$(cd "$SCRIPT_DIR/.." && pwd)
REPO_ROOT=$(cd "$GATEWAY_DIR/.." && pwd)
else
REPO_ROOT=$(cd "$SCRIPT_DIR/../.." && pwd)
GATEWAY_DIR="$REPO_ROOT/fluxer_gateway"
fi
if [ -z "${FLUXER_CONFIG:-}" ] && [ -f "$REPO_ROOT/config/config.test.json" ]; then
export FLUXER_CONFIG="$REPO_ROOT/config/config.test.json"
fi
should_skip_plugins=true
for arg in "$@"; do
case "$arg" in
fmt | plugins)
should_skip_plugins=false
break
;;
esac
done
if [ "$should_skip_plugins" = true ]; then
export REBAR_SKIP_PROJECT_PLUGINS=1
fi
cd "$GATEWAY_DIR"
exec rebar3 "$@"

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env bash
# Copyright (C) 2026 Fluxer Contributors
#
# This file is part of Fluxer.
#
# Fluxer is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Fluxer is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
set -euo pipefail
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
"$SCRIPT_DIR/rebar3_wrapper.sh" compile

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env bash
# Copyright (C) 2026 Fluxer Contributors
#
# This file is part of Fluxer.
#
# Fluxer is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Fluxer is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
set -euo pipefail
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
"$SCRIPT_DIR/rebar3_wrapper.sh" dialyzer

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env bash
# Copyright (C) 2026 Fluxer Contributors
#
# This file is part of Fluxer.
#
# Fluxer is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Fluxer is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
set -euo pipefail
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
"$SCRIPT_DIR/rebar3_wrapper.sh" eunit

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env bash
# Copyright (C) 2026 Fluxer Contributors
#
# This file is part of Fluxer.
#
# Fluxer is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Fluxer is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
set -euo pipefail
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
"$SCRIPT_DIR/rebar3_wrapper.sh" fmt

View File

@@ -0,0 +1,25 @@
#!/usr/bin/env bash
# Copyright (C) 2026 Fluxer Contributors
#
# This file is part of Fluxer.
#
# Fluxer is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Fluxer is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
set -euo pipefail
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
"$SCRIPT_DIR/rebar3_wrapper.sh" as prod clean -a
"$SCRIPT_DIR/rebar3_wrapper.sh" as prod compile

File diff suppressed because it is too large Load Diff

View File

@@ -23,6 +23,10 @@
-export([start_link/0]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type channel_id() :: integer().
-type call_ref() :: {pid(), reference()}.
-type call_data() :: map().
@@ -35,27 +39,10 @@ start_link() ->
-spec init([]) -> {ok, state()}.
init([]) ->
process_flag(trap_exit, true),
ets:new(call_pid_cache, [named_table, public, set]),
{ok, #{calls => #{}}}.
-spec handle_call(Request, From, State) -> Result when
Request ::
{create, channel_id(), call_data()}
| {lookup, channel_id()}
| {get_or_create, channel_id(), call_data()}
| {terminate_call, channel_id()}
| get_local_count
| get_global_count
| term(),
From :: gen_server:from(),
State :: state(),
Result :: {reply, Reply, state()},
Reply ::
{ok, pid()}
| {error, already_exists}
| {error, not_found}
| {error, term()}
| ok
| {ok, non_neg_integer()}.
-spec handle_call(term(), gen_server:from(), state()) -> {reply, term(), state()}.
handle_call({create, ChannelId, CallData}, _From, State) ->
do_create_call(ChannelId, CallData, State);
handle_call({lookup, ChannelId}, _From, State) ->
@@ -75,25 +62,18 @@ handle_call(_Request, _From, State) ->
handle_cast(_Msg, State) ->
{noreply, State}.
-spec handle_info(Info, State) -> {noreply, state()} when
Info :: {'DOWN', reference(), process, pid(), term()} | term(),
State :: state().
-spec handle_info(term(), state()) -> {noreply, state()}.
handle_info({'DOWN', _Ref, process, Pid, _Reason}, #{calls := Calls} = State) ->
NewCalls = process_registry:cleanup_on_down(Pid, Calls),
{noreply, State#{calls := NewCalls}};
handle_info(_Info, State) ->
{noreply, State}.
-spec terminate(Reason, State) -> ok when
Reason :: term(),
State :: state().
-spec terminate(term(), state()) -> ok.
terminate(_Reason, #{calls := _Calls}) ->
ok.
-spec code_change(OldVsn, State, Extra) -> {ok, state()} when
OldVsn :: term(),
State :: state() | {state, map()},
Extra :: term().
-spec code_change(term(), state() | {state, map()}, term()) -> {ok, state()}.
code_change(_OldVsn, {state, Calls}, _Extra) ->
{ok, #{calls => Calls}};
code_change(_OldVsn, State, _Extra) ->
@@ -118,6 +98,7 @@ do_create_call(ChannelId, CallData, #{calls := Calls} = State) ->
ChannelId, {RegisteredPid, Ref}, CleanCalls
),
NewState = State#{calls := NewCalls},
ets:insert(call_pid_cache, {ChannelId, RegisteredPid}),
{reply, {ok, RegisteredPid}, NewState};
{error, Reason} ->
{reply, {error, Reason}, State}
@@ -132,13 +113,31 @@ do_create_call(ChannelId, CallData, #{calls := Calls} = State) ->
-spec do_lookup_call(channel_id(), state()) -> {reply, {ok, pid()} | {error, not_found}, state()}.
do_lookup_call(ChannelId, #{calls := Calls} = State) ->
case ets:lookup(call_pid_cache, ChannelId) of
[{ChannelId, Pid}] when is_pid(Pid) ->
case is_process_alive(Pid) of
true ->
{reply, {ok, Pid}, State};
false ->
ets:delete(call_pid_cache, ChannelId),
do_lookup_call_fallback(ChannelId, Calls, State)
end;
_ ->
do_lookup_call_fallback(ChannelId, Calls, State)
end.
-spec do_lookup_call_fallback(channel_id(), map(), state()) ->
{reply, {ok, pid()} | {error, not_found}, state()}.
do_lookup_call_fallback(ChannelId, Calls, State) ->
case maps:get(ChannelId, Calls, undefined) of
{Pid, _Ref} when is_pid(Pid) ->
ets:insert(call_pid_cache, {ChannelId, Pid}),
{reply, {ok, Pid}, State};
undefined ->
CallName = process_registry:build_process_name(call, ChannelId),
case process_registry:lookup_or_monitor(CallName, ChannelId, Calls) of
{ok, Pid, _Ref, NewCalls} ->
ets:insert(call_pid_cache, {ChannelId, Pid}),
{reply, {ok, Pid}, State#{calls := NewCalls}};
{error, not_found} ->
{reply, {error, not_found}, State}
@@ -163,8 +162,17 @@ do_terminate_call(ChannelId, #{calls := Calls} = State) ->
gen_server:stop(Pid, normal, ?SHUTDOWN_TIMEOUT),
CallName = process_registry:build_process_name(call, ChannelId),
process_registry:safe_unregister(CallName),
ets:delete(call_pid_cache, ChannelId),
NewCalls = maps:remove(ChannelId, Calls),
{reply, ok, State#{calls := NewCalls}};
undefined ->
{reply, {error, not_found}, State}
end.
-ifdef(TEST).
state_operations_test() ->
State = #{calls => #{}},
?assertEqual(#{}, maps:get(calls, State)).
-endif.

View File

@@ -10,12 +10,10 @@
public_key,
ssl,
inets,
jsx,
jose,
cowboy,
hackney,
base64url,
ezstd
ezstd,
base64url
]},
{env, []},
{modules, []},

View File

@@ -19,36 +19,25 @@
-behaviour(application).
-export([start/2, stop/1]).
-spec start(application:start_type(), term()) -> {ok, pid()} | {error, term()}.
start(_StartType, _StartArgs) ->
fluxer_gateway_env:load(),
WsPort = fluxer_gateway_env:get(ws_port),
RpcPort = fluxer_gateway_env:get(rpc_port),
otel_metrics:init(),
Port = fluxer_gateway_env:get(port),
Dispatch = cowboy_router:compile([
{'_', [
{<<"/_health">>, health_handler, []},
{<<"/_rpc">>, gateway_rpc_http_handler, []},
{<<"/_admin/reload">>, hot_reload_handler, []},
{<<"/">>, gateway_handler, []}
]}
]),
{ok, _} = cowboy:start_clear(http, [{port, WsPort}], #{
{ok, _} = cowboy:start_clear(http, [{port, Port}], #{
env => #{dispatch => Dispatch},
max_frame_size => 4096
}),
RpcDispatch = cowboy_router:compile([
{'_', [
{<<"/_rpc">>, gateway_rpc_http_handler, []},
{<<"/_admin/reload">>, hot_reload_handler, []}
]}
]),
{ok, _} = cowboy:start_clear(rpc_http, [{port, RpcPort}], #{
env => #{dispatch => RpcDispatch}
}),
fluxer_gateway_sup:start_link().
-spec stop(term()) -> ok.
stop(_State) ->
ok.

View File

@@ -0,0 +1,311 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(fluxer_gateway_config).
-export([load/0, load_from/1]).
-type config() :: map().
-type log_level() :: debug | info | notice | warning | error | critical | alert | emergency.
-spec load() -> config().
load() ->
case os:getenv("FLUXER_CONFIG") of
false -> erlang:error({missing_env, "FLUXER_CONFIG"});
"" -> erlang:error({missing_env, "FLUXER_CONFIG"});
Path -> load_from(Path)
end.
-spec load_from(string()) -> config().
load_from(Path) when is_list(Path) ->
case file:read_file(Path) of
{ok, Content} ->
Json = json:decode(Content),
build_config(Json);
{error, Reason} ->
erlang:error({json_read_failed, Path, Reason})
end.
-spec build_config(map()) -> config().
build_config(Json) ->
Service = get_map(Json, [<<"services">>, <<"gateway">>]),
Gateway = get_map(Json, [<<"gateway">>]),
Telemetry = get_map(Json, [<<"telemetry">>]),
Sentry = get_map(Json, [<<"sentry">>]),
Vapid = get_map(Json, [<<"auth">>, <<"vapid">>]),
#{
port => get_int(Service, <<"port">>, 8080),
rpc_tcp_port => get_int(Service, <<"rpc_tcp_port">>, 8772),
api_host => get_env_or_string("FLUXER_GATEWAY_API_HOST", Service, <<"api_host">>, "api"),
api_canary_host => get_optional_string(Service, <<"api_canary_host">>),
admin_reload_secret => get_optional_binary(Service, <<"admin_reload_secret">>),
rpc_secret_key => get_binary(Gateway, <<"rpc_secret">>, undefined),
identify_rate_limit_enabled => get_bool(Service, <<"identify_rate_limit_enabled">>, false),
push_enabled => get_bool(Service, <<"push_enabled">>, true),
push_user_guild_settings_cache_mb => get_int(
Service,
<<"push_user_guild_settings_cache_mb">>,
1024
),
push_subscriptions_cache_mb => get_int(Service, <<"push_subscriptions_cache_mb">>, 1024),
push_blocked_ids_cache_mb => get_int(Service, <<"push_blocked_ids_cache_mb">>, 1024),
presence_cache_shards => get_optional_int(Service, <<"presence_cache_shards">>),
presence_bus_shards => get_optional_int(Service, <<"presence_bus_shards">>),
presence_shards => get_optional_int(Service, <<"presence_shards">>),
guild_shards => get_optional_int(Service, <<"guild_shards">>),
session_shards => get_optional_int(Service, <<"session_shards">>),
push_badge_counts_cache_mb => get_int(Service, <<"push_badge_counts_cache_mb">>, 256),
push_badge_counts_cache_ttl_seconds =>
get_int(Service, <<"push_badge_counts_cache_ttl_seconds">>, 60),
push_dispatcher_max_inflight => get_int(Service, <<"push_dispatcher_max_inflight">>, 16),
push_dispatcher_max_queue => get_int(Service, <<"push_dispatcher_max_queue">>, 2048),
gateway_http_rpc_connect_timeout_ms =>
get_int(Service, <<"gateway_http_rpc_connect_timeout_ms">>, 5000),
gateway_http_rpc_recv_timeout_ms =>
get_int(Service, <<"gateway_http_rpc_recv_timeout_ms">>, 30000),
gateway_http_push_connect_timeout_ms =>
get_int(Service, <<"gateway_http_push_connect_timeout_ms">>, 3000),
gateway_http_push_recv_timeout_ms =>
get_int(Service, <<"gateway_http_push_recv_timeout_ms">>, 5000),
gateway_http_rpc_max_concurrency =>
get_int(Service, <<"gateway_http_rpc_max_concurrency">>, 512),
gateway_rpc_tcp_max_input_buffer_bytes =>
get_int(Service, <<"gateway_rpc_tcp_max_input_buffer_bytes">>, 2097152),
gateway_http_push_max_concurrency =>
get_int(Service, <<"gateway_http_push_max_concurrency">>, 256),
gateway_http_failure_threshold =>
get_int(Service, <<"gateway_http_failure_threshold">>, 6),
gateway_http_recovery_timeout_ms =>
get_int(Service, <<"gateway_http_recovery_timeout_ms">>, 15000),
gateway_http_cleanup_interval_ms =>
get_int(Service, <<"gateway_http_cleanup_interval_ms">>, 30000),
gateway_http_cleanup_max_age_ms =>
get_int(Service, <<"gateway_http_cleanup_max_age_ms">>, 300000),
media_proxy_endpoint => get_optional_binary(Service, <<"media_proxy_endpoint">>),
vapid_email => get_binary(Vapid, <<"email">>, <<>>),
vapid_public_key => get_optional_binary(Vapid, <<"public_key">>),
vapid_private_key => get_optional_binary(Vapid, <<"private_key">>),
gateway_metrics_enabled => get_optional_bool(Service, <<"gateway_metrics_enabled">>),
gateway_metrics_report_interval_ms =>
get_optional_int(Service, <<"gateway_metrics_report_interval_ms">>),
release_node => get_string(Service, <<"release_node">>, "fluxer_gateway@127.0.0.1"),
logger_level => get_log_level(Service, <<"logger_level">>, info),
telemetry => #{
enabled => get_bool(Telemetry, <<"enabled">>, true),
otlp_endpoint => get_string(Telemetry, <<"otlp_endpoint">>, ""),
api_key => get_string(Telemetry, <<"api_key">>, ""),
service_name => get_string(Telemetry, <<"service_name">>, "fluxer-gateway"),
environment => get_string(Telemetry, <<"environment">>, "development"),
trace_sampling_ratio => get_float(Telemetry, <<"trace_sampling_ratio">>, 1.0)
},
sentry => #{
build_sha => get_string(Sentry, <<"build_sha">>, ""),
release_channel => get_string(Sentry, <<"release_channel">>, "")
}
}.
-spec get_map(map(), [binary()]) -> map().
get_map(Map, Keys) ->
case get_in(Map, Keys) of
Value when is_map(Value) -> Value;
_ -> #{}
end.
-spec get_int(map(), binary(), integer()) -> integer().
get_int(Map, Key, Default) when is_integer(Default) ->
to_int(get_value(Map, Key), Default).
-spec get_optional_int(map(), binary()) -> integer() | undefined.
get_optional_int(Map, Key) ->
to_optional_int(get_value(Map, Key)).
-spec get_bool(map(), binary(), boolean()) -> boolean().
get_bool(Map, Key, Default) when is_boolean(Default) ->
to_bool(get_value(Map, Key), Default).
-spec get_optional_bool(map(), binary()) -> boolean() | undefined.
get_optional_bool(Map, Key) ->
case get_value(Map, Key) of
undefined -> undefined;
Value -> to_bool(Value, undefined)
end.
-spec get_string(map(), binary(), string()) -> string().
get_string(Map, Key, Default) when is_list(Default) ->
to_string(get_value(Map, Key), Default).
-spec get_env_or_string(string(), map(), binary(), string()) -> string().
get_env_or_string(EnvVar, Map, Key, Default) when is_list(EnvVar), is_list(Default) ->
case os:getenv(EnvVar) of
false -> get_string(Map, Key, Default);
"" -> get_string(Map, Key, Default);
Value -> Value
end.
-spec get_optional_string(map(), binary()) -> string() | undefined.
get_optional_string(Map, Key) ->
case get_value(Map, Key) of
undefined ->
undefined;
Value ->
Clean = string:trim(to_string(Value, "")),
case Clean of
"" -> undefined;
_ -> Clean
end
end.
-spec get_binary(map(), binary(), binary() | undefined) -> binary() | undefined.
get_binary(Map, Key, Default) ->
to_binary(get_value(Map, Key), Default).
-spec get_optional_binary(map(), binary()) -> binary() | undefined.
get_optional_binary(Map, Key) ->
case get_value(Map, Key) of
undefined -> undefined;
Value -> to_binary(Value, undefined)
end.
-spec get_log_level(map(), binary(), log_level()) -> log_level().
get_log_level(Map, Key, Default) when is_atom(Default) ->
Value = get_value(Map, Key),
case normalize_log_level(Value) of
undefined -> Default;
Level -> Level
end.
-spec get_float(map(), binary(), number()) -> float().
get_float(Map, Key, Default) when is_number(Default) ->
to_float(get_value(Map, Key), Default).
-spec get_in(term(), [binary()]) -> term().
get_in(Map, [Key | Rest]) when is_map(Map) ->
case get_value(Map, Key) of
undefined -> undefined;
Value when Rest =:= [] -> Value;
Value -> get_in(Value, Rest)
end;
get_in(_, _) ->
undefined.
-spec get_value(term(), binary()) -> term().
get_value(Map, Key) when is_map(Map) ->
case maps:get(Key, Map, undefined) of
undefined when is_binary(Key) ->
maps:get(binary_to_list(Key), Map, undefined);
Value ->
Value
end.
-spec to_int(term(), integer() | undefined) -> integer() | undefined.
to_int(Value, _Default) when is_integer(Value) ->
Value;
to_int(Value, _Default) when is_float(Value) ->
trunc(Value);
to_int(Value, Default) ->
case to_string(Value, "") of
"" ->
Default;
Str ->
case string:to_integer(Str) of
{Int, _} when is_integer(Int) -> Int;
{error, _} -> Default
end
end.
-spec to_optional_int(term()) -> integer() | undefined.
to_optional_int(Value) ->
case to_int(Value, undefined) of
undefined -> undefined;
Int -> Int
end.
-spec to_bool(term(), boolean() | undefined) -> boolean() | undefined.
to_bool(Value, _Default) when is_boolean(Value) ->
Value;
to_bool(Value, Default) when is_atom(Value) ->
case Value of
true -> true;
false -> false;
_ -> Default
end;
to_bool(Value, Default) ->
case string:lowercase(to_string(Value, "")) of
"true" -> true;
"1" -> true;
"false" -> false;
"0" -> false;
_ -> Default
end.
-spec to_string(term(), string()) -> string().
to_string(Value, Default) when is_list(Default) ->
case Value of
undefined -> Default;
Bin when is_binary(Bin) -> binary_to_list(Bin);
Str when is_list(Str) -> Str;
Atom when is_atom(Atom) -> atom_to_list(Atom);
_ -> Default
end.
-spec to_binary(term(), binary() | undefined) -> binary() | undefined.
to_binary(Value, Default) ->
case Value of
undefined -> Default;
Bin when is_binary(Bin) -> Bin;
Str when is_list(Str) -> list_to_binary(Str);
Atom when is_atom(Atom) -> list_to_binary(atom_to_list(Atom));
_ -> Default
end.
-spec to_float(term(), float()) -> float().
to_float(Value, _Default) when is_float(Value) ->
Value;
to_float(Value, _Default) when is_integer(Value) ->
float(Value);
to_float(Value, Default) ->
case to_string(Value, "") of
"" ->
Default;
Str ->
case string:to_float(Str) of
{Float, _} when is_float(Float) -> Float;
{error, _} -> Default
end
end.
-spec normalize_log_level(term()) -> log_level() | undefined.
normalize_log_level(undefined) ->
undefined;
normalize_log_level(Level) when is_atom(Level) ->
normalize_log_level(atom_to_list(Level));
normalize_log_level(Level) when is_binary(Level) ->
normalize_log_level(binary_to_list(Level));
normalize_log_level(Level) when is_list(Level) ->
case string:lowercase(string:trim(Level)) of
"debug" -> debug;
"info" -> info;
"notice" -> notice;
"warning" -> warning;
"error" -> error;
"critical" -> critical;
"alert" -> alert;
"emergency" -> emergency;
_ -> undefined
end;
normalize_log_level(_) ->
undefined.

View File

@@ -0,0 +1,302 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(fluxer_gateway_crypto).
-export([
init/0,
decrypt/2,
encrypt/2,
derive_shared_secret/2,
generate_keypair/0,
get_public_key/0,
new_crypto_state/1,
is_encrypted_frame/1,
unwrap_encrypted_frame/1,
wrap_encrypted_frame/1
]).
-define(KEYPAIR_KEY, {?MODULE, instance_keypair}).
-define(ENCRYPTED_FRAME_PREFIX, 16#FE).
-define(NONCE_SIZE, 12).
-define(TAG_SIZE, 16).
-define(KEY_SIZE, 32).
-type keypair() :: #{public := binary(), private := binary()}.
-type crypto_state() :: #{
shared_secret := binary(),
send_counter := non_neg_integer(),
recv_counter := non_neg_integer()
}.
-export_type([keypair/0, crypto_state/0]).
-spec init() -> ok.
init() ->
case persistent_term:get(?KEYPAIR_KEY, undefined) of
undefined ->
Keypair = generate_keypair(),
persistent_term:put(?KEYPAIR_KEY, Keypair),
ok;
_ ->
ok
end.
-spec generate_keypair() -> keypair().
generate_keypair() ->
{Public, Private} = crypto:generate_key(ecdh, x25519),
#{public => Public, private => Private}.
-spec get_public_key() -> binary() | undefined.
get_public_key() ->
case persistent_term:get(?KEYPAIR_KEY, undefined) of
undefined -> undefined;
#{public := Public} -> Public
end.
-spec derive_shared_secret(binary(), keypair()) -> {ok, binary()} | {error, term()}.
derive_shared_secret(PeerPublic, #{private := Private}) when
byte_size(PeerPublic) =:= ?KEY_SIZE
->
try
SharedSecret = crypto:compute_key(ecdh, PeerPublic, Private, x25519),
{ok, SharedSecret}
catch
error:Reason ->
{error, {key_exchange_failed, Reason}}
end;
derive_shared_secret(PeerPublic, _Keypair) ->
{error, {invalid_peer_key_size, byte_size(PeerPublic)}}.
-spec new_crypto_state(binary()) -> crypto_state().
new_crypto_state(SharedSecret) when byte_size(SharedSecret) =:= ?KEY_SIZE ->
#{
shared_secret => SharedSecret,
send_counter => 0,
recv_counter => 0
}.
-spec encrypt(binary(), crypto_state()) ->
{ok, binary(), crypto_state()} | {error, term()}.
encrypt(Plaintext, State = #{shared_secret := Key, send_counter := Counter}) ->
try
Nonce = counter_to_nonce(Counter),
AAD = <<>>,
{Ciphertext, Tag} = crypto:crypto_one_time_aead(
aes_256_gcm,
Key,
Nonce,
Plaintext,
AAD,
?TAG_SIZE,
true
),
Encrypted = <<Nonce/binary, Tag/binary, Ciphertext/binary>>,
NewState = State#{send_counter => Counter + 1},
{ok, Encrypted, NewState}
catch
error:Reason ->
{error, {encrypt_failed, Reason}}
end.
-spec decrypt(binary(), crypto_state()) ->
{ok, binary(), crypto_state()} | {error, term()}.
decrypt(Data, State = #{shared_secret := Key, recv_counter := Counter}) ->
MinSize = ?NONCE_SIZE + ?TAG_SIZE,
case byte_size(Data) > MinSize of
false ->
{error, {invalid_encrypted_data, too_short}};
true ->
<<Nonce:?NONCE_SIZE/binary, Tag:?TAG_SIZE/binary, Ciphertext/binary>> = Data,
ExpectedNonce = counter_to_nonce(Counter),
case validate_nonce(Nonce, ExpectedNonce, Counter) of
{ok, ActualCounter} ->
do_decrypt(Ciphertext, Key, Nonce, Tag, State, ActualCounter);
{error, Reason} ->
{error, Reason}
end
end.
-spec do_decrypt(binary(), binary(), binary(), binary(), crypto_state(), non_neg_integer()) ->
{ok, binary(), crypto_state()} | {error, term()}.
do_decrypt(Ciphertext, Key, Nonce, Tag, State, ActualCounter) ->
AAD = <<>>,
try
case crypto:crypto_one_time_aead(
aes_256_gcm,
Key,
Nonce,
Ciphertext,
AAD,
Tag,
false
) of
Plaintext when is_binary(Plaintext) ->
NewState = State#{recv_counter => ActualCounter + 1},
{ok, Plaintext, NewState};
error ->
{error, authentication_failed}
end
catch
error:Reason ->
{error, {decrypt_failed, Reason}}
end.
-spec counter_to_nonce(non_neg_integer()) -> binary().
counter_to_nonce(Counter) ->
<<0:32, Counter:64/big-unsigned-integer>>.
-spec validate_nonce(binary(), binary(), non_neg_integer()) ->
{ok, non_neg_integer()} | {error, term()}.
validate_nonce(Nonce, ExpectedNonce, Counter) when Nonce =:= ExpectedNonce ->
{ok, Counter};
validate_nonce(Nonce, _ExpectedNonce, Counter) ->
<<_Prefix:4/binary, ReceivedCounter:64/big-unsigned-integer>> = Nonce,
MaxWindow = 32,
case ReceivedCounter > Counter andalso ReceivedCounter =< Counter + MaxWindow of
true ->
{ok, ReceivedCounter};
false ->
{error, {nonce_mismatch, Counter, ReceivedCounter}}
end.
-spec is_encrypted_frame(binary()) -> boolean().
is_encrypted_frame(<<?ENCRYPTED_FRAME_PREFIX, _Rest/binary>>) ->
true;
is_encrypted_frame(_) ->
false.
-spec unwrap_encrypted_frame(binary()) -> {ok, binary()} | {error, not_encrypted}.
unwrap_encrypted_frame(<<?ENCRYPTED_FRAME_PREFIX, Data/binary>>) ->
{ok, Data};
unwrap_encrypted_frame(_) ->
{error, not_encrypted}.
-spec wrap_encrypted_frame(binary()) -> binary().
wrap_encrypted_frame(Data) ->
<<?ENCRYPTED_FRAME_PREFIX, Data/binary>>.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
generate_keypair_test() ->
Keypair = generate_keypair(),
?assert(is_map(Keypair)),
?assertEqual(?KEY_SIZE, byte_size(maps:get(public, Keypair))),
?assertEqual(?KEY_SIZE, byte_size(maps:get(private, Keypair))).
derive_shared_secret_test() ->
Keypair1 = generate_keypair(),
Keypair2 = generate_keypair(),
{ok, Secret1} = derive_shared_secret(maps:get(public, Keypair2), Keypair1),
{ok, Secret2} = derive_shared_secret(maps:get(public, Keypair1), Keypair2),
?assertEqual(Secret1, Secret2),
?assertEqual(?KEY_SIZE, byte_size(Secret1)).
derive_shared_secret_invalid_key_test() ->
Keypair = generate_keypair(),
Result = derive_shared_secret(<<"short">>, Keypair),
?assertMatch({error, {invalid_peer_key_size, _}}, Result).
new_crypto_state_test() ->
Secret = crypto:strong_rand_bytes(?KEY_SIZE),
State = new_crypto_state(Secret),
?assertEqual(Secret, maps:get(shared_secret, State)),
?assertEqual(0, maps:get(send_counter, State)),
?assertEqual(0, maps:get(recv_counter, State)).
encrypt_decrypt_roundtrip_test() ->
Secret = crypto:strong_rand_bytes(?KEY_SIZE),
State = new_crypto_state(Secret),
Plaintext = <<"hello world">>,
{ok, Ciphertext, State2} = encrypt(Plaintext, State),
?assert(byte_size(Ciphertext) > byte_size(Plaintext)),
?assertEqual(1, maps:get(send_counter, State2)),
{ok, Decrypted, State3} = decrypt(Ciphertext, State),
?assertEqual(Plaintext, Decrypted),
?assertEqual(1, maps:get(recv_counter, State3)).
encrypt_multiple_messages_test() ->
Secret = crypto:strong_rand_bytes(?KEY_SIZE),
SendState = new_crypto_state(Secret),
RecvState = new_crypto_state(Secret),
Messages = [<<"msg1">>, <<"msg2">>, <<"msg3">>],
{FinalSendState, FinalRecvState, DecryptedMsgs} = lists:foldl(
fun(Msg, {SS, RS, Acc}) ->
{ok, Cipher, SS2} = encrypt(Msg, SS),
{ok, Plain, RS2} = decrypt(Cipher, RS),
{SS2, RS2, [Plain | Acc]}
end,
{SendState, RecvState, []},
Messages
),
?assertEqual(3, maps:get(send_counter, FinalSendState)),
?assertEqual(3, maps:get(recv_counter, FinalRecvState)),
?assertEqual(Messages, lists:reverse(DecryptedMsgs)).
decrypt_tampered_test() ->
Secret = crypto:strong_rand_bytes(?KEY_SIZE),
State = new_crypto_state(Secret),
{ok, Ciphertext, _} = encrypt(<<"hello">>, State),
Tampered = <<(binary:first(Ciphertext) bxor 1), (binary:part(Ciphertext, 1, byte_size(Ciphertext) - 1))/binary>>,
Result = decrypt(Tampered, State),
?assertMatch({error, _}, Result).
decrypt_too_short_test() ->
Secret = crypto:strong_rand_bytes(?KEY_SIZE),
State = new_crypto_state(Secret),
Result = decrypt(<<"short">>, State),
?assertMatch({error, {invalid_encrypted_data, too_short}}, Result).
is_encrypted_frame_test() ->
?assertEqual(true, is_encrypted_frame(<<16#FE, "data">>)),
?assertEqual(false, is_encrypted_frame(<<"data">>)),
?assertEqual(false, is_encrypted_frame(<<16#FF, "data">>)),
?assertEqual(false, is_encrypted_frame(<<>>)).
unwrap_encrypted_frame_test() ->
?assertEqual({ok, <<"data">>}, unwrap_encrypted_frame(<<16#FE, "data">>)),
?assertEqual({error, not_encrypted}, unwrap_encrypted_frame(<<"data">>)).
wrap_encrypted_frame_test() ->
?assertEqual(<<16#FE, "data">>, wrap_encrypted_frame(<<"data">>)).
counter_to_nonce_test() ->
Nonce0 = counter_to_nonce(0),
?assertEqual(?NONCE_SIZE, byte_size(Nonce0)),
?assertEqual(<<0:32, 0:64>>, Nonce0),
Nonce1 = counter_to_nonce(1),
?assertEqual(<<0:32, 1:64>>, Nonce1).
init_creates_keypair_test() ->
persistent_term:erase(?KEYPAIR_KEY),
ok = init(),
Public = get_public_key(),
?assert(is_binary(Public)),
?assertEqual(?KEY_SIZE, byte_size(Public)),
persistent_term:erase(?KEYPAIR_KEY).
init_idempotent_test() ->
persistent_term:erase(?KEYPAIR_KEY),
ok = init(),
Public1 = get_public_key(),
ok = init(),
Public2 = get_public_key(),
?assertEqual(Public1, Public2),
persistent_term:erase(?KEYPAIR_KEY).
-endif.

View File

@@ -19,14 +19,15 @@
-export([load/0, get/1, get_optional/1, get_map/0, patch/1, update/1]).
-define(APP, fluxer_gateway).
-define(CONFIG_TERM_KEY, {fluxer_gateway, runtime_config}).
-type config() :: map().
-spec load() -> config().
load() ->
set_config(build_config()).
Config = build_config(),
apply_system_config(Config),
set_config(Config).
-spec get(atom()) -> term().
get(Key) when is_atom(Key) ->
@@ -67,202 +68,145 @@ ensure_loaded() ->
-spec build_config() -> config().
build_config() ->
#{
ws_port => env_int("FLUXER_GATEWAY_WS_PORT", ws_port, 8080),
rpc_port => env_int("FLUXER_GATEWAY_RPC_PORT", rpc_port, 8081),
api_host => env_string("API_HOST", api_host, "api"),
api_canary_host => env_optional_string("API_CANARY_HOST", api_canary_host),
rpc_secret_key => env_binary("GATEWAY_RPC_SECRET", rpc_secret_key, undefined),
identify_rate_limit_enabled => env_bool("FLUXER_GATEWAY_IDENTIFY_RATE_LIMIT_ENABLED", identify_rate_limit_enabled, false),
push_enabled => env_bool("FLUXER_GATEWAY_PUSH_ENABLED", push_enabled, true),
push_user_guild_settings_cache_mb => env_int("FLUXER_GATEWAY_PUSH_USER_GUILD_SETTINGS_CACHE_MB",
push_user_guild_settings_cache_mb, 1024),
push_subscriptions_cache_mb => env_int("FLUXER_GATEWAY_PUSH_SUBSCRIPTIONS_CACHE_MB",
push_subscriptions_cache_mb, 1024),
push_blocked_ids_cache_mb => env_int("FLUXER_GATEWAY_PUSH_BLOCKED_IDS_CACHE_MB",
push_blocked_ids_cache_mb, 1024),
presence_cache_shards => env_optional_int("FLUXER_GATEWAY_PRESENCE_CACHE_SHARDS", presence_cache_shards),
presence_bus_shards => env_optional_int("FLUXER_GATEWAY_PRESENCE_BUS_SHARDS", presence_bus_shards),
presence_shards => env_optional_int("FLUXER_GATEWAY_PRESENCE_SHARDS", presence_shards),
guild_shards => env_optional_int("FLUXER_GATEWAY_GUILD_SHARDS", guild_shards),
metrics_host => env_optional_string("FLUXER_METRICS_HOST", metrics_host),
push_badge_counts_cache_mb => app_env_int(push_badge_counts_cache_mb, 256),
push_badge_counts_cache_ttl_seconds => app_env_int(push_badge_counts_cache_ttl_seconds, 60),
media_proxy_endpoint => env_optional_binary("MEDIA_PROXY_ENDPOINT", media_proxy_endpoint),
vapid_email => env_binary("VAPID_EMAIL", vapid_email, <<"support@fluxer.app">>),
vapid_public_key => env_binary("VAPID_PUBLIC_KEY", vapid_public_key, undefined),
vapid_private_key => env_binary("VAPID_PRIVATE_KEY", vapid_private_key, undefined),
gateway_metrics_enabled => app_env_optional_bool(gateway_metrics_enabled),
gateway_metrics_report_interval_ms => app_env_optional_int(gateway_metrics_report_interval_ms)
}.
fluxer_gateway_config:load().
-spec env_int(string(), atom(), integer()) -> integer().
env_int(EnvVar, AppKey, Default) when is_atom(AppKey), is_integer(Default) ->
case os:getenv(EnvVar) of
false ->
app_env_int(AppKey, Default);
Value ->
parse_int(Value, Default)
-spec apply_system_config(config()) -> ok.
apply_system_config(Config) ->
apply_logger_config(Config),
apply_telemetry_config(Config).
-spec apply_logger_config(config()) -> ok.
apply_logger_config(Config) ->
LoggerLevel = resolve_logger_level(Config),
logger:set_primary_config(level, LoggerLevel),
logger:set_handler_config(default, level, LoggerLevel).
-spec apply_telemetry_config(config()) -> ok.
apply_telemetry_config(Config) ->
Telemetry = maps:get(telemetry, Config, #{}),
apply_telemetry_config(Telemetry, Config).
-spec resolve_logger_level(config()) -> atom().
resolve_logger_level(Config) ->
Default = maps:get(logger_level, Config, info),
case os:getenv("LOGGER_LEVEL") of
false -> Default;
"" -> Default;
Value -> parse_logger_level(Value, Default)
end.
-spec env_optional_int(string(), atom()) -> integer() | undefined.
env_optional_int(EnvVar, AppKey) when is_atom(AppKey) ->
case os:getenv(EnvVar) of
false ->
app_env_optional_int(AppKey);
Value ->
parse_int(Value, undefined)
end.
-spec env_bool(string(), atom(), boolean()) -> boolean().
env_bool(EnvVar, AppKey, Default) when is_atom(AppKey), is_boolean(Default) ->
case os:getenv(EnvVar) of
false ->
app_env_bool(AppKey, Default);
Value ->
parse_bool(Value, Default)
end.
-spec env_string(string(), atom(), string()) -> string().
env_string(EnvVar, AppKey, Default) when is_atom(AppKey) ->
case os:getenv(EnvVar) of
false ->
app_env_string(AppKey, Default);
Value ->
Value
end.
-spec env_optional_string(string(), atom()) -> string() | undefined.
env_optional_string(EnvVar, AppKey) when is_atom(AppKey) ->
case os:getenv(EnvVar) of
false ->
app_env_optional_string(AppKey);
Value ->
Value
end.
-spec env_binary(string(), atom(), binary() | undefined) -> binary() | undefined.
env_binary(EnvVar, AppKey, Default) when is_atom(AppKey) ->
case os:getenv(EnvVar) of
false ->
app_env_binary(AppKey, Default);
Value ->
to_binary(Value, Default)
end.
-spec env_optional_binary(string(), atom()) -> binary() | undefined.
env_optional_binary(EnvVar, AppKey) when is_atom(AppKey) ->
case os:getenv(EnvVar) of
false ->
app_env_optional_binary(AppKey);
Value ->
to_binary(Value, undefined)
end.
-spec parse_int(string(), integer() | undefined) -> integer() | undefined.
parse_int(Value, Default) ->
Str = string:trim(Value),
try
list_to_integer(Str)
catch
_:_ -> Default
end.
-spec parse_bool(string(), boolean()) -> boolean().
parse_bool(Value, Default) ->
Str = string:lowercase(string:trim(Value)),
case Str of
"true" -> true;
"1" -> true;
"false" -> false;
"0" -> false;
-spec parse_logger_level(string(), atom()) -> atom().
parse_logger_level(Value, Default) ->
case string:lowercase(string:trim(Value)) of
"debug" -> debug;
"info" -> info;
"notice" -> notice;
"warning" -> warning;
"error" -> error;
"critical" -> critical;
"alert" -> alert;
"emergency" -> emergency;
_ -> Default
end.
-spec to_binary(string(), binary() | undefined) -> binary() | undefined.
to_binary(Value, Default) ->
try
list_to_binary(Value)
catch
_:_ -> Default
-ifdef(HAS_OPENTELEMETRY).
-spec apply_telemetry_config(map(), config()) -> ok.
apply_telemetry_config(Telemetry, Config) ->
Sentry = maps:get(sentry, Config, #{}),
ShouldEnable = otel_metrics:configure_enabled(Telemetry),
case ShouldEnable of
true ->
set_opentelemetry_env(Telemetry, Sentry, Config);
false ->
application:set_env(opentelemetry_experimental, readers, []),
application:set_env(opentelemetry, processors, []),
application:set_env(opentelemetry, traces_exporter, none)
end.
-spec app_env_int(atom(), integer()) -> integer().
app_env_int(Key, Default) ->
case application:get_env(?APP, Key) of
{ok, Value} when is_integer(Value) ->
Value;
_ ->
Default
-spec set_opentelemetry_env(map(), map(), config()) -> ok.
-ifdef(DEV_MODE).
set_opentelemetry_env(_Telemetry, _Sentry, _Config) ->
ok.
-else.
set_opentelemetry_env(Telemetry, Sentry, Config) ->
Endpoint = maps:get(otlp_endpoint, Telemetry, ""),
ApiKey = maps:get(api_key, Telemetry, ""),
Headers = otlp_headers(ApiKey),
ServiceName = maps:get(service_name, Telemetry, "fluxer-gateway"),
Environment = maps:get(environment, Telemetry, "development"),
Version = maps:get(build_sha, Sentry, ""),
InstanceId = maps:get(release_node, Config, ""),
Resource = [
{service_name, ServiceName},
{service_version, Version},
{service_namespace, "fluxer"},
{deployment_environment, Environment},
{service_instance_id, InstanceId}
],
application:set_env(
opentelemetry_experimental,
readers,
[
{otel_periodic_reader, #{
exporter =>
{otel_otlp_metrics, #{
protocol => http_protobuf,
endpoint => Endpoint,
headers => Headers
}},
interval => 30000
}}
]
),
application:set_env(opentelemetry_experimental, resource, Resource),
application:set_env(
opentelemetry,
processors,
[
{otel_batch_processor, #{
exporter => {opentelemetry_exporter, #{}},
scheduled_delay_ms => 1000,
max_queue_size => 2048,
export_timeout_ms => 30000
}}
]
),
application:set_env(opentelemetry, traces_exporter, {opentelemetry_exporter, #{}}),
application:set_env(
opentelemetry,
logger,
[
{handler, default, otel_log_handler, #{
level => info,
max_queue_size => 2048,
scheduled_delay_ms => 1000,
exporting_timeout_ms => 30000,
exporter =>
{otel_otlp_logs, #{
protocol => http_protobuf,
endpoint => Endpoint,
headers => Headers
}}
}}
]
),
application:set_env(opentelemetry_exporter, otlp_protocol, http_protobuf),
application:set_env(opentelemetry_exporter, otlp_endpoint, Endpoint),
application:set_env(opentelemetry_exporter, otlp_headers, Headers).
-spec otlp_headers(string()) -> [{string(), string()}].
otlp_headers(ApiKey) ->
ApiKeyStr = string:trim(ApiKey),
case ApiKeyStr of
"" -> [];
_ -> [{"Authorization", "Bearer " ++ ApiKeyStr}]
end.
-spec app_env_optional_int(atom()) -> integer() | undefined.
app_env_optional_int(Key) ->
case application:get_env(?APP, Key) of
{ok, Value} when is_integer(Value) ->
Value;
_ ->
undefined
end.
-spec app_env_bool(atom(), boolean()) -> boolean().
app_env_bool(Key, Default) ->
case application:get_env(?APP, Key) of
{ok, Value} when is_boolean(Value) ->
Value;
_ ->
Default
end.
-spec app_env_optional_bool(atom()) -> boolean() | undefined.
app_env_optional_bool(Key) ->
case application:get_env(?APP, Key) of
{ok, Value} when is_boolean(Value) ->
Value;
_ ->
undefined
end.
-spec app_env_string(atom(), string()) -> string().
app_env_string(Key, Default) ->
case application:get_env(?APP, Key) of
{ok, Value} when is_list(Value) ->
Value;
{ok, Value} when is_binary(Value) ->
binary_to_list(Value);
_ ->
Default
end.
-spec app_env_optional_string(atom()) -> string() | undefined.
app_env_optional_string(Key) ->
case application:get_env(?APP, Key) of
{ok, Value} when is_list(Value) ->
Value;
{ok, Value} when is_binary(Value) ->
binary_to_list(Value);
_ ->
undefined
end.
-spec app_env_binary(atom(), binary() | undefined) -> binary() | undefined.
app_env_binary(Key, Default) ->
case application:get_env(?APP, Key) of
{ok, Value} when is_binary(Value) ->
Value;
{ok, Value} when is_list(Value) ->
list_to_binary(Value);
_ ->
Default
end.
-spec app_env_optional_binary(atom()) -> binary() | undefined.
app_env_optional_binary(Key) ->
case application:get_env(?APP, Key) of
{ok, Value} when is_binary(Value) ->
Value;
{ok, Value} when is_list(Value) ->
list_to_binary(Value);
_ ->
undefined
end.
-endif.
-else.
-spec apply_telemetry_config(map(), config()) -> ok.
apply_telemetry_config(_Telemetry, _Config) ->
application:set_env(opentelemetry_experimental, readers, []),
application:set_env(opentelemetry, processors, []),
application:set_env(opentelemetry, traces_exporter, none).
-endif.

View File

@@ -19,74 +19,39 @@
-behaviour(supervisor).
-export([start_link/0, init/1]).
-spec start_link() -> {ok, pid()} | {error, term()}.
start_link() ->
supervisor:start_link({local, ?MODULE}, ?MODULE, []).
-spec init([]) -> {ok, {supervisor:sup_flags(), [supervisor:child_spec()]}}.
init([]) ->
SessionManager = #{
id => session_manager,
start => {session_manager, start_link, []},
SupFlags = #{
strategy => one_for_one,
intensity => 5,
period => 10
},
Children = [
child_spec(gateway_http_client, gateway_http_client),
child_spec(gateway_rpc_tcp_server, gateway_rpc_tcp_server),
child_spec(session_manager, session_manager),
child_spec(presence_cache, presence_cache),
child_spec(presence_bus, presence_bus),
child_spec(presence_manager, presence_manager),
child_spec(guild_crash_logger, guild_crash_logger),
child_spec(guild_manager, guild_manager),
child_spec(call_manager, call_manager),
child_spec(push_dispatcher, push_dispatcher),
child_spec(push, push),
child_spec(gateway_metrics_collector, gateway_metrics_collector)
],
{ok, {SupFlags, Children}}.
-spec child_spec(atom(), module()) -> supervisor:child_spec().
child_spec(Id, Module) ->
#{
id => Id,
start => {Module, start_link, []},
restart => permanent,
shutdown => 5000,
type => worker
},
PresenceManager = #{
id => presence_manager,
start => {presence_manager, start_link, []},
restart => permanent,
shutdown => 5000,
type => worker
},
GuildManager = #{
id => guild_manager,
start => {guild_manager, start_link, []},
restart => permanent,
shutdown => 5000,
type => worker
},
Push = #{
id => push,
start => {push, start_link, []},
restart => permanent,
shutdown => 5000,
type => worker
},
CallManager = #{
id => call_manager,
start => {call_manager, start_link, []},
restart => permanent,
shutdown => 5000,
type => worker
},
PresenceBus = #{
id => presence_bus,
start => {presence_bus, start_link, []},
restart => permanent,
shutdown => 5000,
type => worker
},
PresenceCache = #{
id => presence_cache,
start => {presence_cache, start_link, []},
restart => permanent,
shutdown => 5000,
type => worker
},
GatewayMetricsCollector = #{
id => gateway_metrics_collector,
start => {gateway_metrics_collector, start_link, []},
restart => permanent,
shutdown => 5000,
type => worker
},
{ok,
{{one_for_one, 5, 10}, [
SessionManager,
PresenceCache,
PresenceBus,
PresenceManager,
GuildManager,
CallManager,
Push,
GatewayMetricsCollector
]}}.
}.

View File

@@ -24,15 +24,18 @@
]).
-type encoding() :: json.
-type frame_type() :: text | binary.
-export_type([encoding/0]).
-spec parse_encoding(binary() | undefined) -> encoding().
parse_encoding(_) -> json.
parse_encoding(_) ->
json.
-spec encode(map(), encoding()) -> {ok, iodata(), text | binary} | {error, term()}.
-spec encode(map(), encoding()) -> {ok, iodata(), frame_type()} | {error, term()}.
encode(Message, json) ->
try
Encoded = jsx:encode(Message),
Encoded = iolist_to_binary(json:encode(Message)),
{ok, Encoded, text}
catch
_:Reason ->
@@ -42,7 +45,7 @@ encode(Message, json) ->
-spec decode(binary(), encoding()) -> {ok, map()} | {error, term()}.
decode(Data, json) ->
try
Decoded = jsx:decode(Data, [{return_maps, true}]),
Decoded = json:decode(Data),
{ok, Decoded}
catch
_:Reason ->
@@ -52,26 +55,66 @@ decode(Data, json) ->
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
parse_encoding_test() ->
?assertEqual(json, parse_encoding(<<"json">>)),
?assertEqual(json, parse_encoding(<<"etf">>)),
?assertEqual(json, parse_encoding(undefined)),
?assertEqual(json, parse_encoding(<<"invalid">>)).
parse_encoding_test_() ->
[
?_assertEqual(json, parse_encoding(<<"json">>)),
?_assertEqual(json, parse_encoding(<<"etf">>)),
?_assertEqual(json, parse_encoding(undefined)),
?_assertEqual(json, parse_encoding(<<"invalid">>)),
?_assertEqual(json, parse_encoding(<<>>))
].
encode_json_test() ->
encode_json_test_() ->
Message = #{<<"op">> => 0, <<"d">> => #{<<"test">> => true}},
[
?_assertMatch({ok, _, text}, encode(Message, json)),
?_test(begin
{ok, Encoded, text} = encode(Message, json),
?assert(is_binary(Encoded))
end)
].
encode_empty_map_test() ->
{ok, Encoded, text} = encode(#{}, json),
?assertEqual(<<"{}">>, Encoded).
encode_nested_test() ->
Message = #{<<"a">> => #{<<"b">> => #{<<"c">> => 1}}},
{ok, Encoded, text} = encode(Message, json),
?assert(is_binary(Encoded)).
?assert(is_binary(Encoded)),
{ok, Decoded} = decode(Encoded, json),
?assertEqual(Message, Decoded).
decode_json_test() ->
decode_json_test_() ->
Data = <<"{\"op\":0,\"d\":{\"test\":true}}">>,
{ok, Decoded} = decode(Data, json),
?assertEqual(0, maps:get(<<"op">>, Decoded)).
[
?_assertMatch({ok, _}, decode(Data, json)),
?_test(begin
{ok, Decoded} = decode(Data, json),
?assertEqual(0, maps:get(<<"op">>, Decoded))
end)
].
roundtrip_json_test() ->
Original = #{<<"op">> => 10, <<"d">> => #{<<"heartbeat_interval">> => 41250}},
{ok, Encoded, _} = encode(Original, json),
{ok, Decoded} = decode(iolist_to_binary(Encoded), json),
?assertEqual(Original, Decoded).
decode_invalid_json_test() ->
?assertMatch({error, {decode_failed, _}}, decode(<<"not json">>, json)).
decode_empty_object_test() ->
{ok, Decoded} = decode(<<"{}">>, json),
?assertEqual(#{}, Decoded).
roundtrip_json_test_() ->
Messages = [
#{<<"op">> => 10, <<"d">> => #{<<"heartbeat_interval">> => 41250}},
#{<<"op">> => 0, <<"s">> => 1, <<"t">> => <<"READY">>, <<"d">> => #{}},
#{<<"list">> => [1, 2, 3], <<"bool">> => true, <<"null">> => null}
],
[
?_test(begin
{ok, Encoded, _} = encode(Msg, json),
{ok, Decoded} = decode(iolist_to_binary(Encoded), json),
?assertEqual(Msg, Decoded)
end)
|| Msg <- Messages
].
-endif.

View File

@@ -27,80 +27,166 @@
]).
-type compression() :: none | zstd_stream.
-export_type([compression/0]).
-record(compress_ctx, {type :: compression()}).
-type compress_ctx() :: #compress_ctx{}.
-export_type([compress_ctx/0]).
-opaque compress_ctx() :: #{type := compression()}.
-export_type([compression/0, compress_ctx/0]).
-spec parse_compression(binary() | undefined) -> compression().
parse_compression(<<"none">>) -> none;
parse_compression(<<"zstd-stream">>) -> zstd_stream;
parse_compression(_) -> none.
parse_compression(<<"none">>) ->
none;
%% TODO: temporarily disabled re-enable zstd-stream once compression issues are resolved
parse_compression(<<"zstd-stream">>) ->
none;
parse_compression(_) ->
none.
-spec new_context(compression()) -> compress_ctx().
new_context(none) ->
#compress_ctx{type = none};
#{type => none};
new_context(zstd_stream) ->
#compress_ctx{type = zstd_stream}.
#{type => zstd_stream}.
-spec close_context(compress_ctx()) -> ok.
close_context(_Ctx) ->
close_context(#{}) ->
ok.
-spec get_type(compress_ctx()) -> compression().
get_type(#compress_ctx{type = Type}) ->
get_type(#{type := Type}) ->
Type.
-spec compress(iodata(), compress_ctx()) -> {ok, binary(), compress_ctx()} | {error, term()}.
compress(Data, Ctx = #compress_ctx{type = none}) ->
compress(Data, Ctx = #{type := none}) ->
{ok, iolist_to_binary(Data), Ctx};
compress(Data, Ctx = #compress_ctx{type = zstd_stream}) ->
try
Binary = iolist_to_binary(Data),
case ezstd:compress(Binary, 3) of
Compressed when is_binary(Compressed) ->
{ok, Compressed, Ctx};
{error, Reason} ->
{error, {compress_failed, Reason}}
end
catch
_:Exception ->
{error, {compress_failed, Exception}}
end.
compress(Data, Ctx = #{type := zstd_stream}) ->
zstd_compress(Data, Ctx).
-spec decompress(binary(), compress_ctx()) -> {ok, binary(), compress_ctx()} | {error, term()}.
decompress(Data, Ctx = #compress_ctx{type = none}) ->
decompress(Data, Ctx = #{type := none}) ->
{ok, Data, Ctx};
decompress(Data, Ctx = #compress_ctx{type = zstd_stream}) ->
try
case ezstd:decompress(Data) of
Decompressed when is_binary(Decompressed) ->
{ok, Decompressed, Ctx};
{error, Reason} ->
{error, {decompress_failed, Reason}}
end
catch
_:Exception ->
{error, {decompress_failed, Exception}}
decompress(Data, Ctx = #{type := zstd_stream}) ->
zstd_decompress(Data, Ctx).
zstd_compress(Data, Ctx) ->
case ezstd_available() of
true ->
try
Binary = iolist_to_binary(Data),
case erlang:apply(ezstd, compress, [Binary, 3]) of
Compressed when is_binary(Compressed) ->
{ok, Compressed, Ctx};
{error, Reason} ->
{error, {compress_failed, Reason}}
end
catch
_:Exception ->
{error, {compress_failed, Exception}}
end;
false ->
{error, {compress_failed, zstd_not_available}}
end.
zstd_decompress(Data, Ctx) ->
case ezstd_available() of
true ->
try
case erlang:apply(ezstd, decompress, [Data]) of
Decompressed when is_binary(Decompressed) ->
{ok, Decompressed, Ctx};
{error, Reason} ->
{error, {decompress_failed, Reason}}
end
catch
_:Exception ->
{error, {decompress_failed, Exception}}
end;
false ->
{error, {decompress_failed, zstd_not_available}}
end.
-spec ezstd_available() -> boolean().
ezstd_available() ->
case code:ensure_loaded(ezstd) of
{module, ezstd} ->
erlang:function_exported(ezstd, compress, 2) andalso
erlang:function_exported(ezstd, decompress, 1);
_ ->
false
end.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
parse_compression_test() ->
?assertEqual(none, parse_compression(undefined)),
?assertEqual(none, parse_compression(<<>>)),
?assertEqual(zstd_stream, parse_compression(<<"zstd-stream">>)),
?assertEqual(none, parse_compression(<<"none">>)).
parse_compression_test_() ->
[
?_assertEqual(none, parse_compression(undefined)),
?_assertEqual(none, parse_compression(<<>>)),
?_assertEqual(none, parse_compression(<<"none">>)),
?_assertEqual(none, parse_compression(<<"invalid">>)),
%% zstd-stream temporarily disabled always returns none
?_assertEqual(none, parse_compression(<<"zstd-stream">>))
].
zstd_roundtrip_test() ->
Ctx = new_context(zstd_stream),
Data = <<"hello world, this is a test message for zstd compression">>,
new_context_test_() ->
[
?_assertEqual(none, get_type(new_context(none))),
?_assertEqual(zstd_stream, get_type(new_context(zstd_stream)))
].
close_context_test() ->
Ctx = new_context(none),
?assertEqual(ok, close_context(Ctx)).
compress_none_test() ->
Ctx = new_context(none),
Data = <<"hello world">>,
{ok, Compressed, Ctx2} = compress(Data, Ctx),
?assert(is_binary(Compressed)),
{ok, Decompressed, _} = decompress(Compressed, Ctx2),
?assertEqual(Data, Decompressed),
ok = close_context(Ctx2).
?assertEqual(Data, Compressed),
?assertEqual(none, get_type(Ctx2)).
compress_none_iolist_test() ->
Ctx = new_context(none),
Data = [<<"hello">>, <<" ">>, <<"world">>],
{ok, Compressed, _} = compress(Data, Ctx),
?assertEqual(<<"hello world">>, Compressed).
decompress_none_test() ->
Ctx = new_context(none),
Data = <<"hello world">>,
{ok, Decompressed, _} = decompress(Data, Ctx),
?assertEqual(Data, Decompressed).
-ifdef(DEV_MODE).
zstd_roundtrip_test() ->
?assertEqual(skip, skip).
zstd_compression_ratio_test() ->
?assertEqual(skip, skip).
-else.
zstd_roundtrip_test() ->
case ezstd_available() of
true ->
Ctx = new_context(zstd_stream),
Data = <<"hello world, this is a test message for zstd compression">>,
{ok, Compressed, Ctx2} = compress(Data, Ctx),
?assert(is_binary(Compressed)),
{ok, Decompressed, _} = decompress(Compressed, Ctx2),
?assertEqual(Data, Decompressed),
ok = close_context(Ctx2);
false ->
?assertEqual(skip, skip)
end.
zstd_compression_ratio_test() ->
case ezstd_available() of
true ->
Ctx = new_context(zstd_stream),
Data = binary:copy(<<"aaaaaaaaaa">>, 100),
{ok, Compressed, _} = compress(Data, Ctx),
?assert(byte_size(Compressed) < byte_size(Data));
false ->
?assertEqual(skip, skip)
end.
-endif.
-endif.

View File

@@ -17,6 +17,8 @@
-module(gateway_errors).
-compile({no_auto_import, [error/1]}).
-export([
error/1,
error_code/1,
@@ -25,11 +27,64 @@
is_recoverable/1
]).
-spec error(atom()) -> {error, atom(), atom()}.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type error_atom() ::
voice_connection_not_found
| voice_channel_not_found
| voice_channel_not_voice
| voice_member_not_found
| voice_user_not_in_voice
| voice_guild_not_found
| voice_permission_denied
| voice_member_timed_out
| voice_channel_full
| voice_missing_connection_id
| voice_invalid_user_id
| voice_invalid_channel_id
| voice_invalid_state
| voice_user_mismatch
| voice_token_failed
| voice_guild_id_missing
| voice_invalid_guild_id
| voice_moderator_missing_connect
| voice_unclaimed_account
| voice_update_rate_limited
| voice_nonce_mismatch
| voice_pending_expired
| voice_camera_user_limit
| dm_channel_not_found
| dm_not_recipient
| dm_invalid_channel_type
| validation_invalid_snowflake
| validation_null_snowflake
| validation_invalid_snowflake_list
| validation_expected_list
| validation_expected_map
| validation_missing_field
| validation_invalid_params
| internal_error
| timeout
| unknown_error
| atom().
-type error_category() ::
not_found
| validation_error
| permission_denied
| voice_error
| rate_limited
| timeout
| unknown
| auth_failed.
-spec error(error_atom()) -> {error, error_category(), error_atom()}.
error(ErrorAtom) ->
{error, error_category(ErrorAtom), ErrorAtom}.
-spec error_code(atom()) -> binary().
-spec error_code(error_atom()) -> binary().
error_code(voice_connection_not_found) -> <<"VOICE_CONNECTION_NOT_FOUND">>;
error_code(voice_channel_not_found) -> <<"VOICE_CHANNEL_NOT_FOUND">>;
error_code(voice_channel_not_voice) -> <<"VOICE_INVALID_CHANNEL_TYPE">>;
@@ -49,6 +104,10 @@ error_code(voice_guild_id_missing) -> <<"VOICE_GUILD_ID_MISSING">>;
error_code(voice_invalid_guild_id) -> <<"VOICE_INVALID_GUILD_ID">>;
error_code(voice_moderator_missing_connect) -> <<"VOICE_PERMISSION_DENIED">>;
error_code(voice_unclaimed_account) -> <<"VOICE_UNCLAIMED_ACCOUNT">>;
error_code(voice_update_rate_limited) -> <<"VOICE_UPDATE_RATE_LIMITED">>;
error_code(voice_nonce_mismatch) -> <<"VOICE_NONCE_MISMATCH">>;
error_code(voice_pending_expired) -> <<"VOICE_PENDING_EXPIRED">>;
error_code(voice_camera_user_limit) -> <<"VOICE_CAMERA_USER_LIMIT">>;
error_code(dm_channel_not_found) -> <<"DM_CHANNEL_NOT_FOUND">>;
error_code(dm_not_recipient) -> <<"DM_NOT_RECIPIENT">>;
error_code(dm_invalid_channel_type) -> <<"DM_INVALID_CHANNEL_TYPE">>;
@@ -64,7 +123,7 @@ error_code(timeout) -> <<"TIMEOUT">>;
error_code(unknown_error) -> <<"UNKNOWN_ERROR">>;
error_code(_) -> <<"UNKNOWN_ERROR">>.
-spec error_message(atom()) -> binary().
-spec error_message(error_atom()) -> binary().
error_message(voice_connection_not_found) -> <<"Voice connection not found">>;
error_message(voice_channel_not_found) -> <<"Voice channel not found">>;
error_message(voice_channel_not_voice) -> <<"Channel is not a voice channel">>;
@@ -84,6 +143,10 @@ error_message(voice_guild_id_missing) -> <<"Guild ID is required">>;
error_message(voice_invalid_guild_id) -> <<"Invalid guild ID">>;
error_message(voice_moderator_missing_connect) -> <<"Moderator missing connect permission">>;
error_message(voice_unclaimed_account) -> <<"Claim your account to join voice">>;
error_message(voice_update_rate_limited) -> <<"Voice updates are rate limited">>;
error_message(voice_nonce_mismatch) -> <<"Voice token nonce mismatch">>;
error_message(voice_pending_expired) -> <<"Voice pending connection expired">>;
error_message(voice_camera_user_limit) -> <<"Too many users in channel to enable camera">>;
error_message(dm_channel_not_found) -> <<"DM channel not found">>;
error_message(dm_not_recipient) -> <<"Not a recipient of this channel">>;
error_message(dm_invalid_channel_type) -> <<"Not a DM or Group DM channel">>;
@@ -99,7 +162,7 @@ error_message(timeout) -> <<"Request timed out">>;
error_message(unknown_error) -> <<"An unknown error occurred">>;
error_message(_) -> <<"An unknown error occurred">>.
-spec error_category(atom()) -> atom().
-spec error_category(error_atom()) -> error_category().
error_category(voice_connection_not_found) -> not_found;
error_category(voice_channel_not_found) -> not_found;
error_category(voice_channel_not_voice) -> validation_error;
@@ -119,6 +182,10 @@ error_category(voice_guild_id_missing) -> validation_error;
error_category(voice_invalid_guild_id) -> validation_error;
error_category(voice_moderator_missing_connect) -> permission_denied;
error_category(voice_unclaimed_account) -> permission_denied;
error_category(voice_update_rate_limited) -> rate_limited;
error_category(voice_nonce_mismatch) -> validation_error;
error_category(voice_pending_expired) -> validation_error;
error_category(voice_camera_user_limit) -> permission_denied;
error_category(dm_channel_not_found) -> not_found;
error_category(dm_not_recipient) -> permission_denied;
error_category(dm_invalid_channel_type) -> validation_error;
@@ -134,7 +201,7 @@ error_category(timeout) -> timeout;
error_category(unknown_error) -> unknown;
error_category(_) -> unknown.
-spec is_recoverable(atom()) -> boolean().
-spec is_recoverable(error_category()) -> boolean().
is_recoverable(not_found) -> true;
is_recoverable(permission_denied) -> true;
is_recoverable(voice_error) -> true;
@@ -144,3 +211,122 @@ is_recoverable(unknown) -> true;
is_recoverable(rate_limited) -> false;
is_recoverable(auth_failed) -> false;
is_recoverable(_) -> true.
-ifdef(TEST).
error_test() ->
?assertEqual({error, not_found, voice_connection_not_found}, error(voice_connection_not_found)),
?assertEqual(
{error, validation_error, voice_channel_not_voice}, error(voice_channel_not_voice)
),
?assertEqual(
{error, permission_denied, voice_permission_denied}, error(voice_permission_denied)
).
error_code_test() ->
?assertEqual(<<"VOICE_CONNECTION_NOT_FOUND">>, error_code(voice_connection_not_found)),
?assertEqual(<<"VOICE_CHANNEL_NOT_FOUND">>, error_code(voice_channel_not_found)),
?assertEqual(<<"VOICE_PERMISSION_DENIED">>, error_code(voice_permission_denied)),
?assertEqual(<<"VOICE_PERMISSION_DENIED">>, error_code(voice_moderator_missing_connect)),
?assertEqual(<<"UNKNOWN_ERROR">>, error_code(some_random_error)),
?assertEqual(<<"TIMEOUT">>, error_code(timeout)),
?assertEqual(<<"INTERNAL_ERROR">>, error_code(internal_error)).
error_message_test() ->
?assertEqual(<<"Voice connection not found">>, error_message(voice_connection_not_found)),
?assertEqual(<<"Voice channel not found">>, error_message(voice_channel_not_found)),
?assertEqual(<<"Missing voice permissions">>, error_message(voice_permission_denied)),
?assertEqual(<<"Voice channel is full">>, error_message(voice_channel_full)),
?assertEqual(<<"An unknown error occurred">>, error_message(some_random_error)),
?assertEqual(<<"Request timed out">>, error_message(timeout)).
error_category_test() ->
?assertEqual(not_found, error_category(voice_connection_not_found)),
?assertEqual(not_found, error_category(voice_channel_not_found)),
?assertEqual(not_found, error_category(dm_channel_not_found)),
?assertEqual(validation_error, error_category(voice_channel_not_voice)),
?assertEqual(validation_error, error_category(validation_invalid_snowflake)),
?assertEqual(permission_denied, error_category(voice_permission_denied)),
?assertEqual(permission_denied, error_category(voice_channel_full)),
?assertEqual(voice_error, error_category(voice_token_failed)),
?assertEqual(rate_limited, error_category(voice_update_rate_limited)),
?assertEqual(timeout, error_category(timeout)),
?assertEqual(unknown, error_category(unknown_error)),
?assertEqual(unknown, error_category(some_random_error)).
is_recoverable_test() ->
?assert(is_recoverable(not_found)),
?assert(is_recoverable(permission_denied)),
?assert(is_recoverable(voice_error)),
?assert(is_recoverable(validation_error)),
?assert(is_recoverable(timeout)),
?assert(is_recoverable(unknown)),
?assertNot(is_recoverable(rate_limited)),
?assertNot(is_recoverable(auth_failed)).
all_voice_errors_have_codes_test() ->
VoiceErrors = [
voice_connection_not_found,
voice_channel_not_found,
voice_channel_not_voice,
voice_member_not_found,
voice_user_not_in_voice,
voice_guild_not_found,
voice_permission_denied,
voice_member_timed_out,
voice_channel_full,
voice_missing_connection_id,
voice_invalid_user_id,
voice_invalid_channel_id,
voice_invalid_state,
voice_user_mismatch,
voice_token_failed,
voice_guild_id_missing,
voice_invalid_guild_id,
voice_moderator_missing_connect,
voice_unclaimed_account,
voice_update_rate_limited,
voice_nonce_mismatch,
voice_pending_expired,
voice_camera_user_limit
],
lists:foreach(
fun(Error) ->
Code = error_code(Error),
?assert(is_binary(Code)),
?assertNotEqual(<<"UNKNOWN_ERROR">>, Code)
end,
VoiceErrors
).
all_dm_errors_have_codes_test() ->
DmErrors = [dm_channel_not_found, dm_not_recipient, dm_invalid_channel_type],
lists:foreach(
fun(Error) ->
Code = error_code(Error),
?assert(is_binary(Code)),
?assertNotEqual(<<"UNKNOWN_ERROR">>, Code)
end,
DmErrors
).
all_validation_errors_have_codes_test() ->
ValidationErrors = [
validation_invalid_snowflake,
validation_null_snowflake,
validation_invalid_snowflake_list,
validation_expected_list,
validation_expected_map,
validation_missing_field,
validation_invalid_params
],
lists:foreach(
fun(Error) ->
Code = error_code(Error),
?assert(is_binary(Code)),
?assertNotEqual(<<"UNKNOWN_ERROR">>, Code)
end,
ValidationErrors
).
-endif.

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,540 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(gateway_http_client).
-behaviour(gen_server).
-export([start_link/0, request/5, request/6]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-define(SERVER, ?MODULE).
-define(CIRCUIT_TABLE, gateway_http_circuit_breaker).
-define(INFLIGHT_TABLE, gateway_http_inflight).
-define(DEFAULT_RPC_CONNECT_TIMEOUT_MS, 5000).
-define(DEFAULT_RPC_RECV_TIMEOUT_MS, 30000).
-define(DEFAULT_PUSH_CONNECT_TIMEOUT_MS, 3000).
-define(DEFAULT_PUSH_RECV_TIMEOUT_MS, 5000).
-define(DEFAULT_RPC_MAX_CONCURRENCY, 512).
-define(DEFAULT_PUSH_MAX_CONCURRENCY, 256).
-define(DEFAULT_FAILURE_THRESHOLD, 6).
-define(DEFAULT_RECOVERY_TIMEOUT_MS, 15000).
-define(DEFAULT_CLEANUP_INTERVAL_MS, 30000).
-define(DEFAULT_CLEANUP_MAX_AGE_MS, 300000).
-type workload() :: rpc | push.
-type method() :: get | post | put | patch | delete | head | options.
-type request_headers() :: [{binary() | string(), binary() | string()}].
-type request_options() :: #{
connect_timeout => timeout(),
recv_timeout => timeout(),
max_concurrency => pos_integer(),
failure_threshold => pos_integer(),
recovery_timeout_ms => pos_integer(),
content_type => binary() | string()
}.
-type response() :: {ok, non_neg_integer(), [{binary(), binary()}], binary()} | {error, term()}.
-type state() :: #{}.
-spec start_link() -> {ok, pid()} | {error, term()}.
start_link() ->
case whereis(?SERVER) of
undefined ->
case gen_server:start_link({local, ?SERVER}, ?MODULE, [], []) of
{error, {already_started, Pid}} when is_pid(Pid) ->
{ok, Pid};
Other ->
Other
end;
Pid when is_pid(Pid) ->
{ok, Pid}
end.
-spec request(workload(), method(), iodata(), request_headers(), iodata() | undefined) -> response().
request(Workload, Method, Url, Headers, Body) ->
request(Workload, Method, Url, Headers, Body, #{}).
-spec request(workload(), method(), iodata(), request_headers(), iodata() | undefined, request_options()) ->
response().
request(Workload, Method, Url, Headers, Body, Opts) when is_map(Opts) ->
ensure_runtime(Workload),
WorkloadOpts = merged_workload_options(Workload, Opts),
MaxConcurrency = maps:get(max_concurrency, WorkloadOpts),
FailureThreshold = maps:get(failure_threshold, WorkloadOpts),
RecoveryTimeoutMs = maps:get(recovery_timeout_ms, WorkloadOpts),
Host = extract_host_key(Url),
CircuitKey = {Workload, Host},
case allow_circuit_request(CircuitKey, RecoveryTimeoutMs) of
ok ->
case acquire_inflight_slot(Workload, MaxConcurrency) of
ok ->
Result = safe_do_request(Workload, Method, Url, Headers, Body, WorkloadOpts),
release_inflight_slot(Workload),
update_circuit_state(CircuitKey, Result, FailureThreshold),
Result;
{error, overloaded} ->
{error, overloaded}
end;
{error, circuit_open} ->
{error, circuit_open}
end.
-spec safe_do_request(
workload(), method(), iodata(), request_headers(), iodata() | undefined, request_options()
) ->
response().
safe_do_request(Workload, Method, Url, Headers, Body, Opts) ->
try do_request(Workload, Method, Url, Headers, Body, Opts) of
Result ->
Result
catch
Class:Reason:Stacktrace ->
{error,
{request_exception, #{
class => Class,
reason => Reason,
frame => first_stack_frame(Stacktrace),
workload => Workload,
method => Method,
url => ensure_binary(Url)
}}}
end.
-spec init([]) -> {ok, state()}.
init([]) ->
process_flag(trap_exit, true),
ensure_table(
?CIRCUIT_TABLE,
[named_table, public, set, {read_concurrency, true}, {write_concurrency, true}]
),
ensure_table(
?INFLIGHT_TABLE,
[named_table, public, set, {read_concurrency, true}, {write_concurrency, true}]
),
ok = ensure_httpc_profile(profile_for(rpc), rpc),
ok = ensure_httpc_profile(profile_for(push), push),
schedule_cleanup(),
{ok, #{}}.
-spec handle_call(term(), gen_server:from(), state()) -> {reply, ok, state()}.
handle_call(_Request, _From, State) ->
{reply, ok, State}.
-spec handle_cast(term(), state()) -> {noreply, state()}.
handle_cast(_Msg, State) ->
{noreply, State}.
-spec handle_info(term(), state()) -> {noreply, state()}.
handle_info(cleanup_circuits, State) ->
prune_circuit_table(),
schedule_cleanup(),
{noreply, State};
handle_info(_Info, State) ->
{noreply, State}.
-spec terminate(term(), state()) -> ok.
terminate(_Reason, _State) ->
ok.
-spec code_change(term(), state(), term()) -> {ok, state()}.
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
-spec ensure_table(atom(), [term()]) -> ok.
ensure_table(Name, Options) ->
case ets:whereis(Name) of
undefined ->
try
_ = ets:new(Name, Options),
ok
catch
error:badarg -> ok
end;
_ ->
ok
end.
-spec ensure_runtime(workload()) -> ok.
ensure_runtime(Workload) ->
ok = ensure_started(),
_ = Workload,
ok.
-spec ensure_started() -> ok.
ensure_started() ->
case start_link() of
{ok, _Pid} ->
ok;
_ ->
ok
end.
-spec schedule_cleanup() -> reference().
schedule_cleanup() ->
erlang:send_after(cleanup_interval_ms(), self(), cleanup_circuits).
-spec prune_circuit_table() -> ok.
prune_circuit_table() ->
Now = erlang:system_time(millisecond),
MaxAgeMs = cleanup_max_age_ms(),
_ =
ets:foldl(
fun({Key, CircuitState}, Acc) ->
case is_stale_circuit(CircuitState, Now, MaxAgeMs) of
true ->
ets:delete(?CIRCUIT_TABLE, Key),
Acc;
false ->
Acc
end
end,
ok,
?CIRCUIT_TABLE
),
ok.
-spec is_stale_circuit(map(), integer(), integer()) -> boolean().
is_stale_circuit(#{state := open, opened_at := OpenedAt}, Now, MaxAgeMs) ->
Now - OpenedAt > MaxAgeMs;
is_stale_circuit(#{state := closed, failures := 0, updated_at := UpdatedAt}, Now, MaxAgeMs) ->
Now - UpdatedAt > MaxAgeMs;
is_stale_circuit(_, _, _) ->
false.
-spec allow_circuit_request({workload(), binary()}, pos_integer()) -> ok | {error, circuit_open}.
allow_circuit_request(CircuitKey, RecoveryTimeoutMs) ->
Now = erlang:system_time(millisecond),
case safe_lookup_circuit(CircuitKey) of
[] ->
ok;
[{_, #{state := open, opened_at := OpenedAt}} = Entry] ->
case Now - OpenedAt >= RecoveryTimeoutMs of
true ->
{_, State0} = Entry,
NewState = State0#{state => half_open, updated_at => Now},
ets:insert(?CIRCUIT_TABLE, {CircuitKey, NewState}),
ok;
false ->
{error, circuit_open}
end;
_ ->
ok
end.
-spec safe_lookup_circuit({workload(), binary()}) -> list().
safe_lookup_circuit(Key) ->
try ets:lookup(?CIRCUIT_TABLE, Key) of
Result -> Result
catch
error:badarg -> []
end.
-spec acquire_inflight_slot(workload(), pos_integer()) -> ok | {error, overloaded}.
acquire_inflight_slot(Workload, MaxConcurrency) ->
case safe_update_counter(?INFLIGHT_TABLE, Workload, {2, 1}) of
{ok, Count} when Count =< MaxConcurrency ->
ok;
{ok, _Count} ->
_ = safe_update_counter(?INFLIGHT_TABLE, Workload, {2, -1}),
{error, overloaded};
{error, _Reason} ->
{error, overloaded}
end.
-spec release_inflight_slot(workload()) -> ok.
release_inflight_slot(Workload) ->
_ = safe_update_counter(?INFLIGHT_TABLE, Workload, {2, -1}),
ok.
-spec safe_update_counter(atom(), term(), {pos_integer(), integer()}) ->
{ok, integer()} | {error, term()}.
safe_update_counter(Table, Key, Op) ->
try
{ok, ets:update_counter(Table, Key, Op, {Key, 0})}
catch
error:badarg ->
ok = ensure_started(),
try
{ok, ets:update_counter(Table, Key, Op, {Key, 0})}
catch
error:badarg ->
{error, badarg}
end
end.
-spec update_circuit_state({workload(), binary()}, response(), pos_integer()) -> ok.
update_circuit_state(CircuitKey, Result, FailureThreshold) ->
Now = erlang:system_time(millisecond),
case should_count_failure(Result) of
true ->
record_failure(CircuitKey, FailureThreshold, Now);
false ->
record_success(CircuitKey, Now)
end.
-spec should_count_failure(response()) -> boolean().
should_count_failure({error, _Reason}) ->
true;
should_count_failure({ok, StatusCode, _Headers, _Body}) when StatusCode >= 500 ->
true;
should_count_failure(_) ->
false.
-spec record_failure({workload(), binary()}, pos_integer(), integer()) -> ok.
record_failure(CircuitKey, Threshold, Now) ->
case safe_lookup_circuit(CircuitKey) of
[] ->
ets:insert(?CIRCUIT_TABLE, {CircuitKey, #{
state => closed,
failures => 1,
opened_at => undefined,
updated_at => Now
}}),
ok;
[{_, #{failures := Failures} = Existing}] ->
NewFailures = Failures + 1,
NewState =
case NewFailures >= Threshold of
true -> open;
false -> maps:get(state, Existing, closed)
end,
OpenedAt =
case NewState of
open -> Now;
_ -> maps:get(opened_at, Existing, undefined)
end,
ets:insert(?CIRCUIT_TABLE, {CircuitKey, Existing#{
state => NewState,
failures => NewFailures,
opened_at => OpenedAt,
updated_at => Now
}}),
ok
end.
-spec record_success({workload(), binary()}, integer()) -> ok.
record_success(CircuitKey, Now) ->
case safe_lookup_circuit(CircuitKey) of
[] ->
ok;
[{_, Existing}] ->
NewState = Existing#{
state => closed,
failures => 0,
updated_at => Now
},
ets:insert(?CIRCUIT_TABLE, {CircuitKey, maps:remove(opened_at, NewState)}),
ok
end.
-spec do_request(workload(), method(), iodata(), request_headers(), iodata() | undefined, request_options()) ->
response().
do_request(Workload, Method, Url, Headers, Body, Opts) ->
HttpMethod = normalize_method(Method),
UrlString = ensure_list(Url),
RequestHeaders = normalize_request_headers(Headers),
RequestTuple = build_request_tuple(UrlString, RequestHeaders, Body, Opts),
ConnectTimeout = maps:get(connect_timeout, Opts),
RecvTimeout = maps:get(recv_timeout, Opts),
HttpOptions = [
{connect_timeout, ConnectTimeout},
{timeout, RecvTimeout},
{autoredirect, false}
],
RequestOptions = [{body_format, binary}],
case httpc:request(HttpMethod, RequestTuple, HttpOptions, RequestOptions, profile_for(Workload)) of
{ok, {{_HttpVersion, StatusCode, _ReasonPhrase}, RespHeaders, RespBody}} ->
{ok, StatusCode, normalize_response_headers(RespHeaders), ensure_binary(RespBody)};
{error, Reason} ->
{error, Reason}
end.
-spec normalize_method(method() | atom()) -> method().
normalize_method(post) -> post;
normalize_method(get) -> get;
normalize_method(put) -> put;
normalize_method(patch) -> patch;
normalize_method(delete) -> delete;
normalize_method(head) -> head;
normalize_method(options) -> options;
normalize_method(_) -> post.
-spec build_request_tuple(string(), [{string(), string()}], iodata() | undefined, request_options()) ->
{string(), [{string(), string()}]}
| {string(), [{string(), string()}], string(), iodata()}.
build_request_tuple(Url, Headers, undefined, _Opts) ->
{Url, Headers};
build_request_tuple(Url, Headers, Body, Opts) ->
ContentType = resolve_content_type(Headers, Opts),
{Url, Headers, ContentType, Body}.
-spec resolve_content_type([{string(), string()}], request_options()) -> string().
resolve_content_type(Headers, Opts) ->
case maps:get(content_type, Opts, undefined) of
undefined ->
case find_content_type_header(Headers) of
undefined -> "application/json";
Value -> Value
end;
Value ->
ensure_list(Value)
end.
-spec find_content_type_header([{string(), string()}]) -> string() | undefined.
find_content_type_header([]) ->
undefined;
find_content_type_header([{Name, Value} | Rest]) ->
case string:lowercase(Name) of
"content-type" -> Value;
_ -> find_content_type_header(Rest)
end.
-spec normalize_request_headers(request_headers()) -> [{string(), string()}].
normalize_request_headers(Headers) ->
[
{ensure_list(Name), ensure_list(Value)}
|| {Name, Value} <- Headers
].
-spec normalize_response_headers([{string(), string()}]) -> [{binary(), binary()}].
normalize_response_headers(Headers) ->
[
{list_to_binary(Name), list_to_binary(Value)}
|| {Name, Value} <- Headers
].
-spec extract_host_key(iodata()) -> binary().
extract_host_key(Url) ->
UrlString = ensure_list(Url),
try
Parsed = uri_string:parse(UrlString),
case maps:get(host, Parsed, undefined) of
undefined -> <<"unknown">>;
Host when is_binary(Host) -> normalize_host(Host);
Host when is_list(Host) -> normalize_host(list_to_binary(Host));
_ -> <<"unknown">>
end
catch
_:_ -> <<"unknown">>
end.
-spec normalize_host(binary()) -> binary().
normalize_host(Host) ->
list_to_binary(string:lowercase(binary_to_list(Host))).
-spec ensure_binary(iodata()) -> binary().
ensure_binary(Value) when is_binary(Value) ->
Value;
ensure_binary(Value) ->
iolist_to_binary(Value).
-spec ensure_list(iodata()) -> string().
ensure_list(Value) when is_binary(Value) ->
binary_to_list(Value);
ensure_list(Value) when is_list(Value) ->
Value;
ensure_list(Value) when is_atom(Value) ->
atom_to_list(Value);
ensure_list(Value) when is_integer(Value) ->
integer_to_list(Value);
ensure_list(_Value) ->
"".
-spec ensure_httpc_profile(atom(), workload()) -> ok.
ensure_httpc_profile(Profile, Workload) ->
_ =
case inets:start(httpc, [{profile, Profile}]) of
{ok, _Pid} -> ok;
{error, {already_started, _Pid}} -> ok;
{error, {already_started, _Pid, _}} -> ok;
{error, _Reason} -> ok
end,
Options = workload_httpc_options(Workload),
_ = httpc:set_options(Options, Profile),
ok.
-spec workload_httpc_options(workload()) -> list().
workload_httpc_options(rpc) ->
[
{max_sessions, 1024},
{max_keep_alive_length, 256}
];
workload_httpc_options(push) ->
[
{max_sessions, 2048},
{max_keep_alive_length, 512}
].
-spec merged_workload_options(workload(), request_options()) -> request_options().
merged_workload_options(Workload, Opts) ->
maps:merge(default_options(Workload), Opts).
-spec default_options(workload()) -> request_options().
default_options(rpc) ->
#{
connect_timeout => get_int_or_default(gateway_http_rpc_connect_timeout_ms, ?DEFAULT_RPC_CONNECT_TIMEOUT_MS),
recv_timeout => get_int_or_default(gateway_http_rpc_recv_timeout_ms, ?DEFAULT_RPC_RECV_TIMEOUT_MS),
max_concurrency =>
get_int_or_default(gateway_http_rpc_max_concurrency, ?DEFAULT_RPC_MAX_CONCURRENCY),
failure_threshold =>
get_int_or_default(gateway_http_failure_threshold, ?DEFAULT_FAILURE_THRESHOLD),
recovery_timeout_ms =>
get_int_or_default(gateway_http_recovery_timeout_ms, ?DEFAULT_RECOVERY_TIMEOUT_MS),
content_type => <<"application/json">>
};
default_options(push) ->
#{
connect_timeout => get_int_or_default(gateway_http_push_connect_timeout_ms, ?DEFAULT_PUSH_CONNECT_TIMEOUT_MS),
recv_timeout => get_int_or_default(gateway_http_push_recv_timeout_ms, ?DEFAULT_PUSH_RECV_TIMEOUT_MS),
max_concurrency =>
get_int_or_default(gateway_http_push_max_concurrency, ?DEFAULT_PUSH_MAX_CONCURRENCY),
failure_threshold =>
get_int_or_default(gateway_http_failure_threshold, ?DEFAULT_FAILURE_THRESHOLD),
recovery_timeout_ms =>
get_int_or_default(gateway_http_recovery_timeout_ms, ?DEFAULT_RECOVERY_TIMEOUT_MS),
content_type => <<"application/octet-stream">>
}.
-spec cleanup_interval_ms() -> pos_integer().
cleanup_interval_ms() ->
get_int_or_default(gateway_http_cleanup_interval_ms, ?DEFAULT_CLEANUP_INTERVAL_MS).
-spec cleanup_max_age_ms() -> pos_integer().
cleanup_max_age_ms() ->
get_int_or_default(gateway_http_cleanup_max_age_ms, ?DEFAULT_CLEANUP_MAX_AGE_MS).
-spec get_int_or_default(atom(), integer()) -> integer().
get_int_or_default(Key, Default) ->
case fluxer_gateway_env:get_optional(Key) of
Value when is_integer(Value), Value > 0 -> Value;
_ -> Default
end.
-spec profile_for(workload()) -> atom().
profile_for(rpc) ->
gateway_http_rpc_profile;
profile_for(push) ->
gateway_http_push_profile.
-spec first_stack_frame(list()) -> term().
first_stack_frame([Frame | _]) ->
Frame;
first_stack_frame([]) ->
undefined.

View File

@@ -19,21 +19,38 @@
-export([execute_method/2]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-define(CALL_LOOKUP_TIMEOUT, 2000).
-define(CALL_CREATE_TIMEOUT, 10000).
-spec execute_method(binary(), map()) -> term().
execute_method(<<"call.get">>, #{<<"channel_id">> := ChannelIdBin}) ->
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
case lookup_call(ChannelId) of
{ok, Pid} ->
case gen_server:call(Pid, {get_state}, 5000) of
{ok, CallData} ->
CallData;
_ ->
throw({error, <<"Failed to get call state">>})
case gen_server:call(Pid, {get_state}, ?CALL_LOOKUP_TIMEOUT) of
{ok, CallData} -> CallData;
_ -> throw({error, <<"call_state_error">>})
end;
{error, not_found} ->
null;
not_found ->
null
end;
execute_method(<<"call.get_pending_joins">>, #{<<"channel_id">> := ChannelIdBin}) ->
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
case lookup_call(ChannelId) of
{ok, Pid} ->
case gen_server:call(Pid, {get_pending_connections}, ?CALL_LOOKUP_TIMEOUT) of
#{pending_joins := PendingJoins} ->
#{<<"pending_joins">> => PendingJoins};
_ ->
throw({error, <<"call_pending_joins_error">>})
end;
not_found ->
#{<<"pending_joins">> => []}
end;
execute_method(<<"call.create">>, Params) ->
#{
<<"channel_id">> := ChannelIdBin,
@@ -42,12 +59,10 @@ execute_method(<<"call.create">>, Params) ->
<<"ringing">> := RingingBins,
<<"recipients">> := RecipientsBins
} = Params,
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
MessageId = validation:snowflake_or_throw(<<"message_id">>, MessageIdBin),
Ringing = validation:snowflake_list_or_throw(<<"ringing">>, RingingBins),
Recipients = validation:snowflake_list_or_throw(<<"recipients">>, RecipientsBins),
CallData = #{
channel_id => ChannelId,
message_id => MessageId,
@@ -55,161 +70,123 @@ execute_method(<<"call.create">>, Params) ->
ringing => Ringing,
recipients => Recipients
},
case gen_server:call(call_manager, {create, ChannelId, CallData}, 10000) of
case gen_server:call(call_manager, {create, ChannelId, CallData}, ?CALL_CREATE_TIMEOUT) of
{ok, Pid} ->
case gen_server:call(Pid, {get_state}, 5000) of
{ok, CallState} ->
CallState;
_ ->
throw({error, <<"Failed to get call state after creation">>})
case gen_server:call(Pid, {get_state}, ?CALL_LOOKUP_TIMEOUT) of
{ok, CallState} -> CallState;
_ -> throw({error, <<"call_state_error">>})
end;
{error, already_exists} ->
throw({error, <<"Call already exists">>});
throw({error, <<"call_already_exists">>});
{error, Reason} ->
throw({error, iolist_to_binary(io_lib:format("Failed to create call: ~p", [Reason]))})
throw({error, iolist_to_binary(io_lib:format("create_call_error: ~p", [Reason]))})
end;
execute_method(<<"call.update_region">>, #{
<<"channel_id">> := ChannelIdBin, <<"region">> := Region
execute_method(<<"call.update_region">>, #{<<"channel_id">> := ChannelIdBin, <<"region">> := Region}) ->
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
with_call(ChannelId, fun(Pid) ->
case gen_server:call(Pid, {update_region, Region}, ?CALL_LOOKUP_TIMEOUT) of
ok -> true;
_ -> throw({error, <<"update_region_error">>})
end
end);
execute_method(<<"call.ring">>, #{
<<"channel_id">> := ChannelIdBin, <<"recipients">> := RecipientsBin
}) ->
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
{ok, Pid} ->
case gen_server:call(Pid, {update_region, Region}, 5000) of
ok ->
true;
_ ->
throw({error, <<"Failed to update region">>})
end;
not_found ->
throw({error, <<"Call not found">>})
end;
execute_method(<<"call.ring">>, Params) ->
#{<<"channel_id">> := ChannelIdBin, <<"recipients">> := RecipientsBin} = Params,
Recipients = validation:snowflake_list_or_throw(<<"recipients">>, RecipientsBin),
with_call(ChannelId, fun(Pid) ->
case gen_server:call(Pid, {ring_recipients, Recipients}, ?CALL_LOOKUP_TIMEOUT) of
ok -> true;
_ -> throw({error, <<"ring_recipients_error">>})
end
end);
execute_method(<<"call.stop_ringing">>, #{
<<"channel_id">> := ChannelIdBin, <<"recipients">> := RecipientsBin
}) ->
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
Recipients = validation:snowflake_list_or_throw(<<"recipients">>, RecipientsBin),
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
{ok, Pid} ->
case gen_server:call(Pid, {ring_recipients, Recipients}, 5000) of
ok ->
true;
_ ->
throw({error, <<"Failed to ring recipients">>})
end;
not_found ->
throw({error, <<"Call not found">>})
end;
execute_method(<<"call.stop_ringing">>, Params) ->
#{<<"channel_id">> := ChannelIdBin, <<"recipients">> := RecipientsBin} = Params,
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
Recipients = validation:snowflake_list_or_throw(<<"recipients">>, RecipientsBin),
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
{ok, Pid} ->
case gen_server:call(Pid, {stop_ringing, Recipients}, 5000) of
ok ->
true;
_ ->
throw({error, <<"Failed to stop ringing">>})
end;
not_found ->
throw({error, <<"Call not found">>})
end;
execute_method(<<"call.join">>, Params) ->
#{
<<"channel_id">> := ChannelIdBin,
<<"user_id">> := UserIdBin,
<<"session_id">> := SessionIdBin,
<<"voice_state">> := VoiceState
} = Params,
with_call(ChannelId, fun(Pid) ->
case gen_server:call(Pid, {stop_ringing, Recipients}, ?CALL_LOOKUP_TIMEOUT) of
ok -> true;
_ -> throw({error, <<"stop_ringing_error">>})
end
end);
execute_method(<<"call.join">>, #{
<<"channel_id">> := ChannelIdBin,
<<"user_id">> := UserIdBin,
<<"session_id">> := SessionIdBin,
<<"voice_state">> := VoiceState
}) ->
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
SessionId = SessionIdBin,
case gen_server:call(session_manager, {lookup, SessionId}, 5000) of
case session_manager:lookup(SessionId) of
{ok, SessionPid} ->
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
{ok, CallPid} ->
case
gen_server:call(
CallPid, {join, UserId, VoiceState, SessionId, SessionPid}, 5000
)
of
ok ->
true;
_ ->
throw({error, <<"Failed to join call">>})
end;
not_found ->
throw({error, <<"Call not found">>})
end;
not_found ->
throw({error, <<"Session not found">>})
with_call(ChannelId, fun(CallPid) ->
gen_server:cast(CallPid, {join_async, UserId, VoiceState, SessionId, SessionPid}),
true
end);
{error, not_found} ->
throw({error, <<"session_not_found">>})
end;
execute_method(<<"call.leave">>, #{<<"channel_id">> := ChannelIdBin, <<"session_id">> := SessionId}) ->
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
{ok, Pid} ->
case gen_server:call(Pid, {leave, SessionId}, 5000) of
ok ->
true;
_ ->
throw({error, <<"Failed to leave call">>})
end;
not_found ->
throw({error, <<"Call not found">>})
end;
with_call(ChannelId, fun(Pid) ->
case gen_server:call(Pid, {leave, SessionId}, ?CALL_LOOKUP_TIMEOUT) of
ok -> true;
_ -> throw({error, <<"leave_call_error">>})
end
end);
execute_method(<<"call.delete">>, #{<<"channel_id">> := ChannelIdBin}) ->
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
case gen_server:call(call_manager, {terminate_call, ChannelId}, 5000) of
ok ->
true;
{error, not_found} ->
throw({error, <<"Call not found">>});
_ ->
throw({error, <<"Failed to delete call">>})
case gen_server:call(call_manager, {terminate_call, ChannelId}, ?CALL_LOOKUP_TIMEOUT) of
ok -> true;
{error, not_found} -> throw({error, <<"call_not_found">>});
_ -> throw({error, <<"delete_call_error">>})
end;
execute_method(<<"call.confirm_connection">>, Params) ->
#{<<"channel_id">> := ChannelIdBin, <<"connection_id">> := ConnectionId} = Params,
execute_method(<<"call.confirm_connection">>, #{
<<"channel_id">> := ChannelIdBin, <<"connection_id">> := ConnectionId
}) ->
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
logger:debug(
"[gateway_rpc_call] call.confirm_connection channel_id=~p connection_id=~p",
[ChannelId, ConnectionId]
),
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
case lookup_call(ChannelId) of
{ok, Pid} ->
gen_server:call(Pid, {confirm_connection, ConnectionId}, 5000);
{error, not_found} ->
logger:debug(
"[gateway_rpc_call] call.confirm_connection call not found for channel_id=~p", [
ChannelId
]
),
#{success => true, call_not_found => true};
gen_server:call(Pid, {confirm_connection, ConnectionId}, ?CALL_LOOKUP_TIMEOUT);
not_found ->
logger:debug(
"[gateway_rpc_call] call.confirm_connection call manager returned not_found for channel_id=~p",
[ChannelId]
),
#{success => true, call_not_found => true}
end;
execute_method(<<"call.disconnect_user_if_in_channel">>, Params) ->
#{<<"channel_id">> := ChannelIdBin, <<"user_id">> := UserIdBin} = Params,
execute_method(
<<"call.disconnect_user_if_in_channel">>,
#{<<"channel_id">> := ChannelIdBin, <<"user_id">> := UserIdBin} = Params
) ->
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
ConnectionId = maps:get(<<"connection_id">>, Params, undefined),
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
case lookup_call(ChannelId) of
{ok, Pid} ->
gen_server:call(
Pid, {disconnect_user_if_in_channel, UserId, ChannelId, ConnectionId}, 5000
Pid,
{disconnect_user_if_in_channel, UserId, ChannelId, ConnectionId},
?CALL_LOOKUP_TIMEOUT
);
{error, not_found} ->
#{success => true, call_not_found => true};
not_found ->
#{success => true, call_not_found => true}
end.
-spec lookup_call(integer()) -> {ok, pid()} | not_found.
lookup_call(ChannelId) ->
case gen_server:call(call_manager, {lookup, ChannelId}, ?CALL_LOOKUP_TIMEOUT) of
{ok, Pid} -> {ok, Pid};
{error, not_found} -> not_found;
not_found -> not_found
end.
-spec with_call(integer(), fun((pid()) -> T)) -> T when T :: term().
with_call(ChannelId, Fun) ->
case lookup_call(ChannelId) of
{ok, Pid} -> Fun(Pid);
not_found -> throw({error, <<"call_not_found">>})
end.
-ifdef(TEST).
-endif.

File diff suppressed because it is too large Load Diff

View File

@@ -21,6 +21,7 @@
-define(JSON_HEADERS, #{<<"content-type">> => <<"application/json">>}).
-spec init(cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
init(Req0, State) ->
case cowboy_req:method(Req0) of
<<"POST">> ->
@@ -30,27 +31,13 @@ init(Req0, State) ->
{ok, Req, State}
end.
-spec handle_post(cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
handle_post(Req0, State) ->
case authorize(Req0) of
ok ->
case read_body(Req0) of
{ok, Decoded, Req1} ->
case maps:get(<<"method">>, Decoded, undefined) of
undefined ->
respond(400, #{<<"error">> => <<"Missing method">>}, Req1, State);
Method when is_binary(Method) ->
ParamsValue = maps:get(<<"params">>, Decoded, #{}),
case is_map(ParamsValue) of
true ->
execute_method(Method, ParamsValue, Req1, State);
false ->
respond(
400, #{<<"error">> => <<"Invalid params">>}, Req1, State
)
end;
_ ->
respond(400, #{<<"error">> => <<"Invalid method">>}, Req1, State)
end;
handle_decoded_body(Decoded, Req1, State);
{error, ErrorBody, Req1} ->
respond(400, ErrorBody, Req1, State)
end;
@@ -58,57 +45,98 @@ handle_post(Req0, State) ->
{ok, Req1, State}
end.
-spec handle_decoded_body(map(), cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
handle_decoded_body(Decoded, Req0, State) ->
case maps:get(<<"method">>, Decoded, undefined) of
undefined ->
respond(400, #{<<"error">> => <<"Missing method">>}, Req0, State);
Method when is_binary(Method) ->
ParamsValue = maps:get(<<"params">>, Decoded, #{}),
case is_map(ParamsValue) of
true ->
execute_method(Method, ParamsValue, Req0, State);
false ->
respond(400, #{<<"error">> => <<"Invalid params">>}, Req0, State)
end;
_ ->
respond(400, #{<<"error">> => <<"Invalid method">>}, Req0, State)
end.
-spec authorize(cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
authorize(Req0) ->
case cowboy_req:header(<<"authorization">>, Req0) of
undefined ->
Req = cowboy_req:reply(
401,
?JSON_HEADERS,
jsx:encode(#{<<"error">> => <<"Unauthorized">>}),
json:encode(#{<<"error">> => <<"Unauthorized">>}),
Req0
),
{error, Req};
AuthHeader ->
case fluxer_gateway_env:get(rpc_secret_key) of
undefined ->
Req = cowboy_req:reply(
500,
?JSON_HEADERS,
jsx:encode(#{<<"error">> => <<"RPC secret not configured">>}),
Req0
),
{error, Req};
Secret when is_binary(Secret) ->
Expected = <<"Bearer ", Secret/binary>>,
case AuthHeader of
Expected ->
ok;
_ ->
Req = cowboy_req:reply(
401,
?JSON_HEADERS,
jsx:encode(#{<<"error">> => <<"Unauthorized">>}),
Req0
),
{error, Req}
end
end
authorize_with_secret(AuthHeader, Req0)
end.
read_body(Req0) ->
read_body(Req0, <<>>).
-spec authorize_with_secret(binary(), cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
authorize_with_secret(AuthHeader, Req0) ->
case fluxer_gateway_env:get(rpc_secret_key) of
undefined ->
Req = cowboy_req:reply(
500,
?JSON_HEADERS,
json:encode(#{<<"error">> => <<"RPC secret not configured">>}),
Req0
),
{error, Req};
Secret when is_binary(Secret) ->
Expected = <<"Bearer ", Secret/binary>>,
check_auth_header(AuthHeader, Expected, Req0)
end.
read_body(Req0, Acc) ->
-spec check_auth_header(binary(), binary(), cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
check_auth_header(AuthHeader, Expected, Req0) ->
case secure_compare(AuthHeader, Expected) of
true ->
ok;
false ->
Req = cowboy_req:reply(
401,
?JSON_HEADERS,
json:encode(#{<<"error">> => <<"Unauthorized">>}),
Req0
),
{error, Req}
end.
-spec secure_compare(binary(), binary()) -> boolean().
secure_compare(Left, Right) when is_binary(Left), is_binary(Right) ->
case byte_size(Left) =:= byte_size(Right) of
true ->
crypto:hash_equals(Left, Right);
false ->
false
end.
-spec read_body(cowboy_req:req()) ->
{ok, map(), cowboy_req:req()} | {error, map(), cowboy_req:req()}.
read_body(Req0) ->
read_body_chunks(Req0, <<>>).
-spec read_body_chunks(cowboy_req:req(), binary()) ->
{ok, map(), cowboy_req:req()} | {error, map(), cowboy_req:req()}.
read_body_chunks(Req0, Acc) ->
case cowboy_req:read_body(Req0) of
{ok, Body, Req1} ->
FullBody = <<Acc/binary, Body/binary>>,
decode_body(FullBody, Req1);
{more, Body, Req1} ->
read_body(Req1, <<Acc/binary, Body/binary>>)
read_body_chunks(Req1, <<Acc/binary, Body/binary>>)
end.
-spec decode_body(binary(), cowboy_req:req()) ->
{ok, map(), cowboy_req:req()} | {error, map(), cowboy_req:req()}.
decode_body(Body, Req0) ->
case catch jsx:decode(Body, [return_maps]) of
case catch json:decode(Body) of
{'EXIT', _Reason} ->
{error, #{<<"error">> => <<"Invalid JSON payload">>}, Req0};
Decoded when is_map(Decoded) ->
@@ -117,6 +145,7 @@ decode_body(Body, Req0) ->
{error, #{<<"error">> => <<"Invalid request body">>}, Req0}
end.
-spec execute_method(binary(), map(), cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
execute_method(Method, Params, Req0, State) ->
try
Result = gateway_rpc_router:execute(Method, Params),
@@ -124,10 +153,15 @@ execute_method(Method, Params, Req0, State) ->
catch
throw:{error, Message} ->
respond(400, #{<<"error">> => Message}, Req0, State);
exit:timeout ->
respond(504, #{<<"error">> => <<"timeout">>}, Req0, State);
exit:{timeout, _} ->
respond(504, #{<<"error">> => <<"timeout">>}, Req0, State);
_:_ ->
respond(500, #{<<"error">> => <<"Internal error">>}, Req0, State)
end.
-spec respond(pos_integer(), map(), cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
respond(Status, Body, Req0, State) ->
Req = cowboy_req:reply(Status, ?JSON_HEADERS, jsx:encode(Body), Req0),
Req = cowboy_req:reply(Status, ?JSON_HEADERS, json:encode(Body), Req0),
{ok, Req, State}.

View File

@@ -19,6 +19,7 @@
-export([execute_method/2, get_local_node_stats/0]).
-spec execute_method(binary(), map()) -> map().
execute_method(<<"process.memory_stats">>, Params) ->
Limit =
case maps:get(<<"limit">>, Params, undefined) of
@@ -27,42 +28,22 @@ execute_method(<<"process.memory_stats">>, Params) ->
LimitValue ->
validation:snowflake_or_throw(<<"limit">>, LimitValue)
end,
Guilds = process_memory_stats:get_guild_memory_stats(Limit),
#{<<"guilds">> => Guilds};
GuildsWithStringMemory = [G#{memory := integer_to_binary(maps:get(memory, G))} || G <- Guilds],
#{<<"guilds">> => GuildsWithStringMemory};
execute_method(<<"process.node_stats">>, _Params) ->
get_local_node_stats().
-spec get_local_node_stats() -> map().
get_local_node_stats() ->
SessionCount =
case gen_server:call(session_manager, get_global_count, 1000) of
{ok, SC} -> SC;
_ -> 0
end,
GuildCount =
case gen_server:call(guild_manager, get_global_count, 1000) of
{ok, GC} -> GC;
_ -> 0
end,
PresenceCount =
case gen_server:call(presence_manager, get_global_count, 1000) of
{ok, PC} -> PC;
_ -> 0
end,
CallCount =
case gen_server:call(call_manager, get_global_count, 1000) of
{ok, CC} -> CC;
_ -> 0
end,
SessionCount = get_manager_count(session_manager),
GuildCount = get_manager_count(guild_manager),
PresenceCount = get_manager_count(presence_manager),
CallCount = get_manager_count(call_manager),
MemoryInfo = erlang:memory(),
TotalMemory = proplists:get_value(total, MemoryInfo, 0),
ProcessMemory = proplists:get_value(processes, MemoryInfo, 0),
SystemMemory = proplists:get_value(system, MemoryInfo, 0),
#{
<<"status">> => <<"healthy">>,
<<"sessions">> => SessionCount,
@@ -70,11 +51,18 @@ get_local_node_stats() ->
<<"presences">> => PresenceCount,
<<"calls">> => CallCount,
<<"memory">> => #{
<<"total">> => TotalMemory,
<<"processes">> => ProcessMemory,
<<"system">> => SystemMemory
<<"total">> => integer_to_binary(TotalMemory),
<<"processes">> => integer_to_binary(ProcessMemory),
<<"system">> => integer_to_binary(SystemMemory)
},
<<"process_count">> => erlang:system_info(process_count),
<<"process_limit">> => erlang:system_info(process_limit),
<<"uptime_seconds">> => element(1, erlang:statistics(wall_clock)) div 1000
}.
-spec get_manager_count(atom()) -> non_neg_integer().
get_manager_count(Manager) ->
case gen_server:call(Manager, get_global_count, 1000) of
{ok, Count} -> Count;
_ -> 0
end.

View File

@@ -19,6 +19,13 @@
-export([execute_method/2]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-define(PRESENCE_LOOKUP_TIMEOUT, 2000).
-spec execute_method(binary(), map()) -> term().
execute_method(<<"presence.dispatch">>, #{
<<"user_id">> := UserIdBin, <<"event">> := Event, <<"data">> := Data
}) ->
@@ -35,132 +42,73 @@ execute_method(<<"presence.join_guild">>, #{
}) ->
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
case gen_server:call(presence_manager, {lookup, UserId}, 10000) of
{ok, Pid} ->
case gen_server:call(Pid, {join_guild, GuildId}, 10000) of
ok -> true;
_ -> throw({error, <<"Join guild failed">>})
end;
not_found ->
true;
{error, _} ->
true;
_ ->
true
end;
presence_manager:lookup_async(UserId, {join_guild, GuildId}),
true;
execute_method(<<"presence.leave_guild">>, #{
<<"user_id">> := UserIdBin, <<"guild_id">> := GuildIdBin
}) ->
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
case gen_server:call(presence_manager, {lookup, UserId}, 10000) of
{ok, Pid} ->
case gen_server:call(Pid, {leave_guild, GuildId}, 10000) of
ok -> true;
_ -> throw({error, <<"Leave guild failed">>})
end;
not_found ->
true;
{error, _} ->
true;
_ ->
true
end;
presence_manager:lookup_async(UserId, {leave_guild, GuildId}),
true;
execute_method(<<"presence.terminate_sessions">>, #{
<<"user_id">> := UserIdBin, <<"session_id_hashes">> := SessionIdHashes
}) ->
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
case gen_server:call(presence_manager, {lookup, UserId}, 10000) of
{ok, Pid} ->
case gen_server:call(Pid, {terminate_session, SessionIdHashes}, 10000) of
ok -> true;
_ -> throw({error, <<"Terminate session failed">>})
end;
not_found ->
true;
{error, _} ->
true;
_ ->
true
end;
execute_method(<<"presence.terminate_all_sessions">>, #{
<<"user_id">> := UserIdBin
}) ->
presence_manager:lookup_async(UserId, {terminate_session, SessionIdHashes}),
true;
execute_method(<<"presence.terminate_all_sessions">>, #{<<"user_id">> := UserIdBin}) ->
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
case presence_manager:terminate_all_sessions(UserId) of
ok -> true;
_ -> throw({error, <<"Terminate all sessions failed">>})
_ -> throw({error, <<"terminate_sessions_error">>})
end;
execute_method(<<"presence.has_active">>, #{<<"user_id">> := UserIdBin}) ->
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
case gen_server:call(presence_manager, {lookup, UserId}, 10000) of
{ok, _Pid} ->
#{<<"has_active">> => true};
_ ->
#{<<"has_active">> => false}
case gen_server:call(presence_manager, {lookup, UserId}, ?PRESENCE_LOOKUP_TIMEOUT) of
{ok, _Pid} -> #{<<"has_active">> => true};
_ -> #{<<"has_active">> => false}
end;
execute_method(<<"presence.add_temporary_guild">>, #{
<<"user_id">> := UserIdBin, <<"guild_id">> := GuildIdBin
}) ->
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
case gen_server:call(presence_manager, {lookup, UserId}, 10000) of
{ok, Pid} ->
case gen_server:call(Pid, {add_temporary_guild, GuildId}, 10000) of
ok -> true;
_ -> throw({error, <<"Add temporary guild failed">>})
end;
not_found ->
true;
{error, _} ->
true;
_ ->
true
end;
presence_manager:lookup_async(UserId, {add_temporary_guild, GuildId}),
true;
execute_method(<<"presence.remove_temporary_guild">>, #{
<<"user_id">> := UserIdBin, <<"guild_id">> := GuildIdBin
}) ->
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
case gen_server:call(presence_manager, {lookup, UserId}, 10000) of
{ok, Pid} ->
case gen_server:call(Pid, {remove_temporary_guild, GuildId}, 10000) of
ok -> true;
_ -> throw({error, <<"Remove temporary guild failed">>})
end;
not_found ->
true;
{error, _} ->
true;
_ ->
true
end;
presence_manager:lookup_async(UserId, {remove_temporary_guild, GuildId}),
true;
execute_method(<<"presence.sync_group_dm_recipients">>, #{
<<"user_id">> := UserIdBin, <<"recipients_by_channel">> := RecipientsByChannel
}) ->
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
NormalizedRecipients =
maps:from_list([
{
validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
[validation:snowflake_or_throw(<<"recipient_id">>, RBin) || RBin <- Recipients]
}
|| {ChannelIdBin, Recipients} <- maps:to_list(RecipientsByChannel)
]),
case gen_server:call(presence_manager, {lookup, UserId}, 10000) of
NormalizedRecipients = normalize_recipients(RecipientsByChannel),
case gen_server:call(presence_manager, {lookup, UserId}, ?PRESENCE_LOOKUP_TIMEOUT) of
{ok, Pid} ->
gen_server:cast(Pid, {sync_group_dm_recipients, NormalizedRecipients}),
true;
not_found ->
true;
{error, _} ->
true;
_ ->
true
end.
-spec normalize_recipients(map()) -> map().
normalize_recipients(RecipientsByChannel) ->
maps:from_list([
{
validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
[validation:snowflake_or_throw(<<"recipient_id">>, RBin) || RBin <- Recipients]
}
|| {ChannelIdBin, Recipients} <- maps:to_list(RecipientsByChannel)
]).
-spec handle_offline_dispatch(atom(), integer(), map()) -> true.
handle_offline_dispatch(message_create, UserId, Data) ->
AuthorIdBin = maps:get(<<"id">>, maps:get(<<"author">>, Data, #{}), <<"0">>),
AuthorIdBin = maps:get(<<"id">>, maps:get(<<"author">>, Data, #{}), undefined),
AuthorId = validation:snowflake_or_throw(<<"author_id">>, AuthorIdBin),
push:handle_message_create(#{
message_data => Data,
@@ -178,5 +126,16 @@ handle_offline_dispatch(relationship_remove, UserId, _Data) ->
handle_offline_dispatch(_Event, _UserId, _Data) ->
true.
-spec sync_blocked_ids_for_user(integer()) -> ok.
sync_blocked_ids_for_user(_UserId) ->
ok.
-ifdef(TEST).
normalize_recipients_test() ->
Input = #{<<"123">> => [<<"456">>, <<"789">>]},
Result = normalize_recipients(Input),
?assert(is_map(Result)),
?assertEqual(1, maps:size(Result)).
-endif.

View File

@@ -19,6 +19,7 @@
-export([execute_method/2]).
-spec execute_method(binary(), map()) -> true.
execute_method(<<"push.sync_user_guild_settings">>, #{
<<"user_id">> := UserIdBin,
<<"guild_id">> := GuildIdBin,

View File

@@ -19,18 +19,33 @@
-export([execute/2]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec execute(binary(), map()) -> term().
execute(Method, Params) ->
case Method of
<<"guild.", _/binary>> ->
gateway_rpc_guild:execute_method(Method, Params);
<<"presence.", _/binary>> ->
gateway_rpc_presence:execute_method(Method, Params);
<<"push.", _/binary>> ->
gateway_rpc_push:execute_method(Method, Params);
<<"call.", _/binary>> ->
gateway_rpc_call:execute_method(Method, Params);
<<"process.", _/binary>> ->
gateway_rpc_misc:execute_method(Method, Params);
_ ->
throw({error, <<"Unknown method: ", Method/binary>>})
end.
route_method(Method, Params).
-spec route_method(binary(), map()) -> term().
route_method(<<"guild.", _/binary>> = Method, Params) ->
gateway_rpc_guild:execute_method(Method, Params);
route_method(<<"presence.", _/binary>> = Method, Params) ->
gateway_rpc_presence:execute_method(Method, Params);
route_method(<<"push.", _/binary>> = Method, Params) ->
gateway_rpc_push:execute_method(Method, Params);
route_method(<<"call.", _/binary>> = Method, Params) ->
gateway_rpc_call:execute_method(Method, Params);
route_method(<<"voice.", _/binary>> = Method, Params) ->
gateway_rpc_voice:execute_method(Method, Params);
route_method(<<"process.", _/binary>> = Method, Params) ->
gateway_rpc_misc:execute_method(Method, Params);
route_method(Method, _Params) ->
throw({error, <<"Unknown method: ", Method/binary>>}).
-ifdef(TEST).
route_method_guild_test() ->
?assertThrow({error, _}, route_method(<<"unknown.method">>, #{})).
-endif.

View File

@@ -0,0 +1,446 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(gateway_rpc_tcp_connection).
-export([serve/1]).
-define(DEFAULT_MAX_INFLIGHT, 1024).
-define(DEFAULT_MAX_INPUT_BUFFER_BYTES, 2097152).
-define(DEFAULT_DISPATCH_RESERVE_DIVISOR, 8).
-define(MAX_FRAME_BYTES, 1048576).
-define(PROTOCOL_VERSION, <<"fluxer.rpc.tcp.v1">>).
-type state() :: #{
socket := inet:socket(),
buffer := binary(),
authenticated := boolean(),
inflight := non_neg_integer(),
max_inflight := pos_integer(),
max_input_buffer_bytes := pos_integer()
}.
-type rpc_result() :: {ok, term()} | {error, binary()}.
-spec serve(inet:socket()) -> ok.
serve(Socket) ->
ok = inet:setopts(Socket, [{active, once}, {nodelay, true}, {keepalive, true}]),
State = #{
socket => Socket,
buffer => <<>>,
authenticated => false,
inflight => 0,
max_inflight => max_inflight(),
max_input_buffer_bytes => max_input_buffer_bytes()
},
loop(State).
-spec loop(state()) -> ok.
loop(#{socket := Socket} = State) ->
receive
{tcp, Socket, Data} ->
case handle_tcp_data(Data, State) of
{ok, NewState} ->
ok = inet:setopts(Socket, [{active, once}]),
loop(NewState);
{stop, Reason, _NewState} ->
logger:debug("Gateway TCP RPC connection closed: ~p", [Reason]),
close_socket(Socket),
ok
end;
{tcp_closed, Socket} ->
ok;
{tcp_error, Socket, Reason} ->
logger:warning("Gateway TCP RPC socket error: ~p", [Reason]),
close_socket(Socket),
ok;
{rpc_response, RequestId, Result} ->
NewState = handle_rpc_response(RequestId, Result, State),
loop(NewState);
_Other ->
loop(State)
end.
-spec handle_tcp_data(binary(), state()) -> {ok, state()} | {stop, term(), state()}.
handle_tcp_data(Data, #{buffer := Buffer, max_input_buffer_bytes := MaxInputBufferBytes} = State) ->
case byte_size(Buffer) + byte_size(Data) =< MaxInputBufferBytes of
false ->
_ = send_error_frame(State, protocol_error_binary(input_buffer_limit_exceeded)),
{stop, input_buffer_limit_exceeded, State};
true ->
Combined = <<Buffer/binary, Data/binary>>,
decode_tcp_frames(Combined, State)
end.
-spec decode_tcp_frames(binary(), state()) -> {ok, state()} | {stop, term(), state()}.
decode_tcp_frames(Combined, State) ->
case decode_frames(Combined, []) of
{ok, Frames, Rest} ->
process_frames(Frames, State#{buffer => Rest});
{error, Reason} ->
_ = send_error_frame(State, protocol_error_binary(Reason)),
{stop, Reason, State}
end.
-spec process_frames([map()], state()) -> {ok, state()} | {stop, term(), state()}.
process_frames([], State) ->
{ok, State};
process_frames([Frame | Rest], State) ->
case process_frame(Frame, State) of
{ok, NewState} ->
process_frames(Rest, NewState);
{stop, Reason, NewState} ->
{stop, Reason, NewState}
end.
-spec process_frame(map(), state()) -> {ok, state()} | {stop, term(), state()}.
process_frame(#{<<"type">> := <<"hello">>} = Frame, #{authenticated := false} = State) ->
handle_hello_frame(Frame, State);
process_frame(#{<<"type">> := <<"hello">>}, State) ->
_ = send_error_frame(State, <<"duplicate_hello">>),
{stop, duplicate_hello, State};
process_frame(#{<<"type">> := <<"request">>} = Frame, #{authenticated := true} = State) ->
handle_request_frame(Frame, State);
process_frame(#{<<"type">> := <<"request">>}, State) ->
_ = send_error_frame(State, <<"unauthorized">>),
{stop, unauthorized, State};
process_frame(#{<<"type">> := <<"ping">>}, State) ->
_ = send_frame(State, #{<<"type">> => <<"pong">>}),
{ok, State};
process_frame(#{<<"type">> := <<"pong">>}, State) ->
{ok, State};
process_frame(#{<<"type">> := <<"close">>}, State) ->
{stop, client_close, State};
process_frame(_Frame, State) ->
_ = send_error_frame(State, <<"unknown_frame_type">>),
{stop, unknown_frame_type, State}.
-spec handle_hello_frame(map(), state()) -> {ok, state()} | {stop, term(), state()}.
handle_hello_frame(Frame, State) ->
case {maps:get(<<"protocol">>, Frame, undefined), maps:get(<<"authorization">>, Frame, undefined)} of
{?PROTOCOL_VERSION, AuthHeader} when is_binary(AuthHeader) ->
authorize_hello(AuthHeader, State);
_ ->
_ = send_error_frame(State, <<"invalid_hello">>),
{stop, invalid_hello, State}
end.
-spec authorize_hello(binary(), state()) -> {ok, state()} | {stop, term(), state()}.
authorize_hello(AuthHeader, State) ->
case fluxer_gateway_env:get(rpc_secret_key) of
Secret when is_binary(Secret) ->
Expected = <<"Bearer ", Secret/binary>>,
case secure_compare(AuthHeader, Expected) of
true ->
HelloAck = #{
<<"type">> => <<"hello_ack">>,
<<"protocol">> => ?PROTOCOL_VERSION,
<<"max_in_flight">> => maps:get(max_inflight, State),
<<"ping_interval_ms">> => 15000
},
_ = send_frame(State, HelloAck),
{ok, State#{authenticated => true}};
false ->
_ = send_error_frame(State, <<"unauthorized">>),
{stop, unauthorized, State}
end;
_ ->
_ = send_error_frame(State, <<"rpc_secret_not_configured">>),
{stop, rpc_secret_not_configured, State}
end.
-spec handle_request_frame(map(), state()) -> {ok, state()}.
handle_request_frame(Frame, State) ->
RequestId = request_id_from_frame(Frame),
Method = maps:get(<<"method">>, Frame, undefined),
case should_reject_request(Method, State) of
true ->
_ =
send_response_frame(
State,
RequestId,
false,
undefined,
<<"overloaded">>
),
{ok, State};
false ->
case {Method, maps:get(<<"params">>, Frame, undefined)} of
{MethodName, Params} when is_binary(RequestId), is_binary(MethodName), is_map(Params) ->
Parent = self(),
_ = spawn(fun() ->
Parent ! {rpc_response, RequestId, execute_method(MethodName, Params)}
end),
{ok, increment_inflight(State)};
_ ->
_ =
send_response_frame(
State,
RequestId,
false,
undefined,
<<"invalid_request">>
),
{ok, State}
end
end.
-spec should_reject_request(term(), state()) -> boolean().
should_reject_request(Method, #{inflight := Inflight, max_inflight := MaxInflight}) ->
case is_dispatch_method(Method) of
true ->
Inflight >= MaxInflight;
false ->
Inflight >= non_dispatch_inflight_limit(MaxInflight)
end.
-spec non_dispatch_inflight_limit(pos_integer()) -> pos_integer().
non_dispatch_inflight_limit(MaxInflight) ->
Reserve = dispatch_reserve_slots(MaxInflight),
max(1, MaxInflight - Reserve).
-spec dispatch_reserve_slots(pos_integer()) -> pos_integer().
dispatch_reserve_slots(MaxInflight) ->
max(1, MaxInflight div ?DEFAULT_DISPATCH_RESERVE_DIVISOR).
-spec is_dispatch_method(term()) -> boolean().
is_dispatch_method(Method) when is_binary(Method) ->
Suffix = <<".dispatch">>,
MethodSize = byte_size(Method),
SuffixSize = byte_size(Suffix),
MethodSize >= SuffixSize andalso
binary:part(Method, MethodSize - SuffixSize, SuffixSize) =:= Suffix;
is_dispatch_method(_) ->
false.
-spec execute_method(binary(), map()) -> rpc_result().
execute_method(Method, Params) ->
try
Result = gateway_rpc_router:execute(Method, Params),
{ok, Result}
catch
throw:{error, Message} ->
{error, error_binary(Message)};
exit:timeout ->
{error, <<"timeout">>};
exit:{timeout, _} ->
{error, <<"timeout">>};
Class:Reason ->
logger:error(
"Gateway TCP RPC method execution failed. method=~ts class=~p reason=~p",
[Method, Class, Reason]
),
{error, <<"internal_error">>}
end.
-spec handle_rpc_response(binary(), rpc_result(), state()) -> state().
handle_rpc_response(RequestId, {ok, Result}, State) ->
_ = send_response_frame(State, RequestId, true, Result, undefined),
decrement_inflight(State);
handle_rpc_response(RequestId, {error, Error}, State) ->
_ = send_response_frame(State, RequestId, false, undefined, Error),
decrement_inflight(State).
-spec send_response_frame(state(), binary(), boolean(), term(), binary() | undefined) -> ok | {error, term()}.
send_response_frame(State, RequestId, true, Result, _Error) ->
send_frame(State, #{
<<"type">> => <<"response">>,
<<"id">> => RequestId,
<<"ok">> => true,
<<"result">> => Result
});
send_response_frame(State, RequestId, false, _Result, Error) ->
send_frame(State, #{
<<"type">> => <<"response">>,
<<"id">> => RequestId,
<<"ok">> => false,
<<"error">> => Error
}).
-spec send_error_frame(state(), binary()) -> ok | {error, term()}.
send_error_frame(State, Error) ->
send_frame(State, #{
<<"type">> => <<"error">>,
<<"error">> => Error
}).
-spec send_frame(state(), map()) -> ok | {error, term()}.
send_frame(#{socket := Socket}, Frame) ->
gen_tcp:send(Socket, encode_frame(Frame)).
-spec encode_frame(map()) -> binary().
encode_frame(Frame) ->
Payload = iolist_to_binary(json:encode(Frame)),
Length = integer_to_binary(byte_size(Payload)),
<<Length/binary, "\n", Payload/binary>>.
-spec decode_frames(binary(), [map()]) -> {ok, [map()], binary()} | {error, term()}.
decode_frames(Buffer, Acc) ->
case binary:match(Buffer, <<"\n">>) of
nomatch ->
{ok, lists:reverse(Acc), Buffer};
{Pos, 1} ->
LengthBin = binary:part(Buffer, 0, Pos),
case parse_length(LengthBin) of
{ok, Length} ->
HeaderSize = Pos + 1,
RequiredSize = HeaderSize + Length,
case byte_size(Buffer) >= RequiredSize of
false ->
{ok, lists:reverse(Acc), Buffer};
true ->
Payload = binary:part(Buffer, HeaderSize, Length),
RestSize = byte_size(Buffer) - RequiredSize,
Rest = binary:part(Buffer, RequiredSize, RestSize),
case decode_payload(Payload) of
{ok, Frame} ->
decode_frames(Rest, [Frame | Acc]);
{error, Reason} ->
{error, Reason}
end
end;
{error, Reason} ->
{error, Reason}
end
end.
-spec decode_payload(binary()) -> {ok, map()} | {error, term()}.
decode_payload(Payload) ->
case catch json:decode(Payload) of
{'EXIT', _} ->
{error, invalid_json};
Frame when is_map(Frame) ->
{ok, Frame};
_ ->
{error, invalid_json}
end.
-spec parse_length(binary()) -> {ok, non_neg_integer()} | {error, term()}.
parse_length(<<>>) ->
{error, invalid_frame_length};
parse_length(LengthBin) ->
try
Length = binary_to_integer(LengthBin),
case Length >= 0 andalso Length =< ?MAX_FRAME_BYTES of
true -> {ok, Length};
false -> {error, invalid_frame_length}
end
catch
_:_ ->
{error, invalid_frame_length}
end.
-spec secure_compare(binary(), binary()) -> boolean().
secure_compare(Left, Right) when is_binary(Left), is_binary(Right) ->
case byte_size(Left) =:= byte_size(Right) of
true ->
crypto:hash_equals(Left, Right);
false ->
false
end.
-spec request_id_from_frame(map()) -> binary().
request_id_from_frame(Frame) ->
case maps:get(<<"id">>, Frame, <<>>) of
Id when is_binary(Id) ->
Id;
Id when is_integer(Id) ->
integer_to_binary(Id);
_ ->
<<>>
end.
-spec increment_inflight(state()) -> state().
increment_inflight(#{inflight := Inflight} = State) ->
State#{inflight => Inflight + 1}.
-spec decrement_inflight(state()) -> state().
decrement_inflight(#{inflight := Inflight} = State) when Inflight > 0 ->
State#{inflight => Inflight - 1};
decrement_inflight(State) ->
State.
-spec error_binary(term()) -> binary().
error_binary(Value) when is_binary(Value) ->
Value;
error_binary(Value) when is_list(Value) ->
unicode:characters_to_binary(Value);
error_binary(Value) when is_atom(Value) ->
atom_to_binary(Value, utf8);
error_binary(Value) ->
unicode:characters_to_binary(io_lib:format("~p", [Value])).
-spec protocol_error_binary(term()) -> binary().
protocol_error_binary(invalid_json) ->
<<"invalid_json">>;
protocol_error_binary(invalid_frame_length) ->
<<"invalid_frame_length">>;
protocol_error_binary(input_buffer_limit_exceeded) ->
<<"input_buffer_limit_exceeded">>.
-spec close_socket(inet:socket()) -> ok.
close_socket(Socket) ->
catch gen_tcp:close(Socket),
ok.
-spec max_inflight() -> pos_integer().
max_inflight() ->
case fluxer_gateway_env:get(gateway_http_rpc_max_concurrency) of
Value when is_integer(Value), Value > 0 ->
Value;
_ ->
?DEFAULT_MAX_INFLIGHT
end.
-spec max_input_buffer_bytes() -> pos_integer().
max_input_buffer_bytes() ->
case fluxer_gateway_env:get(gateway_rpc_tcp_max_input_buffer_bytes) of
Value when is_integer(Value), Value > 0 ->
Value;
_ ->
?DEFAULT_MAX_INPUT_BUFFER_BYTES
end.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
decode_single_frame_test() ->
Frame = #{<<"type">> => <<"ping">>},
Encoded = encode_frame(Frame),
?assertEqual({ok, [Frame], <<>>}, decode_frames(Encoded, [])).
decode_multiple_frames_test() ->
FrameA = #{<<"type">> => <<"ping">>},
FrameB = #{<<"type">> => <<"pong">>},
Encoded = <<(encode_frame(FrameA))/binary, (encode_frame(FrameB))/binary>>,
?assertEqual({ok, [FrameA, FrameB], <<>>}, decode_frames(Encoded, [])).
decode_partial_frame_test() ->
Frame = #{<<"type">> => <<"ping">>},
Encoded = encode_frame(Frame),
Prefix = binary:part(Encoded, 0, 3),
?assertEqual({ok, [], Prefix}, decode_frames(Prefix, [])).
invalid_length_test() ->
?assertEqual({error, invalid_frame_length}, decode_frames(<<"x\n{}">>, [])).
secure_compare_test() ->
?assert(secure_compare(<<"abc">>, <<"abc">>)),
?assertNot(secure_compare(<<"abc">>, <<"abd">>)),
?assertNot(secure_compare(<<"abc">>, <<"abcd">>)).
-endif.

View File

@@ -0,0 +1,108 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(gateway_rpc_tcp_server).
-behaviour(gen_server).
-export([start_link/0, accept_loop/1]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-type state() :: #{
listen_socket := inet:socket(),
acceptor_pid := pid(),
port := inet:port_number()
}.
-spec start_link() -> {ok, pid()} | {error, term()}.
start_link() ->
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
-spec init([]) -> {ok, state()} | {stop, term()}.
init([]) ->
process_flag(trap_exit, true),
Port = fluxer_gateway_env:get(rpc_tcp_port),
case gen_tcp:listen(Port, listen_options()) of
{ok, ListenSocket} ->
AcceptorPid = spawn_link(?MODULE, accept_loop, [ListenSocket]),
logger:info("Gateway TCP RPC listener started on port ~p", [Port]),
{ok, #{
listen_socket => ListenSocket,
acceptor_pid => AcceptorPid,
port => Port
}};
{error, Reason} ->
{stop, {rpc_tcp_listen_failed, Port, Reason}}
end.
-spec handle_call(term(), gen_server:from(), state()) -> {reply, ok, state()}.
handle_call(_Request, _From, State) ->
{reply, ok, State}.
-spec handle_cast(term(), state()) -> {noreply, state()}.
handle_cast(_Msg, State) ->
{noreply, State}.
-spec handle_info(term(), state()) -> {noreply, state()}.
handle_info({'EXIT', Pid, Reason}, #{acceptor_pid := Pid, listen_socket := ListenSocket} = State) ->
case Reason of
normal ->
{noreply, State};
shutdown ->
{noreply, State};
_ ->
logger:error("Gateway TCP RPC acceptor crashed: ~p", [Reason]),
NewAcceptorPid = spawn_link(?MODULE, accept_loop, [ListenSocket]),
{noreply, State#{acceptor_pid => NewAcceptorPid}}
end;
handle_info(_Info, State) ->
{noreply, State}.
-spec terminate(term(), state()) -> ok.
terminate(_Reason, #{listen_socket := ListenSocket, port := Port}) ->
catch gen_tcp:close(ListenSocket),
logger:info("Gateway TCP RPC listener stopped on port ~p", [Port]),
ok.
-spec code_change(term(), state(), term()) -> {ok, state()}.
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
-spec accept_loop(inet:socket()) -> ok.
accept_loop(ListenSocket) ->
case gen_tcp:accept(ListenSocket) of
{ok, Socket} ->
_ = spawn_link(?MODULE, accept_loop, [ListenSocket]),
gateway_rpc_tcp_connection:serve(Socket);
{error, closed} ->
ok;
{error, Reason} ->
logger:error("Gateway TCP RPC accept failed: ~p", [Reason]),
timer:sleep(200),
accept_loop(ListenSocket)
end.
-spec listen_options() -> [gen_tcp:listen_option()].
listen_options() ->
[
binary,
{packet, raw},
{active, false},
{reuseaddr, true},
{nodelay, true},
{backlog, 4096},
{keepalive, true}
].

View File

@@ -0,0 +1,232 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(gateway_rpc_voice).
-export([execute_method/2]).
-spec execute_method(binary(), map()) -> term().
execute_method(<<"voice.confirm_connection">>, Params) ->
ChannelIdBin = maps:get(<<"channel_id">>, Params),
ConnectionId = maps:get(<<"connection_id">>, Params),
case parse_optional_guild_id(Params) of
undefined ->
gateway_rpc_call:execute_method(
<<"call.confirm_connection">>,
#{
<<"channel_id">> => ChannelIdBin,
<<"connection_id">> => ConnectionId
}
);
GuildId ->
TokenNonce = maps:get(<<"token_nonce">>, Params, undefined),
gateway_rpc_guild:execute_method(
<<"guild.confirm_voice_connection_from_livekit">>,
#{
<<"guild_id">> => integer_to_binary(GuildId),
<<"connection_id">> => ConnectionId,
<<"token_nonce">> => TokenNonce
}
)
end;
execute_method(<<"voice.disconnect_user_if_in_channel">>, Params) ->
ChannelIdBin = maps:get(<<"channel_id">>, Params),
UserIdBin = maps:get(<<"user_id">>, Params),
ConnectionId = maps:get(<<"connection_id">>, Params, undefined),
case parse_optional_guild_id(Params) of
undefined ->
CallParams = #{
<<"channel_id">> => ChannelIdBin,
<<"user_id">> => UserIdBin
},
gateway_rpc_call:execute_method(
<<"call.disconnect_user_if_in_channel">>,
maybe_put_connection_id(ConnectionId, CallParams)
);
GuildId ->
GuildParams = #{
<<"guild_id">> => integer_to_binary(GuildId),
<<"user_id">> => UserIdBin,
<<"expected_channel_id">> => ChannelIdBin
},
gateway_rpc_guild:execute_method(
<<"guild.disconnect_voice_user_if_in_channel">>,
maybe_put_connection_id(ConnectionId, GuildParams)
)
end;
execute_method(<<"voice.get_voice_states_for_channel">>, Params) ->
ChannelIdBin = maps:get(<<"channel_id">>, Params),
case parse_optional_guild_id(Params) of
undefined ->
build_dm_voice_states_response(ChannelIdBin);
GuildId ->
gateway_rpc_guild:execute_method(
<<"guild.get_voice_states_for_channel">>,
#{
<<"guild_id">> => integer_to_binary(GuildId),
<<"channel_id">> => ChannelIdBin
}
)
end;
execute_method(<<"voice.get_pending_joins_for_channel">>, Params) ->
ChannelIdBin = maps:get(<<"channel_id">>, Params),
case parse_optional_guild_id(Params) of
undefined ->
normalize_pending_joins_response(
gateway_rpc_call:execute_method(
<<"call.get_pending_joins">>,
#{<<"channel_id">> => ChannelIdBin}
)
);
GuildId ->
gateway_rpc_guild:execute_method(
<<"guild.get_pending_joins_for_channel">>,
#{
<<"guild_id">> => integer_to_binary(GuildId),
<<"channel_id">> => ChannelIdBin
}
)
end;
execute_method(Method, _Params) ->
throw({error, <<"Unknown method: ", Method/binary>>}).
-spec parse_optional_guild_id(map()) -> integer() | undefined.
parse_optional_guild_id(Params) ->
case maps:get(<<"guild_id">>, Params, undefined) of
undefined ->
undefined;
null ->
undefined;
GuildIdBin ->
validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin)
end.
-spec maybe_put_connection_id(binary() | undefined, map()) -> map().
maybe_put_connection_id(undefined, Params) ->
Params;
maybe_put_connection_id(ConnectionId, Params) ->
Params#{<<"connection_id">> => ConnectionId}.
-spec build_dm_voice_states_response(binary()) -> map().
build_dm_voice_states_response(ChannelIdBin) ->
case gateway_rpc_call:execute_method(<<"call.get">>, #{<<"channel_id">> => ChannelIdBin}) of
null ->
#{<<"voice_states">> => []};
CallData when is_map(CallData) ->
VoiceStates = get_map_value(CallData, [<<"voice_states">>, voice_states]),
#{<<"voice_states">> => normalize_voice_states(VoiceStates)}
end.
-spec normalize_voice_states(term()) -> [map()].
normalize_voice_states(VoiceStates) when is_list(VoiceStates) ->
lists:reverse(
lists:foldl(fun normalize_voice_state_entry/2, [], VoiceStates)
);
normalize_voice_states(_) ->
[].
-spec normalize_voice_state_entry(map(), [map()]) -> [map()].
normalize_voice_state_entry(VoiceState, Acc) ->
ConnectionId = normalize_id(get_map_value(VoiceState, [<<"connection_id">>, connection_id])),
UserId = normalize_id(get_map_value(VoiceState, [<<"user_id">>, user_id])),
ChannelId = normalize_id(get_map_value(VoiceState, [<<"channel_id">>, channel_id])),
case {ConnectionId, UserId, ChannelId} of
{undefined, _, _} ->
Acc;
{_, undefined, _} ->
Acc;
{_, _, undefined} ->
Acc;
_ ->
[#{
<<"connection_id">> => ConnectionId,
<<"user_id">> => UserId,
<<"channel_id">> => ChannelId
} | Acc]
end.
-spec normalize_pending_joins_response(term()) -> map().
normalize_pending_joins_response(Response) when is_map(Response) ->
PendingJoins = get_map_value(Response, [<<"pending_joins">>, pending_joins]),
#{<<"pending_joins">> => normalize_pending_joins(PendingJoins)};
normalize_pending_joins_response(_) ->
#{<<"pending_joins">> => []}.
-spec normalize_pending_joins(term()) -> [map()].
normalize_pending_joins(PendingJoins) when is_list(PendingJoins) ->
lists:reverse(
lists:foldl(fun normalize_pending_join_entry/2, [], PendingJoins)
);
normalize_pending_joins(_) ->
[].
-spec normalize_pending_join_entry(map(), [map()]) -> [map()].
normalize_pending_join_entry(PendingJoin, Acc) ->
ConnectionId = normalize_id(get_map_value(PendingJoin, [<<"connection_id">>, connection_id])),
UserId = normalize_id(get_map_value(PendingJoin, [<<"user_id">>, user_id])),
TokenNonce = normalize_token_nonce(get_map_value(PendingJoin, [<<"token_nonce">>, token_nonce])),
ExpiresAt = normalize_expiry(get_map_value(PendingJoin, [<<"expires_at">>, expires_at])),
case {ConnectionId, UserId} of
{undefined, _} ->
Acc;
{_, undefined} ->
Acc;
_ ->
[#{
<<"connection_id">> => ConnectionId,
<<"user_id">> => UserId,
<<"token_nonce">> => TokenNonce,
<<"expires_at">> => ExpiresAt
} | Acc]
end.
-spec normalize_id(term()) -> binary() | undefined.
normalize_id(undefined) ->
undefined;
normalize_id(Value) when is_binary(Value) ->
Value;
normalize_id(Value) when is_integer(Value) ->
integer_to_binary(Value);
normalize_id(_) ->
undefined.
-spec normalize_token_nonce(term()) -> binary().
normalize_token_nonce(undefined) ->
<<>>;
normalize_token_nonce(Value) when is_binary(Value) ->
Value;
normalize_token_nonce(Value) when is_integer(Value) ->
integer_to_binary(Value);
normalize_token_nonce(_) ->
<<>>.
-spec normalize_expiry(term()) -> integer().
normalize_expiry(Value) when is_integer(Value) ->
Value;
normalize_expiry(_) ->
0.
-spec get_map_value(map(), [term()]) -> term().
get_map_value(_Map, []) ->
undefined;
get_map_value(Map, [Key | Rest]) ->
case maps:find(Key, Map) of
{ok, Value} ->
Value;
error ->
get_map_value(Map, Rest)
end.

View File

@@ -19,6 +19,7 @@
-export([init/2]).
-spec init(cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
init(Req0, State) ->
Req = cowboy_req:reply(
200,

View File

@@ -60,6 +60,7 @@
-type purge_mode() :: none | soft | hard.
-type reload_opts() :: #{purge => purge_mode()}.
-type reload_result() :: map().
-spec reload_module(atom()) -> {ok, map()} | {error, term()}.
reload_module(Module) when is_atom(Module) ->
@@ -79,11 +80,11 @@ reload_module(Module) when is_atom(Module) ->
end
end.
-spec reload_modules([atom()]) -> {ok, [map()]}.
-spec reload_modules([atom()]) -> {ok, [reload_result()]}.
reload_modules(Modules) when is_list(Modules) ->
reload_modules(Modules, #{purge => soft}).
-spec reload_modules([atom()], reload_opts()) -> {ok, [map()]}.
-spec reload_modules([atom()], reload_opts()) -> {ok, [reload_result()]}.
reload_modules(Modules, Opts) when is_list(Modules), is_map(Opts) ->
Purge = maps:get(purge, Opts, soft),
Results = lists:map(
@@ -94,7 +95,7 @@ reload_modules(Modules, Opts) when is_list(Modules), is_map(Opts) ->
),
{ok, Results}.
-spec reload_beams([{atom(), binary()}], reload_opts()) -> {ok, [map()]}.
-spec reload_beams([{atom(), binary()}], reload_opts()) -> {ok, [reload_result()]}.
reload_beams(Pairs, Opts) when is_list(Pairs), is_map(Opts) ->
Purge = maps:get(purge, Opts, soft),
Results =
@@ -106,11 +107,11 @@ reload_beams(Pairs, Opts) when is_list(Pairs), is_map(Opts) ->
),
{ok, Results}.
-spec reload_all_changed() -> {ok, [map()]}.
-spec reload_all_changed() -> {ok, [reload_result()]}.
reload_all_changed() ->
reload_all_changed(soft).
-spec reload_all_changed(purge_mode()) -> {ok, [map()]}.
-spec reload_all_changed(purge_mode()) -> {ok, [reload_result()]}.
reload_all_changed(Purge) ->
ChangedModules = get_changed_modules(),
reload_modules(ChangedModules, #{purge => Purge}).
@@ -141,6 +142,7 @@ get_module_info(Module) when is_atom(Module) ->
}}
end.
-spec reload_one(atom(), purge_mode()) -> reload_result().
reload_one(Module, Purge) ->
case is_critical_module(Module) of
true ->
@@ -149,6 +151,7 @@ reload_one(Module, Purge) ->
do_reload_one(Module, Purge)
end.
-spec reload_one_beam(atom(), binary(), purge_mode()) -> reload_result().
reload_one_beam(Module, BeamBin, Purge) ->
case is_critical_module(Module) of
true ->
@@ -157,19 +160,21 @@ reload_one_beam(Module, BeamBin, Purge) ->
do_reload_one_beam(Module, BeamBin, Purge)
end.
-spec do_reload_one(atom(), purge_mode()) -> reload_result().
do_reload_one(Module, Purge) ->
OldLoadedMd5 = loaded_md5(Module),
OldBeamPath = code:which(Module),
OldDiskMd5 = disk_md5(OldBeamPath),
ok = maybe_purge_before_load(Module, Purge),
case code:load_file(Module) of
{module, Module} ->
NewLoadedMd5 = loaded_md5(Module),
NewBeamPath = code:which(Module),
NewDiskMd5 = disk_md5(NewBeamPath),
Verified = (NewLoadedMd5 =/= undefined) andalso (NewDiskMd5 =/= undefined) andalso (NewLoadedMd5 =:= NewDiskMd5),
Verified =
(NewLoadedMd5 =/= undefined) andalso
(NewDiskMd5 =/= undefined) andalso
(NewLoadedMd5 =:= NewDiskMd5),
{PurgedOld, LingeringCount} = maybe_purge_old_after_load(Module, Purge),
#{
module => Module,
@@ -195,9 +200,9 @@ do_reload_one(Module, Purge) ->
}
end.
-spec do_reload_one_beam(atom(), binary(), purge_mode()) -> reload_result().
do_reload_one_beam(Module, BeamBin, Purge) ->
OldLoadedMd5 = loaded_md5(Module),
ExpectedMd5 =
case beam_lib:md5(BeamBin) of
{ok, {Module, Md5}} ->
@@ -207,9 +212,7 @@ do_reload_one_beam(Module, BeamBin, Purge) ->
_ ->
erlang:error(invalid_beam)
end,
ok = maybe_purge_before_load(Module, Purge),
Filename = atom_to_list(Module) ++ ".beam(hot)",
case code:load_binary(Module, Filename, BeamBin) of
{module, Module} ->
@@ -239,6 +242,7 @@ do_reload_one_beam(Module, BeamBin, Purge) ->
}
end.
-spec maybe_purge_before_load(atom(), purge_mode()) -> ok.
maybe_purge_before_load(_Module, none) ->
ok;
maybe_purge_before_load(_Module, soft) ->
@@ -247,16 +251,28 @@ maybe_purge_before_load(Module, hard) ->
_ = code:purge(Module),
ok.
-spec maybe_purge_old_after_load(atom(), purge_mode()) -> {boolean(), non_neg_integer()}.
maybe_purge_old_after_load(_Module, none) ->
{false, 0};
maybe_purge_old_after_load(Module, hard) ->
_ = code:soft_purge(Module),
Purged = code:purge(Module),
{Purged, case Purged of true -> 0; false -> count_lingering(Module) end};
LingeringCount =
case Purged of
true -> 0;
false -> count_lingering(Module)
end,
{Purged, LingeringCount};
maybe_purge_old_after_load(Module, soft) ->
Purged = wait_soft_purge(Module, 40, 50),
{Purged, case Purged of true -> 0; false -> count_lingering(Module) end}.
LingeringCount =
case Purged of
true -> 0;
false -> count_lingering(Module)
end,
{Purged, LingeringCount}.
-spec wait_soft_purge(atom(), non_neg_integer(), pos_integer()) -> boolean().
wait_soft_purge(_Module, 0, _SleepMs) ->
false;
wait_soft_purge(Module, N, SleepMs) ->
@@ -264,10 +280,13 @@ wait_soft_purge(Module, N, SleepMs) ->
true ->
true;
false ->
receive after SleepMs -> ok end,
receive
after SleepMs -> ok
end,
wait_soft_purge(Module, N - 1, SleepMs)
end.
-spec count_lingering(atom()) -> non_neg_integer().
count_lingering(Module) ->
lists:foldl(
fun(Pid, Acc) ->
@@ -280,21 +299,26 @@ count_lingering(Module) ->
processes()
).
-spec get_changed_modules() -> [atom()].
get_changed_modules() ->
Modified = code:modified_modules(),
[M || M <- Modified, is_fluxer_module(M), not is_critical_module(M)].
-spec is_critical_module(atom()) -> boolean().
is_critical_module(Module) ->
lists:member(Module, ?CRITICAL_MODULES).
-spec is_fluxer_module(atom()) -> boolean().
is_fluxer_module(Module) ->
ModuleStr = atom_to_list(Module),
lists:prefix("fluxer_", ModuleStr) orelse
lists:prefix("gateway", ModuleStr) orelse
lists:prefix("gateway_http_", ModuleStr) orelse
lists:prefix("session", ModuleStr) orelse
lists:prefix("guild", ModuleStr) orelse
lists:prefix("presence", ModuleStr) orelse
lists:prefix("push", ModuleStr) orelse
lists:prefix("push_dispatcher", ModuleStr) orelse
lists:prefix("call", ModuleStr) orelse
lists:prefix("health", ModuleStr) orelse
lists:prefix("hot_reload", ModuleStr) orelse
@@ -311,9 +335,13 @@ is_fluxer_module(Module) ->
lists:prefix("map_utils", ModuleStr) orelse
lists:prefix("type_conv", ModuleStr) orelse
lists:prefix("utils", ModuleStr) orelse
lists:prefix("snowflake_", ModuleStr) orelse
lists:prefix("user_utils", ModuleStr) orelse
lists:prefix("custom_status", ModuleStr).
lists:prefix("custom_status", ModuleStr) orelse
lists:prefix("otel_", ModuleStr) orelse
lists:prefix("event_", ModuleStr).
-spec loaded_md5(atom()) -> binary() | undefined.
loaded_md5(Module) ->
try
Module:module_info(md5)
@@ -321,6 +349,7 @@ loaded_md5(Module) ->
_:_ -> undefined
end.
-spec disk_md5(string() | atom()) -> binary() | undefined.
disk_md5(Path) when is_list(Path) ->
case beam_lib:md5(Path) of
{ok, {_M, Md5}} -> Md5;
@@ -329,11 +358,13 @@ disk_md5(Path) when is_list(Path) ->
disk_md5(_) ->
undefined.
-spec hex_or_null(binary() | undefined) -> binary() | null.
hex_or_null(undefined) ->
null;
hex_or_null(Bin) when is_binary(Bin) ->
binary:encode_hex(Bin, lowercase).
-spec get_loaded_time(atom()) -> term().
get_loaded_time(Module) ->
try
case Module:module_info(compile) of
@@ -346,6 +377,7 @@ get_loaded_time(Module) ->
_:_ -> undefined
end.
-spec get_disk_time(string() | atom()) -> calendar:datetime() | undefined.
get_disk_time(BeamPath) when is_list(BeamPath) ->
case file:read_file_info(BeamPath) of
{ok, FileInfo} ->

View File

@@ -23,6 +23,9 @@
-define(MAX_MODULES, 600).
-define(MAX_BODY_BYTES, 26214400).
-type purge_mode() :: none | soft | hard.
-spec init(cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
init(Req0, State) ->
case cowboy_req:method(Req0) of
<<"POST">> ->
@@ -32,6 +35,7 @@ init(Req0, State) ->
{ok, Req, State}
end.
-spec handle_post(cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
handle_post(Req0, State) ->
case authorize(Req0) of
ok ->
@@ -45,52 +49,75 @@ handle_post(Req0, State) ->
{ok, Req1, State}
end.
-spec authorize(cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
authorize(Req0) ->
case cowboy_req:header(<<"authorization">>, Req0) of
undefined ->
Req = cowboy_req:reply(
401,
?JSON_HEADERS,
jsx:encode(#{<<"error">> => <<"Unauthorized">>}),
json:encode(#{<<"error">> => <<"Unauthorized">>}),
Req0
),
{error, Req};
AuthHeader ->
case os:getenv("GATEWAY_ADMIN_SECRET") of
false ->
Req = cowboy_req:reply(
500,
?JSON_HEADERS,
jsx:encode(#{<<"error">> => <<"GATEWAY_ADMIN_SECRET not configured">>}),
Req0
),
{error, Req};
Secret ->
Expected = <<"Bearer ", (list_to_binary(Secret))/binary>>,
case AuthHeader of
Expected ->
ok;
_ ->
Req = cowboy_req:reply(
401,
?JSON_HEADERS,
jsx:encode(#{<<"error">> => <<"Unauthorized">>}),
Req0
),
{error, Req}
end
end
authorize_with_secret(AuthHeader, Req0)
end.
-spec authorize_with_secret(binary(), cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
authorize_with_secret(AuthHeader, Req0) ->
case fluxer_gateway_env:get(admin_reload_secret) of
undefined ->
Req = cowboy_req:reply(
500,
?JSON_HEADERS,
json:encode(#{<<"error">> => <<"admin reload secret not configured">>}),
Req0
),
{error, Req};
Secret when is_binary(Secret) ->
check_auth_header(AuthHeader, <<"Bearer ", Secret/binary>>, Req0);
Secret when is_list(Secret) ->
check_auth_header(AuthHeader, <<"Bearer ", (list_to_binary(Secret))/binary>>, Req0)
end.
-spec check_auth_header(binary(), binary(), cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
check_auth_header(AuthHeader, Expected, Req0) ->
case secure_compare(AuthHeader, Expected) of
true ->
ok;
false ->
Req = cowboy_req:reply(
401,
?JSON_HEADERS,
json:encode(#{<<"error">> => <<"Unauthorized">>}),
Req0
),
{error, Req}
end.
-spec secure_compare(binary(), binary()) -> boolean().
secure_compare(Left, Right) when is_binary(Left), is_binary(Right) ->
case byte_size(Left) =:= byte_size(Right) of
true ->
crypto:hash_equals(Left, Right);
false ->
false
end.
-spec read_body(cowboy_req:req()) ->
{ok, map(), cowboy_req:req()} | {error, pos_integer(), map(), cowboy_req:req()}.
read_body(Req0) ->
case cowboy_req:body_length(Req0) of
Length when is_integer(Length), Length > ?MAX_BODY_BYTES ->
{error, 413, #{<<"error">> => <<"Request body too large">>}, Req0};
_ ->
read_body(Req0, <<>>)
read_body_chunks(Req0, <<>>)
end.
read_body(Req0, Acc) ->
-spec read_body_chunks(cowboy_req:req(), binary()) ->
{ok, map(), cowboy_req:req()} | {error, pos_integer(), map(), cowboy_req:req()}.
read_body_chunks(Req0, Acc) ->
case cowboy_req:read_body(Req0, #{length => 1048576}) of
{ok, Body, Req1} ->
FullBody = <<Acc/binary, Body/binary>>,
@@ -101,14 +128,16 @@ read_body(Req0, Acc) ->
true ->
{error, 413, #{<<"error">> => <<"Request body too large">>}, Req1};
false ->
read_body(Req1, NewAcc)
read_body_chunks(Req1, NewAcc)
end
end.
-spec decode_body(binary(), cowboy_req:req()) ->
{ok, map(), cowboy_req:req()} | {error, pos_integer(), map(), cowboy_req:req()}.
decode_body(<<>>, Req0) ->
{ok, #{}, Req0};
decode_body(Body, Req0) ->
case catch jsx:decode(Body, [return_maps]) of
case catch json:decode(Body) of
{'EXIT', _Reason} ->
{error, 400, #{<<"error">> => <<"Invalid JSON payload">>}, Req0};
Decoded when is_map(Decoded) ->
@@ -117,6 +146,7 @@ decode_body(Body, Req0) ->
{error, 400, #{<<"error">> => <<"Invalid request body">>}, Req0}
end.
-spec handle_reload(map(), cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
handle_reload(Params, Req0, State) ->
try
Purge = parse_purge(maps:get(<<"purge">>, Params, <<"soft">>)),
@@ -124,14 +154,7 @@ handle_reload(Params, Req0, State) ->
undefined ->
handle_modules_reload(Params, Purge, Req0, State);
Beams when is_list(Beams) ->
case length(Beams) =< ?MAX_MODULES of
true ->
Pairs = decode_beams(Beams),
{ok, Results} = hot_reload:reload_beams(Pairs, #{purge => Purge}),
respond(200, #{<<"results">> => Results}, Req0, State);
false ->
respond(400, #{<<"error">> => <<"Too many modules">>}, Req0, State)
end;
handle_beams_reload(Beams, Purge, Req0, State);
_ ->
respond(400, #{<<"error">> => <<"beams must be an array">>}, Req0, State)
end
@@ -142,11 +165,24 @@ handle_reload(Params, Req0, State) ->
respond(400, #{<<"error">> => <<"Invalid module name or beam payload">>}, Req0, State);
error:{beam_module_mismatch, _, _} ->
respond(400, #{<<"error">> => <<"Invalid module name or beam payload">>}, Req0, State);
_:Reason ->
logger:error("hot_reload_handler: Error during reload: ~p", [Reason]),
_:_Reason ->
respond(500, #{<<"error">> => <<"Internal error">>}, Req0, State)
end.
-spec handle_beams_reload([map()], purge_mode(), cowboy_req:req(), term()) ->
{ok, cowboy_req:req(), term()}.
handle_beams_reload(Beams, Purge, Req0, State) ->
case length(Beams) =< ?MAX_MODULES of
true ->
Pairs = decode_beams(Beams),
{ok, Results} = hot_reload:reload_beams(Pairs, #{purge => Purge}),
respond(200, #{<<"results">> => Results}, Req0, State);
false ->
respond(400, #{<<"error">> => <<"Too many modules">>}, Req0, State)
end.
-spec handle_modules_reload(map(), purge_mode(), cowboy_req:req(), term()) ->
{ok, cowboy_req:req(), term()}.
handle_modules_reload(Params, Purge, Req0, State) ->
case maps:get(<<"modules">>, Params, []) of
[] ->
@@ -165,6 +201,7 @@ handle_modules_reload(Params, Purge, Req0, State) ->
respond(400, #{<<"error">> => <<"modules must be an array">>}, Req0, State)
end.
-spec decode_beams([map()]) -> [{atom(), binary()}].
decode_beams(Beams) ->
lists:map(
fun(Elem) ->
@@ -187,6 +224,7 @@ decode_beams(Beams) ->
Beams
).
-spec to_binary(binary() | list()) -> binary().
to_binary(B) when is_binary(B) ->
B;
to_binary(L) when is_list(L) ->
@@ -194,6 +232,7 @@ to_binary(L) when is_list(L) ->
to_binary(_) ->
erlang:error(badarg).
-spec parse_purge(binary() | atom()) -> purge_mode().
parse_purge(<<"none">>) -> none;
parse_purge(<<"soft">>) -> soft;
parse_purge(<<"hard">>) -> hard;
@@ -202,6 +241,7 @@ parse_purge(soft) -> soft;
parse_purge(hard) -> hard;
parse_purge(_) -> soft.
-spec to_module_atom(binary() | list()) -> atom().
to_module_atom(B) when is_binary(B) ->
case is_allowed_module_name(B) of
true -> erlang:binary_to_atom(B, utf8);
@@ -212,10 +252,12 @@ to_module_atom(L) when is_list(L) ->
to_module_atom(_) ->
erlang:error(badarg).
-spec is_allowed_module_name(binary()) -> boolean().
is_allowed_module_name(Bin) when is_binary(Bin) ->
byte_size(Bin) > 0 andalso byte_size(Bin) < 128 andalso
is_safe_chars(Bin) andalso has_allowed_prefix(Bin).
-spec is_safe_chars(binary()) -> boolean().
is_safe_chars(Bin) ->
lists:all(
fun(C) ->
@@ -226,14 +268,17 @@ is_safe_chars(Bin) ->
binary_to_list(Bin)
).
-spec has_allowed_prefix(binary()) -> boolean().
has_allowed_prefix(Bin) ->
Prefixes = [
<<"fluxer_">>,
<<"gateway">>,
<<"gateway_http_">>,
<<"session">>,
<<"guild">>,
<<"presence">>,
<<"push">>,
<<"push_dispatcher">>,
<<"call">>,
<<"health">>,
<<"hot_reload">>,
@@ -251,7 +296,10 @@ has_allowed_prefix(Bin) ->
<<"type_conv">>,
<<"utils">>,
<<"user_utils">>,
<<"custom_status">>
<<"snowflake_">>,
<<"custom_status">>,
<<"otel_">>,
<<"event_">>
],
lists:any(
fun(P) ->
@@ -261,6 +309,7 @@ has_allowed_prefix(Bin) ->
Prefixes
).
-spec respond(pos_integer(), map(), cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
respond(Status, Body, Req0, State) ->
Req = cowboy_req:reply(Status, ?JSON_HEADERS, jsx:encode(Body), Req0),
Req = cowboy_req:reply(Status, ?JSON_HEADERS, json:encode(Body), Req0),
{ok, Req, State}.

View File

@@ -19,10 +19,6 @@
-export([select/2, group_keys/2]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-define(HASH_LIMIT, 16#FFFFFFFF).
-spec select(term(), pos_integer()) -> non_neg_integer().
@@ -51,22 +47,21 @@ select(_Key, _ShardCount) ->
-spec group_keys([term()], pos_integer()) -> [{non_neg_integer(), [term()]}].
group_keys(Keys, ShardCount) when is_list(Keys), ShardCount > 0 ->
Sorted =
maps:to_list(
lists:foldl(
fun(Key, Acc) ->
Index = select(Key, ShardCount),
Existing = maps:get(Index, Acc, []),
maps:put(Index, [Key | Existing], Acc)
end,
#{},
Keys
)
Grouped =
lists:foldl(
fun(Key, Acc) ->
Index = select(Key, ShardCount),
Existing = maps:get(Index, Acc, []),
maps:put(Index, [Key | Existing], Acc)
end,
#{},
Keys
),
lists:sort(
Sorted = lists:sort(
fun({IdxA, _}, {IdxB, _}) -> IdxA =< IdxB end,
[{Index, lists:usort(Group)} || {Index, Group} <- Sorted]
);
[{Index, lists:usort(Group)} || {Index, Group} <- maps:to_list(Grouped)]
),
Sorted;
group_keys(_Keys, _ShardCount) ->
[].
@@ -75,20 +70,55 @@ weight(Key, Index) ->
erlang:phash2({Key, Index}, ?HASH_LIMIT).
-ifdef(TEST).
select_returns_valid_index_test() ->
-include_lib("eunit/include/eunit.hrl").
select_single_shard_test() ->
?assertEqual(0, select(test_key, 1)),
Index = select(test_key, 5),
?assert(Index >= 0),
?assert(Index < 5).
?assertEqual(0, select(any_key, 1)),
?assertEqual(0, select(12345, 1)).
select_is_stable_for_same_inputs_test() ->
?assertEqual(select(<<"abc">>, 8), select(<<"abc">>, 8)),
?assertEqual(select(12345, 3), select(12345, 3)).
select_valid_index_test_() ->
[
?_test(begin
Index = select(test_key, N),
?assert(Index >= 0),
?assert(Index < N)
end)
|| N <- [2, 5, 10, 100]
].
group_keys_sorts_and_deduplicates_test() ->
select_stability_test_() ->
[
?_assertEqual(select(<<"abc">>, 8), select(<<"abc">>, 8)),
?_assertEqual(select(12345, 3), select(12345, 3)),
?_assertEqual(select({user, 1}, 10), select({user, 1}, 10))
].
select_distribution_test() ->
Keys = lists:seq(1, 1000),
ShardCount = 10,
Distribution = lists:foldl(
fun(Key, Acc) ->
Index = select(Key, ShardCount),
maps:update_with(Index, fun(V) -> V + 1 end, 1, Acc)
end,
#{},
Keys
),
Counts = maps:values(Distribution),
?assertEqual(ShardCount, maps:size(Distribution)),
lists:foreach(fun(Count) -> ?assert(Count > 0) end, Counts).
group_keys_empty_test() ->
?assertEqual([], group_keys([], 4)).
group_keys_single_test() ->
Groups = group_keys([key1], 4),
?assertEqual(1, length(Groups)).
group_keys_deduplicates_test() ->
Keys = [1, 2, 3, 1, 2],
Groups = group_keys(Keys, 2),
?assertMatch([{_, _}, {_, _}], Groups),
lists:foreach(
fun({_Index, GroupKeys}) ->
?assertEqual(GroupKeys, lists:usort(GroupKeys))
@@ -96,6 +126,16 @@ group_keys_sorts_and_deduplicates_test() ->
Groups
).
group_keys_handles_empty_test() ->
?assertEqual([], group_keys([], 4)).
group_keys_sorted_indices_test() ->
Keys = lists:seq(1, 100),
Groups = group_keys(Keys, 5),
Indices = [I || {I, _} <- Groups],
?assertEqual(Indices, lists:sort(Indices)).
group_keys_all_keys_present_test() ->
Keys = [a, b, c, d, e],
Groups = group_keys(Keys, 3),
AllGroupedKeys = lists:flatten([K || {_, K} <- Groups]),
?assertEqual(lists:sort(Keys), lists:sort(AllGroupedKeys)).
-endif.

View File

@@ -27,59 +27,86 @@
-type rpc_request() :: map().
-type rpc_response() :: {ok, map()} | {error, term()}.
-type rpc_options() :: map().
-spec call(rpc_request()) -> rpc_response().
call(Request) ->
call(Request, #{}).
-spec call(rpc_request(), map()) -> rpc_response().
-spec call(rpc_request(), rpc_options()) -> rpc_response().
call(Request, _Options) ->
Url = get_rpc_url(),
Headers = get_rpc_headers(),
Body = jsx:encode(Request),
case
hackney:request(post, Url, Headers, Body, [{recv_timeout, 30000}, {connect_timeout, 5000}])
of
{ok, 200, _RespHeaders, ClientRef} ->
case hackney:body(ClientRef) of
{ok, RespBody} ->
Response = jsx:decode(RespBody, [return_maps]),
Data = maps:get(<<"data">>, Response, #{}),
{ok, Data};
{error, Reason} ->
logger:error("[rpc_client] Failed to read response body: ~p", [Reason]),
{error, {body_read_failed, Reason}}
end;
{ok, StatusCode, _RespHeaders, ClientRef} ->
case hackney:body(ClientRef) of
{ok, RespBody} ->
hackney:close(ClientRef),
logger:error("[rpc_client] RPC request failed with status ~p: ~s", [
StatusCode, RespBody
]),
{error, {http_error, StatusCode, RespBody}};
{error, Reason} ->
hackney:close(ClientRef),
logger:error(
"[rpc_client] Failed to read error response body (status ~p): ~p", [
StatusCode, Reason
]
),
{error, {http_error, StatusCode, body_read_failed}}
end;
Body = json:encode(Request),
case gateway_http_client:request(rpc, post, Url, Headers, Body) of
{ok, 200, _RespHeaders, RespBody} ->
handle_success_response(RespBody);
{ok, StatusCode, _RespHeaders, RespBody} ->
handle_error_response(StatusCode, RespBody);
{error, Reason} ->
logger:error("[rpc_client] RPC request failed: ~p", [Reason]),
{error, Reason}
end.
-spec handle_success_response(binary()) -> rpc_response().
handle_success_response(RespBody) ->
Response = json:decode(RespBody),
Data = maps:get(<<"data">>, Response, #{}),
{ok, Data}.
-spec handle_error_response(pos_integer(), binary()) -> {error, term()}.
handle_error_response(StatusCode, RespBody) ->
{error, {http_error, StatusCode, RespBody}}.
-spec get_rpc_url() -> string().
get_rpc_url() ->
ApiHost = fluxer_gateway_env:get(api_host),
get_rpc_url(ApiHost).
-spec get_rpc_url(string() | binary()) -> string().
get_rpc_url(ApiHost) ->
"http://" ++ ApiHost ++ "/_rpc".
BaseUrl = api_host_base_url(ApiHost),
BaseUrl ++ "/_rpc".
-spec api_host_base_url(string() | binary()) -> string().
api_host_base_url(ApiHost) ->
HostString = ensure_string(ApiHost),
Normalized = normalize_api_host(HostString),
strip_trailing_slash(Normalized).
-spec ensure_string(binary() | string()) -> string().
ensure_string(Value) when is_binary(Value) ->
binary_to_list(Value);
ensure_string(Value) when is_list(Value) ->
Value.
-spec normalize_api_host(string()) -> string().
normalize_api_host(Host) ->
Lower = string:lowercase(Host),
case {has_protocol_prefix(Lower, "http://"), has_protocol_prefix(Lower, "https://")} of
{true, _} -> Host;
{_, true} -> Host;
_ -> "http://" ++ Host
end.
-spec has_protocol_prefix(string(), string()) -> boolean().
has_protocol_prefix(Str, Prefix) ->
case string:prefix(Str, Prefix) of
nomatch -> false;
_ -> true
end.
-spec strip_trailing_slash(string()) -> string().
strip_trailing_slash([]) ->
"";
strip_trailing_slash(Url) ->
case lists:last(Url) of
$/ -> strip_trailing_slash(lists:droplast(Url));
_ -> Url
end.
-spec get_rpc_headers() -> [{binary() | string(), binary() | string()}].
get_rpc_headers() ->
RpcSecretKey = fluxer_gateway_env:get(rpc_secret_key),
[{<<"Authorization">>, <<"Bearer ", RpcSecretKey/binary>>}].
AuthHeader = {<<"Authorization">>, <<"Bearer ", RpcSecretKey/binary>>},
InitialHeaders = [AuthHeader],
gateway_tracing:inject_rpc_headers(InitialHeaders).

File diff suppressed because it is too large Load Diff

View File

@@ -21,123 +21,533 @@
is_guild_unavailable_for_user/2,
is_user_staff/2,
check_unavailability_transition/2,
handle_unavailability_transition/2
handle_unavailability_transition/2,
get_cached_unavailability_mode/1,
is_guild_unavailable_for_user_from_cache/2,
update_unavailability_cache_for_state/1
]).
-import(guild_permissions, [find_member_by_user_id/2]).
-import(guild_data, [get_guild_state/2]).
-type guild_state() :: map().
-type user_id() :: integer().
-type guild_id() :: integer().
-type unavailability_mode() ::
available
| unavailable_for_everyone
| unavailable_for_everyone_but_staff.
-type transition_result() :: {unavailable_enabled, boolean()} | unavailable_disabled | no_change.
-define(GUILD_UNAVAILABILITY_CACHE, guild_unavailability_cache).
-define(STAFF_USER_FLAG, 16#1).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec is_guild_unavailable_for_user(user_id(), guild_state()) -> boolean().
is_guild_unavailable_for_user(UserId, State) ->
Data = maps:get(data, State),
Guild = maps:get(<<"guild">>, Data),
Features = maps:get(<<"features">>, Guild, []),
HasUnavailableForEveryone = lists:member(<<"UNAVAILABLE_FOR_EVERYONE">>, Features),
HasUnavailableForEveryoneButStaff =
lists:member(<<"UNAVAILABLE_FOR_EVERYONE_BUT_STAFF">>, Features),
case {HasUnavailableForEveryone, HasUnavailableForEveryoneButStaff} of
{true, _} ->
case get_unavailability_mode_from_state(State) of
unavailable_for_everyone ->
true;
{false, true} ->
unavailable_for_everyone_but_staff ->
not is_user_staff(UserId, State);
{false, false} ->
available ->
false
end.
-spec is_user_staff(user_id(), guild_state()) -> boolean().
is_user_staff(UserId, State) ->
case find_member_by_user_id(UserId, State) of
undefined ->
case is_user_staff_from_sessions(UserId, State) of
true ->
true;
false ->
false;
Member ->
User = maps:get(<<"user">>, Member, #{}),
Flags = utils:binary_to_integer_safe(maps:get(<<"flags">>, User, <<"0">>)),
(Flags band 16#1) =:= 16#1
end.
check_unavailability_transition(OldState, NewState) ->
OldData = maps:get(data, OldState),
OldGuild = maps:get(<<"guild">>, OldData),
OldFeatures = maps:get(<<"features">>, OldGuild, []),
NewData = maps:get(data, NewState),
NewGuild = maps:get(<<"guild">>, NewData),
NewFeatures = maps:get(<<"features">>, NewGuild, []),
OldUnavailableForEveryone = lists:member(<<"UNAVAILABLE_FOR_EVERYONE">>, OldFeatures),
NewUnavailableForEveryone = lists:member(<<"UNAVAILABLE_FOR_EVERYONE">>, NewFeatures),
OldUnavailableForEveryoneButStaff =
lists:member(<<"UNAVAILABLE_FOR_EVERYONE_BUT_STAFF">>, OldFeatures),
NewUnavailableForEveryoneButStaff =
lists:member(<<"UNAVAILABLE_FOR_EVERYONE_BUT_STAFF">>, NewFeatures),
OldIsUnavailable = OldUnavailableForEveryone orelse OldUnavailableForEveryoneButStaff,
NewIsUnavailable = NewUnavailableForEveryone orelse NewUnavailableForEveryoneButStaff,
case {OldIsUnavailable, NewIsUnavailable} of
{false, true} ->
{unavailable_enabled, NewUnavailableForEveryoneButStaff};
{true, false} ->
unavailable_disabled;
_ ->
case
{OldUnavailableForEveryoneButStaff, NewUnavailableForEveryoneButStaff,
OldUnavailableForEveryone, NewUnavailableForEveryone}
of
{true, false, false, true} ->
{unavailable_enabled, false};
{false, true, true, false} ->
{unavailable_enabled, true};
_ ->
no_change
undefined ->
case guild_permissions:find_member_by_user_id(UserId, State) of
undefined ->
false;
Member ->
User = maps:get(<<"user">>, Member, #{}),
is_user_staff_from_user_data(User)
end
end.
handle_unavailability_transition(OldState, NewState) ->
GuildId = maps:get(id, NewState),
UnavailablePayload = #{
<<"id">> => integer_to_binary(GuildId),
<<"unavailable">> => true
},
-spec is_user_staff_from_sessions(user_id(), guild_state()) -> boolean() | undefined.
is_user_staff_from_sessions(UserId, State) ->
Sessions = maps:get(sessions, State, #{}),
maps:fold(
fun(_SessionId, SessionData, Acc) ->
case Acc of
undefined ->
SessionUserId = maps:get(user_id, SessionData, undefined),
SessionIsStaff = maps:get(is_staff, SessionData, undefined),
case {SessionUserId =:= UserId, SessionIsStaff} of
{true, true} ->
true;
{true, false} ->
false;
_ ->
undefined
end;
_ ->
Acc
end
end,
undefined,
Sessions
).
-spec get_cached_unavailability_mode(guild_id()) -> unavailability_mode().
get_cached_unavailability_mode(GuildId) ->
ensure_unavailability_cache_table(),
case ets:lookup(?GUILD_UNAVAILABILITY_CACHE, GuildId) of
[{GuildId, Mode}] ->
normalize_unavailability_mode(Mode);
[] ->
available
end.
-spec is_guild_unavailable_for_user_from_cache(guild_id(), map()) -> boolean().
is_guild_unavailable_for_user_from_cache(GuildId, UserData) ->
case get_cached_unavailability_mode(GuildId) of
unavailable_for_everyone ->
true;
unavailable_for_everyone_but_staff ->
not is_user_staff_from_user_data(UserData);
available ->
false
end.
-spec update_unavailability_cache_for_state(guild_state()) -> unavailability_mode().
update_unavailability_cache_for_state(State) ->
GuildId = maps:get(id, State),
Mode = get_unavailability_mode_from_state(State),
set_cached_unavailability_mode(GuildId, Mode),
Mode.
-spec check_unavailability_transition(guild_state(), guild_state()) -> transition_result().
check_unavailability_transition(OldState, NewState) ->
OldMode = get_unavailability_mode_from_state(OldState),
NewMode = get_unavailability_mode_from_state(NewState),
case {OldMode, NewMode} of
{available, unavailable_for_everyone} ->
{unavailable_enabled, false};
{available, unavailable_for_everyone_but_staff} ->
{unavailable_enabled, true};
{unavailable_for_everyone, available} ->
unavailable_disabled;
{unavailable_for_everyone_but_staff, available} ->
unavailable_disabled;
{unavailable_for_everyone_but_staff, unavailable_for_everyone} ->
{unavailable_enabled, false};
{unavailable_for_everyone, unavailable_for_everyone_but_staff} ->
{unavailable_enabled, true};
_ ->
no_change
end.
-spec handle_unavailability_transition(guild_state(), guild_state()) -> guild_state().
handle_unavailability_transition(OldState, NewState) ->
_ = update_unavailability_cache_for_state(NewState),
GuildId = maps:get(id, NewState),
case check_unavailability_transition(OldState, NewState) of
{unavailable_enabled, StaffOnly} ->
Sessions = maps:get(sessions, NewState, #{}),
lists:foreach(
fun({_SessionId, SessionData}) ->
UserId = maps:get(user_id, SessionData),
Pid = maps:get(pid, SessionData),
ShouldBeUnavailable =
case StaffOnly of
true -> not is_user_staff(UserId, NewState);
false -> true
end,
case ShouldBeUnavailable of
true ->
gen_server:cast(Pid, {dispatch, guild_delete, UnavailablePayload});
false ->
ok
end
end,
maps:to_list(Sessions)
);
disconnect_ineligible_sessions(StaffOnly, NewState, GuildId);
unavailable_disabled ->
Sessions = maps:get(sessions, NewState, #{}),
GuildId = maps:get(id, NewState),
BulkPresences = presence_utils:collect_guild_member_presences(NewState),
lists:foreach(
fun({_SessionId, SessionData}) ->
UserId = maps:get(user_id, SessionData),
Pid = maps:get(pid, SessionData),
GuildState = get_guild_state(UserId, NewState),
gen_server:cast(Pid, {dispatch, guild_create, GuildState}),
presence_utils:send_presence_bulk(Pid, GuildId, UserId, BulkPresences)
case maps:get(pending_connect, SessionData, false) of
true ->
ok;
false ->
UserId = maps:get(user_id, SessionData),
Pid = maps:get(pid, SessionData),
GuildState = guild_data:get_guild_state(UserId, NewState),
gen_server:cast(Pid, {dispatch, guild_create, GuildState}),
presence_utils:send_presence_bulk(Pid, GuildId, UserId, BulkPresences)
end
end,
maps:to_list(Sessions)
);
),
NewState;
no_change ->
NewState
end.
-spec disconnect_ineligible_sessions(boolean(), guild_state(), guild_id()) -> guild_state().
disconnect_ineligible_sessions(StaffOnly, State, GuildId) ->
Sessions = maps:get(sessions, State, #{}),
{FinalState, _DisconnectedUsers} = lists:foldl(
fun({SessionId, SessionData}, {AccState, ProcessedUsers}) ->
UserId = maps:get(user_id, SessionData),
case should_disconnect_user(UserId, StaffOnly, AccState) of
true ->
Pid = maps:get(pid, SessionData, undefined),
maybe_send_guild_leave(Pid, GuildId),
{VoiceState, UpdatedUsers} =
maybe_disconnect_voice_for_user(UserId, ProcessedUsers, AccState),
NewState = guild_sessions:remove_session(SessionId, VoiceState),
{NewState, UpdatedUsers};
false ->
{AccState, ProcessedUsers}
end
end,
{State, sets:new()},
maps:to_list(Sessions)
),
FinalState.
-spec should_disconnect_user(user_id(), boolean(), guild_state()) -> boolean().
should_disconnect_user(UserId, true, State) ->
not is_user_staff(UserId, State);
should_disconnect_user(_UserId, false, _State) ->
true.
-spec maybe_send_guild_leave(pid() | undefined, guild_id()) -> ok.
maybe_send_guild_leave(Pid, GuildId) when is_pid(Pid) ->
gen_server:cast(Pid, {guild_leave, GuildId, forced_unavailable}),
ok;
maybe_send_guild_leave(_Pid, _GuildId) ->
ok.
-spec maybe_disconnect_voice_for_user(user_id(), sets:set(user_id()), guild_state()) ->
{guild_state(), sets:set(user_id())}.
maybe_disconnect_voice_for_user(UserId, ProcessedUsers, State) ->
case sets:is_element(UserId, ProcessedUsers) of
true ->
{State, ProcessedUsers};
false ->
{reply, _Result, VoiceState} = guild_voice_disconnect:disconnect_voice_user(
#{user_id => UserId, connection_id => null},
State
),
{VoiceState, sets:add_element(UserId, ProcessedUsers)}
end.
-spec ensure_unavailability_cache_table() -> ok.
ensure_unavailability_cache_table() ->
case ets:whereis(?GUILD_UNAVAILABILITY_CACHE) of
undefined ->
try ets:new(?GUILD_UNAVAILABILITY_CACHE, [named_table, public, set, {read_concurrency, true}]) of
_ -> ok
catch
error:badarg -> ok
end;
_ ->
ok
end.
-spec set_cached_unavailability_mode(guild_id(), unavailability_mode()) -> ok.
set_cached_unavailability_mode(GuildId, available) ->
ensure_unavailability_cache_table(),
ets:delete(?GUILD_UNAVAILABILITY_CACHE, GuildId),
ok;
set_cached_unavailability_mode(GuildId, Mode) ->
ensure_unavailability_cache_table(),
ets:insert(?GUILD_UNAVAILABILITY_CACHE, {GuildId, Mode}),
ok.
-spec normalize_unavailability_mode(term()) -> unavailability_mode().
normalize_unavailability_mode(unavailable_for_everyone) ->
unavailable_for_everyone;
normalize_unavailability_mode(unavailable_for_everyone_but_staff) ->
unavailable_for_everyone_but_staff;
normalize_unavailability_mode(_) ->
available.
-spec get_unavailability_mode_from_state(guild_state()) -> unavailability_mode().
get_unavailability_mode_from_state(State) ->
Data = maps:get(data, State, #{}),
Guild = maps:get(<<"guild">>, Data, #{}),
Features = maps:get(<<"features">>, Guild, []),
get_unavailability_mode_from_features(Features).
-spec get_unavailability_mode_from_features(term()) -> unavailability_mode().
get_unavailability_mode_from_features(Features) when is_list(Features) ->
HasUnavailableForEveryone = lists:member(<<"UNAVAILABLE_FOR_EVERYONE">>, Features),
HasUnavailableForEveryoneButStaff =
lists:member(<<"UNAVAILABLE_FOR_EVERYONE_BUT_STAFF">>, Features),
case {HasUnavailableForEveryone, HasUnavailableForEveryoneButStaff} of
{true, _} ->
unavailable_for_everyone;
{false, true} ->
unavailable_for_everyone_but_staff;
{false, false} ->
available
end;
get_unavailability_mode_from_features(_) ->
available.
-spec is_user_staff_from_user_data(map()) -> boolean().
is_user_staff_from_user_data(UserData) when is_map(UserData) ->
case parse_is_staff_value(maps:get(<<"is_staff">>, UserData, undefined)) of
undefined ->
is_user_staff_from_flags(UserData);
IsStaff ->
IsStaff
end;
is_user_staff_from_user_data(_) ->
false.
-spec parse_is_staff_value(term()) -> boolean() | undefined.
parse_is_staff_value(true) ->
true;
parse_is_staff_value(false) ->
false;
parse_is_staff_value(<<"true">>) ->
true;
parse_is_staff_value(<<"false">>) ->
false;
parse_is_staff_value(_) ->
undefined.
-spec is_user_staff_from_flags(map()) -> boolean().
is_user_staff_from_flags(UserData) ->
FlagsValue = maps:get(<<"flags">>, UserData, 0),
Flags = type_conv:to_integer(FlagsValue),
case Flags of
undefined ->
false;
Value when is_integer(Value) ->
(Value band ?STAFF_USER_FLAG) =:= ?STAFF_USER_FLAG
end.
-ifdef(TEST).
-spec cleanup_unavailability_cache(guild_id()) -> ok.
cleanup_unavailability_cache(GuildId) ->
set_cached_unavailability_mode(GuildId, available).
disconnect_ineligible_sessions_staff_only_test() ->
Parent = self(),
GuildId = 99001,
NonStaffPid = start_session_capture(non_staff, Parent),
StaffPid = start_session_capture(staff, Parent),
try
BaseState = state_for_unavailability_transition_test(GuildId, NonStaffPid, StaffPid),
OldState = BaseState,
NewState = BaseState#{
data => #{
<<"guild">> => #{
<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE_BUT_STAFF">>]
},
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"1001">>, <<"flags">> => <<"0">>}},
#{<<"user">> => #{<<"id">> => <<"1002">>, <<"flags">> => <<"1">>}}
]
}
},
UpdatedState = handle_unavailability_transition(OldState, NewState),
Sessions = maps:get(sessions, UpdatedState, #{}),
?assertEqual(1, map_size(Sessions)),
?assert(maps:is_key(<<"staff">>, Sessions)),
?assertEqual(unavailable_for_everyone_but_staff, get_cached_unavailability_mode(GuildId)),
receive
{non_staff, {'$gen_cast', {guild_leave, GuildId, forced_unavailable}}} -> ok
after 1000 ->
?assert(false)
end,
receive
{staff, {'$gen_cast', {guild_leave, GuildId, forced_unavailable}}} ->
?assert(false)
after 200 ->
ok
end
after
cleanup_unavailability_cache(GuildId),
NonStaffPid ! stop,
StaffPid ! stop
end.
disconnect_ineligible_sessions_everyone_test() ->
Parent = self(),
GuildId = 99002,
UserOnePid = start_session_capture(user_one, Parent),
UserTwoPid = start_session_capture(user_two, Parent),
try
BaseState = state_for_unavailability_transition_test(GuildId, UserOnePid, UserTwoPid),
OldState = BaseState,
NewState = BaseState#{
data => #{
<<"guild">> => #{
<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]
},
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"1001">>, <<"flags">> => <<"0">>}},
#{<<"user">> => #{<<"id">> => <<"1002">>, <<"flags">> => <<"1">>}}
]
}
},
UpdatedState = handle_unavailability_transition(OldState, NewState),
?assertEqual(#{}, maps:get(sessions, UpdatedState, #{})),
?assertEqual(unavailable_for_everyone, get_cached_unavailability_mode(GuildId)),
receive
{user_one, {'$gen_cast', {guild_leave, GuildId, forced_unavailable}}} -> ok
after 1000 ->
?assert(false)
end,
receive
{user_two, {'$gen_cast', {guild_leave, GuildId, forced_unavailable}}} -> ok
after 1000 ->
?assert(false)
end
after
cleanup_unavailability_cache(GuildId),
UserOnePid ! stop,
UserTwoPid ! stop
end.
is_guild_unavailable_for_user_from_cache_is_staff_test() ->
GuildId = 99003,
try
set_cached_unavailability_mode(GuildId, unavailable_for_everyone_but_staff),
?assertEqual(
true,
is_guild_unavailable_for_user_from_cache(
GuildId,
#{<<"is_staff">> => false}
)
),
?assertEqual(
false,
is_guild_unavailable_for_user_from_cache(
GuildId,
#{<<"is_staff">> => true}
)
)
after
cleanup_unavailability_cache(GuildId)
end.
-spec start_session_capture(atom(), pid()) -> pid().
start_session_capture(Tag, Parent) ->
spawn(fun() -> session_capture_loop(Tag, Parent) end).
-spec session_capture_loop(atom(), pid()) -> ok.
session_capture_loop(Tag, Parent) ->
receive
stop ->
ok;
{'$gen_cast', Msg} ->
Parent ! {Tag, {'$gen_cast', Msg}},
session_capture_loop(Tag, Parent);
_Other ->
session_capture_loop(Tag, Parent)
end.
-spec state_for_unavailability_transition_test(guild_id(), pid(), pid()) -> guild_state().
state_for_unavailability_transition_test(GuildId, NonStaffPid, StaffPid) ->
#{
id => GuildId,
sessions => #{
<<"non_staff">> => #{
session_id => <<"non_staff">>,
user_id => 1001,
pid => NonStaffPid,
mref => make_ref(),
active_guilds => sets:new(),
user_roles => [],
bot => false,
is_staff => false,
previous_passive_updates => #{},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
},
<<"staff">> => #{
session_id => <<"staff">>,
user_id => 1002,
pid => StaffPid,
mref => make_ref(),
active_guilds => sets:new(),
user_roles => [],
bot => false,
is_staff => true,
previous_passive_updates => #{},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
}
},
presence_subscriptions => #{},
member_list_subscriptions => #{},
member_subscriptions => guild_subscriptions:init_state(),
data => #{
<<"guild">> => #{
<<"features">> => []
},
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"1001">>, <<"flags">> => <<"0">>}},
#{<<"user">> => #{<<"id">> => <<"1002">>, <<"flags">> => <<"1">>}}
]
},
voice_states => #{},
pending_voice_connections => #{}
}.
is_guild_unavailable_for_user_unavailable_for_everyone_test() ->
State = #{
data => #{
<<"guild">> => #{
<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]
},
<<"members">> => []
}
},
?assertEqual(true, is_guild_unavailable_for_user(123, State)).
is_guild_unavailable_for_user_available_test() ->
State = #{
data => #{
<<"guild">> => #{
<<"features">> => []
},
<<"members">> => []
}
},
?assertEqual(false, is_guild_unavailable_for_user(123, State)).
check_unavailability_transition_no_change_test() ->
State = #{
data => #{
<<"guild">> => #{
<<"features">> => []
}
}
},
?assertEqual(no_change, check_unavailability_transition(State, State)).
check_unavailability_transition_enabled_test() ->
OldState = #{
data => #{
<<"guild">> => #{
<<"features">> => []
}
}
},
NewState = #{
data => #{
<<"guild">> => #{
<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]
}
}
},
?assertEqual({unavailable_enabled, false}, check_unavailability_transition(OldState, NewState)).
check_unavailability_transition_disabled_test() ->
OldState = #{
data => #{
<<"guild">> => #{
<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]
}
}
},
NewState = #{
data => #{
<<"guild">> => #{
<<"features">> => []
}
}
},
?assertEqual(unavailable_disabled, check_unavailability_transition(OldState, NewState)).
-endif.

View File

@@ -17,9 +17,7 @@
-module(guild_client).
-export([
voice_state_update/3
]).
-export([voice_state_update/3]).
-export_type([
voice_state_update_success/0,
@@ -27,7 +25,10 @@
voice_state_update_result/0
]).
-define(DEFAULT_TIMEOUT, 12000).
-define(CIRCUIT_BREAKER_TABLE, guild_circuit_breaker).
-define(FAILURE_THRESHOLD, 5).
-define(RECOVERY_TIMEOUT_MS, 30000).
-define(MAX_CONCURRENT, 50).
-type voice_state_update_success() :: #{
success := true,
@@ -44,10 +45,39 @@
{ok, voice_state_update_success()}
| {error, timeout}
| {error, noproc}
| {error, circuit_breaker_open}
| {error, too_many_requests}
| {error, atom(), atom()}.
-type circuit_state() :: closed | open | half_open.
-spec voice_state_update(pid(), map(), timeout()) -> voice_state_update_result().
voice_state_update(GuildPid, Request, Timeout) ->
ensure_table(),
case acquire_slot(GuildPid) of
ok ->
try
execute_with_circuit_breaker(GuildPid, Request, Timeout)
after
release_slot(GuildPid)
end;
{error, Reason} ->
{error, Reason}
end.
-spec execute_with_circuit_breaker(pid(), map(), timeout()) -> voice_state_update_result().
execute_with_circuit_breaker(GuildPid, Request, Timeout) ->
case get_circuit_state(GuildPid) of
open ->
{error, circuit_breaker_open};
State when State =:= closed; State =:= half_open ->
Result = do_call(GuildPid, Request, Timeout),
update_circuit_state(GuildPid, Result, State),
Result
end.
-spec do_call(pid(), map(), timeout()) -> voice_state_update_result().
do_call(GuildPid, Request, Timeout) ->
try gen_server:call(GuildPid, {voice_state_update, Request}, Timeout) of
Response when is_map(Response) ->
case maps:get(success, Response, false) of
@@ -57,12 +87,138 @@ voice_state_update(GuildPid, Request, Timeout) ->
{error, Category, ErrorAtom} when is_atom(Category), is_atom(ErrorAtom) ->
{error, Category, ErrorAtom}
catch
exit:{timeout, _} ->
{error, timeout};
exit:{noproc, _} ->
{error, noproc};
exit:{normal, _} ->
{error, noproc}
exit:{timeout, _} -> {error, timeout};
exit:{noproc, _} -> {error, noproc};
exit:{normal, _} -> {error, noproc}
end.
-spec get_circuit_state(pid()) -> circuit_state().
get_circuit_state(GuildPid) ->
case safe_lookup(GuildPid) of
[] ->
closed;
[{_, #{state := open, opened_at := OpenedAt}}] ->
Now = erlang:system_time(millisecond),
case Now - OpenedAt > ?RECOVERY_TIMEOUT_MS of
true -> half_open;
false -> open
end;
[{_, #{state := State}}] ->
State
end.
-spec update_circuit_state(pid(), voice_state_update_result(), circuit_state()) -> ok.
update_circuit_state(GuildPid, Result, PrevState) ->
IsSuccess = is_success_result(Result),
case {IsSuccess, PrevState} of
{true, half_open} ->
ets:delete(?CIRCUIT_BREAKER_TABLE, GuildPid),
ok;
{true, closed} ->
reset_failures(GuildPid);
{false, _} ->
record_failure(GuildPid)
end.
-spec is_success_result(voice_state_update_result()) -> boolean().
is_success_result({ok, _}) -> true;
is_success_result(_) -> false.
-spec reset_failures(pid()) -> ok.
reset_failures(GuildPid) ->
case safe_lookup(GuildPid) of
[{_, State}] ->
ets:insert(?CIRCUIT_BREAKER_TABLE, {GuildPid, State#{failures => 0}}),
ok;
[] ->
ok
end.
-spec record_failure(pid()) -> ok.
record_failure(GuildPid) ->
Now = erlang:system_time(millisecond),
case safe_lookup(GuildPid) of
[] ->
ets:insert(
?CIRCUIT_BREAKER_TABLE,
{GuildPid, #{
state => closed,
failures => 1,
concurrent => 0
}}
),
ok;
[{_, #{failures := F} = State}] when F + 1 >= ?FAILURE_THRESHOLD ->
ets:insert(
?CIRCUIT_BREAKER_TABLE,
{GuildPid, State#{
state => open,
failures => F + 1,
opened_at => Now
}}
),
ok;
[{_, #{failures := F} = State}] ->
ets:insert(?CIRCUIT_BREAKER_TABLE, {GuildPid, State#{failures => F + 1}}),
ok
end.
-spec acquire_slot(pid()) -> ok | {error, too_many_requests}.
acquire_slot(GuildPid) ->
case safe_lookup(GuildPid) of
[] ->
ets:insert(
?CIRCUIT_BREAKER_TABLE,
{GuildPid, #{
state => closed,
failures => 0,
concurrent => 1
}}
),
ok;
[{_, #{concurrent := C}}] when C >= ?MAX_CONCURRENT ->
{error, too_many_requests};
[{_, #{concurrent := C} = State}] ->
ets:insert(?CIRCUIT_BREAKER_TABLE, {GuildPid, State#{concurrent => C + 1}}),
ok
end.
-spec release_slot(pid()) -> ok.
release_slot(GuildPid) ->
case safe_lookup(GuildPid) of
[{_, #{concurrent := C} = State}] when C > 0 ->
ets:insert(?CIRCUIT_BREAKER_TABLE, {GuildPid, State#{concurrent => C - 1}}),
ok;
_ ->
ok
end.
-spec safe_lookup(pid()) -> list().
safe_lookup(GuildPid) ->
try ets:lookup(?CIRCUIT_BREAKER_TABLE, GuildPid) of
Result -> Result
catch
error:badarg -> []
end.
-spec ensure_table() -> ok.
ensure_table() ->
case ets:whereis(?CIRCUIT_BREAKER_TABLE) of
undefined ->
try
ets:new(?CIRCUIT_BREAKER_TABLE, [
named_table,
public,
set,
{read_concurrency, true},
{write_concurrency, true}
]),
ok
catch
error:badarg -> ok
end;
_ ->
ok
end.
-ifdef(TEST).
@@ -72,4 +228,136 @@ module_exports_test() ->
Exports = guild_client:module_info(exports),
?assert(lists:member({voice_state_update, 3}, Exports)).
ensure_table_creates_table_test() ->
catch ets:delete(?CIRCUIT_BREAKER_TABLE),
?assertEqual(undefined, ets:whereis(?CIRCUIT_BREAKER_TABLE)),
ensure_table(),
?assertNotEqual(undefined, ets:whereis(?CIRCUIT_BREAKER_TABLE)).
ensure_table_idempotent_test() ->
ensure_table(),
ensure_table(),
?assertNotEqual(undefined, ets:whereis(?CIRCUIT_BREAKER_TABLE)).
acquire_slot_creates_entry_test() ->
ensure_table(),
Pid = spawn(fun() ->
receive
done -> ok
end
end),
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
?assertEqual(ok, acquire_slot(Pid)),
[{Pid, State}] = ets:lookup(?CIRCUIT_BREAKER_TABLE, Pid),
?assertEqual(1, maps:get(concurrent, State)),
Pid ! done.
acquire_slot_increments_test() ->
ensure_table(),
Pid = spawn(fun() ->
receive
done -> ok
end
end),
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
acquire_slot(Pid),
acquire_slot(Pid),
[{Pid, State}] = ets:lookup(?CIRCUIT_BREAKER_TABLE, Pid),
?assertEqual(2, maps:get(concurrent, State)),
Pid ! done.
release_slot_decrements_test() ->
ensure_table(),
Pid = spawn(fun() ->
receive
done -> ok
end
end),
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
acquire_slot(Pid),
acquire_slot(Pid),
release_slot(Pid),
[{Pid, State}] = ets:lookup(?CIRCUIT_BREAKER_TABLE, Pid),
?assertEqual(1, maps:get(concurrent, State)),
Pid ! done.
get_circuit_state_closed_test() ->
ensure_table(),
Pid = spawn(fun() ->
receive
done -> ok
end
end),
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
?assertEqual(closed, get_circuit_state(Pid)),
Pid ! done.
get_circuit_state_open_test() ->
ensure_table(),
Pid = spawn(fun() ->
receive
done -> ok
end
end),
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
Now = erlang:system_time(millisecond),
ets:insert(
?CIRCUIT_BREAKER_TABLE,
{Pid, #{
state => open,
failures => 5,
concurrent => 0,
opened_at => Now
}}
),
?assertEqual(open, get_circuit_state(Pid)),
Pid ! done.
get_circuit_state_half_open_test() ->
ensure_table(),
Pid = spawn(fun() ->
receive
done -> ok
end
end),
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
OldTime = erlang:system_time(millisecond) - ?RECOVERY_TIMEOUT_MS - 1000,
ets:insert(
?CIRCUIT_BREAKER_TABLE,
{Pid, #{
state => open,
failures => 5,
concurrent => 0,
opened_at => OldTime
}}
),
?assertEqual(half_open, get_circuit_state(Pid)),
Pid ! done.
record_failure_opens_circuit_test() ->
ensure_table(),
Pid = spawn(fun() ->
receive
done -> ok
end
end),
ets:delete_all_objects(?CIRCUIT_BREAKER_TABLE),
ets:insert(
?CIRCUIT_BREAKER_TABLE,
{Pid, #{
state => closed,
failures => ?FAILURE_THRESHOLD - 1,
concurrent => 0
}}
),
record_failure(Pid),
[{Pid, State}] = ets:lookup(?CIRCUIT_BREAKER_TABLE, Pid),
?assertEqual(open, maps:get(state, State)),
Pid ! done.
is_success_result_test() ->
?assertEqual(true, is_success_result({ok, #{}})),
?assertEqual(false, is_success_result({error, timeout})),
?assertEqual(false, is_success_result({error, noproc})).
-endif.

View File

@@ -31,6 +31,7 @@
-type guild_data_map() :: map().
-type guild_member() :: map().
-type channel_list() :: [map()].
-type user_id() :: integer().
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
@@ -45,11 +46,10 @@ get_guild_data(#{user_id := UserId}, State) ->
Reply = #{guild_data => GuildData},
{reply, Reply, State};
_ ->
Members = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
case member_in_list(UserId, Members) of
false ->
case guild_data_index:get_member(UserId, Data) of
undefined ->
{reply, #{guild_data => null, error_reason => <<"forbidden">>}, State};
true ->
_ ->
GuildData = build_complete_guild_data(Data, State),
{reply, #{guild_data => GuildData}, State}
end
@@ -76,7 +76,7 @@ has_member(#{user_id := UserId}, State) ->
-spec list_guild_members(map(), guild_state()) -> guild_reply(map()).
list_guild_members(#{limit := Limit, offset := Offset}, State) ->
Data = guild_data_map(State),
AllMembers = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
AllMembers = guild_data_index:member_list(Data),
TotalCount = length(AllMembers),
PaginatedMembers = paginate_members(AllMembers, Limit, Offset),
{reply, #{members => PaginatedMembers, total => TotalCount}, State}.
@@ -93,19 +93,24 @@ get_first_viewable_text_channel(State) ->
EveryoneChannelId = find_everyone_viewable_text_channel(Channels, State),
{reply, #{channel_id => EveryoneChannelId}, State}.
-spec get_guild_state(integer(), guild_state()) -> map().
-spec get_guild_state(user_id(), guild_state()) -> map().
get_guild_state(UserId, State) ->
Data = guild_data_map(State),
GuildId = map_utils:get_integer(State, id, 0),
AllChannels = channels_from_data(Data),
AllMembers = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
AllMembers = guild_data_index:member_values(Data),
Member = find_member_by_user_id(UserId, State),
{ViewableChannels, JoinedAt} = derive_member_view(UserId, Member, State, AllChannels),
OnlineCount = guild_member_list:get_online_count(State),
OwnMemberList = case Member of
undefined -> [];
M -> [M]
end,
OwnMemberList =
case Member of
undefined -> [];
M -> [M]
end,
VoiceStates = guild_voice:get_voice_states_list(State),
VoiceMembers = voice_members_from_states(VoiceStates, AllMembers),
Members = merge_members(OwnMemberList, VoiceMembers),
MemberCount = maps:get(member_count, State, length(AllMembers)),
#{
<<"id">> => integer_to_binary(GuildId),
<<"properties">> => maps:get(<<"guild">>, Data, #{}),
@@ -113,11 +118,11 @@ get_guild_state(UserId, State) ->
<<"channels">> => ViewableChannels,
<<"emojis">> => maps:get(<<"emojis">>, Data, []),
<<"stickers">> => maps:get(<<"stickers">>, Data, []),
<<"members">> => OwnMemberList,
<<"member_count">> => length(AllMembers),
<<"members">> => Members,
<<"member_count">> => MemberCount,
<<"online_count">> => OnlineCount,
<<"presences">> => [],
<<"voice_states">> => guild_voice:get_voice_states_list(State),
<<"voice_states">> => VoiceStates,
<<"joined_at">> => JoinedAt
}.
@@ -140,6 +145,7 @@ find_everyone_viewable_text_channel(Channels, State) ->
map_utils:ensure_list(Channels)
).
-spec find_member_by_user_id(user_id(), guild_state()) -> guild_member() | undefined.
find_member_by_user_id(UserId, State) ->
guild_permissions:find_member_by_user_id(UserId, State).
@@ -164,19 +170,7 @@ channels_from_state(State) ->
-spec channels_from_data(guild_data_map()) -> channel_list().
channels_from_data(Data) ->
map_utils:ensure_list(maps:get(<<"channels">>, Data, [])).
-spec member_in_list(integer(), [guild_member()]) -> boolean().
member_in_list(UserId, Members) ->
lists:any(fun(Member) -> member_matches(UserId, Member) end, Members).
-spec member_matches(integer(), guild_member()) -> boolean().
member_matches(UserId, Member) ->
MemberUser = map_utils:ensure_map(maps:get(<<"user">>, Member, #{})),
case map_utils:get_integer(MemberUser, <<"id">>, undefined) of
undefined -> false;
Id -> Id =:= UserId
end.
guild_data_index:channel_list(Data).
-spec paginate_members([guild_member()], non_neg_integer(), non_neg_integer()) -> [guild_member()].
paginate_members(Members, Limit, Offset) ->
@@ -191,7 +185,7 @@ paginate_members(Members, Limit, Offset) ->
end
end.
-spec derive_member_view(integer(), guild_member() | undefined, guild_state(), channel_list()) ->
-spec derive_member_view(user_id(), guild_member() | undefined, guild_state(), channel_list()) ->
{channel_list(), term()}.
derive_member_view(_UserId, undefined, _State, _Channels) ->
{[], null};
@@ -210,7 +204,63 @@ derive_member_view(UserId, Member, State, Channels) ->
JoinedAt = maps:get(<<"joined_at">>, Member, null),
{Filtered, JoinedAt}.
-spec role_permissions_for_id(list(), integer()) -> integer().
-spec voice_members_from_states([map()], [guild_member()]) -> [guild_member()].
voice_members_from_states(VoiceStates, Members) ->
MemberIndex = build_member_index(Members),
lists:filtermap(
fun(VoiceState) ->
case voice_state_utils:voice_state_user_id(VoiceState) of
undefined ->
false;
UserId ->
case maps:get(UserId, MemberIndex, undefined) of
undefined -> false;
Member -> {true, Member}
end
end
end,
VoiceStates
).
-spec build_member_index([guild_member()]) -> #{integer() => guild_member()}.
build_member_index(Members) ->
lists:foldl(
fun(Member, Acc) ->
case member_user_id(Member) of
undefined -> Acc;
UserId -> maps:put(UserId, Member, Acc)
end
end,
#{},
Members
).
-spec merge_members([guild_member()], [guild_member()]) -> [guild_member()].
merge_members(Primary, Secondary) ->
{Merged, _} =
lists:foldl(
fun(Member, {Acc, Seen}) ->
case member_user_id(Member) of
undefined ->
{Acc, Seen};
UserId ->
case sets:is_element(UserId, Seen) of
true -> {Acc, Seen};
false -> {[Member | Acc], sets:add_element(UserId, Seen)}
end
end
end,
{[], sets:new()},
Primary ++ Secondary
),
lists:reverse(Merged).
-spec member_user_id(guild_member()) -> integer() | undefined.
member_user_id(Member) ->
MemberUser = map_utils:ensure_map(maps:get(<<"user">>, Member, #{})),
map_utils:get_integer(MemberUser, <<"id">>, undefined).
-spec role_permissions_for_id([map()], integer()) -> integer().
role_permissions_for_id(Roles, GuildId) ->
lists:foldl(
fun(Role, Acc) ->
@@ -229,6 +279,10 @@ select_first_viewable(Channel, GuildId, BasePerms) ->
ChannelId = map_utils:get_integer(Channel, <<"id">>, undefined),
select_first_viewable(ChannelType, ChannelId, Channel, GuildId, BasePerms).
-spec select_first_viewable(
integer() | undefined, integer() | undefined, map(), integer(), integer()
) ->
integer() | null.
select_first_viewable(0, ChannelId, Channel, GuildId, BasePerms) when is_integer(ChannelId) ->
case (BasePerms band constants:administrator_permission()) =/= 0 of
true ->
@@ -273,6 +327,13 @@ find_everyone_viewable_text_channel_test() ->
ChannelId = find_everyone_viewable_text_channel(Channels, State),
?assertEqual(500, ChannelId).
paginate_members_test() ->
Members = [#{<<"id">> => 1}, #{<<"id">> => 2}, #{<<"id">> => 3}],
?assertEqual([#{<<"id">> => 1}, #{<<"id">> => 2}], paginate_members(Members, 2, 0)),
?assertEqual([#{<<"id">> => 2}, #{<<"id">> => 3}], paginate_members(Members, 2, 1)),
?assertEqual([#{<<"id">> => 3}], paginate_members(Members, 2, 2)),
?assertEqual([], paginate_members(Members, 2, 5)).
test_state() ->
GuildId = 100,
ViewPerm = constants:view_channel_permission(),

View File

@@ -0,0 +1,480 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(guild_data_index).
-export([
normalize_data/1,
member_map/1,
member_values/1,
member_list/1,
member_count/1,
member_ids/1,
member_role_index/1,
get_member/2,
put_member/2,
put_member_map/2,
put_member_list/2,
remove_member/2,
role_list/1,
role_index/1,
put_roles/2,
channel_list/1,
channel_index/1,
put_channels/2
]).
-type guild_data() :: map().
-type member() :: map().
-type role() :: map().
-type channel() :: map().
-type user_id() :: integer().
-type snowflake_id() :: integer().
-type role_member_index() :: #{snowflake_id() => #{user_id() => true}}.
-spec normalize_data(guild_data()) -> guild_data().
normalize_data(Data) when is_map(Data) ->
MemberMap = member_map(Data),
Roles = role_list(Data),
Channels = channel_list(Data),
Data1 = maps:put(<<"members">>, MemberMap, Data),
Data2 = maps:put(<<"roles">>, Roles, Data1),
Data3 = maps:put(<<"channels">>, Channels, Data2),
Data4 = maps:put(<<"role_index">>, build_id_index(Roles), Data3),
Data5 = maps:put(<<"channel_index">>, build_id_index(Channels), Data4),
maps:put(<<"member_role_index">>, build_member_role_index(MemberMap), Data5);
normalize_data(Data) ->
Data.
-spec member_map(guild_data()) -> #{user_id() => member()}.
member_map(Data) when is_map(Data) ->
case maps:get(<<"members">>, Data, #{}) of
Members when is_map(Members) ->
normalize_member_map(Members);
Members when is_list(Members) ->
build_member_map(Members);
_ ->
#{}
end;
member_map(_) ->
#{}.
-spec member_list(guild_data()) -> [member()].
member_list(Data) ->
MemberPairs = maps:to_list(member_map(Data)),
SortedPairs = lists:sort(fun({A, _}, {B, _}) -> A =< B end, MemberPairs),
[Member || {_UserId, Member} <- SortedPairs].
-spec member_values(guild_data()) -> [member()].
member_values(Data) ->
maps:values(member_map(Data)).
-spec member_count(guild_data()) -> non_neg_integer().
member_count(Data) ->
map_size(member_map(Data)).
-spec member_ids(guild_data()) -> [user_id()].
member_ids(Data) ->
maps:keys(member_map(Data)).
-spec member_role_index(guild_data()) -> role_member_index().
member_role_index(Data) when is_map(Data) ->
case maps:get(<<"member_role_index">>, Data, undefined) of
Index when is_map(Index) -> normalize_member_role_index(Index);
_ -> build_member_role_index(member_map(Data))
end;
member_role_index(_) ->
#{}.
-spec get_member(user_id(), guild_data()) -> member() | undefined.
get_member(UserId, Data) when is_integer(UserId) ->
maps:get(UserId, member_map(Data), undefined);
get_member(_, _) ->
undefined.
-spec put_member(member(), guild_data()) -> guild_data().
put_member(Member, Data) when is_map(Member), is_map(Data) ->
case member_user_id(Member) of
undefined ->
Data;
UserId ->
MemberMap = member_map(Data),
ExistingMember = maps:get(UserId, MemberMap, undefined),
ExistingRoles = member_role_ids(ExistingMember),
UpdatedRoles = member_role_ids(Member),
RoleIndex = member_role_index(Data),
RoleIndex1 = remove_user_from_member_role_index(UserId, ExistingRoles, RoleIndex),
RoleIndex2 = add_user_to_member_role_index(UserId, UpdatedRoles, RoleIndex1),
Data1 = maps:put(<<"members">>, maps:put(UserId, Member, MemberMap), Data),
maps:put(<<"member_role_index">>, RoleIndex2, Data1)
end;
put_member(_, Data) ->
Data.
-spec put_member_map(#{user_id() => member()}, guild_data()) -> guild_data().
put_member_map(MemberMap, Data) when is_map(MemberMap), is_map(Data) ->
NormalizedMemberMap = normalize_member_map(MemberMap),
Data1 = maps:put(<<"members">>, NormalizedMemberMap, Data),
maps:put(<<"member_role_index">>, build_member_role_index(NormalizedMemberMap), Data1);
put_member_map(_, Data) ->
Data.
-spec put_member_list([member()], guild_data()) -> guild_data().
put_member_list(Members, Data) when is_list(Members), is_map(Data) ->
put_member_map(build_member_map(Members), Data);
put_member_list(_, Data) ->
Data.
-spec remove_member(user_id(), guild_data()) -> guild_data().
remove_member(UserId, Data) when is_integer(UserId), is_map(Data) ->
MemberMap = member_map(Data),
Member = maps:get(UserId, MemberMap, undefined),
MemberRoles = member_role_ids(Member),
RoleIndex = member_role_index(Data),
RoleIndex1 = remove_user_from_member_role_index(UserId, MemberRoles, RoleIndex),
Data1 = maps:put(<<"members">>, maps:remove(UserId, MemberMap), Data),
maps:put(<<"member_role_index">>, RoleIndex1, Data1);
remove_member(_, Data) ->
Data.
-spec role_list(guild_data()) -> [role()].
role_list(Data) when is_map(Data) ->
ensure_list(maps:get(<<"roles">>, Data, []));
role_list(_) ->
[].
-spec role_index(guild_data()) -> #{snowflake_id() => role()}.
role_index(Data) when is_map(Data) ->
case maps:get(<<"role_index">>, Data, undefined) of
Index when is_map(Index) -> normalize_id_index(Index);
_ -> build_id_index(role_list(Data))
end;
role_index(_) ->
#{}.
-spec put_roles([role()], guild_data()) -> guild_data().
put_roles(Roles, Data) when is_map(Data) ->
RoleList = ensure_list(Roles),
Data1 = maps:put(<<"roles">>, RoleList, Data),
maps:put(<<"role_index">>, build_id_index(RoleList), Data1);
put_roles(_, Data) ->
Data.
-spec channel_list(guild_data()) -> [channel()].
channel_list(Data) when is_map(Data) ->
ensure_list(maps:get(<<"channels">>, Data, []));
channel_list(_) ->
[].
-spec channel_index(guild_data()) -> #{snowflake_id() => channel()}.
channel_index(Data) when is_map(Data) ->
case maps:get(<<"channel_index">>, Data, undefined) of
Index when is_map(Index) -> normalize_id_index(Index);
_ -> build_id_index(channel_list(Data))
end;
channel_index(_) ->
#{}.
-spec put_channels([channel()], guild_data()) -> guild_data().
put_channels(Channels, Data) when is_map(Data) ->
ChannelList = ensure_list(Channels),
Data1 = maps:put(<<"channels">>, ChannelList, Data),
maps:put(<<"channel_index">>, build_id_index(ChannelList), Data1);
put_channels(_, Data) ->
Data.
-spec build_member_map([member()]) -> #{user_id() => member()}.
build_member_map(Members) ->
lists:foldl(
fun(Member, Acc) ->
case member_user_id(Member) of
undefined ->
Acc;
UserId ->
maps:put(UserId, Member, Acc)
end
end,
#{},
Members
).
-spec normalize_member_map(map()) -> #{user_id() => member()}.
normalize_member_map(MemberMap) ->
maps:fold(
fun(Key, Member, Acc) ->
case normalize_member_key(Key, Member) of
undefined ->
Acc;
UserId ->
maps:put(UserId, Member, Acc)
end
end,
#{},
MemberMap
).
-spec normalize_member_key(term(), member()) -> user_id() | undefined.
normalize_member_key(Key, Member) ->
case type_conv:to_integer(Key) of
undefined -> member_user_id(Member);
UserId -> UserId
end.
-spec member_user_id(member()) -> user_id() | undefined.
member_user_id(Member) when is_map(Member) ->
User = maps:get(<<"user">>, Member, #{}),
map_utils:get_integer(User, <<"id">>, undefined);
member_user_id(_) ->
undefined.
-spec member_role_ids(term()) -> [snowflake_id()].
member_role_ids(Member) when is_map(Member) ->
extract_integer_list(maps:get(<<"roles">>, Member, []));
member_role_ids(_) ->
[].
-spec build_member_role_index(#{user_id() => member()}) -> role_member_index().
build_member_role_index(MemberMap) ->
maps:fold(
fun(UserId, Member, Acc) ->
add_user_to_member_role_index(UserId, member_role_ids(Member), Acc)
end,
#{},
MemberMap
).
-spec normalize_member_role_index(map()) -> role_member_index().
normalize_member_role_index(Index) ->
maps:fold(
fun(RoleKey, Members, Acc) ->
case type_conv:to_integer(RoleKey) of
undefined ->
Acc;
RoleId ->
NormalizedMembers = normalize_member_role_members(Members),
case map_size(NormalizedMembers) of
0 ->
Acc;
_ ->
maps:put(RoleId, NormalizedMembers, Acc)
end
end
end,
#{},
Index
).
-spec normalize_member_role_members(term()) -> #{user_id() => true}.
normalize_member_role_members(Members) when is_map(Members) ->
maps:fold(
fun(UserKey, _Flag, Acc) ->
case type_conv:to_integer(UserKey) of
undefined ->
Acc;
UserId ->
maps:put(UserId, true, Acc)
end
end,
#{},
Members
);
normalize_member_role_members(_) ->
#{}.
-spec add_user_to_member_role_index(user_id(), [snowflake_id()], role_member_index()) ->
role_member_index().
add_user_to_member_role_index(UserId, RoleIds, RoleIndex) when is_integer(UserId) ->
lists:foldl(
fun(RoleId, Acc) ->
RoleMembers = maps:get(RoleId, Acc, #{}),
maps:put(RoleId, maps:put(UserId, true, RoleMembers), Acc)
end,
RoleIndex,
RoleIds
).
-spec remove_user_from_member_role_index(user_id(), [snowflake_id()], role_member_index()) ->
role_member_index().
remove_user_from_member_role_index(UserId, RoleIds, RoleIndex) when is_integer(UserId) ->
lists:foldl(
fun(RoleId, Acc) ->
RoleMembers = maps:get(RoleId, Acc, #{}),
UpdatedRoleMembers = maps:remove(UserId, RoleMembers),
case map_size(UpdatedRoleMembers) of
0 ->
maps:remove(RoleId, Acc);
_ ->
maps:put(RoleId, UpdatedRoleMembers, Acc)
end
end,
RoleIndex,
RoleIds
).
-spec build_id_index([map()]) -> #{snowflake_id() => map()}.
build_id_index(Items) ->
lists:foldl(
fun(Item, Acc) ->
case map_utils:get_integer(Item, <<"id">>, undefined) of
undefined ->
Acc;
Id ->
maps:put(Id, Item, Acc)
end
end,
#{},
Items
).
-spec normalize_id_index(map()) -> #{snowflake_id() => map()}.
normalize_id_index(Index) ->
maps:fold(
fun(Key, Item, Acc) ->
case type_conv:to_integer(Key) of
undefined ->
case map_utils:get_integer(Item, <<"id">>, undefined) of
undefined -> Acc;
Id -> maps:put(Id, Item, Acc)
end;
Id ->
maps:put(Id, Item, Acc)
end
end,
#{},
Index
).
-spec ensure_list(term()) -> list().
ensure_list(List) when is_list(List) ->
List;
ensure_list(_) ->
[].
-spec extract_integer_list(term()) -> [integer()].
extract_integer_list(List) when is_list(List) ->
lists:reverse(
lists:foldl(
fun(Value, Acc) ->
case type_conv:to_integer(Value) of
undefined -> Acc;
Int -> [Int | Acc]
end
end,
[],
List
)
);
extract_integer_list(_) ->
[].
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
member_map_from_list_test() ->
Data = #{
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"2">>}, <<"nick">> => <<"beta">>},
#{<<"user">> => #{<<"id">> => <<"1">>}, <<"nick">> => <<"alpha">>}
]
},
MemberMap = member_map(Data),
?assertMatch(#{1 := _, 2 := _}, MemberMap),
?assertEqual(2, member_count(Data)).
member_list_is_sorted_test() ->
Data = #{
<<"members">> => #{
5 => #{<<"user">> => #{<<"id">> => <<"5">>}},
2 => #{<<"user">> => #{<<"id">> => <<"2">>}}
}
},
[First, Second] = member_list(Data),
?assertEqual(2, member_user_id(First)),
?assertEqual(5, member_user_id(Second)).
member_values_returns_members_without_sorting_test() ->
Data = #{
<<"members">> => #{
9 => #{<<"user">> => #{<<"id">> => <<"9">>}},
2 => #{<<"user">> => #{<<"id">> => <<"2">>}}
}
},
Values = member_values(Data),
?assertEqual(2, length(Values)).
put_member_updates_entry_test() ->
Data = #{
<<"members">> => #{
10 => #{<<"user">> => #{<<"id">> => <<"10">>}, <<"nick">> => <<"old">>}
}
},
UpdatedData = put_member(
#{<<"user">> => #{<<"id">> => <<"10">>}, <<"nick">> => <<"new">>},
Data
),
UpdatedMember = get_member(10, UpdatedData),
?assertEqual(<<"new">>, maps:get(<<"nick">>, UpdatedMember)).
remove_member_removes_entry_test() ->
Data = #{
<<"members">> => #{
10 => #{<<"user">> => #{<<"id">> => <<"10">>}}
}
},
UpdatedData = remove_member(10, Data),
?assertEqual(undefined, get_member(10, UpdatedData)).
normalize_data_builds_indexes_test() ->
Data = #{
<<"members">> => [#{<<"user">> => #{<<"id">> => <<"1">>}}],
<<"roles">> => [#{<<"id">> => <<"100">>}],
<<"channels">> => [#{<<"id">> => <<"200">>}]
},
Normalized = normalize_data(Data),
?assert(is_map(maps:get(<<"members">>, Normalized))),
?assertMatch(#{100 := _}, role_index(Normalized)),
?assertMatch(#{200 := _}, channel_index(Normalized)).
member_role_index_builds_role_to_user_lookup_test() ->
Data = #{
<<"members">> => #{
1 => #{<<"user">> => #{<<"id">> => <<"1">>}, <<"roles">> => [<<"10">>, <<"11">>]},
2 => #{<<"user">> => #{<<"id">> => <<"2">>}, <<"roles">> => [<<"11">>]}
}
},
Index = member_role_index(Data),
?assertEqual(#{1 => true}, maps:get(10, Index)),
?assertEqual(#{1 => true, 2 => true}, maps:get(11, Index)).
put_member_and_remove_member_keep_member_role_index_in_sync_test() ->
Data0 = #{
<<"members">> => #{
3 => #{<<"user">> => #{<<"id">> => <<"3">>}, <<"roles">> => [<<"20">>]}
}
},
Data1 = put_member(
#{<<"user">> => #{<<"id">> => <<"3">>}, <<"roles">> => [<<"30">>]},
Data0
),
Index1 = member_role_index(Data1),
?assertEqual(undefined, maps:get(20, Index1, undefined)),
?assertEqual(#{3 => true}, maps:get(30, Index1)),
Data2 = remove_member(3, Data1),
Index2 = member_role_index(Data2),
?assertEqual(undefined, maps:get(30, Index2, undefined)).
-endif.

File diff suppressed because it is too large Load Diff

View File

@@ -23,23 +23,15 @@
-export([start_link/0]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-define(GUILD_PID_CACHE, guild_pid_cache).
-type guild_id() :: integer().
-type shard_map() :: #{pid => pid(), ref => reference()}.
-type shard_map() :: #{pid := pid(), ref := reference()}.
-type state() :: #{
shards => #{non_neg_integer() => shard_map()},
shard_count => pos_integer()
shards := #{non_neg_integer() => shard_map()},
shard_count := pos_integer()
}.
-record(shard, {
pid :: pid(),
ref :: reference()
}).
-record(state, {
shards = #{} :: #{non_neg_integer() => #shard{}},
shard_count = 1 :: pos_integer()
}).
-spec start_link() -> {ok, pid()} | {error, term()}.
start_link() ->
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
@@ -47,23 +39,23 @@ start_link() ->
-spec init(list()) -> {ok, state()}.
init([]) ->
process_flag(trap_exit, true),
{ShardCount, Source} = determine_shard_count(),
ShardMap = start_shards(ShardCount, #{}),
maybe_log_shard_source(guild_manager, ShardCount, Source),
ets:new(?GUILD_PID_CACHE, [named_table, public, set, {read_concurrency, true}]),
{ShardCount, _Source} = determine_shard_count(),
ShardMap = start_shards(ShardCount),
{ok, #{shards => ShardMap, shard_count => ShardCount}}.
-spec handle_call(term(), gen_server:from(), state()) -> {reply, term(), state()}.
handle_call({start_or_lookup, GuildId} = Request, _From, State) ->
{Reply, NewState} = forward_call(GuildId, Request, State),
handle_call({start_or_lookup, GuildId}, _From, State) ->
{Reply, NewState} = forward_call(GuildId, {start_or_lookup, GuildId}, State),
{reply, Reply, NewState};
handle_call({stop_guild, GuildId} = Request, _From, State) ->
{Reply, NewState} = forward_call(GuildId, Request, State),
handle_call({stop_guild, GuildId}, _From, State) ->
{Reply, NewState} = forward_call(GuildId, {stop_guild, GuildId}, State),
{reply, Reply, NewState};
handle_call({reload_guild, GuildId} = Request, _From, State) ->
{Reply, NewState} = forward_call(GuildId, Request, State),
handle_call({reload_guild, GuildId}, _From, State) ->
{Reply, NewState} = forward_call(GuildId, {reload_guild, GuildId}, State),
{reply, Reply, NewState};
handle_call({shutdown_guild, GuildId} = Request, _From, State) ->
{Reply, NewState} = forward_call(GuildId, Request, State),
handle_call({shutdown_guild, GuildId}, _From, State) ->
{Reply, NewState} = forward_call(GuildId, {shutdown_guild, GuildId}, State),
{reply, Reply, NewState};
handle_call({reload_all_guilds, GuildIds}, _From, State) ->
{Reply, NewState} = handle_reload_all(GuildIds, State),
@@ -74,8 +66,7 @@ handle_call(get_local_count, _From, State) ->
handle_call(get_global_count, _From, State) ->
{Count, NewState} = aggregate_counts(get_global_count, State),
{reply, {ok, Count}, NewState};
handle_call(Request, _From, State) ->
logger:warning("[guild_manager] unknown request ~p", [Request]),
handle_call(_Request, _From, State) ->
{reply, ok, State}.
-spec handle_cast(term(), state()) -> {noreply, state()}.
@@ -83,21 +74,20 @@ handle_cast(_Msg, State) ->
{noreply, State}.
-spec handle_info(term(), state()) -> {noreply, state()}.
handle_info({'DOWN', Ref, process, _Pid, Reason}, State) ->
handle_info({'DOWN', Ref, process, Pid, _Reason}, State) ->
Shards = maps:get(shards, State),
case find_shard_by_ref(Ref, Shards) of
{ok, Index} ->
logger:warning("[guild_manager] shard ~p crashed: ~p", [Index, Reason]),
{_Shard, NewState} = restart_shard(Index, State),
{noreply, NewState};
not_found ->
cleanup_guild_from_cache(Pid),
{noreply, State}
end;
handle_info({'EXIT', Pid, Reason}, State) ->
handle_info({'EXIT', Pid, _Reason}, State) ->
Shards = maps:get(shards, State),
case find_shard_by_pid(Pid, Shards) of
{ok, Index} ->
logger:warning("[guild_manager] shard ~p exited: ~p", [Index, Reason]),
{_Shard, NewState} = restart_shard(Index, State),
{noreply, NewState};
not_found ->
@@ -116,17 +106,10 @@ terminate(_Reason, State) ->
end,
maps:values(Shards)
),
catch ets:delete(?GUILD_PID_CACHE),
ok.
-spec code_change(term(), term(), term()) -> {ok, state()}.
code_change(_OldVsn, #state{shards = OldShards, shard_count = ShardCount}, _Extra) ->
NewShards = maps:map(
fun(_Index, #shard{pid = Pid, ref = Ref}) ->
#{pid => Pid, ref => Ref}
end,
OldShards
),
{ok, #{shards => NewShards, shard_count => ShardCount}};
code_change(_OldVsn, State, _Extra) when is_map(State) ->
{ok, State}.
@@ -139,19 +122,26 @@ determine_shard_count() ->
{default_shard_count(), auto}
end.
-spec start_shards(pos_integer(), #{}) -> #{non_neg_integer() => shard_map()}.
start_shards(Count, Acc) ->
-spec default_shard_count() -> pos_integer().
default_shard_count() ->
Candidates = [
erlang:system_info(logical_processors_available),
erlang:system_info(schedulers_online)
],
max(1, lists:max([C || C <- Candidates, is_integer(C), C > 0] ++ [1])).
-spec start_shards(pos_integer()) -> #{non_neg_integer() => shard_map()}.
start_shards(Count) ->
lists:foldl(
fun(Index, MapAcc) ->
case start_shard(Index) of
{ok, Shard} ->
maps:put(Index, Shard, MapAcc);
{error, Reason} ->
logger:warning("[guild_manager] failed to start shard ~p: ~p", [Index, Reason]),
{error, _Reason} ->
MapAcc
end
end,
Acc,
#{},
lists:seq(0, Count - 1)
).
@@ -172,14 +162,31 @@ restart_shard(Index, State) ->
{ok, Shard} ->
Updated = State#{shards => maps:put(Index, Shard, Shards)},
{Shard, Updated};
{error, Reason} ->
logger:error("[guild_manager] failed to restart shard ~p: ~p", [Index, Reason]),
Dummy = #{pid => spawn(fun() -> exit(normal) end), ref => make_ref()},
{error, _Reason} ->
DummyPid = spawn(fun() -> ok end),
Dummy = #{pid => DummyPid, ref => make_ref()},
{Dummy, State}
end.
-spec forward_call(guild_id(), term(), state()) -> {term(), state()}.
forward_call(GuildId, {start_or_lookup, _} = Request, State) ->
case ets:lookup(?GUILD_PID_CACHE, GuildId) of
[{GuildId, GuildPid}] when is_pid(GuildPid) ->
case erlang:is_process_alive(GuildPid) of
true ->
{{ok, GuildPid}, State};
false ->
ets:delete(?GUILD_PID_CACHE, GuildId),
forward_call_to_shard(GuildId, Request, State)
end;
[] ->
forward_call_to_shard(GuildId, Request, State)
end;
forward_call(GuildId, Request, State) ->
forward_call_to_shard(GuildId, Request, State).
-spec forward_call_to_shard(guild_id(), term(), state()) -> {term(), state()}.
forward_call_to_shard(GuildId, Request, State) ->
{Index, State1} = ensure_shard(GuildId, State),
Shards = maps:get(shards, State1),
ShardMap = maps:get(Index, Shards),
@@ -187,7 +194,11 @@ forward_call(GuildId, Request, State) ->
case catch gen_server:call(Pid, Request, ?DEFAULT_GEN_SERVER_TIMEOUT) of
{'EXIT', _} ->
{_Shard, State2} = restart_shard(Index, State1),
forward_call(GuildId, Request, State2);
forward_call_to_shard(GuildId, Request, State2);
{ok, GuildPid} = Reply ->
ets:insert(?GUILD_PID_CACHE, {GuildId, GuildPid}),
erlang:monitor(process, GuildPid),
{Reply, State1};
Reply ->
{Reply, State1}
end.
@@ -223,146 +234,137 @@ select_shard(GuildId, Count) when Count > 0 ->
-spec aggregate_counts(term(), state()) -> {non_neg_integer(), state()}.
aggregate_counts(Request, State) ->
Shards = maps:get(shards, State),
Counts =
[
begin
Pid = maps:get(pid, ShardMap),
case catch gen_server:call(Pid, Request, ?DEFAULT_GEN_SERVER_TIMEOUT) of
{ok, Count} -> Count;
_ -> 0
end
Counts = lists:map(
fun(ShardMap) ->
Pid = maps:get(pid, ShardMap),
case catch gen_server:call(Pid, Request, ?DEFAULT_GEN_SERVER_TIMEOUT) of
{ok, Count} -> Count;
_ -> 0
end
|| ShardMap <- maps:values(Shards)
],
end,
maps:values(Shards)
),
{lists:sum(Counts), State}.
-spec handle_reload_all([guild_id()], state()) -> {#{count => non_neg_integer()}, state()}.
-spec handle_reload_all([guild_id()], state()) -> {#{count := non_neg_integer()}, state()}.
handle_reload_all([], State) ->
Shards = maps:get(shards, State),
{Replies, FinalState} =
lists:foldl(
fun({_Index, ShardMap}, {AccReplies, AccState}) ->
Pid = maps:get(pid, ShardMap),
case catch gen_server:call(Pid, {reload_all_guilds, []}, 60000) of
Reply ->
{AccReplies ++ [Reply], AccState}
end
end,
{[], State},
maps:to_list(Shards)
),
Count = lists:sum([maps:get(count, Reply, 0) || Reply <- Replies]),
{Replies, FinalState} = lists:foldl(
fun({_Index, ShardMap}, {AccReplies, AccState}) ->
Pid = maps:get(pid, ShardMap),
Reply = catch gen_server:call(Pid, {reload_all_guilds, []}, 15000),
{[Reply | AccReplies], AccState}
end,
{[], State},
maps:to_list(Shards)
),
Count = lists:sum([maps:get(count, Reply, 0) || Reply <- Replies, is_map(Reply)]),
{#{count => Count}, FinalState};
handle_reload_all(GuildIds, State) ->
Count = maps:get(shard_count, State),
Groups = group_ids_by_shard(GuildIds, Count),
{TotalCount, FinalState} =
lists:foldl(
fun({Index, Ids}, {AccCount, AccState}) ->
{ShardIdx, State1} = ensure_shard_for_index(Index, AccState),
Shards = maps:get(shards, State1),
ShardMap = maps:get(ShardIdx, Shards),
Pid = maps:get(pid, ShardMap),
case catch gen_server:call(Pid, {reload_all_guilds, Ids}, 60000) of
#{count := CountReply} ->
{AccCount + CountReply, State1};
_ ->
{AccCount, State1}
end
end,
{0, State},
Groups
),
{TotalCount, FinalState} = lists:foldl(
fun({Index, Ids}, {AccCount, AccState}) ->
{ShardIdx, State1} = ensure_shard_for_index(Index, AccState),
Shards = maps:get(shards, State1),
ShardMap = maps:get(ShardIdx, Shards),
Pid = maps:get(pid, ShardMap),
case catch gen_server:call(Pid, {reload_all_guilds, Ids}, 15000) of
#{count := CountReply} ->
{AccCount + CountReply, State1};
_ ->
{AccCount, State1}
end
end,
{0, State},
Groups
),
{#{count => TotalCount}, FinalState}.
-spec group_ids_by_shard([guild_id()], pos_integer()) -> [{non_neg_integer(), [guild_id()]}].
group_ids_by_shard(GuildIds, ShardCount) ->
lists:foldl(
fun(GuildId, Acc) ->
Index = select_shard(GuildId, ShardCount),
case lists:keytake(Index, 1, Acc) of
{value, {Index, Ids}, Rest} ->
[{Index, [GuildId | Ids]} | Rest];
false ->
[{Index, [GuildId]} | Acc]
end
end,
[],
GuildIds
).
rendezvous_router:group_keys(GuildIds, ShardCount).
-spec find_shard_by_ref(reference(), #{non_neg_integer() => shard_map()}) ->
{ok, non_neg_integer()} | not_found.
find_shard_by_ref(Ref, Shards) ->
maps:fold(
fun
(Index, ShardMap, _) when is_map(ShardMap) ->
case maps:get(ref, ShardMap) of
R when R =:= Ref -> {ok, Index};
_ -> not_found
end;
(_, _, Acc) ->
Acc
end,
not_found,
Shards
).
find_shard_by(fun(#{ref := R}) -> R =:= Ref end, Shards).
-spec find_shard_by_pid(pid(), #{non_neg_integer() => shard_map()}) ->
{ok, non_neg_integer()} | not_found.
find_shard_by_pid(Pid, Shards) ->
find_shard_by(fun(#{pid := P}) -> P =:= Pid end, Shards).
-spec find_shard_by(fun((shard_map()) -> boolean()), #{non_neg_integer() => shard_map()}) ->
{ok, non_neg_integer()} | not_found.
find_shard_by(Pred, Shards) ->
maps:fold(
fun
(Index, ShardMap, _) when is_map(ShardMap) ->
case maps:get(pid, ShardMap) of
P when P =:= Pid -> {ok, Index};
_ -> not_found
end;
(_, _, Acc) ->
Acc
(_, _, {ok, _} = Found) ->
Found;
(Index, ShardMap, not_found) ->
case Pred(ShardMap) of
true -> {ok, Index};
false -> not_found
end
end,
not_found,
Shards
).
-spec default_shard_count() -> pos_integer().
default_shard_count() ->
Candidates = [
erlang:system_info(logical_processors_available), erlang:system_info(schedulers_online)
],
lists:max([C || C <- Candidates, is_integer(C), C > 0] ++ [1]).
-spec maybe_log_shard_source(atom(), pos_integer(), configured | auto) -> ok.
maybe_log_shard_source(Name, Count, configured) ->
logger:info("[~p] starting with ~p shards (configured)", [Name, Count]),
ok;
maybe_log_shard_source(Name, Count, auto) ->
logger:info("[~p] starting with ~p shards (auto)", [Name, Count]),
-spec cleanup_guild_from_cache(pid()) -> ok.
cleanup_guild_from_cache(Pid) ->
case ets:match_object(?GUILD_PID_CACHE, {'$1', Pid}) of
[{GuildId, _Pid}] ->
ets:delete(?GUILD_PID_CACHE, GuildId);
[] ->
ok
end,
ok.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
determine_shard_count_configured_test() ->
with_runtime_config(guild_shards, 4, fun() ->
?assertMatch({4, configured}, determine_shard_count())
end).
default_shard_count_positive_test() ->
Count = default_shard_count(),
?assert(Count >= 1).
determine_shard_count_auto_test() ->
with_runtime_config(guild_shards, undefined, fun() ->
{Count, auto} = determine_shard_count(),
?assert(Count > 0)
end).
select_shard_deterministic_test() ->
GuildId = 12345,
ShardCount = 8,
Shard1 = select_shard(GuildId, ShardCount),
Shard2 = select_shard(GuildId, ShardCount),
?assertEqual(Shard1, Shard2).
select_shard_in_range_test() ->
ShardCount = 8,
lists:foreach(
fun(GuildId) ->
Shard = select_shard(GuildId, ShardCount),
?assert(Shard >= 0 andalso Shard < ShardCount)
end,
lists:seq(1, 100)
).
group_ids_by_shard_test() ->
GuildIds = [1, 2, 3, 4, 5],
ShardCount = 2,
Groups = group_ids_by_shard(GuildIds, ShardCount),
AllIds = lists:flatten([Ids || {_, Ids} <- Groups]),
?assertEqual(lists:sort(GuildIds), lists:sort(AllIds)).
find_shard_by_ref_found_test() ->
Ref = make_ref(),
Shards = #{0 => #{pid => self(), ref => Ref}},
?assertMatch({ok, 0}, find_shard_by_ref(Ref, Shards)).
find_shard_by_ref_not_found_test() ->
Shards = #{0 => #{pid => self(), ref => make_ref()}},
?assertEqual(not_found, find_shard_by_ref(make_ref(), Shards)).
find_shard_by_pid_found_test() ->
Pid = self(),
Shards = #{0 => #{pid => Pid, ref => make_ref()}},
?assertMatch({ok, 0}, find_shard_by_pid(Pid, Shards)).
with_runtime_config(Key, Value, Fun) ->
Original = fluxer_gateway_env:get(Key),
fluxer_gateway_env:patch(#{Key => Value}),
Result = Fun(),
fluxer_gateway_env:update(fun(Map) ->
case Original of
undefined -> maps:remove(Key, Map);
Val -> maps:put(Key, Val, Map)
end
end),
Result.
-endif.

View File

@@ -21,6 +21,8 @@
-include_lib("fluxer_gateway/include/timeout_config.hrl").
-define(GUILD_API_CANARY_PERCENTAGE, 5).
-define(BATCH_SIZE, 10).
-define(BATCH_DELAY_MS, 100).
-export([start_link/1]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
@@ -30,56 +32,34 @@
-type guild_data() :: #{binary() => term()}.
-type fetch_result() :: {ok, guild_data()} | {error, term()}.
-type state() :: #{
guilds => #{guild_id() => guild_ref() | loading},
api_host => string(),
api_canary_host => undefined | string(),
pending_requests => #{guild_id() => [gen_server:from()]}
guilds := #{guild_id() => guild_ref() | loading},
api_host := string(),
api_canary_host := undefined | string(),
pending_requests := #{guild_id() => [gen_server:from()]},
shard_index := non_neg_integer()
}.
-record(state, {
guilds = #{} :: #{guild_id() => guild_ref() | loading},
api_host :: string(),
api_canary_host :: undefined | string(),
pending_requests = #{} :: #{guild_id() => [gen_server:from()]}
}).
-spec start_link(non_neg_integer()) -> {ok, pid()} | {error, term()}.
start_link(ShardIndex) ->
gen_server:start_link(?MODULE, #{shard_index => ShardIndex}, []).
-spec init(map()) -> {ok, state()}.
init(_Args) ->
init(Args) ->
process_flag(trap_exit, true),
fluxer_gateway_env:load(),
ApiHost = fluxer_gateway_env:get(api_host),
ApiCanaryHost = fluxer_gateway_env:get(api_canary_host),
ShardIndex = maps:get(shard_index, Args, 0),
{ok, #{
guilds => #{},
api_host => ApiHost,
api_canary_host => ApiCanaryHost,
pending_requests => #{}
pending_requests => #{},
shard_index => ShardIndex
}}.
-spec handle_call(Request, From, State) -> Result when
Request ::
{start_or_lookup, guild_id()}
| {stop_guild, guild_id()}
| {reload_guild, guild_id()}
| {reload_all_guilds, [guild_id()]}
| {shutdown_guild, guild_id()}
| get_local_count
| get_global_count
| term(),
From :: gen_server:from(),
State :: state(),
Result ::
{reply, Reply, state()}
| {noreply, state()},
Reply ::
{ok, pid()}
| {error, term()}
| ok
| {ok, non_neg_integer()}.
-spec handle_call(term(), gen_server:from(), state()) ->
{reply, term(), state()} | {noreply, state()}.
handle_call({start_or_lookup, GuildId}, From, State) ->
do_start_or_lookup(GuildId, From, State);
handle_call({stop_guild, GuildId}, _From, State) ->
@@ -87,37 +67,7 @@ handle_call({stop_guild, GuildId}, _From, State) ->
handle_call({reload_guild, GuildId}, From, State) ->
do_reload_guild(GuildId, From, State);
handle_call({reload_all_guilds, GuildIds}, From, State) ->
Guilds = maps:get(guilds, State),
GuildsToReload =
case GuildIds of
[] ->
[{GuildId, Pid} || {GuildId, {Pid, _Ref}} <- maps:to_list(Guilds)];
Ids ->
lists:filtermap(
fun(GuildId) ->
case maps:get(GuildId, Guilds, undefined) of
{Pid, _Ref} -> {true, {GuildId, Pid}};
_ -> false
end
end,
Ids
)
end,
Manager = self(),
spawn(fun() ->
try
reload_guilds_in_batches(GuildsToReload, Manager, State, 10, 100),
gen_server:cast(Manager, {all_guilds_reloaded, From, length(GuildsToReload)})
catch
Class:Error:Stacktrace ->
logger:error(
"[guild_manager] Spawned process failed: ~p:~p~n~p",
[Class, Error, Stacktrace]
),
gen_server:cast(Manager, {all_guilds_reloaded, From, 0})
end
end),
{noreply, State};
do_reload_all_guilds(GuildIds, From, State);
handle_call({shutdown_guild, GuildId}, _From, State) ->
do_shutdown_guild(GuildId, State);
handle_call(get_local_count, _From, State) ->
@@ -131,60 +81,18 @@ handle_call(get_global_count, _From, State) ->
handle_call(_Unknown, _From, State) ->
{reply, ok, State}.
-spec handle_cast(Request, State) -> {noreply, state()} when
Request ::
{guild_data_fetched, guild_id(), fetch_result()}
| {guild_data_reloaded, guild_id(), pid(), gen_server:from(), fetch_result()}
| {all_guilds_reloaded, gen_server:from(), non_neg_integer()}
| term(),
State :: state().
-spec handle_cast(term(), state()) -> {noreply, state()}.
handle_cast({guild_data_fetched, GuildId, Result}, State) ->
Pending = maps:get(pending_requests, State),
Requests = maps:get(GuildId, Pending, []),
Guilds = maps:get(guilds, State),
case Result of
{ok, Data} ->
case start_guild(GuildId, Data, State) of
{ok, Pid, NewState} ->
lists:foreach(fun(From) -> gen_server:reply(From, {ok, Pid}) end, Requests),
NewPending = maps:remove(GuildId, Pending),
NewGuilds = maps:get(guilds, NewState),
CleanGuilds = maps:remove(GuildId, NewGuilds),
{noreply, NewState#{pending_requests => NewPending, guilds => CleanGuilds}};
{error, Reason} ->
logger:error("[guild_manager] Failed to start guild ~p: ~p", [GuildId, Reason]),
lists:foreach(
fun(From) -> gen_server:reply(From, {error, Reason}) end, Requests
),
NewGuilds = maps:remove(GuildId, Guilds),
NewPending = maps:remove(GuildId, Pending),
{noreply, State#{guilds => NewGuilds, pending_requests => NewPending}}
end;
_ ->
lists:foreach(fun(From) -> gen_server:reply(From, {error, not_found}) end, Requests),
NewGuilds = maps:remove(GuildId, Guilds),
NewPending = maps:remove(GuildId, Pending),
{noreply, State#{guilds => NewGuilds, pending_requests => NewPending}}
end;
handle_cast({guild_data_reloaded, _GuildId, Pid, From, Result}, State) ->
case Result of
{ok, Data} ->
gen_server:call(Pid, {reload, Data}, ?GUILD_CALL_TIMEOUT),
gen_server:reply(From, ok),
{noreply, State};
_ ->
gen_server:reply(From, {error, fetch_failed}),
{noreply, State}
end;
handle_guild_data_fetched(GuildId, Result, State);
handle_cast({guild_data_reloaded, GuildId, Pid, From, Result}, State) ->
handle_guild_data_reloaded(GuildId, Pid, From, Result, State);
handle_cast({all_guilds_reloaded, From, Count}, State) ->
gen_server:reply(From, #{count => Count}),
{noreply, State};
handle_cast(_Unknown, State) ->
{noreply, State}.
-spec handle_info(Info, State) -> {noreply, state()} when
Info :: {'DOWN', reference(), process, pid(), term()} | term(),
State :: state().
-spec handle_info(term(), state()) -> {noreply, state()}.
handle_info({'DOWN', _Ref, process, Pid, _Reason}, State) ->
Guilds = maps:get(guilds, State),
NewGuilds = process_registry:cleanup_on_down(Pid, Guilds),
@@ -192,23 +100,323 @@ handle_info({'DOWN', _Ref, process, Pid, _Reason}, State) ->
handle_info(_Unknown, State) ->
{noreply, State}.
-spec terminate(Reason, State) -> ok when
Reason :: term(),
State :: state().
-spec terminate(term(), state()) -> ok.
terminate(_Reason, _State) ->
ok.
-spec code_change(term(), term(), term()) -> {ok, state()}.
code_change(_OldVsn, #state{guilds = Guilds, api_host = ApiHost, api_canary_host = ApiCanaryHost, pending_requests = Pending}, _Extra) ->
{ok, #{
guilds => Guilds,
api_host => ApiHost,
api_canary_host => ApiCanaryHost,
pending_requests => Pending
}};
code_change(_OldVsn, State, _Extra) when is_map(State) ->
{ok, State}.
-spec do_start_or_lookup(guild_id(), gen_server:from(), state()) ->
{reply, {ok, pid()} | {error, term()}, state()} | {noreply, state()}.
do_start_or_lookup(GuildId, From, State) ->
Guilds = maps:get(guilds, State),
case maps:get(GuildId, Guilds, undefined) of
{Pid, _Ref} ->
{reply, {ok, Pid}, State};
loading ->
add_pending_request(GuildId, From, State);
undefined ->
lookup_or_fetch(GuildId, From, State)
end.
-spec lookup_or_fetch(guild_id(), gen_server:from(), state()) ->
{reply, {ok, pid()}, state()} | {noreply, state()}.
lookup_or_fetch(GuildId, From, State) ->
GuildName = process_registry:build_process_name(guild, GuildId),
case whereis(GuildName) of
undefined ->
start_fetch(GuildId, From, State);
_ExistingPid ->
Guilds = maps:get(guilds, State),
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
{ok, Pid, _Ref, NewGuilds} ->
{reply, {ok, Pid}, State#{guilds => NewGuilds}};
{error, not_found} ->
{reply, {error, process_died}, State}
end
end.
-spec start_fetch(guild_id(), gen_server:from(), state()) -> {noreply, state()}.
start_fetch(GuildId, From, State) ->
Guilds = maps:get(guilds, State),
NewGuilds = maps:put(GuildId, loading, Guilds),
Pending = maps:get(pending_requests, State),
NewPending = maps:put(GuildId, [From], Pending),
NewState = State#{guilds => NewGuilds, pending_requests => NewPending},
spawn_fetch(GuildId, State),
{noreply, NewState}.
-spec spawn_fetch(guild_id(), state()) -> pid().
spawn_fetch(GuildId, State) ->
Manager = self(),
ApiHostInfo = select_api_host(State),
spawn(fun() ->
try
Result = fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State),
gen_server:cast(Manager, {guild_data_fetched, GuildId, Result})
catch
_:_:_ ->
gen_server:cast(Manager, {guild_data_fetched, GuildId, {error, fetch_failed}})
end
end).
-spec add_pending_request(guild_id(), gen_server:from(), state()) -> {noreply, state()}.
add_pending_request(GuildId, From, State) ->
Pending = maps:get(pending_requests, State),
Requests = maps:get(GuildId, Pending, []),
NewPending = maps:put(GuildId, [From | Requests], Pending),
{noreply, State#{pending_requests => NewPending}}.
-spec handle_guild_data_fetched(guild_id(), fetch_result(), state()) -> {noreply, state()}.
handle_guild_data_fetched(GuildId, Result, State) ->
Pending = maps:get(pending_requests, State),
Requests = maps:get(GuildId, Pending, []),
Guilds = maps:get(guilds, State),
case Result of
{ok, Data} ->
case start_guild(GuildId, Data, State) of
{ok, Pid, NewState} ->
reply_to_all(Requests, {ok, Pid}),
NewPending = maps:remove(GuildId, Pending),
{noreply, NewState#{pending_requests => NewPending}};
{error, Reason} ->
reply_to_all(Requests, {error, Reason}),
NewGuilds = maps:remove(GuildId, Guilds),
NewPending = maps:remove(GuildId, Pending),
{noreply, State#{guilds => NewGuilds, pending_requests => NewPending}}
end;
{error, Reason} ->
reply_to_all(Requests, {error, Reason}),
NewGuilds = maps:remove(GuildId, Guilds),
NewPending = maps:remove(GuildId, Pending),
{noreply, State#{guilds => NewGuilds, pending_requests => NewPending}}
end.
-spec handle_guild_data_reloaded(guild_id(), pid(), gen_server:from(), fetch_result(), state()) ->
{noreply, state()}.
handle_guild_data_reloaded(_GuildId, Pid, From, Result, State) ->
case Result of
{ok, Data} ->
catch gen_server:call(Pid, {reload, Data}, ?GUILD_CALL_TIMEOUT),
gen_server:reply(From, ok);
_ ->
gen_server:reply(From, {error, fetch_failed})
end,
{noreply, State}.
-spec reply_to_all([gen_server:from()], term()) -> ok.
reply_to_all(Requests, Reply) ->
lists:foreach(fun(From) -> gen_server:reply(From, Reply) end, Requests).
-spec do_stop_guild(guild_id(), state()) -> {reply, ok, state()}.
do_stop_guild(GuildId, State) ->
Guilds = maps:get(guilds, State),
GuildName = process_registry:build_process_name(guild, GuildId),
case maps:get(GuildId, Guilds, undefined) of
{Pid, Ref} ->
demonitor(Ref, [flush]),
catch gen_server:stop(Pid, normal, ?SHUTDOWN_TIMEOUT),
process_registry:safe_unregister(GuildName),
NewGuilds = maps:remove(GuildId, Guilds),
{reply, ok, State#{guilds => NewGuilds}};
_ ->
case whereis(GuildName) of
undefined ->
{reply, ok, State};
ExistingPid ->
catch gen_server:stop(ExistingPid, normal, ?SHUTDOWN_TIMEOUT),
process_registry:safe_unregister(GuildName),
{reply, ok, State}
end
end.
-spec do_reload_guild(guild_id(), gen_server:from(), state()) ->
{reply, {error, not_found}, state()} | {noreply, state()}.
do_reload_guild(GuildId, From, State) ->
Guilds = maps:get(guilds, State),
GuildName = process_registry:build_process_name(guild, GuildId),
case maps:get(GuildId, Guilds, undefined) of
{Pid, _Ref} ->
spawn_reload(GuildId, Pid, From, State),
{noreply, State};
_ ->
case whereis(GuildName) of
undefined ->
{reply, {error, not_found}, State};
_ExistingPid ->
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
{ok, Pid, _Ref, NewGuilds} ->
NewState = State#{guilds => NewGuilds},
spawn_reload(GuildId, Pid, From, NewState),
{noreply, NewState};
{error, not_found} ->
{reply, {error, not_found}, State}
end
end
end.
-spec spawn_reload(guild_id(), pid(), gen_server:from(), state()) -> pid().
spawn_reload(GuildId, Pid, From, State) ->
Manager = self(),
ApiHostInfo = select_api_host(State),
spawn(fun() ->
try
Result = fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State),
gen_server:cast(Manager, {guild_data_reloaded, GuildId, Pid, From, Result})
catch
_:_:_ ->
gen_server:cast(
Manager, {guild_data_reloaded, GuildId, Pid, From, {error, fetch_failed}}
)
end
end).
-spec do_reload_all_guilds([guild_id()], gen_server:from(), state()) -> {noreply, state()}.
do_reload_all_guilds(GuildIds, From, State) ->
Guilds = maps:get(guilds, State),
GuildsToReload = select_guilds_to_reload(GuildIds, Guilds),
Manager = self(),
spawn(fun() ->
try
reload_guilds_in_batches(GuildsToReload, State),
gen_server:cast(Manager, {all_guilds_reloaded, From, length(GuildsToReload)})
catch
_:_:_ ->
gen_server:cast(Manager, {all_guilds_reloaded, From, 0})
end
end),
{noreply, State}.
-spec select_guilds_to_reload([guild_id()], #{guild_id() => guild_ref() | loading}) ->
[{guild_id(), pid()}].
select_guilds_to_reload([], Guilds) ->
[{GuildId, Pid} || {GuildId, {Pid, _Ref}} <- maps:to_list(Guilds)];
select_guilds_to_reload(GuildIds, Guilds) ->
lists:filtermap(
fun(GuildId) ->
case maps:get(GuildId, Guilds, undefined) of
{Pid, _Ref} -> {true, {GuildId, Pid}};
_ -> false
end
end,
GuildIds
).
-spec reload_guilds_in_batches([{guild_id(), pid()}], state()) -> ok.
reload_guilds_in_batches([], _State) ->
ok;
reload_guilds_in_batches(Guilds, State) ->
{Batch, Remaining} = lists:split(min(?BATCH_SIZE, length(Guilds)), Guilds),
reload_batch(Batch, State),
case Remaining of
[] ->
ok;
_ ->
timer:sleep(?BATCH_DELAY_MS),
reload_guilds_in_batches(Remaining, State)
end.
-spec reload_batch([{guild_id(), pid()}], state()) -> ok.
reload_batch(Batch, State) ->
ApiHostInfo = select_api_host(State),
lists:foreach(
fun({GuildId, Pid}) ->
spawn(fun() ->
try
case fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State) of
{ok, Data} ->
catch gen_server:call(Pid, {reload, Data}, ?GUILD_CALL_TIMEOUT);
{error, _Reason} ->
ok
end
catch
_:_ -> ok
end
end)
end,
Batch
).
-spec do_shutdown_guild(guild_id(), state()) -> {reply, ok, state()}.
do_shutdown_guild(GuildId, State) ->
Guilds = maps:get(guilds, State),
GuildName = process_registry:build_process_name(guild, GuildId),
case maps:get(GuildId, Guilds, undefined) of
{Pid, Ref} ->
demonitor(Ref, [flush]),
catch gen_server:call(Pid, {terminate}, ?SHUTDOWN_TIMEOUT),
process_registry:safe_unregister(GuildName),
NewGuilds = maps:remove(GuildId, Guilds),
{reply, ok, State#{guilds => NewGuilds}};
_ ->
case whereis(GuildName) of
undefined ->
{reply, ok, State};
ExistingPid ->
catch gen_server:call(ExistingPid, {terminate}, ?SHUTDOWN_TIMEOUT),
process_registry:safe_unregister(GuildName),
{reply, ok, State}
end
end.
-spec start_guild(guild_id(), guild_data(), state()) -> {ok, pid(), state()} | {error, term()}.
start_guild(GuildId, Data, State) ->
GuildName = process_registry:build_process_name(guild, GuildId),
case whereis(GuildName) of
undefined ->
start_new_guild(GuildId, Data, GuildName, State);
_ExistingPid ->
lookup_existing_guild(GuildId, GuildName, State)
end.
-spec start_new_guild(guild_id(), guild_data(), atom(), state()) ->
{ok, pid(), state()} | {error, term()}.
start_new_guild(GuildId, Data, GuildName, State) ->
GuildState = #{
id => GuildId,
data => Data,
sessions => #{}
},
Guilds = maps:get(guilds, State),
GuildModule =
case is_very_large_guild(Data) of
true -> very_large_guild;
false -> guild
end,
case GuildModule:start_link(GuildState) of
{ok, Pid} ->
case process_registry:register_and_monitor(GuildName, Pid, Guilds) of
{ok, RegisteredPid, Ref, NewGuilds0} ->
CleanGuilds = maps:remove(GuildName, NewGuilds0),
NewGuilds = maps:put(GuildId, {RegisteredPid, Ref}, CleanGuilds),
{ok, RegisteredPid, State#{guilds => NewGuilds}};
{error, Reason} ->
{error, Reason}
end;
Error ->
Error
end.
-spec is_very_large_guild(guild_data()) -> boolean().
is_very_large_guild(Data) when is_map(Data) ->
Guild = maps:get(<<"guild">>, Data, #{}),
Features = maps:get(<<"features">>, Guild, []),
lists:member(<<"VERY_LARGE_GUILD">>, Features);
is_very_large_guild(_) ->
false.
-spec lookup_existing_guild(guild_id(), atom(), state()) -> {ok, pid(), state()} | {error, term()}.
lookup_existing_guild(GuildId, GuildName, State) ->
Guilds = maps:get(guilds, State),
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
{ok, Pid, _Ref, NewGuilds} ->
{ok, Pid, State#{guilds => NewGuilds}};
{error, not_found} ->
{error, process_died}
end.
-spec fetch_guild_data(guild_id(), string()) -> fetch_result().
fetch_guild_data(GuildId, ApiHost) ->
RpcRequest = #{
@@ -217,43 +425,27 @@ fetch_guild_data(GuildId, ApiHost) ->
<<"version">> => 1
},
Url = rpc_client:get_rpc_url(ApiHost),
Headers =
rpc_client:get_rpc_headers() ++ [{<<"content-type">>, <<"application/json">>}],
Body = jsx:encode(RpcRequest),
case
hackney:request(post, Url, Headers, Body, [{recv_timeout, 30000}, {connect_timeout, 5000}])
of
{ok, 200, _RespHeaders, ClientRef} ->
case hackney:body(ClientRef) of
{ok, RespBody} ->
hackney:close(ClientRef),
Response = jsx:decode(RespBody, [return_maps]),
Data = maps:get(<<"data">>, Response, #{}),
{ok, Data};
{error, BodyReason} ->
hackney:close(ClientRef),
logger:error("[guild_manager] Failed to read guild response body: ~p", [
BodyReason
]),
{error, fetch_failed}
end;
{ok, StatusCode, _RespHeaders, ClientRef} ->
ErrorBody =
case hackney:body(ClientRef) of
{ok, Body2} -> Body2;
{error, _} -> <<"<unable to read error body>">>
end,
hackney:close(ClientRef),
logger:error(
"[guild_manager] Guild RPC failed with status ~p: ~s",
[StatusCode, ErrorBody]
),
{error, fetch_failed};
Headers = rpc_client:get_rpc_headers() ++ [{<<"content-type">>, <<"application/json">>}],
Body = json:encode(RpcRequest),
case gateway_http_client:request(rpc, post, Url, Headers, Body) of
{ok, 200, _RespHeaders, RespBody} ->
handle_fetch_response(RespBody);
{ok, StatusCode, _RespHeaders, _RespBody} ->
handle_fetch_error(StatusCode);
{error, Reason} ->
logger:error("[guild_manager] Guild RPC request failed: ~p", [Reason]),
{error, fetch_failed}
{error, {request_failed, Reason}}
end.
-spec handle_fetch_response(binary()) -> fetch_result().
handle_fetch_response(RespBody) ->
Response = json:decode(RespBody),
Data = maps:get(<<"data">>, Response, #{}),
{ok, Data}.
-spec handle_fetch_error(integer()) -> {error, {http_status, integer()}}.
handle_fetch_error(StatusCode) ->
{error, {http_status, StatusCode}}.
-spec select_api_host(state()) -> {string(), boolean()}.
select_api_host(State) ->
case maps:get(api_canary_host, State) of
@@ -270,12 +462,8 @@ select_api_host(State) ->
should_use_canary_api() ->
erlang:unique_integer([positive]) rem 100 < ?GUILD_API_CANARY_PERCENTAGE.
-spec fetch_guild_data_with_fallback(
guild_id(),
{string(), boolean()},
state()
) -> fetch_result().
fetch_guild_data_with_fallback(GuildId, {ApiHost, false}, _) ->
-spec fetch_guild_data_with_fallback(guild_id(), {string(), boolean()}, state()) -> fetch_result().
fetch_guild_data_with_fallback(GuildId, {ApiHost, false}, _State) ->
fetch_guild_data(GuildId, ApiHost);
fetch_guild_data_with_fallback(GuildId, {ApiHost, true}, State) ->
case fetch_guild_data(GuildId, ApiHost) of
@@ -287,244 +475,57 @@ fetch_guild_data_with_fallback(GuildId, {ApiHost, true}, State) ->
true ->
Error;
false ->
logger:warning(
"[guild_manager] Canary API request failed for ~p, retrying against stable host",
[GuildId]
),
fetch_guild_data(GuildId, StableHost)
end
end.
-spec start_guild(guild_id(), guild_data(), state()) -> {ok, pid(), state()} | {error, term()}.
start_guild(GuildId, Data, State) ->
GuildName = process_registry:build_process_name(guild, GuildId),
case whereis(GuildName) of
undefined ->
GuildState = #{
id => GuildId,
data => Data,
sessions => #{},
presences => #{}
},
Guilds = maps:get(guilds, State),
case guild:start_link(GuildState) of
{ok, Pid} ->
case process_registry:register_and_monitor(GuildName, Pid, Guilds) of
{ok, RegisteredPid, Ref, NewGuilds0} ->
CleanGuilds = maps:remove(GuildName, NewGuilds0),
NewGuilds = maps:put(GuildId, {RegisteredPid, Ref}, CleanGuilds),
{ok, RegisteredPid, State#{guilds => NewGuilds}};
{error, Reason} ->
{error, Reason}
end;
Error ->
Error
end;
_ExistingPid ->
Guilds = maps:get(guilds, State),
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
{ok, Pid, _Ref, NewGuilds} ->
{ok, Pid, State#{guilds => NewGuilds}};
{error, not_found} ->
{error, process_died}
end
end.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-spec reload_guilds_in_batches(
[{guild_id(), pid()}],
pid(),
state(),
pos_integer(),
non_neg_integer()
) -> ok.
reload_guilds_in_batches([], _Manager, _State, _BatchSize, _DelayMs) ->
ok;
reload_guilds_in_batches(Guilds, Manager, State, BatchSize, DelayMs) ->
{Batch, Remaining} = lists:split(min(BatchSize, length(Guilds)), Guilds),
lists:foreach(
fun({GuildId, Pid}) ->
ApiHostInfo = select_api_host(State),
spawn(fun() ->
try
case fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State) of
{ok, Data} ->
gen_server:call(Pid, {reload, Data}, ?GUILD_CALL_TIMEOUT);
{error, Reason} ->
logger:error("[guild_manager] Failed to reload guild ~p: ~p", [
GuildId, Reason
])
end
catch
Class:Error:Stacktrace ->
logger:error(
"[guild_manager] Spawned process failed: ~p:~p~n~p",
[Class, Error, Stacktrace]
)
end
end)
end,
Batch
),
case Remaining of
[] ->
ok;
_ ->
timer:sleep(DelayMs),
reload_guilds_in_batches(Remaining, Manager, State, BatchSize, DelayMs)
end.
select_api_host_no_canary_test() ->
State = #{api_host => "http://api.local", api_canary_host => undefined},
{Host, IsCanary} = select_api_host(State),
?assertEqual("http://api.local", Host),
?assertEqual(false, IsCanary).
-spec do_start_or_lookup(guild_id(), gen_server:from(), state()) ->
{reply, {ok, pid()} | {error, term()}, state()} | {noreply, state()}.
do_start_or_lookup(GuildId, From, State) ->
Guilds = maps:get(guilds, State),
case maps:get(GuildId, Guilds, undefined) of
{Pid, _Ref} ->
{reply, {ok, Pid}, State};
loading ->
Pending = maps:get(pending_requests, State),
Requests = maps:get(GuildId, Pending, []),
NewPending = maps:put(GuildId, [From | Requests], Pending),
{noreply, State#{pending_requests => NewPending}};
undefined ->
GuildName = process_registry:build_process_name(guild, GuildId),
case whereis(GuildName) of
undefined ->
NewGuilds = maps:put(GuildId, loading, Guilds),
Pending = maps:get(pending_requests, State),
NewPending = maps:put(GuildId, [From], Pending),
NewState = State#{guilds => NewGuilds, pending_requests => NewPending},
Manager = self(),
ApiHostInfo = select_api_host(State),
spawn(fun() ->
try
Result = fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State),
gen_server:cast(Manager, {guild_data_fetched, GuildId, Result})
catch
Class:Error:Stacktrace ->
logger:error(
"[guild_manager] Spawned process failed: ~p:~p~n~p",
[Class, Error, Stacktrace]
),
gen_server:cast(
Manager, {guild_data_fetched, GuildId, {error, fetch_failed}}
)
end
end),
{noreply, NewState};
_ExistingPid ->
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
{ok, Pid, _Ref, NewGuilds} ->
{reply, {ok, Pid}, State#{guilds => NewGuilds}};
{error, not_found} ->
{reply, {error, process_died}, State}
end
end
end.
should_use_canary_api_returns_boolean_test() ->
Result = should_use_canary_api(),
?assert(is_boolean(Result)).
-spec do_stop_guild(guild_id(), state()) -> {reply, ok, state()}.
do_stop_guild(GuildId, State) ->
Guilds = maps:get(guilds, State),
GuildName = process_registry:build_process_name(guild, GuildId),
case maps:get(GuildId, Guilds, undefined) of
{Pid, Ref} ->
demonitor(Ref, [flush]),
gen_server:stop(Pid, normal, ?SHUTDOWN_TIMEOUT),
process_registry:safe_unregister(GuildName),
NewGuilds = maps:remove(GuildId, Guilds),
{reply, ok, State#{guilds => NewGuilds}};
_ ->
case whereis(GuildName) of
undefined ->
{reply, ok, State};
ExistingPid ->
gen_server:stop(ExistingPid, normal, ?SHUTDOWN_TIMEOUT),
process_registry:safe_unregister(GuildName),
{reply, ok, State}
end
end.
select_guilds_to_reload_empty_ids_test() ->
Guilds = #{1 => {self(), make_ref()}, 2 => {self(), make_ref()}},
Result = select_guilds_to_reload([], Guilds),
?assertEqual(2, length(Result)).
-spec do_reload_guild(guild_id(), gen_server:from(), state()) ->
{reply, {error, not_found}, state()} | {noreply, state()}.
do_reload_guild(GuildId, From, State) ->
Guilds = maps:get(guilds, State),
GuildName = process_registry:build_process_name(guild, GuildId),
case maps:get(GuildId, Guilds, undefined) of
{Pid, _Ref} ->
Manager = self(),
ApiHostInfo = select_api_host(State),
spawn(fun() ->
try
Result = fetch_guild_data_with_fallback(GuildId, ApiHostInfo, State),
gen_server:cast(Manager, {guild_data_reloaded, GuildId, Pid, From, Result})
catch
Class:Error:Stacktrace ->
logger:error(
"[guild_manager] Spawned process failed: ~p:~p~n~p",
[Class, Error, Stacktrace]
),
gen_server:cast(
Manager,
{guild_data_reloaded, GuildId, Pid, From, {error, fetch_failed}}
)
end
end),
{noreply, State};
_ ->
case whereis(GuildName) of
undefined ->
{reply, {error, not_found}, State};
_ExistingPid ->
case process_registry:lookup_or_monitor(GuildName, GuildId, Guilds) of
{ok, Pid, _Ref, NewGuilds} ->
NewState = State#{guilds => NewGuilds},
Manager = self(),
ApiHostInfo = select_api_host(NewState),
spawn(fun() ->
try
Result = fetch_guild_data_with_fallback(
GuildId, ApiHostInfo, NewState
),
gen_server:cast(
Manager, {guild_data_reloaded, GuildId, Pid, From, Result}
)
catch
Class:Error:Stacktrace ->
logger:error(
"[guild_manager] Spawned process failed: ~p:~p~n~p",
[Class, Error, Stacktrace]
),
gen_server:cast(
Manager,
{guild_data_reloaded, GuildId, Pid, From,
{error, fetch_failed}}
)
end
end),
{noreply, NewState};
{error, not_found} ->
{reply, {error, not_found}, State}
end
end
end.
select_guilds_to_reload_specific_ids_test() ->
Pid = self(),
Ref = make_ref(),
Guilds = #{1 => {Pid, Ref}, 2 => {Pid, Ref}, 3 => loading},
Result = select_guilds_to_reload([1, 3], Guilds),
?assertEqual(1, length(Result)).
-spec do_shutdown_guild(guild_id(), state()) -> {reply, ok, state()}.
do_shutdown_guild(GuildId, State) ->
Guilds = maps:get(guilds, State),
GuildName = process_registry:build_process_name(guild, GuildId),
case maps:get(GuildId, Guilds, undefined) of
{Pid, Ref} ->
demonitor(Ref, [flush]),
gen_server:call(Pid, {terminate}, ?SHUTDOWN_TIMEOUT),
process_registry:safe_unregister(GuildName),
NewGuilds = maps:remove(GuildId, Guilds),
{reply, ok, State#{guilds => NewGuilds}};
_ ->
case whereis(GuildName) of
undefined ->
{reply, ok, State};
ExistingPid ->
catch gen_server:call(ExistingPid, {terminate}, ?SHUTDOWN_TIMEOUT),
process_registry:safe_unregister(GuildName),
{reply, ok, State}
end
end.
reply_to_all_empty_list_test() ->
?assertEqual(ok, reply_to_all([], ok)).
do_start_or_lookup_loading_deduplicates_requests_test() ->
GuildId = 4444,
From1 = {self(), make_ref()},
From2 = {self(), make_ref()},
State0 = #{
guilds => #{GuildId => loading},
api_host => "http://api.local",
api_canary_host => undefined,
pending_requests => #{},
shard_index => 0
},
{noreply, State1} = do_start_or_lookup(GuildId, From1, State0),
Pending1 = maps:get(pending_requests, State1),
?assertEqual([From1], maps:get(GuildId, Pending1)),
{noreply, State2} = do_start_or_lookup(GuildId, From2, State1),
Pending2 = maps:get(pending_requests, State2),
Requests = maps:get(GuildId, Pending2),
?assertEqual(2, length(Requests)),
?assert(lists:member(From1, Requests)),
?assert(lists:member(From2, Requests)).
-endif.

View File

@@ -25,11 +25,14 @@
get_items_in_range/3,
handle_member_update/3,
build_sync_response/4,
member_list_delta/4,
member_list_snapshot/2,
get_online_count/1,
broadcast_member_list_updates/3,
broadcast_all_member_list_updates/1,
broadcast_member_list_updates_for_channel/2,
normalize_ranges/1
normalize_ranges/1,
get_members_cursor/2
]).
-ifdef(TEST).
@@ -43,13 +46,19 @@
-type group_item() :: map().
-type list_item() :: member_item() | group_item().
-type user_id() :: integer().
-type channel_id() :: integer().
-define(MAX_RANGE_END, 100000).
-spec validate_range(range()) -> range() | invalid.
validate_range({Start, End}) when is_integer(Start), is_integer(End),
Start >= 0, End >= 0, Start =< End,
End =< ?MAX_RANGE_END ->
validate_range({Start, End}) when
is_integer(Start),
is_integer(End),
Start >= 0,
End >= 0,
Start =< End,
End =< ?MAX_RANGE_END
->
{Start, End};
validate_range(_) ->
invalid.
@@ -77,27 +86,39 @@ merge_overlapping_ranges([{S1, E1}, {S2, E2} | Rest]) when S2 =< E1 + 1 ->
merge_overlapping_ranges([Range | Rest]) ->
[Range | merge_overlapping_ranges(Rest)].
-spec calculate_list_id(integer(), guild_state()) -> list_id().
calculate_list_id(ChannelId, _State) when is_integer(ChannelId), ChannelId > 0 ->
integer_to_binary(ChannelId);
-spec calculate_list_id(channel_id(), guild_state()) -> list_id().
calculate_list_id(ChannelId, State) when is_integer(ChannelId), ChannelId > 0 ->
Data = maps:get(data, State, #{}),
Channels = map_utils:ensure_list(maps:get(<<"channels">>, Data, [])),
ChannelIdBin = integer_to_binary(ChannelId),
ChannelExists = lists:any(
fun(Ch) -> maps:get(<<"id">>, Ch, undefined) =:= ChannelIdBin end, Channels
),
case ChannelExists of
true -> ChannelIdBin;
false -> <<"0">>
end;
calculate_list_id(_, _) ->
<<"0">>.
-spec get_member_groups(list_id(), guild_state()) -> [group_item()].
get_member_groups(ListId, State) ->
Data = maps:get(data, State, #{}),
Members = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
Members = guild_data_index:member_values(Data),
Roles = map_utils:ensure_list(maps:get(<<"roles">>, Data, [])),
GuildId = maps:get(id, State, 0),
HoistedRoles = get_hoisted_roles_sorted(Roles, GuildId),
FilteredMembers = filter_members_for_list(ListId, Members, State),
{OnlineMembers, OfflineMembers} = partition_members_by_online(FilteredMembers, State),
RoleGroups = build_role_groups(HoistedRoles, OnlineMembers),
OnlineGroup = #{<<"id">> => <<"online">>, <<"count">> => count_ungrouped_online(OnlineMembers, HoistedRoles)},
OnlineGroup = #{
<<"id">> => <<"online">>, <<"count">> => count_ungrouped_online(OnlineMembers, HoistedRoles)
},
OfflineGroup = #{<<"id">> => <<"offline">>, <<"count">> => length(OfflineMembers)},
RoleGroups ++ [OnlineGroup, OfflineGroup].
-spec subscribe_ranges(binary(), list_id(), [range()], guild_state()) -> {guild_state(), boolean(), [range()]}.
-spec subscribe_ranges(binary(), list_id(), [range()], guild_state()) ->
{guild_state(), boolean(), [range()]}.
subscribe_ranges(SessionId, ListId, Ranges, State) ->
NormalizedRanges = normalize_ranges(Ranges),
Subscriptions = maps:get(member_list_subscriptions, State, #{}),
@@ -205,8 +226,9 @@ build_sync_response(GuildId, ListId, Ranges, State) ->
-spec get_online_count(guild_state()) -> non_neg_integer().
get_online_count(State) ->
Data = maps:get(data, State, #{}),
Members = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
{OnlineMembers, _} = partition_members_by_online(Members, State),
Members = guild_data_index:member_values(Data),
EligibleMembers = filter_members_for_list(<<"0">>, Members, State),
{OnlineMembers, _} = partition_members_by_online(EligibleMembers, State),
length(OnlineMembers).
-spec broadcast_member_list_updates(user_id(), guild_state(), guild_state()) -> ok.
@@ -226,7 +248,9 @@ broadcast_member_list_updates(_UserId, OldState, UpdatedState) ->
<<"groups">> => Groups,
<<"ops">> => Ops
},
send_member_list_update_to_sessions(ListId, ListSubs, Sessions, Payload, UpdatedState);
send_member_list_update_to_sessions(
ListId, ListSubs, Sessions, Payload, UpdatedState
);
_ ->
ok
end
@@ -246,7 +270,9 @@ send_member_list_update_to_sessions(ListId, ListSubs, Sessions, Payload, State)
true ->
case session_can_view_channel(SessionData, ChannelId, State) of
true ->
gen_server:cast(SessionPid, {dispatch, guild_member_list_update, Payload});
gen_server:cast(
SessionPid, {dispatch, guild_member_list_update, Payload}
);
false ->
ok
end;
@@ -260,7 +286,7 @@ send_member_list_update_to_sessions(ListId, ListSubs, Sessions, Payload, State)
ListSubs
).
-spec member_list_delta(list_id(), guild_state(), guild_state(), user_id()) ->
-spec member_list_delta(list_id(), guild_state(), guild_state(), user_id()) ->
{non_neg_integer(), non_neg_integer(), [group_item()], [list_item()], boolean()}.
member_list_delta(ListId, OldState, UpdatedState, UserId) ->
{OldCount, OldOnline, OldGroups, OldItems} = member_list_snapshot(ListId, OldState),
@@ -270,10 +296,14 @@ member_list_delta(ListId, OldState, UpdatedState, UserId) ->
{MemberCount, OnlineCount, Groups, Ops, true};
{false, _} ->
Ops = diff_items_to_ops(OldItems, Items),
Changed = Ops =/= [] orelse OldCount =/= MemberCount orelse OldOnline =/= OnlineCount orelse OldGroups =/= Groups,
Changed =
Ops =/= [] orelse OldCount =/= MemberCount orelse OldOnline =/= OnlineCount orelse
OldGroups =/= Groups,
{MemberCount, OnlineCount, Groups, Ops, Changed}
end.
-spec presence_move_ops(user_id(), guild_state(), guild_state(), [list_item()], [list_item()]) ->
{boolean(), [map()]}.
presence_move_ops(UserId, OldState, UpdatedState, OldItems, NewItems) ->
case presence_status_changed(UserId, OldState, UpdatedState) of
false ->
@@ -295,6 +325,7 @@ presence_move_ops(UserId, OldState, UpdatedState, OldItems, NewItems) ->
end
end.
-spec presence_status_changed(user_id(), guild_state(), guild_state()) -> boolean().
presence_status_changed(UserId, OldState, UpdatedState) ->
OldPresence = resolve_presence_for_user(OldState, UserId),
NewPresence = resolve_presence_for_user(UpdatedState, UserId),
@@ -302,9 +333,13 @@ presence_status_changed(UserId, OldState, UpdatedState) ->
NewStatus = maps:get(<<"status">>, NewPresence, <<"offline">>),
OldStatus =/= NewStatus.
-spec find_member_entry(user_id(), [list_item()]) ->
{ok, non_neg_integer(), list_item()} | {error, not_found}.
find_member_entry(UserId, Items) ->
find_member_entry(UserId, Items, 0).
-spec find_member_entry(user_id(), [list_item()], non_neg_integer()) ->
{ok, non_neg_integer(), list_item()} | {error, not_found}.
find_member_entry(_UserId, [], _Index) ->
{error, not_found};
find_member_entry(UserId, [Item | Rest], Index) ->
@@ -318,6 +353,7 @@ find_member_entry(UserId, [Item | Rest], Index) ->
end
end.
-spec adjusted_insert_index(non_neg_integer(), non_neg_integer()) -> non_neg_integer().
adjusted_insert_index(OldIdx, NewIdx) when NewIdx > OldIdx ->
NewIdx - 1;
adjusted_insert_index(_OldIdx, NewIdx) ->
@@ -345,9 +381,16 @@ broadcast_all_member_list_updates(State) ->
case maps:get(SessionId, Sessions, undefined) of
SessionData when is_map(SessionData) ->
SessionPid = maps:get(pid, SessionData, undefined),
case {is_pid(SessionPid), session_can_view_channel(SessionData, ChannelId, State)} of
case
{
is_pid(SessionPid),
session_can_view_channel(SessionData, ChannelId, State)
}
of
{true, true} ->
SyncResponse = build_sync_response(GuildId, ListId, Ranges, State),
SyncResponse = build_sync_response(
GuildId, ListId, Ranges, State
),
gen_server:cast(
SessionPid,
{dispatch, guild_member_list_update, SyncResponse}
@@ -366,7 +409,7 @@ broadcast_all_member_list_updates(State) ->
),
ok.
-spec broadcast_member_list_updates_for_channel(integer(), guild_state()) -> ok.
-spec broadcast_member_list_updates_for_channel(channel_id(), guild_state()) -> ok.
broadcast_member_list_updates_for_channel(ChannelId, State) ->
GuildId = maps:get(id, State, 0),
ListId = calculate_list_id(ChannelId, State),
@@ -381,13 +424,23 @@ broadcast_member_list_updates_for_channel(ChannelId, State) ->
case maps:get(SessionId, Sessions, undefined) of
SessionData when is_map(SessionData) ->
SessionPid = maps:get(pid, SessionData, undefined),
case {is_pid(SessionPid), session_can_view_channel(SessionData, ChannelId, State)} of
case
{
is_pid(SessionPid),
session_can_view_channel(SessionData, ChannelId, State)
}
of
{true, true} ->
SyncResponse = build_sync_response(GuildId, ListId, Ranges, State),
SyncResponseWithChannel = maps:put(<<"channel_id">>, integer_to_binary(ChannelId), SyncResponse),
SyncResponse = build_sync_response(
GuildId, ListId, Ranges, State
),
SyncResponseWithChannel = maps:put(
<<"channel_id">>, integer_to_binary(ChannelId), SyncResponse
),
gen_server:cast(
SessionPid,
{dispatch, guild_member_list_update, SyncResponseWithChannel}
{dispatch, guild_member_list_update,
SyncResponseWithChannel}
);
_ ->
ok
@@ -445,14 +498,33 @@ filter_members_for_list(ListId, Members, State) ->
)
end.
-spec connected_session_user_ids(guild_state()) -> sets:set().
connected_session_user_ids(State) ->
Sessions = maps:get(sessions, State, #{}),
maps:fold(
fun(_SessionId, SessionData, Acc) ->
case maps:get(user_id, SessionData, undefined) of
UserId when is_integer(UserId), UserId > 0 ->
sets:add_element(UserId, Acc);
_ ->
Acc
end
end,
sets:new(),
Sessions
).
-spec partition_members_by_online([map()], guild_state()) -> {[map()], [map()]}.
partition_members_by_online(Members, State) ->
ConnectedUserIds = connected_session_user_ids(State),
lists:partition(
fun(Member) ->
UserId = get_member_user_id(Member),
Presence = resolve_presence_for_user(State, UserId),
Status = maps:get(<<"status">>, Presence, <<"offline">>),
Status =/= <<"offline">> andalso Status =/= <<"invisible">>
IsConnected = sets:is_element(UserId, ConnectedUserIds),
IsOnlineStatus = Status =/= <<"offline">> andalso Status =/= <<"invisible">>,
IsConnected andalso IsOnlineStatus
end,
Members
).
@@ -471,15 +543,17 @@ build_role_groups(HoistedRoles, OnlineMembers) ->
-spec count_members_with_top_role(integer(), [map()], [map()]) -> non_neg_integer().
count_members_with_top_role(RoleId, Members, HoistedRoles) ->
HoistedRoleIds = [map_utils:get_integer(R, <<"id">>, 0) || R <- HoistedRoles],
length(lists:filter(
fun(Member) ->
MemberRoles = map_utils:ensure_list(maps:get(<<"roles">>, Member, [])),
MemberRoleIds = [type_conv:to_integer(R) || R <- MemberRoles],
TopHoisted = find_top_hoisted_role(MemberRoleIds, HoistedRoleIds),
TopHoisted =:= RoleId
end,
Members
)).
length(
lists:filter(
fun(Member) ->
MemberRoles = map_utils:ensure_list(maps:get(<<"roles">>, Member, [])),
MemberRoleIds = [type_conv:to_integer(R) || R <- MemberRoles],
TopHoisted = find_top_hoisted_role(MemberRoleIds, HoistedRoleIds),
TopHoisted =:= RoleId
end,
Members
)
).
-spec find_top_hoisted_role([integer()], [integer()]) -> integer() | undefined.
find_top_hoisted_role(MemberRoleIds, HoistedRoleIds) ->
@@ -491,20 +565,22 @@ find_top_hoisted_role(MemberRoleIds, HoistedRoleIds) ->
-spec count_ungrouped_online([map()], [map()]) -> non_neg_integer().
count_ungrouped_online(OnlineMembers, HoistedRoles) ->
HoistedRoleIds = [map_utils:get_integer(R, <<"id">>, 0) || R <- HoistedRoles],
length(lists:filter(
fun(Member) ->
MemberRoles = map_utils:ensure_list(maps:get(<<"roles">>, Member, [])),
MemberRoleIds = [type_conv:to_integer(R) || R <- MemberRoles],
TopHoisted = find_top_hoisted_role(MemberRoleIds, HoistedRoleIds),
TopHoisted =:= undefined
end,
OnlineMembers
)).
length(
lists:filter(
fun(Member) ->
MemberRoles = map_utils:ensure_list(maps:get(<<"roles">>, Member, [])),
MemberRoleIds = [type_conv:to_integer(R) || R <- MemberRoles],
TopHoisted = find_top_hoisted_role(MemberRoleIds, HoistedRoleIds),
TopHoisted =:= undefined
end,
OnlineMembers
)
).
-spec get_sorted_members_for_list(list_id(), guild_state()) -> [map()].
get_sorted_members_for_list(ListId, State) ->
Data = maps:get(data, State, #{}),
Members = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
Members = guild_data_index:member_values(Data),
FilteredMembers = filter_members_for_list(ListId, Members, State),
lists:sort(
fun(A, B) ->
@@ -519,11 +595,18 @@ get_member_user_id(Member) ->
map_utils:get_integer(User, <<"id">>, 0).
-spec normalize_name(term()) -> binary().
normalize_name(undefined) -> <<>>;
normalize_name(null) -> <<>>;
normalize_name(<<_/binary>> = B) -> B;
normalize_name(undefined) ->
<<>>;
normalize_name(null) ->
<<>>;
normalize_name(<<_/binary>> = B) ->
B;
normalize_name(L) when is_list(L) ->
try unicode:characters_to_binary(L) catch _:_ -> <<>> end;
try
unicode:characters_to_binary(L)
catch
_:_ -> <<>>
end;
normalize_name(I) when is_integer(I) ->
integer_to_binary(I);
normalize_name(_) ->
@@ -560,11 +643,13 @@ casefold_binary(Value) ->
_:_ -> Bin
end.
-spec add_presence_to_member(map(), guild_state()) -> map().
add_presence_to_member(Member, State) ->
UserId = get_member_user_id(Member),
Presence = resolve_presence_for_user(State, UserId),
maps:put(<<"presence">>, Presence, Member).
-spec default_presence() -> map().
default_presence() ->
#{
<<"status">> => <<"offline">>,
@@ -572,20 +657,10 @@ default_presence() ->
<<"afk">> => false
}.
-spec resolve_presence_for_user(guild_state(), user_id()) -> map().
resolve_presence_for_user(State, UserId) ->
Presences = maps:get(presences, State, #{}),
case maps:get(UserId, Presences, undefined) of
undefined -> fetch_presence_from_cache(UserId);
Presence -> Presence
end.
fetch_presence_from_cache(UserId) ->
try presence_cache:get(UserId) of
{ok, Presence} -> Presence;
_ -> default_presence()
catch
exit:{noproc, _} -> default_presence()
end.
MemberPresence = maps:get(member_presence, State, #{}),
maps:get(UserId, MemberPresence, default_presence()).
-spec build_member_list_items([group_item()], [map()], guild_state()) -> [list_item()].
build_member_list_items(Groups, Members, State) ->
@@ -609,9 +684,21 @@ build_member_list_items(Groups, Members, State) ->
end,
OnlineMembers
),
[GroupHeader | [#{<<"member">> => add_presence_to_member(M, State)} || M <- UngroupedOnline]];
[
GroupHeader
| [
#{<<"member">> => add_presence_to_member(M, State)}
|| M <- UngroupedOnline
]
];
<<"offline">> ->
[GroupHeader | [#{<<"member">> => add_presence_to_member(M, State)} || M <- OfflineMembers]];
[
GroupHeader
| [
#{<<"member">> => add_presence_to_member(M, State)}
|| M <- OfflineMembers
]
];
RoleIdBin ->
RoleId = type_conv:to_integer(RoleIdBin),
RoleMembers = lists:filter(
@@ -622,7 +709,10 @@ build_member_list_items(Groups, Members, State) ->
end,
OnlineMembers
),
[GroupHeader | [#{<<"member">> => add_presence_to_member(M, State)} || M <- RoleMembers]]
[
GroupHeader
| [#{<<"member">> => add_presence_to_member(M, State)} || M <- RoleMembers]
]
end
end,
Groups
@@ -636,7 +726,7 @@ slice_items(Items, Start, End) ->
false -> lists:sublist(Items, Start + 1, SafeEnd - Start + 1)
end.
-spec member_in_list(integer(), [map()]) -> boolean().
-spec member_in_list(user_id(), [map()]) -> boolean().
member_in_list(UserId, Members) ->
lists:any(fun(M) -> get_member_user_id(M) =:= UserId end, Members).
@@ -650,50 +740,83 @@ full_sync_ops(ListId, State) ->
SortedMembers = get_sorted_members_for_list(ListId, State),
Items = build_full_items(ListId, State, SortedMembers),
case length(Items) of
0 -> [];
0 ->
[];
N ->
[#{
<<"op">> => <<"SYNC">>,
<<"range">> => [0, N - 1],
<<"items">> => Items
}]
[
#{
<<"op">> => <<"SYNC">>,
<<"range">> => [0, N - 1],
<<"items">> => Items
}
]
end.
-spec upsert_member_in_state(integer(), map(), guild_state()) -> {map() | undefined, map(), guild_state()}.
-spec get_members_cursor(map(), guild_state()) -> {reply, map(), guild_state()}.
get_members_cursor(Request, State) ->
Limit = maps:get(<<"limit">>, Request, 1),
AfterId = maps:get(<<"after">>, Request, undefined),
SortedMembers = sort_members_by_user_id(State),
FilteredMembers = filter_members_after(SortedMembers, AfterId),
ResponseMembers = take_first(FilteredMembers, Limit),
{reply,
#{
members => ResponseMembers,
total => length(SortedMembers)
},
State}.
-spec take_first([map()], integer()) -> [map()].
take_first(_Members, Limit) when Limit =< 0 ->
[];
take_first(Members, Limit) ->
Count = min(Limit, length(Members)),
case Count of
0 -> [];
_ -> lists:sublist(Members, 1, Count)
end.
-spec filter_members_after([map()], integer() | undefined) -> [map()].
filter_members_after(Members, undefined) ->
Members;
filter_members_after(Members, AfterId) ->
lists:dropwhile(
fun(Member) ->
get_member_user_id(Member) =< AfterId
end,
Members
).
-spec sort_members_by_user_id(guild_state()) -> [map()].
sort_members_by_user_id(State) ->
Data = maps:get(data, State, #{}),
Members = guild_data_index:member_values(Data),
lists:sort(
fun(A, B) ->
get_member_user_id(A) =< get_member_user_id(B)
end,
Members
).
-spec upsert_member_in_state(user_id(), map(), guild_state()) ->
{map() | undefined, map(), guild_state()}.
upsert_member_in_state(UserId, MemberUpdate, State) ->
Data = maps:get(data, State, #{}),
Members0 = map_utils:ensure_list(maps:get(<<"members">>, Data, [])),
{Found, CurrentMember} = find_member_by_user_id(UserId, Members0),
Members0 = guild_data_index:member_map(Data),
CurrentMember = maps:get(UserId, Members0, undefined),
UpdatedMember =
case Found of
true -> deep_merge_member(CurrentMember, MemberUpdate);
false -> MemberUpdate
case CurrentMember of
undefined -> MemberUpdate;
_ -> deep_merge_member(CurrentMember, MemberUpdate)
end,
Members1 =
case Found of
true ->
lists:map(
fun(M) ->
case get_member_user_id(M) =:= UserId of
true -> UpdatedMember;
false -> M
end
end,
Members0
);
false ->
Members0 ++ [UpdatedMember]
end,
Data1 = maps:put(<<"members">>, Members1, Data),
Members1 = maps:put(UserId, UpdatedMember, Members0),
Data1 = guild_data_index:put_member_map(Members1, Data),
NewState = maps:put(data, Data1, State),
{case Found of true -> CurrentMember; false -> undefined end, UpdatedMember, NewState}.
-spec find_member_by_user_id(integer(), [map()]) -> {boolean(), map()} | {false, undefined}.
find_member_by_user_id(UserId, Members) ->
case lists:search(fun(M) -> get_member_user_id(M) =:= UserId end, Members) of
{value, M} -> {true, M};
false -> {false, undefined}
end.
{
CurrentMember,
UpdatedMember,
NewState
}.
-spec deep_merge_member(map(), map()) -> map().
deep_merge_member(CurrentMember, MemberUpdate) ->
@@ -741,13 +864,16 @@ diff_items_to_ops(OldItems, NewItems) ->
-spec full_sync_from_items([list_item()]) -> [map()].
full_sync_from_items(Items) ->
case length(Items) of
0 -> [];
0 ->
[];
N ->
[#{
<<"op">> => <<"SYNC">>,
<<"range">> => [0, N - 1],
<<"items">> => Items
}]
[
#{
<<"op">> => <<"SYNC">>,
<<"range">> => [0, N - 1],
<<"items">> => Items
}
]
end.
-spec item_keys([list_item()]) -> [term()].
@@ -777,11 +903,14 @@ updates_for_changed_items(OldItems, NewItems) ->
true ->
Acc;
false ->
[#{
<<"op">> => <<"UPDATE">>,
<<"index">> => Idx,
<<"item">> => NewItem
} | Acc]
[
#{
<<"op">> => <<"UPDATE">>,
<<"index">> => Idx,
<<"item">> => NewItem
}
| Acc
]
end
end,
[],
@@ -792,6 +921,9 @@ updates_for_changed_items(OldItems, NewItems) ->
-spec zip_with_index([term()], [term()]) -> [{non_neg_integer(), term(), term()}].
zip_with_index(A, B) ->
zip_with_index(A, B, 0, []).
-spec zip_with_index([term()], [term()], non_neg_integer(), [{non_neg_integer(), term(), term()}]) ->
[{non_neg_integer(), term(), term()}].
zip_with_index([], [], _I, Acc) ->
lists:reverse(Acc);
zip_with_index([HA | TA], [HB | TB], I, Acc) ->
@@ -802,6 +934,11 @@ zip_with_index(_, _, _I, Acc) ->
-spec mismatch_span([term()], [term()]) -> none | {non_neg_integer(), non_neg_integer()}.
mismatch_span(A, B) ->
mismatch_span(A, B, 0, none).
-spec mismatch_span(
[term()], [term()], non_neg_integer(), none | {non_neg_integer(), non_neg_integer()}
) ->
none | {non_neg_integer(), non_neg_integer()}.
mismatch_span([], [], _I, none) ->
none;
mismatch_span([], [], _I, {S, E}) ->
@@ -831,7 +968,8 @@ sync_range_op(Start, End, NewItems) ->
<<"items">> => slice_items(NewItems, Start, End)
}.
-spec try_pure_insert_delete([list_item()], [list_item()], [term()], [term()]) -> {ok, [map()]} | error.
-spec try_pure_insert_delete([list_item()], [list_item()], [term()], [term()]) ->
{ok, [map()]} | error.
try_pure_insert_delete(OldItems, NewItems, OldKeys, NewKeys) ->
LenOld = length(OldKeys),
LenNew = length(NewKeys),
@@ -872,6 +1010,8 @@ try_pure_insert_delete(OldItems, NewItems, OldKeys, NewKeys) ->
-spec first_mismatch_index([term()], [term()]) -> non_neg_integer().
first_mismatch_index(A, B) ->
first_mismatch_index(A, B, 0).
-spec first_mismatch_index([term()], [term()], non_neg_integer()) -> non_neg_integer().
first_mismatch_index([], _B, I) ->
I;
first_mismatch_index(_A, [], I) ->
@@ -896,6 +1036,8 @@ delete_many(List, Index, Count) ->
-spec insert_ops(non_neg_integer(), [list_item()]) -> [map()].
insert_ops(StartIdx, Items) ->
insert_ops(StartIdx, Items, 0, []).
-spec insert_ops(non_neg_integer(), [list_item()], non_neg_integer(), [map()]) -> [map()].
insert_ops(_StartIdx, [], _Offset, Acc) ->
lists:reverse(Acc);
insert_ops(StartIdx, [Item | Rest], Offset, Acc) ->
@@ -920,18 +1062,23 @@ delete_ops(Idx, Count) ->
lists:seq(1, Count)
).
-spec session_can_view_channel(map(), integer(), guild_state()) -> boolean().
session_can_view_channel(_SessionData, ChannelId, _State) when not is_integer(ChannelId); ChannelId =< 0 ->
-spec session_can_view_channel(map(), channel_id(), guild_state()) -> boolean().
session_can_view_channel(_SessionData, ChannelId, _State) when
not is_integer(ChannelId); ChannelId =< 0
->
false;
session_can_view_channel(SessionData, ChannelId, State) ->
case maps:get(user_id, SessionData, undefined) of
UserId when is_integer(UserId) ->
case {maps:get(user_id, SessionData, undefined), maps:get(viewable_channels, SessionData, undefined)} of
{UserId, ViewableChannels} when is_integer(UserId), is_map(ViewableChannels) ->
maps:is_key(ChannelId, ViewableChannels) orelse
guild_permissions:can_view_channel(UserId, ChannelId, undefined, State);
{UserId, _} when is_integer(UserId) ->
guild_permissions:can_view_channel(UserId, ChannelId, undefined, State);
_ ->
false
end.
-spec list_id_channel_id(list_id()) -> integer().
-spec list_id_channel_id(list_id()) -> channel_id().
list_id_channel_id(ListId) when is_binary(ListId) ->
case type_conv:to_integer(ListId) of
undefined -> 0;
@@ -966,6 +1113,16 @@ list_id_channel_id_parses_binary_test() ->
list_id_channel_id_invalid_value_test() ->
?assertEqual(0, list_id_channel_id(<<"abc">>)).
session_can_view_channel_uses_cached_visibility_test() ->
SessionData = #{user_id => 12, viewable_channels => #{500 => true}},
State = #{data => #{<<"members">> => #{}}},
?assertEqual(true, session_can_view_channel(SessionData, 500, State)).
session_can_view_channel_rejects_when_cache_misses_and_user_missing_test() ->
SessionData = #{user_id => 99, viewable_channels => #{}},
State = #{data => #{<<"members">> => #{}}},
?assertEqual(false, session_can_view_channel(SessionData, 500, State)).
subscribe_ranges_test() ->
State = #{member_list_subscriptions => #{}},
{NewState, ShouldSync, NormalizedRanges} =
@@ -1024,4 +1181,64 @@ validate_range_invalid_negative_test() ->
validate_range_invalid_too_large_test() ->
?assertEqual(invalid, validate_range({0, 100001})).
get_members_cursor_returns_atom_keys_test() ->
Member1 = #{<<"user">> => #{<<"id">> => <<"2">>}},
Member2 = #{<<"user">> => #{<<"id">> => <<"1">>}},
State = #{data => #{<<"members">> => [Member1, Member2]}},
{reply, Reply, _NewState} = get_members_cursor(#{<<"limit">> => 1}, State),
?assert(maps:is_key(members, Reply)),
?assert(maps:is_key(total, Reply)),
?assertNot(maps:is_key(<<"members">>, Reply)),
?assertNot(maps:is_key(<<"total">>, Reply)).
get_online_count_ignores_members_without_connected_session_test() ->
Members = [
#{<<"user">> => #{<<"id">> => <<"1">>}},
#{<<"user">> => #{<<"id">> => <<"2">>}}
],
State = #{
data => #{<<"members">> => Members},
sessions => #{<<"s1">> => #{user_id => 1}},
member_presence => #{
1 => #{<<"status">> => <<"online">>},
2 => #{<<"status">> => <<"online">>}
}
},
?assertEqual(1, get_online_count(State)).
filter_members_for_list_keeps_members_without_connected_session_test() ->
Members = [
#{<<"user">> => #{<<"id">> => <<"1">>}},
#{<<"user">> => #{<<"id">> => <<"2">>}}
],
State = #{
sessions => #{<<"s1">> => #{user_id => 1}},
data => #{<<"channels">> => []}
},
FilteredMembers = filter_members_for_list(<<"0">>, Members, State),
?assertEqual([1, 2], [get_member_user_id(Member) || Member <- FilteredMembers]).
get_member_groups_counts_members_without_connected_session_as_offline_test() ->
Members = [
#{<<"user">> => #{<<"id">> => <<"1">>}},
#{<<"user">> => #{<<"id">> => <<"2">>}}
],
State = #{
id => 123,
data => #{
<<"members">> => Members,
<<"roles">> => []
},
sessions => #{<<"s1">> => #{user_id => 1}},
member_presence => #{
1 => #{<<"status">> => <<"online">>},
2 => #{<<"status">> => <<"online">>}
}
},
Groups = get_member_groups(<<"0">>, State),
OnlineGroup = lists:nth(1, Groups),
OfflineGroup = lists:nth(2, Groups),
?assertEqual(1, maps:get(<<"count">>, OnlineGroup)),
?assertEqual(1, maps:get(<<"count">>, OfflineGroup)).
-endif.

View File

@@ -29,40 +29,35 @@
compute_list_id/1
]).
-record(member_storage, {
members_table :: ets:tid(),
display_name_index :: gb_trees:tree()
}).
-type storage() :: #member_storage{}.
-type storage() :: #{
members_table := ets:tid(),
display_name_index := gb_trees:tree({binary(), user_id()}, user_id())
}.
-type user_id() :: integer().
-type member() :: map().
-type index_key() :: {binary(), user_id()}.
-export_type([storage/0]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec new() -> storage().
new() ->
MembersTable = ets:new(members, [set, private]),
MembersTable = ets:new(members, [set, private, {read_concurrency, true}]),
DisplayNameIndex = gb_trees:empty(),
#member_storage{
members_table = MembersTable,
display_name_index = DisplayNameIndex
#{
members_table => MembersTable,
display_name_index => DisplayNameIndex
}.
-spec insert_member(member(), storage()) -> storage().
insert_member(Member, Storage) ->
UserId = extract_user_id(Member),
case UserId of
case extract_user_id(Member) of
undefined ->
Storage;
_ ->
UserId ->
OldMember = get_member(UserId, Storage),
Storage1 = remove_from_index(OldMember, Storage),
ets:insert(Storage1#member_storage.members_table, {UserId, Member}),
#{members_table := MembersTable} = Storage1,
ets:insert(MembersTable, {UserId, Member}),
add_to_index(UserId, Member, Storage1)
end.
@@ -73,13 +68,15 @@ remove_member(UserId, Storage) ->
Storage;
Member ->
Storage1 = remove_from_index(Member, Storage),
ets:delete(Storage1#member_storage.members_table, UserId),
#{members_table := MembersTable} = Storage1,
ets:delete(MembersTable, UserId),
Storage1
end.
-spec get_member(user_id(), storage()) -> member() | undefined.
get_member(UserId, Storage) ->
case ets:lookup(Storage#member_storage.members_table, UserId) of
#{members_table := MembersTable} = Storage,
case ets:lookup(MembersTable, UserId) of
[{UserId, Member}] -> Member;
[] -> undefined
end.
@@ -100,41 +97,46 @@ get_members_by_ids(UserIds, Storage) ->
search_members(Query, Limit, Storage) when is_binary(Query), Limit > 0 ->
NormalizedQuery = normalize_display_name(Query),
case NormalizedQuery of
<<>> ->
[];
_ ->
search_by_prefix(NormalizedQuery, Limit, Storage)
<<>> -> [];
_ -> search_by_prefix(NormalizedQuery, Limit, Storage)
end;
search_members(_, _, _) ->
[].
-spec get_range(non_neg_integer(), non_neg_integer(), storage()) -> [member()].
get_range(Offset, Limit, Storage) when is_integer(Offset), is_integer(Limit), Limit > 0 ->
Index = Storage#member_storage.display_name_index,
case gb_trees:size(Index) of
Size when Offset >= Size ->
[];
Size ->
AllKeys = gb_trees:keys(Index),
EndIdx = min(Offset + Limit, Size),
SelectedKeys = lists:sublist(AllKeys, Offset + 1, EndIdx - Offset),
lists:filtermap(
fun(Key) ->
UserId = gb_trees:get(Key, Index),
case get_member(UserId, Storage) of
undefined -> false;
Member -> {true, Member}
end
end,
SelectedKeys
)
#{display_name_index := Index} = Storage,
Size = gb_trees:size(Index),
case Offset >= Size of
true -> [];
false -> get_range_from_index(Offset, Limit, Size, Index, Storage)
end;
get_range(_, _, _) ->
[].
-spec get_range_from_index(
non_neg_integer(), non_neg_integer(), non_neg_integer(), gb_trees:tree(), storage()
) ->
[member()].
get_range_from_index(Offset, Limit, Size, Index, Storage) ->
AllKeys = gb_trees:keys(Index),
EndIdx = min(Offset + Limit, Size),
SelectedKeys = lists:sublist(AllKeys, Offset + 1, EndIdx - Offset),
lists:filtermap(
fun(Key) ->
UserId = gb_trees:get(Key, Index),
case get_member(UserId, Storage) of
undefined -> false;
Member -> {true, Member}
end
end,
SelectedKeys
).
-spec count(storage()) -> non_neg_integer().
count(Storage) ->
ets:info(Storage#member_storage.members_table, size).
#{members_table := MembersTable} = Storage,
ets:info(MembersTable, size).
-spec compute_list_id([user_id()]) -> integer().
compute_list_id(UserIds) ->
@@ -155,19 +157,17 @@ extract_user_id(_) ->
-spec get_display_name(member()) -> binary().
get_display_name(Member) when is_map(Member) ->
Nick = maps:get(<<"nick">>, Member, undefined),
case Nick of
undefined ->
User = maps:get(<<"user">>, Member, #{}),
GlobalName = maps:get(<<"global_name">>, User, undefined),
case GlobalName of
undefined ->
maps:get(<<"username">>, User, <<>>);
_ ->
GlobalName
end;
_ ->
Nick
case maps:get(<<"nick">>, Member, undefined) of
undefined -> get_display_name_from_user(Member);
Nick -> Nick
end.
-spec get_display_name_from_user(member()) -> binary().
get_display_name_from_user(Member) ->
User = maps:get(<<"user">>, Member, #{}),
case maps:get(<<"global_name">>, User, undefined) of
undefined -> maps:get(<<"username">>, User, <<>>);
GlobalName -> GlobalName
end.
-spec normalize_display_name(binary()) -> binary().
@@ -180,9 +180,9 @@ add_to_index(UserId, Member, Storage) ->
DisplayName = get_display_name(Member),
NormalizedName = normalize_display_name(DisplayName),
Key = make_index_key(NormalizedName, UserId),
Index = Storage#member_storage.display_name_index,
#{display_name_index := Index} = Storage,
NewIndex = gb_trees:enter(Key, UserId, Index),
Storage#member_storage{display_name_index = NewIndex}.
Storage#{display_name_index => NewIndex}.
-spec remove_from_index(member() | undefined, storage()) -> storage().
remove_from_index(undefined, Storage) ->
@@ -192,41 +192,46 @@ remove_from_index(Member, Storage) ->
DisplayName = get_display_name(Member),
NormalizedName = normalize_display_name(DisplayName),
Key = make_index_key(NormalizedName, UserId),
Index = Storage#member_storage.display_name_index,
#{display_name_index := Index} = Storage,
case gb_trees:is_defined(Key, Index) of
true ->
NewIndex = gb_trees:delete(Key, Index),
Storage#member_storage{display_name_index = NewIndex};
Storage#{display_name_index => NewIndex};
false ->
Storage
end.
-spec make_index_key(binary(), user_id()) -> {binary(), user_id()}.
-spec make_index_key(binary(), user_id()) -> index_key().
make_index_key(NormalizedName, UserId) ->
{NormalizedName, UserId}.
-spec search_by_prefix(binary(), non_neg_integer(), storage()) -> [member()].
search_by_prefix(Prefix, Limit, Storage) ->
Index = Storage#member_storage.display_name_index,
#{display_name_index := Index} = Storage,
AllKeys = gb_trees:keys(Index),
Matches = lists:filtermap(
fun({Name, UserId}) ->
PrefixLen = byte_size(Prefix),
case Name of
<<Prefix:PrefixLen/binary, _/binary>> ->
case get_member(UserId, Storage) of
undefined -> false;
Member -> {true, Member}
end;
_ ->
false
end
end,
AllKeys
),
PrefixLen = byte_size(Prefix),
Matches = find_prefix_matches(AllKeys, Prefix, PrefixLen, Storage, []),
lists:sublist(Matches, Limit).
-spec find_prefix_matches([index_key()], binary(), non_neg_integer(), storage(), [member()]) ->
[member()].
find_prefix_matches([], _Prefix, _PrefixLen, _Storage, Acc) ->
lists:reverse(Acc);
find_prefix_matches([{Name, UserId} | Rest], Prefix, PrefixLen, Storage, Acc) ->
case Name of
<<Prefix:PrefixLen/binary, _/binary>> ->
case get_member(UserId, Storage) of
undefined ->
find_prefix_matches(Rest, Prefix, PrefixLen, Storage, Acc);
Member ->
find_prefix_matches(Rest, Prefix, PrefixLen, Storage, [Member | Acc])
end;
_ ->
find_prefix_matches(Rest, Prefix, PrefixLen, Storage, Acc)
end.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
new_creates_empty_storage_test() ->
Storage = new(),
@@ -309,11 +314,16 @@ display_name_username_fallback_test() ->
},
?assertEqual(<<"user">>, get_display_name(Member)).
compute_list_id_test() ->
compute_list_id_deterministic_test() ->
Id1 = compute_list_id([1, 2, 3]),
Id2 = compute_list_id([3, 2, 1]),
?assertEqual(Id1, Id2).
compute_list_id_different_for_different_lists_test() ->
Id1 = compute_list_id([1, 2, 3]),
Id2 = compute_list_id([1, 2, 4]),
?assertNotEqual(Id1, Id2).
get_range_test() ->
Storage = new(),
Member1 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"alice">>}},
@@ -325,4 +335,46 @@ get_range_test() ->
Results = get_range(1, 2, Storage3),
?assertEqual(2, length(Results)).
get_range_offset_beyond_size_test() ->
Storage = new(),
Member1 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"alice">>}},
Storage1 = insert_member(Member1, Storage),
Results = get_range(10, 5, Storage1),
?assertEqual([], Results).
search_members_empty_query_test() ->
Storage = new(),
Member1 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"alice">>}},
Storage1 = insert_member(Member1, Storage),
Results = search_members(<<>>, 10, Storage1),
?assertEqual([], Results).
search_members_limit_test() ->
Storage = new(),
Member1 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"aaa">>}},
Member2 = #{<<"user">> => #{<<"id">> => <<"2">>, <<"username">> => <<"aab">>}},
Member3 = #{<<"user">> => #{<<"id">> => <<"3">>, <<"username">> => <<"aac">>}},
Storage1 = insert_member(Member1, Storage),
Storage2 = insert_member(Member2, Storage1),
Storage3 = insert_member(Member3, Storage2),
Results = search_members(<<"aa">>, 2, Storage3),
?assertEqual(2, length(Results)).
normalize_display_name_test() ->
?assertEqual(<<"hello">>, normalize_display_name(<<"HELLO">>)),
?assertEqual(<<"hello">>, normalize_display_name(<<"Hello">>)),
?assertEqual(<<"hello">>, normalize_display_name(<<"hello">>)).
insert_member_updates_index_test() ->
Storage = new(),
Member1 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"alice">>}},
Storage1 = insert_member(Member1, Storage),
Member2 = #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"bob">>}},
Storage2 = insert_member(Member2, Storage1),
?assertEqual(1, count(Storage2)),
AliceResults = search_members(<<"alice">>, 10, Storage2),
?assertEqual(0, length(AliceResults)),
BobResults = search_members(<<"bob">>, 10, Storage2),
?assertEqual(1, length(BobResults)).
-endif.

View File

@@ -17,16 +17,18 @@
-module(guild_members).
-export([get_users_to_mention_by_roles/2]).
-export([get_users_to_mention_by_user_ids/2]).
-export([get_all_users_to_mention/2]).
-export([resolve_all_mentions/2]).
-export([get_members_with_role/2]).
-export([can_manage_roles/2]).
-export([can_manage_role/2]).
-export([get_assignable_roles/2]).
-export([check_target_member/2]).
-export([get_viewable_channels/2]).
-export([
get_users_to_mention_by_roles/2,
get_users_to_mention_by_user_ids/2,
get_all_users_to_mention/2,
resolve_all_mentions/2,
get_members_with_role/2,
can_manage_roles/2,
can_manage_role/2,
get_assignable_roles/2,
check_target_member/2,
get_viewable_channels/2
]).
-type guild_state() :: map().
-type guild_reply(T) :: {reply, T, guild_state()}.
@@ -37,22 +39,18 @@
-type role_id() :: integer().
-type channel_id() :: integer().
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec get_users_to_mention_by_roles(map(), guild_state()) -> guild_reply(map()).
get_users_to_mention_by_roles(
#{channel_id := ChannelId, role_ids := RoleIds, author_id := AuthorId}, State
) ->
Members = guild_members(State),
RoleIdSet = normalize_int_list(RoleIds),
UserIds = collect_mentions(
Members,
RoleIdList = normalize_int_list(RoleIds),
CandidateUserIds = user_ids_for_any_role(RoleIdList, State),
UserIds = collect_mentions_for_user_ids(
CandidateUserIds,
AuthorId,
ChannelId,
State,
fun(Member) -> member_has_any_role(Member, RoleIdSet) end
fun(_UserId, _Member) -> true end
),
{reply, #{user_ids => UserIds}, State}.
@@ -60,19 +58,13 @@ get_users_to_mention_by_roles(
get_users_to_mention_by_user_ids(
#{channel_id := ChannelId, user_ids := UserIdsReq, author_id := AuthorId}, State
) ->
Members = guild_members(State),
TargetIds = normalize_int_list(UserIdsReq),
UserIds = collect_mentions(
Members,
UserIds = collect_mentions_for_user_ids(
TargetIds,
AuthorId,
ChannelId,
State,
fun(Member) ->
case member_user_id(Member) of
undefined -> false;
Id -> lists:member(Id, TargetIds)
end
end
fun(_UserId, _Member) -> true end
),
{reply, #{user_ids => UserIds}, State}.
@@ -83,7 +75,7 @@ get_all_users_to_mention(#{channel_id := ChannelId, author_id := AuthorId}, Stat
{reply, #{user_ids => UserIds}, State}.
-spec resolve_all_mentions(map(), guild_state()) -> guild_reply(map()).
resolve_all_mentions(
resolve_all_mentions(Request, State) ->
#{
channel_id := ChannelId,
author_id := AuthorId,
@@ -91,48 +83,122 @@ resolve_all_mentions(
mention_here := MentionHere,
role_ids := RoleIds,
user_ids := DirectUserIds
},
State
) ->
} = Request,
Members = guild_members(State),
MemberMap = guild_data_index:member_map(guild_data(State)),
Sessions = maps:get(sessions, State, #{}),
RoleIdSet = gb_sets:from_list(normalize_int_list(RoleIds)),
DirectUserIdSet = gb_sets:from_list(normalize_int_list(DirectUserIds)),
HasRoleMentions = not gb_sets:is_empty(RoleIdSet),
HasDirectMentions = not gb_sets:is_empty(DirectUserIdSet),
ConnectedUserIds =
case MentionHere of
ConnectedUserIds = build_connected_user_ids(MentionHere, Sessions),
UserIds =
case MentionEveryone of
true ->
gb_sets:from_list([
maps:get(user_id, S)
|| {_Sid, S} <- maps:to_list(Sessions)
]);
resolve_mentions(
Members,
AuthorId,
ChannelId,
MentionEveryone,
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds,
State
);
false ->
gb_sets:empty()
CandidateUserIds = candidate_user_ids_for_mentions(
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds,
State
),
resolve_mentions_for_user_ids(
CandidateUserIds,
MemberMap,
AuthorId,
ChannelId,
MentionEveryone,
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds,
State
)
end,
{reply, #{user_ids => UserIds}, State}.
UserIds = lists:filtermap(
-spec build_connected_user_ids(boolean(), map()) -> gb_sets:set(user_id()).
build_connected_user_ids(false, _Sessions) ->
gb_sets:empty();
build_connected_user_ids(true, Sessions) ->
gb_sets:from_list(
lists:filtermap(
fun({_Sid, SessionData}) ->
case maps:get(user_id, SessionData, undefined) of
UserId when is_integer(UserId) -> {true, UserId};
_ -> false
end
end,
maps:to_list(Sessions)
)
).
-spec resolve_mentions(
[member()],
user_id(),
channel_id(),
boolean(),
boolean(),
boolean(),
boolean(),
gb_sets:set(),
gb_sets:set(),
gb_sets:set(),
guild_state()
) -> [user_id()].
resolve_mentions(
Members,
AuthorId,
ChannelId,
MentionEveryone,
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds,
State
) ->
lists:filtermap(
fun(Member) ->
case member_user_id(Member) of
undefined ->
false;
UserId when UserId =:= AuthorId ->
false;
UserId when UserId =:= AuthorId -> false;
UserId ->
case is_member_bot(Member) of
true ->
false;
false ->
ShouldMention =
MentionEveryone orelse
(MentionHere andalso
gb_sets:is_member(UserId, ConnectedUserIds)) orelse
(HasRoleMentions andalso
member_has_any_role_set(Member, RoleIdSet)) orelse
(HasDirectMentions andalso
gb_sets:is_member(UserId, DirectUserIdSet)),
ShouldMention = check_should_mention(
UserId,
Member,
MentionEveryone,
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds
),
case
ShouldMention andalso
member_can_view_channel(UserId, ChannelId, Member, State)
@@ -144,61 +210,206 @@ resolve_all_mentions(
end
end,
Members
),
{reply, #{user_ids => UserIds}, State}.
).
-spec get_members_with_role(map(), guild_state()) -> guild_reply(map()).
get_members_with_role(#{role_id := RoleId}, State) ->
Members = guild_members(State),
TargetRoles = [RoleId],
UserIds = lists:filtermap(
fun(Member) ->
case member_user_id(Member) of
undefined ->
-spec resolve_mentions_for_user_ids(
[user_id()],
#{user_id() => member()},
user_id(),
channel_id(),
boolean(),
boolean(),
boolean(),
boolean(),
gb_sets:set(),
gb_sets:set(),
gb_sets:set(),
guild_state()
) -> [user_id()].
resolve_mentions_for_user_ids(
CandidateUserIds,
MemberMap,
AuthorId,
ChannelId,
MentionEveryone,
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds,
State
) ->
lists:filtermap(
fun(UserId) ->
case UserId =:= AuthorId of
true ->
false;
UserId ->
case member_has_any_role(Member, TargetRoles) of
true -> {true, UserId};
false -> false
false ->
case maps:get(UserId, MemberMap, undefined) of
undefined ->
false;
Member ->
case is_member_bot(Member) of
true ->
false;
false ->
ShouldMention = check_should_mention(
UserId,
Member,
MentionEveryone,
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds
),
case
ShouldMention andalso
member_can_view_channel(UserId, ChannelId, Member, State)
of
true -> {true, UserId};
false -> false
end
end
end
end
end,
Members
lists:usort(CandidateUserIds)
).
-spec candidate_user_ids_for_mentions(
boolean(),
boolean(),
boolean(),
gb_sets:set(),
gb_sets:set(),
gb_sets:set(),
guild_state()
) -> [user_id()].
candidate_user_ids_for_mentions(
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds,
State
) ->
HereSet =
case MentionHere of
true -> ConnectedUserIds;
false -> gb_sets:empty()
end,
RoleUsersSet =
case HasRoleMentions of
true ->
gb_sets:from_list(user_ids_for_any_role(gb_sets:to_list(RoleIdSet), State));
false ->
gb_sets:empty()
end,
DirectSet =
case HasDirectMentions of
true -> DirectUserIdSet;
false -> gb_sets:empty()
end,
gb_sets:to_list(gb_sets:union(HereSet, gb_sets:union(RoleUsersSet, DirectSet))).
-spec user_ids_for_any_role([role_id()], guild_state()) -> [user_id()].
user_ids_for_any_role(RoleIds, State) ->
Data = guild_data(State),
MemberRoleIndex = guild_data_index:member_role_index(Data),
RoleUserIds = lists:foldl(
fun(RoleId, AccSet) ->
case maps:get(RoleId, MemberRoleIndex, undefined) of
undefined ->
AccSet;
UserMap ->
lists:foldl(
fun(UserId, InnerSet) -> gb_sets:add(UserId, InnerSet) end,
AccSet,
maps:keys(UserMap)
)
end
end,
gb_sets:empty(),
RoleIds
),
gb_sets:to_list(RoleUserIds).
-spec check_should_mention(
user_id(),
member(),
boolean(),
boolean(),
boolean(),
boolean(),
gb_sets:set(),
gb_sets:set(),
gb_sets:set()
) -> boolean().
check_should_mention(
UserId,
Member,
MentionEveryone,
MentionHere,
HasRoleMentions,
HasDirectMentions,
RoleIdSet,
DirectUserIdSet,
ConnectedUserIds
) ->
MentionEveryone orelse
(MentionHere andalso gb_sets:is_member(UserId, ConnectedUserIds)) orelse
(HasRoleMentions andalso member_has_any_role_set(Member, RoleIdSet)) orelse
(HasDirectMentions andalso gb_sets:is_member(UserId, DirectUserIdSet)).
-spec get_members_with_role(map(), guild_state()) -> guild_reply(map()).
get_members_with_role(#{role_id := RoleId}, State) ->
Data = guild_data(State),
MemberRoleIndex = guild_data_index:member_role_index(Data),
TargetRoleId = type_conv:to_integer(RoleId),
UserMap =
case TargetRoleId of
undefined ->
#{};
_ ->
maps:get(TargetRoleId, MemberRoleIndex, #{})
end,
UserIds = lists:sort(maps:keys(UserMap)),
{reply, #{user_ids => UserIds}, State}.
-spec can_manage_roles(map(), guild_state()) -> guild_reply(map()).
can_manage_roles(#{user_id := UserId, role_id := RoleId}, State) ->
Data = guild_data(State),
OwnerId = owner_id(State),
Reply =
if
UserId =:= OwnerId ->
true;
true ->
UserPermissions = guild_permissions:get_member_permissions(
UserId, undefined, State
),
case (UserPermissions band constants:manage_roles_permission()) =/= 0 of
false ->
false;
true ->
Roles = maps:get(<<"roles">>, Data, []),
case find_role_by_id(RoleId, Roles) of
undefined ->
false;
Role ->
UserMax = guild_permissions:get_max_role_position(UserId, State),
UserMax > role_position(Role)
end
end
end,
Reply = check_can_manage_roles(UserId, RoleId, OwnerId, Data, State),
{reply, #{can_manage => Reply}, State}.
-spec check_can_manage_roles(user_id(), role_id(), user_id(), map(), guild_state()) -> boolean().
check_can_manage_roles(UserId, _RoleId, UserId, _Data, _State) ->
true;
check_can_manage_roles(UserId, RoleId, _OwnerId, Data, State) ->
UserPermissions = guild_permissions:get_member_permissions(UserId, undefined, State),
case (UserPermissions band constants:manage_roles_permission()) =/= 0 of
false ->
false;
true ->
Roles = guild_data_index:role_index(Data),
case find_role_by_id(RoleId, Roles) of
undefined ->
false;
Role ->
UserMax = guild_permissions:get_max_role_position(UserId, State),
UserMax > role_position(Role)
end
end.
-spec can_manage_role(map(), guild_state()) -> guild_reply(map()).
can_manage_role(#{user_id := UserId, role_id := RoleId}, State) ->
Data = guild_data(State),
Roles = maps:get(<<"roles">>, Data, []),
Roles = guild_data_index:role_index(Data),
Reply =
case find_role_by_id(RoleId, Roles) of
undefined ->
@@ -212,6 +423,7 @@ can_manage_role(#{user_id := UserId, role_id := RoleId}, State) ->
end,
{reply, #{can_manage => Reply}, State}.
-spec compare_role_ids_for_equal_position(user_id(), role_id(), guild_state()) -> boolean().
compare_role_ids_for_equal_position(UserId, TargetRoleId, State) ->
case guild_permissions:find_member_by_user_id(UserId, State) of
undefined ->
@@ -219,9 +431,8 @@ compare_role_ids_for_equal_position(UserId, TargetRoleId, State) ->
Member ->
MemberRoles = member_roles(Member),
Data = guild_data(State),
Roles = maps:get(<<"roles">>, Data, []),
UserHighestRole = get_highest_role(MemberRoles, Roles),
case UserHighestRole of
Roles = guild_data_index:role_index(Data),
case get_highest_role(MemberRoles, Roles) of
undefined ->
false;
HighestRole ->
@@ -230,39 +441,42 @@ compare_role_ids_for_equal_position(UserId, TargetRoleId, State) ->
end
end.
-spec get_highest_role([role_id()], [role()] | map()) -> role() | undefined.
get_highest_role(MemberRoleIds, Roles) ->
lists:foldl(
fun(RoleId, Acc) ->
case find_role_by_id(RoleId, Roles) of
undefined ->
Acc;
Role ->
case Acc of
undefined ->
Role;
AccRole ->
AccPos = role_position(AccRole),
RolePos = role_position(Role),
if
RolePos > AccPos ->
Role;
RolePos =:= AccPos ->
AccId = map_utils:get_integer(AccRole, <<"id">>, 0),
RId = map_utils:get_integer(Role, <<"id">>, 0),
if
RId < AccId -> Role;
true -> AccRole
end;
true ->
AccRole
end
end
undefined -> Acc;
Role -> compare_roles(Role, Acc)
end
end,
undefined,
MemberRoleIds
).
-spec compare_roles(role(), role() | undefined) -> role().
compare_roles(Role, undefined) ->
Role;
compare_roles(Role, AccRole) ->
AccPos = role_position(AccRole),
RolePos = role_position(Role),
case RolePos > AccPos of
true ->
Role;
false ->
case RolePos =:= AccPos of
true ->
AccId = map_utils:get_integer(AccRole, <<"id">>, 0),
RId = map_utils:get_integer(Role, <<"id">>, 0),
case RId < AccId of
true -> Role;
false -> AccRole
end;
false ->
AccRole
end
end.
-spec get_assignable_roles(map(), guild_state()) -> guild_reply(map()).
get_assignable_roles(#{user_id := UserId}, State) ->
Roles = guild_roles(State),
@@ -270,6 +484,7 @@ get_assignable_roles(#{user_id := UserId}, State) ->
RoleIds = get_assignable_role_ids(UserId, OwnerId, Roles, State),
{reply, #{role_ids => RoleIds}, State}.
-spec get_assignable_role_ids(user_id(), user_id(), [role()], guild_state()) -> [role_id()].
get_assignable_role_ids(OwnerId, OwnerId, Roles, _State) ->
role_ids_from_roles(Roles);
get_assignable_role_ids(UserId, _OwnerId, Roles, State) ->
@@ -279,6 +494,7 @@ get_assignable_role_ids(UserId, _OwnerId, Roles, State) ->
Roles
).
-spec filter_assignable_role(role(), integer()) -> {true, role_id()} | false.
filter_assignable_role(Role, UserMaxPosition) ->
case role_position(Role) < UserMaxPosition of
true ->
@@ -293,19 +509,19 @@ filter_assignable_role(Role, UserMaxPosition) ->
-spec check_target_member(map(), guild_state()) -> guild_reply(map()).
check_target_member(#{user_id := UserId, target_user_id := TargetUserId}, State) ->
OwnerId = owner_id(State),
CanManage =
if
UserId =:= OwnerId ->
true;
TargetUserId =:= OwnerId ->
false;
true ->
UserMaxPos = guild_permissions:get_max_role_position(UserId, State),
TargetMaxPos = guild_permissions:get_max_role_position(TargetUserId, State),
UserMaxPos > TargetMaxPos
end,
CanManage = check_can_manage_target(UserId, TargetUserId, OwnerId, State),
{reply, #{can_manage => CanManage}, State}.
-spec check_can_manage_target(user_id(), user_id(), user_id(), guild_state()) -> boolean().
check_can_manage_target(UserId, _TargetUserId, UserId, _State) ->
true;
check_can_manage_target(_UserId, OwnerId, OwnerId, _State) ->
false;
check_can_manage_target(UserId, TargetUserId, _OwnerId, State) ->
UserMaxPos = guild_permissions:get_max_role_position(UserId, State),
TargetMaxPos = guild_permissions:get_max_role_position(TargetUserId, State),
UserMaxPos > TargetMaxPos.
-spec get_viewable_channels(map(), guild_state()) -> guild_reply(map()).
get_viewable_channels(#{user_id := UserId}, State) ->
Channels = guild_channels(State),
@@ -313,29 +529,33 @@ get_viewable_channels(#{user_id := UserId}, State) ->
undefined ->
{reply, #{channel_ids => []}, State};
Member ->
ChannelIds = lists:filtermap(
fun(Channel) ->
ChannelId = map_utils:get_integer(Channel, <<"id">>, undefined),
case ChannelId of
undefined ->
false;
_ ->
case
guild_permissions:can_view_channel(UserId, ChannelId, Member, State)
of
true -> {true, ChannelId};
false -> false
end
end
end,
Channels
),
ChannelIds = filter_viewable_channels(Channels, UserId, Member, State),
{reply, #{channel_ids => ChannelIds}, State}
end.
-spec filter_viewable_channels([channel()], user_id(), member(), guild_state()) -> [channel_id()].
filter_viewable_channels(Channels, UserId, Member, State) ->
lists:filtermap(
fun(Channel) ->
ChannelId = map_utils:get_integer(Channel, <<"id">>, undefined),
case ChannelId of
undefined ->
false;
_ ->
case guild_permissions:can_view_channel(UserId, ChannelId, Member, State) of
true -> {true, ChannelId};
false -> false
end
end
end,
Channels
).
-spec find_member_by_user_id(user_id(), guild_state()) -> member() | undefined.
find_member_by_user_id(UserId, State) ->
guild_permissions:find_member_by_user_id(UserId, State).
-spec find_role_by_id(role_id(), [role()] | map()) -> role() | undefined.
find_role_by_id(RoleId, Roles) ->
guild_permissions:find_role_by_id(RoleId, Roles).
@@ -345,15 +565,15 @@ guild_data(State) ->
-spec guild_members(guild_state()) -> [member()].
guild_members(State) ->
map_utils:ensure_list(maps:get(<<"members">>, guild_data(State), [])).
guild_data_index:member_values(guild_data(State)).
-spec guild_roles(guild_state()) -> [role()].
guild_roles(State) ->
map_utils:ensure_list(maps:get(<<"roles">>, guild_data(State), [])).
guild_data_index:role_list(guild_data(State)).
-spec guild_channels(guild_state()) -> [channel()].
guild_channels(State) ->
map_utils:ensure_list(maps:get(<<"channels">>, guild_data(State), [])).
guild_data_index:channel_list(guild_data(State)).
-spec owner_id(guild_state()) -> user_id().
owner_id(State) ->
@@ -369,11 +589,6 @@ member_user_id(Member) ->
member_roles(Member) ->
normalize_int_list(map_utils:ensure_list(maps:get(<<"roles">>, Member, []))).
-spec member_has_any_role(member(), [role_id()]) -> boolean().
member_has_any_role(Member, RoleIds) ->
MemberRoles = member_roles(Member),
lists:any(fun(RoleId) -> lists:member(RoleId, MemberRoles) end, RoleIds).
-spec member_has_any_role_set(member(), gb_sets:set(role_id())) -> boolean().
member_has_any_role_set(Member, RoleIdSet) ->
MemberRoles = member_roles(Member),
@@ -390,6 +605,39 @@ member_can_view_channel(UserId, ChannelId, Member, State) when is_integer(Channe
member_can_view_channel(_, _, _, _) ->
false.
-spec collect_mentions_for_user_ids(
[user_id()],
user_id(),
channel_id(),
guild_state(),
fun((user_id(), member()) -> boolean())
) ->
[user_id()].
collect_mentions_for_user_ids(UserIds, AuthorId, ChannelId, State, Predicate) ->
MemberMap = guild_data_index:member_map(guild_data(State)),
lists:filtermap(
fun(UserId) ->
case UserId =:= AuthorId of
true ->
false;
false ->
case maps:get(UserId, MemberMap, undefined) of
undefined ->
false;
Member ->
case
Predicate(UserId, Member) andalso
member_can_view_channel(UserId, ChannelId, Member, State)
of
true -> {true, UserId};
false -> false
end
end
end
end,
lists:usort(UserIds)
).
-spec collect_mentions([member()], user_id(), channel_id(), guild_state(), fun(
(member()) -> boolean()
)) ->
@@ -446,6 +694,7 @@ role_position(Role) ->
maps:get(<<"position">>, Role, 0).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
get_users_to_mention_by_roles_basic_test() ->
State = test_state(),
@@ -455,6 +704,30 @@ get_users_to_mention_by_roles_basic_test() ->
{reply, #{user_ids := UserIds}, _} = get_users_to_mention_by_roles(Request, State),
?assertEqual([2], UserIds).
get_users_to_mention_by_user_ids_filters_unknown_members_test() ->
State = test_state(),
Request = #{channel_id => 500, user_ids => [2, 999], author_id => 1},
{reply, #{user_ids := UserIds}, _} = get_users_to_mention_by_user_ids(Request, State),
?assertEqual([2], UserIds).
get_members_with_role_uses_member_role_index_test() ->
State = test_state(),
{reply, #{user_ids := UserIds}, _} = get_members_with_role(#{role_id => 200}, State),
?assertEqual([2], UserIds).
resolve_all_mentions_merges_role_and_direct_candidates_test() ->
State = test_state(),
Request = #{
channel_id => 500,
author_id => 1,
mention_everyone => false,
mention_here => false,
role_ids => [200],
user_ids => [3]
},
{reply, #{user_ids := UserIds}, _} = resolve_all_mentions(Request, State),
?assertEqual([2, 3], UserIds).
get_assignable_roles_owner_test() ->
State = test_state(),
{reply, #{role_ids := RoleIds}, _} = get_assignable_roles(#{user_id => 1}, State),
@@ -470,6 +743,32 @@ get_viewable_channels_filters_test() ->
{reply, #{channel_ids := ChannelIds}, _} = get_viewable_channels(#{user_id => 2}, State),
?assert(lists:member(500, ChannelIds)).
member_has_any_role_set_test() ->
Member = #{<<"roles">> => [<<"100">>, <<"200">>]},
RoleSet = gb_sets:from_list([200]),
MissingRoleSet = gb_sets:from_list([999]),
?assertEqual(true, member_has_any_role_set(Member, RoleSet)),
?assertEqual(false, member_has_any_role_set(Member, MissingRoleSet)).
is_member_bot_test() ->
BotMember = #{<<"user">> => #{<<"bot">> => true}},
HumanMember = #{<<"user">> => #{<<"bot">> => false}},
?assertEqual(true, is_member_bot(BotMember)),
?assertEqual(false, is_member_bot(HumanMember)).
normalize_int_list_test() ->
?assertEqual([1, 2, 3], normalize_int_list([<<"1">>, <<"2">>, <<"3">>])),
?assertEqual([1, 2], normalize_int_list([1, 2])),
?assertEqual([], normalize_int_list([])).
role_ids_from_roles_test() ->
Roles = [
#{<<"id">> => <<"100">>},
#{<<"id">> => <<"200">>},
#{<<"name">> => <<"no_id">>}
],
?assertEqual([100, 200], role_ids_from_roles(Roles)).
test_state() ->
GuildId = 100,
OwnerId = 1,

View File

@@ -21,7 +21,8 @@
schedule_passive_sync/1,
handle_passive_sync/1,
send_passive_updates_to_sessions/1,
compute_delta/2
compute_delta/2,
compute_channel_diffs/2
]).
-ifdef(TEST).
@@ -30,71 +31,156 @@
-define(PASSIVE_SYNC_INTERVAL, 30000).
-type guild_state() :: map().
-type channel_id() :: binary().
-type last_message_id() :: binary().
-type version() :: integer().
-type voice_state() :: map().
-spec schedule_passive_sync(guild_state()) -> guild_state().
schedule_passive_sync(State) ->
erlang:send_after(?PASSIVE_SYNC_INTERVAL, self(), passive_sync),
State.
-spec handle_passive_sync(guild_state()) -> {noreply, guild_state()}.
handle_passive_sync(State) ->
NewState = send_passive_updates_to_sessions(State),
schedule_passive_sync(NewState),
{noreply, NewState}.
-spec send_passive_updates_to_sessions(guild_state()) -> guild_state().
send_passive_updates_to_sessions(State) ->
GuildId = maps:get(id, State),
Sessions = maps:get(sessions, State, #{}),
Data = maps:get(data, State, #{}),
Channels = maps:get(<<"channels">>, Data, []),
MemberCount = maps:get(member_count, State, undefined),
IsLargeGuild = case MemberCount of
undefined -> false;
Count when is_integer(Count) -> Count > 250
end,
IsLargeGuild =
case MemberCount of
undefined -> false;
Count when is_integer(Count) -> Count > 250
end,
PassiveSessions = maps:filter(
fun(_SessionId, SessionData) ->
IsLargeGuild andalso session_passive:is_passive(GuildId, SessionData)
end,
Sessions
),
case map_size(PassiveSessions) of
0 ->
State;
_ ->
UpdatedSessions = lists:foldl(
fun({SessionId, SessionData}, AccSessions) ->
Pid = maps:get(pid, SessionData),
UserId = maps:get(user_id, SessionData),
Member = guild_permissions:find_member_by_user_id(UserId, State),
CurrentLastMessageIds = build_last_message_ids(Channels, UserId, Member, State),
PreviousLastMessageIds = maps:get(previous_passive_updates, SessionData, #{}),
Delta = compute_delta(CurrentLastMessageIds, PreviousLastMessageIds),
case {map_size(Delta), is_pid(Pid)} of
{0, _} ->
AccSessions;
{_, true} ->
EventData = #{
<<"guild_id">> => integer_to_binary(GuildId),
<<"channels">> => Delta
},
gen_server:cast(Pid, {dispatch, passive_updates, EventData}),
MergedLastMessageIds = maps:merge(PreviousLastMessageIds, Delta),
UpdatedSessionData = maps:put(previous_passive_updates, MergedLastMessageIds, SessionData),
maps:put(SessionId, UpdatedSessionData, AccSessions);
_ ->
AccSessions
end
end,
Sessions,
maps:to_list(PassiveSessions)
UpdatedSessions = process_passive_sessions(
maps:to_list(PassiveSessions), GuildId, Sessions, Channels, State
),
maps:put(sessions, UpdatedSessions, State)
end.
-spec process_passive_sessions([{binary(), map()}], integer(), map(), [map()], guild_state()) ->
map().
process_passive_sessions(PassiveSessionList, GuildId, Sessions, Channels, State) ->
lists:foldl(
fun({SessionId, SessionData}, AccSessions) ->
process_single_passive_session(
SessionId, SessionData, GuildId, Channels, State, AccSessions
)
end,
Sessions,
PassiveSessionList
).
-spec process_single_passive_session(binary(), map(), integer(), [map()], guild_state(), map()) ->
map().
process_single_passive_session(SessionId, SessionData, GuildId, Channels, State, AccSessions) ->
Pid = maps:get(pid, SessionData),
UserId = maps:get(user_id, SessionData),
Member = guild_permissions:find_member_by_user_id(UserId, State),
CurrentLastMessageIds = build_last_message_ids(Channels, UserId, Member, State),
PreviousLastMessageIds = maps:get(previous_passive_updates, SessionData, #{}),
Delta = compute_delta(CurrentLastMessageIds, PreviousLastMessageIds),
PreviousChannelVersions = maps:get(previous_passive_channel_versions, SessionData, #{}),
{CurrentChannelVersions, CurrentChannelsById} =
build_viewable_channel_snapshots(Channels, UserId, Member, State),
{CreatedChannelIds, UpdatedChannelIds, DeletedChannelIds} =
compute_channel_diffs(CurrentChannelVersions, PreviousChannelVersions),
CreatedChannels = [maps:get(Id, CurrentChannelsById) || Id <- CreatedChannelIds],
UpdatedChannels = [maps:get(Id, CurrentChannelsById) || Id <- UpdatedChannelIds],
ViewableChannels = guild_visibility:viewable_channel_set(UserId, State),
CurrentVoiceStates = build_current_voice_state_map(ViewableChannels, State),
PreviousVoiceStates = maps:get(previous_passive_voice_states, SessionData, #{}),
VoiceStateUpdates = compute_voice_state_updates(
CurrentVoiceStates, PreviousVoiceStates, GuildId
),
UpdatedSessionDataBase =
maps:put(previous_passive_voice_states, CurrentVoiceStates, SessionData),
HasChannelDelta = map_size(Delta) > 0,
HasVoiceUpdates = VoiceStateUpdates =/= [],
HasCreatedChannels = CreatedChannels =/= [],
HasUpdatedChannels = UpdatedChannels =/= [],
HasDeletedChannels = DeletedChannelIds =/= [],
ShouldSend =
HasChannelDelta orelse HasVoiceUpdates orelse
HasCreatedChannels orelse HasUpdatedChannels orelse HasDeletedChannels,
case {ShouldSend, is_pid(Pid)} of
{true, true} ->
EventData = build_passive_event_data(
GuildId,
Delta,
CreatedChannels,
UpdatedChannels,
DeletedChannelIds,
VoiceStateUpdates
),
gen_server:cast(Pid, {dispatch, passive_updates, EventData}),
PreviousLastMessageIds1 = maps:without(DeletedChannelIds, PreviousLastMessageIds),
MergedLastMessageIds = maps:merge(PreviousLastMessageIds1, Delta),
UpdatedSessionData0 =
maps:put(previous_passive_updates, MergedLastMessageIds, UpdatedSessionDataBase),
UpdatedSessionData =
maps:put(
previous_passive_channel_versions, CurrentChannelVersions, UpdatedSessionData0
),
maps:put(SessionId, UpdatedSessionData, AccSessions);
_ ->
UpdatedSessionData =
maps:put(
previous_passive_channel_versions,
CurrentChannelVersions,
UpdatedSessionDataBase
),
maps:put(SessionId, UpdatedSessionData, AccSessions)
end.
-spec build_passive_event_data(integer(), map(), [map()], [map()], [binary()], [map()]) -> map().
build_passive_event_data(
GuildId, Delta, CreatedChannels, UpdatedChannels, DeletedChannelIds, VoiceStateUpdates
) ->
EventDataBase = #{
<<"guild_id">> => integer_to_binary(GuildId),
<<"channels">> => Delta
},
EventData1 =
case CreatedChannels of
[] -> EventDataBase;
_ -> maps:put(<<"created_channels">>, CreatedChannels, EventDataBase)
end,
EventData2 =
case UpdatedChannels of
[] -> EventData1;
_ -> maps:put(<<"updated_channels">>, UpdatedChannels, EventData1)
end,
EventData3 =
case DeletedChannelIds of
[] -> EventData2;
_ -> maps:put(<<"deleted_channel_ids">>, DeletedChannelIds, EventData2)
end,
case VoiceStateUpdates of
[] -> EventData3;
_ -> maps:put(<<"voice_states">>, VoiceStateUpdates, EventData3)
end.
-spec compute_delta(#{channel_id() => last_message_id()}, #{channel_id() => last_message_id()}) ->
#{channel_id() => last_message_id()}.
compute_delta(CurrentLastMessageIds, PreviousLastMessageIds) ->
maps:filter(
fun(ChannelId, CurrentValue) ->
@@ -106,6 +192,23 @@ compute_delta(CurrentLastMessageIds, PreviousLastMessageIds) ->
CurrentLastMessageIds
).
-spec compute_channel_diffs(#{channel_id() => version()}, #{channel_id() => version()}) ->
{[channel_id()], [channel_id()], [channel_id()]}.
compute_channel_diffs(Current, Previous) ->
Created =
[Id || {Id, _} <- maps:to_list(Current), not maps:is_key(Id, Previous)],
Updated =
[
Id
|| {Id, CurV} <- maps:to_list(Current),
maps:is_key(Id, Previous) andalso maps:get(Id, Previous) =/= CurV
],
Deleted =
[Id || {Id, _} <- maps:to_list(Previous), not maps:is_key(Id, Current)],
{Created, Updated, Deleted}.
-spec build_last_message_ids([map()], integer(), map() | undefined, guild_state()) ->
#{channel_id() => last_message_id()}.
build_last_message_ids(Channels, UserId, Member, State) ->
lists:foldl(
fun(Channel, Acc) ->
@@ -117,16 +220,22 @@ build_last_message_ids(Channels, UserId, Member, State) ->
{_, null} ->
Acc;
_ ->
ChannelId = validation:snowflake_or_default(<<"id">>, ChannelIdBin, 0),
case Member of
case parse_snowflake(<<"id">>, ChannelIdBin) of
undefined ->
Acc;
_ ->
case guild_permissions:can_view_channel(UserId, ChannelId, Member, State) of
true ->
maps:put(ChannelIdBin, LastMessageId, Acc);
false ->
Acc
ChannelId ->
case Member of
undefined ->
Acc;
_ ->
case
guild_permissions:can_view_channel(UserId, ChannelId, Member, State)
of
true ->
maps:put(ChannelIdBin, LastMessageId, Acc);
false ->
Acc
end
end
end
end
@@ -135,41 +244,196 @@ build_last_message_ids(Channels, UserId, Member, State) ->
Channels
).
-spec build_viewable_channel_snapshots([map()], integer(), map() | undefined, guild_state()) ->
{#{channel_id() => version()}, #{channel_id() => map()}}.
build_viewable_channel_snapshots(Channels, UserId, Member, State) ->
case Member of
undefined ->
{#{}, #{}};
_ ->
lists:foldl(
fun(Channel, {VersionsAcc, ChannelsAcc}) ->
ChannelIdBin = maps:get(<<"id">>, Channel, undefined),
case ChannelIdBin of
undefined ->
{VersionsAcc, ChannelsAcc};
_ ->
case parse_snowflake(<<"id">>, ChannelIdBin) of
undefined ->
{VersionsAcc, ChannelsAcc};
ChannelId ->
case
guild_permissions:can_view_channel(UserId, ChannelId, Member, State)
of
true ->
Version = map_utils:get_integer(Channel, <<"version">>, 0),
{
maps:put(ChannelIdBin, Version, VersionsAcc),
maps:put(ChannelIdBin, Channel, ChannelsAcc)
};
false ->
{VersionsAcc, ChannelsAcc}
end
end
end
end,
{#{}, #{}},
Channels
)
end.
-spec build_current_voice_state_map(sets:set(), guild_state()) -> #{binary() => voice_state()}.
build_current_voice_state_map(ViewableChannels, State) ->
VoiceStates = maps:get(voice_states, State, #{}),
maps:fold(
fun(ConnectionId, VoiceState, Acc) ->
ChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
case ChannelIdBin of
null ->
Acc;
_ ->
case parse_snowflake(<<"channel_id">>, ChannelIdBin) of
undefined ->
Acc;
ChannelId ->
case sets:is_element(ChannelId, ViewableChannels) of
true -> maps:put(ConnectionId, VoiceState, Acc);
false -> Acc
end
end
end
end,
#{},
VoiceStates
).
-spec parse_snowflake(binary(), term()) -> integer() | undefined.
parse_snowflake(FieldName, Value) ->
case validation:validate_snowflake(FieldName, Value) of
{ok, Id} -> Id;
{error, _, _} -> undefined
end.
-spec compute_voice_state_updates(
#{binary() => voice_state()}, #{binary() => voice_state()}, integer()
) ->
[voice_state()].
compute_voice_state_updates(Current, Previous, GuildId) ->
GuildIdBin = integer_to_binary(GuildId),
Updated = maps:fold(
fun(ConnectionId, VoiceState, Acc) ->
PrevState = maps:get(ConnectionId, Previous, undefined),
case is_voice_state_changed(VoiceState, PrevState) of
true -> [ensure_voice_state_guild(VoiceState, GuildIdBin) | Acc];
false -> Acc
end
end,
[],
Current
),
Removed = maps:fold(
fun(ConnectionId, PrevState, Acc) ->
case maps:is_key(ConnectionId, Current) of
true -> Acc;
false -> [build_removed_voice_state(PrevState, GuildIdBin) | Acc]
end
end,
[],
Previous
),
lists:reverse(Updated) ++ lists:reverse(Removed).
-spec is_voice_state_changed(voice_state(), voice_state() | undefined) -> boolean().
is_voice_state_changed(_Current, undefined) ->
true;
is_voice_state_changed(Current, Previous) ->
voice_state_version(Current) =/= voice_state_version(Previous).
-spec voice_state_version(voice_state()) -> integer().
voice_state_version(VoiceState) ->
map_utils:get_integer(VoiceState, <<"version">>, 0).
-spec ensure_voice_state_guild(voice_state(), binary()) -> voice_state().
ensure_voice_state_guild(VoiceState, GuildIdBin) ->
case maps:get(<<"guild_id">>, VoiceState, undefined) of
undefined -> maps:put(<<"guild_id">>, GuildIdBin, VoiceState);
_ -> VoiceState
end.
-spec build_removed_voice_state(voice_state(), binary()) -> voice_state().
build_removed_voice_state(PrevState, GuildIdBin) ->
PrevWithGuild = ensure_voice_state_guild(PrevState, GuildIdBin),
maps:put(<<"channel_id">>, null, PrevWithGuild).
-ifdef(TEST).
compute_delta_empty_previous_test() ->
Current = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
Previous = #{},
Delta = compute_delta(Current, Previous),
?assertEqual(Current, Delta),
ok.
?assertEqual(Current, Delta).
compute_delta_no_changes_test() ->
Current = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
Previous = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
Delta = compute_delta(Current, Previous),
?assertEqual(#{}, Delta),
ok.
?assertEqual(#{}, Delta).
compute_delta_partial_changes_test() ->
Current = #{<<"1">> => <<"101">>, <<"2">> => <<"200">>, <<"3">> => <<"300">>},
Previous = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
Delta = compute_delta(Current, Previous),
?assertEqual(#{<<"1">> => <<"101">>, <<"3">> => <<"300">>}, Delta),
ok.
?assertEqual(#{<<"1">> => <<"101">>, <<"3">> => <<"300">>}, Delta).
compute_delta_only_new_channels_test() ->
Current = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>, <<"3">> => <<"300">>},
Previous = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
Delta = compute_delta(Current, Previous),
?assertEqual(#{<<"3">> => <<"300">>}, Delta),
ok.
?assertEqual(#{<<"3">> => <<"300">>}, Delta).
compute_delta_ignores_removed_channels_test() ->
Current = #{<<"1">> => <<"100">>},
Previous = #{<<"1">> => <<"100">>, <<"2">> => <<"200">>},
Delta = compute_delta(Current, Previous),
?assertEqual(#{}, Delta),
ok.
?assertEqual(#{}, Delta).
compute_channel_diffs_detects_created_updated_deleted_test() ->
Current = #{<<"1">> => 2, <<"2">> => 1},
Previous = #{<<"1">> => 1, <<"3">> => 9},
{Created, Updated, Deleted} = compute_channel_diffs(Current, Previous),
?assertEqual([<<"2">>], lists:sort(Created)),
?assertEqual([<<"1">>], lists:sort(Updated)),
?assertEqual([<<"3">>], lists:sort(Deleted)).
compute_voice_state_updates_reports_changes_test() ->
PrevVoiceState = #{
<<"connection_id">> => <<"conn1">>,
<<"channel_id">> => <<"100">>,
<<"user_id">> => <<"200">>,
<<"version">> => 1
},
CurrentVoiceState = maps:put(<<"version">>, 2, PrevVoiceState),
Current = #{<<"conn1">> => CurrentVoiceState},
Previous = #{<<"conn1">> => PrevVoiceState},
Updates = compute_voice_state_updates(Current, Previous, 999),
?assertEqual(1, length(Updates)),
Update = hd(Updates),
?assertEqual(<<"conn1">>, maps:get(<<"connection_id">>, Update)),
?assertEqual(integer_to_binary(999), maps:get(<<"guild_id">>, Update)).
compute_voice_state_updates_reports_removal_test() ->
RemovedVoiceState = #{
<<"connection_id">> => <<"conn2">>,
<<"channel_id">> => <<"200">>,
<<"user_id">> => <<"300">>,
<<"version">> => 3
},
Current = #{},
Previous = #{<<"conn2">> => RemovedVoiceState},
Updates = compute_voice_state_updates(Current, Previous, 101),
?assertEqual(1, length(Updates)),
Update = hd(Updates),
?assertEqual(null, maps:get(<<"channel_id">>, Update)),
?assertEqual(integer_to_binary(101), maps:get(<<"guild_id">>, Update)).
-endif.

View File

@@ -0,0 +1,132 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(guild_permission_cache).
-export([
put_state/1,
put_data/2,
delete/1,
get_permissions/3,
get_snapshot/1
]).
-type guild_id() :: integer().
-type user_id() :: integer().
-type channel_id() :: integer().
-type guild_state() :: map().
-type guild_data() :: map().
-define(TABLE, guild_permission_cache).
-spec put_state(guild_state()) -> ok.
put_state(State) when is_map(State) ->
GuildId = maps:get(id, State, undefined),
Data = maps:get(data, State, #{}),
case is_integer(GuildId) of
true ->
put_data(GuildId, Data);
false ->
ok
end;
put_state(_) ->
ok.
-spec put_data(guild_id(), guild_data()) -> ok.
put_data(GuildId, Data) when is_integer(GuildId), is_map(Data) ->
ensure_table(),
NormalizedData = guild_data_index:normalize_data(Data),
Snapshot = #{id => GuildId, data => NormalizedData},
true = ets:insert(?TABLE, {GuildId, Snapshot}),
ok;
put_data(_, _) ->
ok.
-spec delete(guild_id()) -> ok.
delete(GuildId) when is_integer(GuildId) ->
ensure_table(),
_ = ets:delete(?TABLE, GuildId),
ok;
delete(_) ->
ok.
-spec get_permissions(guild_id(), user_id(), channel_id() | undefined) ->
{ok, integer()} | {error, not_found}.
get_permissions(GuildId, UserId, ChannelId) when is_integer(GuildId), is_integer(UserId) ->
case get_snapshot(GuildId) of
{ok, Snapshot} ->
Permissions = guild_permissions:get_member_permissions(UserId, ChannelId, Snapshot),
{ok, Permissions};
{error, not_found} ->
{error, not_found}
end;
get_permissions(_, _, _) ->
{error, not_found}.
-spec get_snapshot(guild_id()) -> {ok, guild_state()} | {error, not_found}.
get_snapshot(GuildId) when is_integer(GuildId) ->
ensure_table(),
case ets:lookup(?TABLE, GuildId) of
[{GuildId, Snapshot}] ->
{ok, Snapshot};
[] ->
{error, not_found}
end;
get_snapshot(_) ->
{error, not_found}.
-spec ensure_table() -> ok.
ensure_table() ->
case ets:whereis(?TABLE) of
undefined ->
try ets:new(?TABLE, [named_table, public, set, {read_concurrency, true}]) of
_ -> ok
catch
error:badarg -> ok
end;
_ ->
ok
end.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
put_and_get_permissions_test() ->
GuildId = 101,
UserId = 44,
ViewPermission = constants:view_channel_permission(),
Data = #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(ViewPermission)}
],
<<"members">> => #{
UserId => #{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
},
<<"channels">> => [
#{<<"id">> => <<"500">>, <<"permission_overwrites">> => []}
]
},
ok = put_data(GuildId, Data),
{ok, Permissions} = get_permissions(GuildId, UserId, 500),
?assert((Permissions band ViewPermission) =/= 0),
ok = delete(GuildId).
missing_guild_returns_not_found_test() ->
?assertEqual({error, not_found}, get_permissions(999999, 1, undefined)).
-endif.

View File

@@ -19,22 +19,21 @@
-define(ALL_PERMISSIONS, 16#FFFFFFFFFFFFFFFF).
-export([get_member_permissions/3]).
-export([can_view_channel/4]).
-export([can_view_channel_by_permissions/4]).
-export([can_manage_channel/3]).
-export([apply_channel_overwrites/5]).
-export([get_max_role_position/2]).
-export([find_member_by_user_id/2]).
-export([find_role_by_id/2]).
-export([find_channel_by_id/2]).
-export([
get_member_permissions/3,
can_view_channel/4,
can_view_channel_by_permissions/4,
can_manage_channel/3,
can_access_message_by_permissions/3,
apply_channel_overwrites/5,
get_max_role_position/2,
find_member_by_user_id/2,
find_role_by_id/2,
find_channel_by_id/2
]).
-export_type([permission/0]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type permission() :: non_neg_integer().
-type user_id() :: integer().
-type role_id() :: integer().
@@ -61,13 +60,48 @@ can_view_channel(UserId, ChannelId, Member, State) ->
-spec can_view_channel_by_permissions(user_id(), channel_id(), maybe_member(), guild_state()) ->
boolean().
can_view_channel_by_permissions(UserId, ChannelId, Member, State) ->
(compute_member_permissions(UserId, ChannelId, Member, State) band
constants:view_channel_permission()) =/= 0.
Perms = compute_member_permissions(UserId, ChannelId, Member, State),
(Perms band constants:view_channel_permission()) =/= 0.
-spec can_manage_channel(user_id(), maybe_channel_id(), guild_state()) -> boolean().
can_manage_channel(UserId, ChannelId, State) ->
(get_member_permissions(UserId, ChannelId, State) band
constants:manage_channels_permission()) =/= 0.
Perms = get_member_permissions(UserId, ChannelId, State),
(Perms band constants:manage_channels_permission()) =/= 0.
-spec can_access_message_by_permissions(permission(), binary(), guild_state()) -> boolean().
can_access_message_by_permissions(Permissions, MessageId, State) ->
HasReadHistory = (Permissions band constants:read_message_history_permission()) =/= 0,
case HasReadHistory of
true ->
true;
false ->
case get_message_history_cutoff(State) of
null ->
false;
CutoffMs ->
MessageMs = snowflake_util:extract_timestamp(MessageId),
MessageMs >= CutoffMs
end
end.
-spec get_message_history_cutoff(guild_state()) -> integer() | null.
get_message_history_cutoff(State) ->
case resolve_data_map(State) of
undefined ->
null;
Data ->
Guild = maps:get(<<"guild">>, Data, #{}),
case maps:get(<<"message_history_cutoff">>, Guild, null) of
null ->
null;
CutoffBin when is_binary(CutoffBin) ->
calendar:rfc3339_to_system_time(
binary_to_list(CutoffBin), [{unit, millisecond}]
);
CutoffInt when is_integer(CutoffInt) ->
CutoffInt
end
end.
-spec apply_channel_overwrites(permission(), user_id(), member_roles(), channel(), role_id()) ->
permission().
@@ -86,64 +120,51 @@ get_max_role_position(UserId, State) ->
{_, undefined} ->
-1;
{Member, Data} ->
Roles = ensure_list(maps:get(<<"roles">>, Data, [])),
lists:foldl(
fun(RoleId, MaxPos) ->
case find_role_by_id(RoleId, Roles) of
undefined ->
MaxPos;
Role ->
Position = maps:get(<<"position">>, Role, 0),
max(Position, MaxPos)
end
end,
-1,
member_role_ids(Member)
)
Roles = guild_data_index:role_index(Data),
compute_max_position(Member, Roles)
end.
-spec compute_max_position(member(), [role()] | map()) -> integer().
compute_max_position(Member, Roles) ->
lists:foldl(
fun(RoleId, MaxPos) ->
case find_role_by_id(RoleId, Roles) of
undefined ->
MaxPos;
Role ->
Position = maps:get(<<"position">>, Role, 0),
max(Position, MaxPos)
end
end,
-1,
member_role_ids(Member)
).
-spec find_member_by_user_id(user_id(), guild_state()) -> member() | undefined.
find_member_by_user_id(UserId, State) when is_integer(UserId) ->
case resolve_data_map(State) of
undefined ->
undefined;
Data ->
Members = ensure_list(maps:get(<<"members">>, Data, [])),
lists:foldl(
fun(Member, Acc) ->
case Acc of
undefined ->
MUser = maps:get(<<"user">>, Member, #{}),
MemberId = to_int(maps:get(<<"id">>, MUser, <<"0">>)),
case MemberId =:= UserId of
true -> Member;
false -> undefined
end;
Found ->
Found
end
end,
undefined,
Members
)
Members = guild_data_index:member_map(Data),
maps:get(UserId, Members, undefined)
end;
find_member_by_user_id(_, _) ->
undefined.
-spec find_role_by_id(role_id(), list()) -> role() | undefined.
-spec find_role_by_id(role_id(), [role()] | map()) -> role() | undefined.
find_role_by_id(RoleId, Roles) when is_map(Roles) ->
maps:get(to_int(RoleId), Roles, undefined);
find_role_by_id(RoleId, Roles) ->
TargetId = to_int(RoleId),
lists:foldl(
fun(Role, Acc) ->
case Acc of
undefined ->
case role_id(Role) =:= TargetId of
true -> Role;
false -> undefined
end;
Found ->
Found
end
fun
(_, Found) when Found =/= undefined -> Found;
(Role, undefined) ->
case role_id(Role) =:= TargetId of
true -> Role;
false -> undefined
end
end,
undefined,
ensure_list(Roles)
@@ -155,23 +176,8 @@ find_channel_by_id(ChannelId, State) when is_integer(ChannelId) ->
undefined ->
undefined;
Data ->
Channels = ensure_list(maps:get(<<"channels">>, Data, [])),
lists:foldl(
fun(Channel, Acc) ->
case Acc of
undefined ->
ChanId = to_int(maps:get(<<"id">>, Channel, <<"0">>)),
case ChanId =:= ChannelId of
true -> Channel;
false -> undefined
end;
Found ->
Found
end
end,
undefined,
Channels
)
Channels = guild_data_index:channel_index(Data),
maps:get(ChannelId, Channels, undefined)
end;
find_channel_by_id(_, _) ->
undefined.
@@ -188,36 +194,36 @@ compute_member_permissions(UserId, ChannelId, ProvidedMember, State) when is_int
true ->
?ALL_PERMISSIONS;
false ->
case resolve_member(UserId, ProvidedMember, State) of
undefined ->
0;
Member ->
GuildId = guild_id(State),
Roles = ensure_list(maps:get(<<"roles">>, Data, [])),
BasePermissions = base_role_permissions(GuildId, Roles),
MemberRoles = member_role_ids(Member),
Permissions = aggregate_role_permissions(
MemberRoles, Roles, BasePermissions
),
case (Permissions band constants:administrator_permission()) =/= 0 of
true ->
?ALL_PERMISSIONS;
false ->
maybe_apply_channel_overwrites(
Permissions,
UserId,
MemberRoles,
ChannelId,
GuildId,
State
)
end
end
compute_non_owner_permissions(UserId, ChannelId, ProvidedMember, State, Data)
end
end;
compute_member_permissions(_, _, _, _) ->
0.
-spec compute_non_owner_permissions(
user_id(), maybe_channel_id(), maybe_member(), guild_state(), guild_data()
) ->
permission().
compute_non_owner_permissions(UserId, ChannelId, ProvidedMember, State, Data) ->
case resolve_member(UserId, ProvidedMember, State) of
undefined ->
0;
Member ->
GuildId = guild_id(State),
Roles = guild_data_index:role_index(Data),
BasePermissions = base_role_permissions(GuildId, Roles),
MemberRoles = member_role_ids(Member),
Permissions = aggregate_role_permissions(MemberRoles, Roles, BasePermissions),
case (Permissions band constants:administrator_permission()) =/= 0 of
true ->
?ALL_PERMISSIONS;
false ->
maybe_apply_channel_overwrites(
Permissions, UserId, MemberRoles, ChannelId, GuildId, State
)
end
end.
-spec resolve_member(user_id(), maybe_member(), guild_state()) -> maybe_member().
resolve_member(_UserId, Member, _State) when is_map(Member) ->
Member;
@@ -232,36 +238,25 @@ guild_owner_id(Data) ->
-spec guild_id(guild_state()) -> integer().
guild_id(State) ->
case maps:get(id, State, undefined) of
undefined ->
to_int(maps:get(<<"id">>, State, 0));
GuildId when is_integer(GuildId) ->
GuildId;
GuildId ->
to_int(GuildId)
undefined -> to_int(maps:get(<<"id">>, State, 0));
GuildId when is_integer(GuildId) -> GuildId;
GuildId -> to_int(GuildId)
end.
-spec base_role_permissions(role_id(), list()) -> permission().
-spec base_role_permissions(role_id(), map()) -> permission().
base_role_permissions(GuildId, Roles) ->
lists:foldl(
fun(Role, Acc) ->
case role_id(Role) =:= GuildId of
true -> role_permissions(Role);
false -> Acc
end
end,
0,
ensure_list(Roles)
).
case find_role_by_id(GuildId, Roles) of
undefined -> 0;
Role -> role_permissions(Role)
end.
-spec aggregate_role_permissions(member_roles(), list(), permission()) -> permission().
-spec aggregate_role_permissions(member_roles(), [role()] | map(), permission()) -> permission().
aggregate_role_permissions(MemberRoles, Roles, BasePermissions) ->
lists:foldl(
fun(RoleId, Acc) ->
case find_role_by_id(RoleId, Roles) of
undefined ->
Acc;
Role ->
Acc bor role_permissions(Role)
undefined -> Acc;
Role -> Acc bor role_permissions(Role)
end
end,
BasePermissions,
@@ -277,10 +272,8 @@ maybe_apply_channel_overwrites(Permissions, UserId, MemberRoles, ChannelId, Guil
is_integer(ChannelId)
->
case find_channel_by_id(ChannelId, State) of
undefined ->
Permissions;
Channel ->
apply_channel_overwrites(Permissions, UserId, MemberRoles, Channel, GuildId)
undefined -> Permissions;
Channel -> apply_channel_overwrites(Permissions, UserId, MemberRoles, Channel, GuildId)
end;
maybe_apply_channel_overwrites(Permissions, _UserId, _MemberRoles, _ChannelId, _GuildId, _State) ->
Permissions.
@@ -327,10 +320,8 @@ accumulate_role_overwrites(MemberRoles, Overwrites) ->
lists:foldl(
fun(Overwrite, {A, D}) ->
case overwrite_matches_role(Overwrite, RoleId) of
true ->
{A bor overwrite_allow(Overwrite), D bor overwrite_deny(Overwrite)};
false ->
{A, D}
true -> {A bor overwrite_allow(Overwrite), D bor overwrite_deny(Overwrite)};
false -> {A, D}
end
end,
{AllowAcc, DenyAcc},
@@ -406,10 +397,8 @@ extract_integer_list(_) ->
[].
-spec ensure_list(term()) -> list().
ensure_list(List) when is_list(List) ->
List;
ensure_list(_) ->
[].
ensure_list(List) when is_list(List) -> List;
ensure_list(_) -> [].
-spec to_int(term()) -> integer().
to_int(Value) ->
@@ -421,22 +410,20 @@ to_int(Value) ->
-spec resolve_data_map(guild_state() | map()) -> guild_data() | undefined.
resolve_data_map(State) when is_map(State) ->
case maps:find(data, State) of
{ok, Data} when is_map(Data) ->
Data;
{ok, Data} when is_map(Data) =:= false ->
{ok, Data} when is_map(Data) -> Data;
{ok, Data} ->
Data;
error ->
case State of
#{<<"members">> := _} ->
State;
_ ->
undefined
case maps:is_key(<<"members">>, State) of
true -> State;
false -> undefined
end
end;
resolve_data_map(_) ->
undefined.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
owner_receives_full_permissions_test() ->
OwnerId = 1,
@@ -537,19 +524,16 @@ administrator_role_grants_all_permissions_test() ->
data => #{
<<"guild">> => #{<<"owner_id">> => integer_to_binary(OwnerId)},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(Admin)}
#{
<<"id">> => integer_to_binary(GuildId),
<<"permissions">> => integer_to_binary(Admin)
}
],
<<"members">> => [
#{
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
<<"roles">> => []
}
#{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
],
<<"channels">> => [
#{
<<"id">> => integer_to_binary(ChannelId),
<<"permission_overwrites">> => []
}
#{<<"id">> => integer_to_binary(ChannelId), <<"permission_overwrites">> => []}
]
}
},
@@ -557,4 +541,122 @@ administrator_role_grants_all_permissions_test() ->
?assertEqual(?ALL_PERMISSIONS, get_member_permissions(UserId, ChannelId, State)),
?assert(can_view_channel(UserId, ChannelId, undefined, State)).
find_member_by_user_id_found_test() ->
State = #{
data => #{
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"123">>}, <<"nick">> => <<"Test">>}
]
}
},
Result = find_member_by_user_id(123, State),
?assertEqual(<<"Test">>, maps:get(<<"nick">>, Result)).
find_member_by_user_id_not_found_test() ->
State = #{data => #{<<"members">> => []}},
?assertEqual(undefined, find_member_by_user_id(123, State)).
find_member_by_user_id_map_storage_test() ->
State = #{
data => #{
<<"members">> => #{
321 => #{<<"user">> => #{<<"id">> => <<"321">>}, <<"nick">> => <<"Mapped">>}
}
}
},
Result = find_member_by_user_id(321, State),
?assertEqual(<<"Mapped">>, maps:get(<<"nick">>, Result)).
find_role_by_id_found_test() ->
Roles = [#{<<"id">> => <<"100">>, <<"name">> => <<"Admin">>}],
Result = find_role_by_id(100, Roles),
?assertEqual(<<"Admin">>, maps:get(<<"name">>, Result)).
find_role_by_id_not_found_test() ->
Roles = [#{<<"id">> => <<"100">>}],
?assertEqual(undefined, find_role_by_id(999, Roles)).
find_role_by_id_map_index_test() ->
Roles = #{
100 => #{<<"id">> => <<"100">>, <<"name">> => <<"Admin">>}
},
Result = find_role_by_id(100, Roles),
?assertEqual(<<"Admin">>, maps:get(<<"name">>, Result)).
find_channel_by_id_with_index_test() ->
State = #{
data => #{
<<"channels">> => [#{<<"id">> => <<"900">>, <<"name">> => <<"general">>}],
<<"channel_index">> => #{900 => #{<<"id">> => <<"900">>, <<"name">> => <<"general">>}}
}
},
Result = find_channel_by_id(900, State),
?assertEqual(<<"general">>, maps:get(<<"name">>, Result)).
to_int_test() ->
?assertEqual(123, to_int(123)),
?assertEqual(123, to_int(<<"123">>)),
?assertEqual(0, to_int(undefined)).
ensure_list_test() ->
?assertEqual([1, 2], ensure_list([1, 2])),
?assertEqual([], ensure_list(undefined)),
?assertEqual([], ensure_list(#{})).
can_access_message_with_read_history_test() ->
ReadHistory = constants:read_message_history_permission(),
State = #{data => #{<<"guild">> => #{}}},
MessageId = <<"100">>,
?assertEqual(true, can_access_message_by_permissions(ReadHistory, MessageId, State)).
can_access_message_no_read_history_no_cutoff_test() ->
State = #{data => #{<<"guild">> => #{}}},
MessageId = <<"100">>,
?assertEqual(false, can_access_message_by_permissions(0, MessageId, State)).
can_access_message_no_read_history_null_cutoff_test() ->
State = #{data => #{<<"guild">> => #{<<"message_history_cutoff">> => null}}},
MessageId = <<"100">>,
?assertEqual(false, can_access_message_by_permissions(0, MessageId, State)).
can_access_message_no_read_history_message_before_cutoff_test() ->
CutoffMs = 1704067200000,
BeforeCutoffTimestamp = CutoffMs - 60000,
FluxerEpoch = 1420070400000,
RelativeTs = BeforeCutoffTimestamp - FluxerEpoch,
Snowflake = RelativeTs bsl 22,
MessageId = integer_to_binary(Snowflake),
State = #{data => #{<<"guild">> => #{<<"message_history_cutoff">> => CutoffMs}}},
?assertEqual(false, can_access_message_by_permissions(0, MessageId, State)).
can_access_message_no_read_history_message_after_cutoff_test() ->
CutoffMs = 1704067200000,
AfterCutoffTimestamp = CutoffMs + 60000,
FluxerEpoch = 1420070400000,
RelativeTs = AfterCutoffTimestamp - FluxerEpoch,
Snowflake = RelativeTs bsl 22,
MessageId = integer_to_binary(Snowflake),
State = #{data => #{<<"guild">> => #{<<"message_history_cutoff">> => CutoffMs}}},
?assertEqual(true, can_access_message_by_permissions(0, MessageId, State)).
can_access_message_no_read_history_message_at_cutoff_test() ->
CutoffMs = 1704067200000,
FluxerEpoch = 1420070400000,
RelativeTs = CutoffMs - FluxerEpoch,
Snowflake = RelativeTs bsl 22,
MessageId = integer_to_binary(Snowflake),
State = #{data => #{<<"guild">> => #{<<"message_history_cutoff">> => CutoffMs}}},
?assertEqual(true, can_access_message_by_permissions(0, MessageId, State)).
can_access_message_with_rfc3339_cutoff_test() ->
CutoffBin = <<"2024-01-01T00:00:00Z">>,
CutoffMs = calendar:rfc3339_to_system_time("2024-01-01T00:00:00Z", [{unit, millisecond}]),
AfterCutoffTimestamp = CutoffMs + 60000,
FluxerEpoch = 1420070400000,
RelativeTs = AfterCutoffTimestamp - FluxerEpoch,
Snowflake = RelativeTs bsl 22,
MessageId = integer_to_binary(Snowflake),
State = #{data => #{<<"guild">> => #{<<"message_history_cutoff">> => CutoffBin}}},
?assertEqual(true, can_access_message_by_permissions(0, MessageId, State)).
-endif.

View File

@@ -20,8 +20,6 @@
-export([handle_bus_presence/3, send_cached_presence_to_session/3]).
-export([broadcast_presence_update/3]).
-import(guild_sessions, [handle_user_offline/2]).
-type guild_state() :: map().
-type member() :: map().
-type user_id() :: integer().
@@ -31,13 +29,19 @@
-endif.
-spec handle_bus_presence(user_id(), map(), guild_state()) -> {noreply, guild_state()}.
-spec send_cached_presence_to_session(user_id(), binary(), guild_state()) -> guild_state().
handle_bus_presence(UserId, Payload, State) ->
case maps:get(<<"user_update">>, Payload, false) of
true ->
UserData = maps:get(<<"user">>, Payload, #{}),
UpdatedState = handle_user_data_update(UserId, UserData, State),
guild_member_list:broadcast_member_list_updates(UserId, State, UpdatedState),
MemberData = find_member_by_user_id(UserId, UpdatedState),
case is_map(MemberData) of
true ->
maybe_notify_very_large_guild_member_list({upsert_member, MemberData}, UpdatedState);
false ->
maybe_notify_very_large_guild_member_list({notify_member_update, UserId}, UpdatedState)
end,
maybe_broadcast_member_list_updates(UserId, State, UpdatedState),
{noreply, UpdatedState};
false ->
Member = find_member_by_user_id(UserId, State),
@@ -50,7 +54,6 @@ handle_bus_presence(UserId, Payload, State) ->
Status = constants:status_type_atom(NormalizedStatusBin),
Mobile = maps:get(<<"mobile">>, Payload, false),
Afk = maps:get(<<"afk">>, Payload, false),
logger:debug("[guild_presence] Presence update for UserId=~p, Status=~p", [UserId, Status]),
MemberUser = maps:get(<<"user">>, Member, #{}),
CustomStatus = maps:get(<<"custom_status">>, Payload, null),
PresenceMap = presence_payload:build(
@@ -60,29 +63,48 @@ handle_bus_presence(UserId, Payload, State) ->
Afk,
CustomStatus
),
Presences = maps:get(presences, State, #{}),
UpdatedPresences = maps:put(UserId, PresenceMap, Presences),
StateWithPresences = maps:put(presences, UpdatedPresences, State),
broadcast_presence_update(UserId, PresenceMap, StateWithPresences),
logger:debug("[guild_presence] Broadcasting member list updates for UserId=~p", [UserId]),
guild_member_list:broadcast_member_list_updates(UserId, State, StateWithPresences),
StateWithPresence = store_member_presence(UserId, PresenceMap, State),
broadcast_presence_update(UserId, PresenceMap, StateWithPresence),
maybe_notify_very_large_guild_member_list(
{presence_update, UserId, PresenceMap}, StateWithPresence
),
maybe_broadcast_member_list_updates(UserId, State, StateWithPresence),
StateAfterOffline =
case Status of
offline ->
handle_user_offline(UserId, StateWithPresences);
guild_sessions:handle_user_offline(UserId, StateWithPresence);
_ ->
StateWithPresences
StateWithPresence
end,
{noreply, StateAfterOffline}
end
end.
-spec maybe_broadcast_member_list_updates(user_id(), guild_state(), guild_state()) -> ok.
maybe_broadcast_member_list_updates(UserId, OldState, UpdatedState) ->
case maps:get(very_large_guild_coordinator_pid, UpdatedState, undefined) of
Pid when is_pid(Pid) ->
ok;
_ ->
guild_member_list:broadcast_member_list_updates(UserId, OldState, UpdatedState)
end.
-spec maybe_notify_very_large_guild_member_list(term(), guild_state()) -> ok.
maybe_notify_very_large_guild_member_list(NotifyMsg, State) ->
case maps:get(very_large_guild_coordinator_pid, State, undefined) of
CoordPid when is_pid(CoordPid) ->
gen_server:cast(CoordPid, {very_large_guild_member_list_notify, NotifyMsg}),
ok;
_ ->
ok
end.
-spec broadcast_presence_update(user_id(), map(), guild_state()) -> ok.
broadcast_presence_update(UserId, Payload, State) ->
case find_member_by_user_id(UserId, State) of
undefined ->
ok;
_Member ->
_Member ->
GuildId = map_utils:get_integer(State, id, 0),
PresenceUpdate = maps:put(<<"guild_id">>, integer_to_binary(GuildId), Payload),
Sessions = maps:get(sessions, State, #{}),
@@ -90,7 +112,9 @@ broadcast_presence_update(UserId, Payload, State) ->
SubscribedSessionIds = guild_subscriptions:get_subscribed_sessions(UserId, MemberSubs),
TargetChannels = guild_visibility:viewable_channel_set(UserId, State),
{ValidSessionIds, InvalidSessionIds} =
partition_subscribed_sessions(SubscribedSessionIds, Sessions, TargetChannels, UserId, State),
partition_subscribed_sessions(
SubscribedSessionIds, Sessions, TargetChannels, UserId, State
),
StateAfterInvalidRemovals =
lists:foldl(
fun(SessionId, AccState) ->
@@ -122,10 +146,12 @@ broadcast_presence_update(UserId, Payload, State) ->
ok
end.
-spec normalize_presence_status(binary() | term()) -> binary().
normalize_presence_status(<<"invisible">>) -> <<"offline">>;
normalize_presence_status(Status) when is_binary(Status) -> Status;
normalize_presence_status(_) -> <<"offline">>.
-spec send_cached_presence_to_session(user_id(), binary(), guild_state()) -> guild_state().
send_cached_presence_to_session(UserId, SessionId, State) ->
case presence_cache:get(UserId) of
{ok, Payload} ->
@@ -134,6 +160,7 @@ send_cached_presence_to_session(UserId, SessionId, State) ->
State
end.
-spec send_presence_payload_to_session(user_id(), binary(), map(), guild_state()) -> guild_state().
send_presence_payload_to_session(UserId, SessionId, Payload, State) ->
GuildId = map_utils:get_integer(State, id, 0),
Sessions = maps:get(sessions, State, #{}),
@@ -151,7 +178,9 @@ send_presence_payload_to_session(UserId, SessionId, Payload, State) ->
CustomStatus = maps:get(<<"custom_status">>, Payload, null),
PresenceBase =
presence_payload:build(MemberUser, StatusBin, Mobile, Afk, CustomStatus),
PresenceUpdate = maps:put(<<"guild_id">>, integer_to_binary(GuildId), PresenceBase),
PresenceUpdate = maps:put(
<<"guild_id">>, integer_to_binary(GuildId), PresenceBase
),
gen_server:cast(SessionPid, {dispatch, presence_update, PresenceUpdate}),
State
end;
@@ -162,7 +191,7 @@ send_presence_payload_to_session(UserId, SessionId, Payload, State) ->
-spec handle_user_data_update(user_id(), map(), guild_state()) -> guild_state().
handle_user_data_update(UserId, UserData, State) ->
Data = guild_data(State),
Members = guild_members(State),
Members = guild_data_index:member_map(Data),
case find_member_by_user_id(UserId, State) of
undefined ->
State;
@@ -172,13 +201,13 @@ handle_user_data_update(UserId, UserData, State) ->
false ->
State;
true ->
UpdatedMembers = lists:map(
fun(M) ->
maybe_replace_member(M, UserId, UserData)
UpdatedMembers = maps:map(
fun(_MemberUserId, Member0) ->
maybe_replace_member(Member0, UserId, UserData)
end,
Members
),
UpdatedData = maps:put(<<"members">>, UpdatedMembers, Data),
UpdatedData = guild_data_index:put_member_map(UpdatedMembers, Data),
UpdatedState = maps:put(data, UpdatedData, State),
maybe_dispatch_member_update(UserId, UpdatedState),
UpdatedState
@@ -211,15 +240,13 @@ maybe_dispatch_member_update(UserId, State) ->
guild_data(State) ->
map_utils:ensure_map(map_utils:get_safe(State, data, #{})).
-spec guild_members(guild_state()) -> [map()].
guild_members(State) ->
map_utils:ensure_list(maps:get(<<"members">>, guild_data(State), [])).
-spec member_id(map()) -> user_id() | undefined.
member_id(Member) ->
User = map_utils:ensure_map(maps:get(<<"user">>, Member, #{})),
map_utils:get_integer(User, <<"id">>, undefined).
-spec partition_subscribed_sessions([binary()], map(), sets:set(), user_id(), guild_state()) ->
{[binary()], [binary()]}.
partition_subscribed_sessions(SessionIds, Sessions, TargetChannels, TargetUserId, State) ->
lists:foldl(
fun(SessionId, {Valids, Invalids}) ->
@@ -235,8 +262,12 @@ partition_subscribed_sessions(SessionIds, Sessions, TargetChannels, TargetUserId
UserId when UserId =:= TargetUserId ->
false;
_ ->
SessionChannels = guild_visibility:viewable_channel_set(SessionUserId, State),
not sets:is_empty(sets:intersection(SessionChannels, TargetChannels))
SessionChannels = guild_visibility:viewable_channel_set(
SessionUserId, State
),
not sets:is_empty(
sets:intersection(SessionChannels, TargetChannels)
)
end,
case Shared of
true -> {[SessionId | Valids], Invalids};
@@ -248,12 +279,27 @@ partition_subscribed_sessions(SessionIds, Sessions, TargetChannels, TargetUserId
SessionIds
).
-spec remove_session_member_subscription(binary(), user_id(), guild_state()) -> guild_state().
remove_session_member_subscription(SessionId, UserId, State) ->
MemberSubs = maps:get(member_subscriptions, State, guild_subscriptions:init_state()),
NewMemberSubs = guild_subscriptions:unsubscribe(SessionId, UserId, MemberSubs),
State1 = maps:put(member_subscriptions, NewMemberSubs, State),
guild_sessions:unsubscribe_from_user_presence(UserId, State1).
-spec check_user_data_differs(map(), map()) -> boolean().
check_user_data_differs(CurrentUserData, NewUserData) ->
utils:check_user_data_differs(CurrentUserData, NewUserData).
-spec find_member_by_user_id(user_id(), guild_state()) -> member() | undefined.
find_member_by_user_id(UserId, State) ->
guild_permissions:find_member_by_user_id(UserId, State).
-spec store_member_presence(user_id(), map(), guild_state()) -> guild_state().
store_member_presence(UserId, PresenceMap, State) ->
MemberPresence = maps:get(member_presence, State, #{}),
UpdatedPresence = maps:put(UserId, PresenceMap, MemberPresence),
maps:put(member_presence, UpdatedPresence, State).
-ifdef(TEST).
handle_bus_presence_non_member_noop_test() ->
@@ -279,24 +325,24 @@ handle_bus_presence_user_update_test() ->
Payload = #{<<"user">> => UserData, <<"user_update">> => true},
{noreply, NewState} = handle_bus_presence(1, Payload, State),
Data = maps:get(data, NewState),
[Member | _] = maps:get(<<"members">>, Data),
Member = maps:get(1, maps:get(<<"members">>, Data)),
?assertEqual(<<"Updated">>, maps:get(<<"username">>, maps:get(<<"user">>, Member))).
normalize_presence_status_test() ->
?assertEqual(<<"offline">>, normalize_presence_status(<<"invisible">>)),
?assertEqual(<<"online">>, normalize_presence_status(<<"online">>)),
?assertEqual(<<"idle">>, normalize_presence_status(<<"idle">>)),
?assertEqual(<<"offline">>, normalize_presence_status(undefined)).
presence_test_state() ->
#{
id => 42,
data => #{
<<"members">> => [
#{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"Alpha">>}}
]
<<"members">> => #{
1 => #{<<"user">> => #{<<"id">> => <<"1">>, <<"username">> => <<"Alpha">>}}
}
},
sessions => #{}
}.
-endif.
check_user_data_differs(CurrentUserData, NewUserData) ->
utils:check_user_data_differs(CurrentUserData, NewUserData).
find_member_by_user_id(UserId, State) ->
guild_permissions:find_member_by_user_id(UserId, State).

View File

@@ -27,6 +27,8 @@
-type session_state() :: map().
-type request_data() :: map().
-type member() :: map().
-type presence() :: map().
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
@@ -34,13 +36,10 @@
-spec handle_request(request_data(), pid(), session_state()) -> ok | {error, atom()}.
handle_request(Data, SocketPid, SessionState) when is_map(Data), is_pid(SocketPid) ->
logger:debug("[guild_request_members] Handling guild members request: ~p", [Data]),
case parse_request(Data) of
{ok, Request} ->
logger:debug("[guild_request_members] Request parsed successfully: ~p", [Request]),
process_request(Request, SocketPid, SessionState);
{error, Reason} ->
logger:warning("[guild_request_members] Failed to parse request: ~p", [Reason]),
{error, Reason}
end;
handle_request(_, _, _) ->
@@ -55,7 +54,6 @@ parse_request(Data) ->
Presences = maps:get(<<"presences">>, Data, false),
Nonce = maps:get(<<"nonce">>, Data, null),
NormalizedNonce = normalize_nonce(Nonce),
case validate_guild_id(GuildIdRaw) of
{ok, GuildId} ->
case validate_user_ids(UserIdsRaw) of
@@ -137,25 +135,16 @@ process_request(Request, SocketPid, SessionState) ->
#{guild_id := GuildId, query := Query, limit := Limit, user_ids := UserIds} = Request,
UserIdBin = maps:get(user_id, SessionState),
UserId = type_conv:to_integer(UserIdBin),
logger:debug(
"[guild_request_members] Processing request for guild ~p, user ~p, user_ids: ~p",
[GuildId, UserId, UserIds]
),
case check_permission(UserId, GuildId, Query, Limit, UserIds, SessionState) of
ok ->
logger:debug("[guild_request_members] Permission check passed, fetching members"),
fetch_and_send_members(Request, SocketPid, SessionState);
{error, Reason} ->
logger:warning(
"[guild_request_members] Permission check failed: ~p",
[Reason]
),
{error, Reason}
end.
-spec check_permission(integer(), integer(), binary(), non_neg_integer(), [integer()], session_state()) ->
-spec check_permission(
integer(), integer(), binary(), non_neg_integer(), [integer()], session_state()
) ->
ok | {error, atom()}.
check_permission(UserId, GuildId, Query, Limit, UserIds, SessionState) ->
RequiresPermission = Query =:= <<>> andalso Limit =:= 0 andalso UserIds =:= [],
@@ -177,7 +166,6 @@ check_management_permission(UserId, _GuildId, GuildPid) ->
KickMembers = constants:kick_members_permission(),
BanMembers = constants:ban_members_permission(),
RequiredPermission = ManageRoles bor KickMembers bor BanMembers,
PermRequest = #{
user_id => UserId,
permission => RequiredPermission,
@@ -215,65 +203,45 @@ fetch_and_send_members(Request, _SocketPid, SessionState) ->
nonce := Nonce
} = Request,
SessionId = maps:get(session_id, SessionState),
logger:debug(
"[guild_request_members] Looking up guild ~p for member request",
[GuildId]
),
case lookup_guild(GuildId, SessionState) of
{ok, GuildPid} ->
logger:debug("[guild_request_members] Guild ~p found, fetching members", [GuildId]),
Members = fetch_members(GuildPid, Query, Limit, UserIds),
logger:debug("[guild_request_members] Found ~p members", [length(Members)]),
PresencesList = maybe_fetch_presences(Presences, GuildPid, Members),
send_member_chunks(GuildPid, SessionId, Members, PresencesList, Nonce),
logger:debug(
"[guild_request_members] Sent ~p member chunks for guild ~p with nonce ~p",
[max(1, (length(Members) + ?CHUNK_SIZE - 1) div ?CHUNK_SIZE), GuildId, Nonce]
),
ok;
{error, Reason} ->
logger:warning(
"[guild_request_members] Failed to lookup guild ~p: ~p",
[GuildId, Reason]
),
{error, Reason}
end.
-spec fetch_members(pid(), binary(), non_neg_integer(), [integer()]) -> [map()].
-spec fetch_members(pid(), binary(), non_neg_integer(), [integer()]) -> [member()].
fetch_members(GuildPid, _Query, _Limit, UserIds) when UserIds =/= [] ->
logger:debug("[guild_request_members] Fetching members by user_ids: ~p", [UserIds]),
case gen_server:call(GuildPid, {list_guild_members, #{limit => 100000, offset => 0}}, 10000) of
#{members := AllMembers} ->
logger:debug("[guild_request_members] Got ~p members from guild, filtering by user_ids", [length(AllMembers)]),
Filtered = filter_members_by_ids(AllMembers, UserIds),
logger:debug("[guild_request_members] Filtered to ~p members", [length(Filtered)]),
Filtered;
Other ->
logger:warning("[guild_request_members] Unexpected response from guild: ~p", [Other]),
filter_members_by_ids(AllMembers, UserIds);
_ ->
[]
end;
fetch_members(GuildPid, Query, Limit, []) ->
ActualLimit = case Limit of 0 -> 100000; L -> L end,
logger:debug("[guild_request_members] Fetching members with query '~s', limit ~p", [Query, ActualLimit]),
case gen_server:call(GuildPid, {list_guild_members, #{limit => ActualLimit, offset => 0}}, 10000) of
ActualLimit =
case Limit of
0 -> 100000;
L -> L
end,
case
gen_server:call(GuildPid, {list_guild_members, #{limit => ActualLimit, offset => 0}}, 10000)
of
#{members := AllMembers} ->
logger:debug("[guild_request_members] Got ~p members from guild", [length(AllMembers)]),
Result = case Query of
case Query of
<<>> ->
lists:sublist(AllMembers, ActualLimit);
_ ->
filter_members_by_query(AllMembers, Query, ActualLimit)
end,
logger:debug("[guild_request_members] Returning ~p members after query/filter", [length(Result)]),
Result;
Other ->
logger:warning("[guild_request_members] Unexpected response from guild: ~p", [Other]),
end;
_ ->
[]
end.
-spec filter_members_by_ids([map()], [integer()]) -> [map()].
-spec filter_members_by_ids([member()], [integer()]) -> [member()].
filter_members_by_ids(Members, UserIds) ->
UserIdSet = sets:from_list(UserIds),
lists:filter(
@@ -284,7 +252,7 @@ filter_members_by_ids(Members, UserIds) ->
Members
).
-spec filter_members_by_query([map()], binary(), non_neg_integer()) -> [map()].
-spec filter_members_by_query([member()], binary(), non_neg_integer()) -> [member()].
filter_members_by_query(Members, Query, Limit) ->
NormalizedQuery = string:lowercase(binary_to_list(Query)),
Matches = lists:filter(
@@ -297,50 +265,47 @@ filter_members_by_query(Members, Query, Limit) ->
),
lists:sublist(Matches, Limit).
-spec get_display_name(map()) -> binary().
-spec get_display_name(member()) -> binary().
get_display_name(Member) when is_map(Member) ->
Nick = maps:get(<<"nick">>, Member, undefined),
case Nick of
undefined -> nick_isundefined(Member);
null -> nick_isundefined(Member);
undefined -> get_fallback_name(Member);
null -> get_fallback_name(Member);
_ when is_binary(Nick) -> Nick;
_ -> nick_isundefined(Member)
_ -> get_fallback_name(Member)
end;
get_display_name(_) ->
<<>>.
nick_isundefined(Member) ->
-spec get_fallback_name(member()) -> binary().
get_fallback_name(Member) ->
User = maps:get(<<"user">>, Member, #{}),
GlobalName = maps:get(<<"global_name">>, User, undefined),
case GlobalName of
undefined ->
Username = maps:get(<<"username">>, User, <<>>),
case Username of
null -> <<>>;
undefined -> <<>>;
_ when is_binary(Username) -> Username;
_ -> <<>>
end;
null ->
Username = maps:get(<<"username">>, User, <<>>),
case Username of
null -> <<>>;
undefined -> <<>>;
_ when is_binary(Username) -> Username;
_ -> <<>>
end;
undefined -> get_username(User);
null -> get_username(User);
_ when is_binary(GlobalName) -> GlobalName;
_ -> get_username(User)
end.
-spec get_username(map()) -> binary().
get_username(User) ->
Username = maps:get(<<"username">>, User, <<>>),
case Username of
null -> <<>>;
undefined -> <<>>;
_ when is_binary(Username) -> Username;
_ -> <<>>
end.
-spec extract_user_id(map()) -> integer() | undefined.
-spec extract_user_id(member()) -> integer() | undefined.
extract_user_id(Member) when is_map(Member) ->
User = maps:get(<<"user">>, Member, #{}),
map_utils:get_integer(User, <<"id">>, undefined);
extract_user_id(_) ->
undefined.
-spec maybe_fetch_presences(boolean(), pid(), [map()]) -> [map()].
-spec maybe_fetch_presences(boolean(), pid(), [member()]) -> [presence()].
maybe_fetch_presences(false, _GuildPid, _Members) ->
[];
maybe_fetch_presences(true, _GuildPid, Members) ->
@@ -361,41 +326,30 @@ maybe_fetch_presences(true, _GuildPid, Members) ->
[P || P <- Cached, presence_visible(P)]
end.
-spec presence_visible(map()) -> boolean().
-spec presence_visible(presence()) -> boolean().
presence_visible(P) ->
Status = maps:get(<<"status">>, P, <<"offline">>),
Status =/= <<"offline">> andalso Status =/= <<"invisible">>.
-spec send_member_chunks(pid(), binary(), [map()], [map()], term()) -> ok.
-spec send_member_chunks(pid(), binary(), [member()], [presence()], term()) -> ok.
send_member_chunks(GuildPid, SessionId, Members, Presences, Nonce) ->
TotalChunks = max(1, (length(Members) + ?CHUNK_SIZE - 1) div ?CHUNK_SIZE),
MemberChunks = chunk_list(Members, ?CHUNK_SIZE),
PresenceChunks = chunk_presences(Presences, MemberChunks),
logger:debug(
"[guild_request_members] Sending ~p member chunks (total members: ~p, nonce: ~p)",
[TotalChunks, length(Members), Nonce]
),
lists:foldl(
fun({MemberChunk, PresenceChunk}, ChunkIndex) ->
ChunkData = build_chunk_data(
MemberChunk, PresenceChunk, ChunkIndex, TotalChunks, Nonce
),
logger:debug(
"[guild_request_members] Sending chunk ~p/~p with ~p members, nonce: ~p",
[ChunkIndex + 1, TotalChunks, length(MemberChunk), Nonce]
),
gen_server:cast(GuildPid, {send_members_chunk, SessionId, ChunkData}),
ChunkIndex + 1
end,
0,
lists:zip(MemberChunks, PresenceChunks)
),
logger:debug("[guild_request_members] All chunks sent successfully"),
ok.
-spec build_chunk_data([map()], [map()], non_neg_integer(), non_neg_integer(), term()) ->
-spec build_chunk_data([member()], [presence()], non_neg_integer(), non_neg_integer(), term()) ->
map().
build_chunk_data(Members, Presences, ChunkIndex, TotalChunks, Nonce) ->
Base = #{
@@ -403,15 +357,15 @@ build_chunk_data(Members, Presences, ChunkIndex, TotalChunks, Nonce) ->
<<"chunk_index">> => ChunkIndex,
<<"chunk_count">> => TotalChunks
},
WithPresences = case Presences of
[] -> Base;
_ -> maps:put(<<"presences">>, Presences, Base)
end,
WithNonce = case Nonce of
WithPresences =
case Presences of
[] -> Base;
_ -> maps:put(<<"presences">>, Presences, Base)
end,
case Nonce of
null -> WithPresences;
_ -> maps:put(<<"nonce">>, Nonce, WithPresences)
end,
WithNonce.
end.
-spec chunk_list([T], pos_integer()) -> [[T]] when T :: term().
chunk_list([], _Size) ->
@@ -419,13 +373,14 @@ chunk_list([], _Size) ->
chunk_list(List, Size) ->
chunk_list(List, Size, []).
-spec chunk_list([T], pos_integer(), [[T]]) -> [[T]] when T :: term().
chunk_list([], _Size, Acc) ->
lists:reverse(Acc);
chunk_list(List, Size, Acc) ->
{Chunk, Rest} = lists:split(min(Size, length(List)), List),
chunk_list(Rest, Size, [Chunk | Acc]).
-spec chunk_presences([map()], [[map()]]) -> [[map()]].
-spec chunk_presences([presence()], [[member()]]) -> [[presence()]].
chunk_presences(Presences, MemberChunks) ->
lists:map(
fun(MemberChunk) ->
@@ -491,15 +446,18 @@ display_name_priority_test() ->
<<"nick">> => <<"Nick">>
},
?assertEqual(<<"Nick">>, get_display_name(MemberWithNick)),
MemberWithGlobal = #{
<<"user">> => #{<<"username">> => <<"user">>, <<"global_name">> => <<"Global">>}
},
?assertEqual(<<"Global">>, get_display_name(MemberWithGlobal)),
MemberWithUsername = #{
<<"user">> => #{<<"username">> => <<"user">>}
},
?assertEqual(<<"user">>, get_display_name(MemberWithUsername)).
normalize_nonce_test() ->
?assertEqual(<<"abc">>, normalize_nonce(<<"abc">>)),
?assertEqual(null, normalize_nonce(<<"this_nonce_is_way_too_long_to_be_valid">>)),
?assertEqual(null, normalize_nonce(undefined)).
-endif.

View File

@@ -20,7 +20,9 @@
-export([
handle_session_connect/3,
handle_session_down/2,
remove_session/2,
filter_sessions_for_channel/4,
filter_sessions_for_message/5,
filter_sessions_for_manage_channels/4,
filter_sessions_exclude_session/2,
handle_user_offline/2,
@@ -29,67 +31,79 @@
build_initial_last_message_ids/1,
is_session_active/2,
subscribe_to_user_presence/2,
unsubscribe_from_user_presence/2
unsubscribe_from_user_presence/2,
set_session_viewable_channels/3,
refresh_all_viewable_channels/1
]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-import(guild_permissions, [can_view_channel/4, can_manage_channel/3, find_member_by_user_id/2]).
-import(guild_data, [get_guild_state/2]).
-import(guild_availability, [is_guild_unavailable_for_user/2]).
-type guild_state() :: map().
-type session_id() :: binary().
-type user_id() :: integer().
-type guild_id() :: integer().
-type channel_id() :: integer().
-type session_data() :: map().
-type sessions_map() :: #{session_id() => session_data()}.
-type session_pair() :: {session_id(), session_data()}.
-spec handle_session_connect(map(), pid(), guild_state()) ->
{reply, {ok, map()} | {ok, unavailable, map()}, guild_state()}.
handle_session_connect(Request, Pid, State) ->
#{session_id := SessionId, user_id := UserId} = Request,
Sessions = maps:get(sessions, State, #{}),
case maps:is_key(SessionId, Sessions) of
true ->
{reply, {ok, guild_data:get_guild_state(UserId, State)}, State};
false ->
register_new_session(Request, Pid, UserId, SessionId, State)
end.
-spec register_new_session(map(), pid(), user_id(), session_id(), guild_state()) ->
{reply, {ok, map()} | {ok, unavailable, map()}, guild_state()}.
register_new_session(Request, Pid, UserId, SessionId, State) ->
Sessions = maps:get(sessions, State, #{}),
ActiveGuilds = maps:get(active_guilds, Request, sets:new()),
InitialGuildId = maps:get(initial_guild_id, Request, undefined),
UserRoles = session_passive:get_user_roles_for_guild(UserId, State),
Bot = maps:get(bot, Request, false),
GuildId = maps:get(id, State),
case maps:is_key(SessionId, Sessions) of
Ref = monitor(process, Pid),
GuildState = guild_data:get_guild_state(UserId, State),
InitialLastMessageIds = build_initial_last_message_ids(GuildState),
InitialChannelVersions = build_initial_channel_versions(GuildState),
InitialViewableChannels = build_viewable_channel_map(
guild_visibility:get_user_viewable_channels(UserId, State)
),
SessionData = #{
session_id => SessionId,
user_id => UserId,
pid => Pid,
mref => Ref,
active_guilds => ActiveGuilds,
user_roles => UserRoles,
bot => Bot,
is_staff => maps:get(is_staff, Request, false),
previous_passive_updates => InitialLastMessageIds,
previous_passive_channel_versions => InitialChannelVersions,
previous_passive_voice_states => #{},
viewable_channels => InitialViewableChannels
},
NewSessions = maps:put(SessionId, SessionData, Sessions),
State1 = maps:put(sessions, NewSessions, State),
State2 = subscribe_to_user_presence(UserId, State1),
_ = maybe_notify_coordinator(session_connected, SessionId, UserId, State2),
case guild_availability:is_guild_unavailable_for_user(UserId, State2) of
true ->
{reply, {ok, get_guild_state(UserId, State)}, State};
false ->
Ref = monitor(process, Pid),
GuildState = get_guild_state(UserId, State),
InitialLastMessageIds = build_initial_last_message_ids(GuildState),
SessionData = #{
session_id => SessionId,
user_id => UserId,
pid => Pid,
mref => Ref,
active_guilds => ActiveGuilds,
user_roles => UserRoles,
bot => Bot,
previous_passive_updates => InitialLastMessageIds
UnavailableResponse = #{
<<"id">> => integer_to_binary(GuildId),
<<"unavailable">> => true
},
NewSessions = maps:put(SessionId, SessionData, Sessions),
State1 = maps:put(sessions, NewSessions, State),
State2 = subscribe_to_user_presence(UserId, State1),
case is_guild_unavailable_for_user(UserId, State2) of
true ->
GuildId = maps:get(id, State2),
UnavailableResponse = #{
<<"id">> => integer_to_binary(GuildId),
<<"unavailable">> => true
},
{reply, {ok, unavailable, UnavailableResponse}, State2};
false ->
SyncedState = maybe_auto_sync_initial_guild(
SessionId,
GuildId,
InitialGuildId,
State2
),
{reply, {ok, GuildState}, SyncedState}
end
{reply, {ok, unavailable, UnavailableResponse}, State2};
false ->
SyncedState = maybe_auto_sync_initial_guild(SessionId, GuildId, InitialGuildId, State2),
{reply, {ok, GuildState}, SyncedState}
end.
-spec build_initial_last_message_ids(map()) -> #{binary() => binary()}.
build_initial_last_message_ids(GuildState) ->
Channels = maps:get(<<"channels">>, GuildState, []),
lists:foldl(
@@ -106,10 +120,69 @@ build_initial_last_message_ids(GuildState) ->
Channels
).
-spec build_initial_channel_versions(map()) -> #{binary() => integer()}.
build_initial_channel_versions(GuildState) ->
Channels = maps:get(<<"channels">>, GuildState, []),
lists:foldl(
fun(Channel, Acc) ->
ChannelIdBin = maps:get(<<"id">>, Channel, undefined),
case ChannelIdBin of
undefined ->
Acc;
_ ->
Version = map_utils:get_integer(Channel, <<"version">>, 0),
maps:put(ChannelIdBin, Version, Acc)
end
end,
#{},
Channels
).
-spec handle_session_down(reference(), guild_state()) ->
{noreply, guild_state()} | {stop, normal, guild_state()}.
handle_session_down(Ref, State) ->
Sessions = maps:get(sessions, State, #{}),
DisconnectingSession = find_session_by_ref(Ref, Sessions),
State1 = cleanup_disconnecting_session(DisconnectingSession, State),
NewSessions = maps:filter(fun(_K, S) -> maps:get(mref, S) =/= Ref end, Sessions),
NewState = maps:put(sessions, NewSessions, State1),
case map_size(NewSessions) of
0 ->
case should_auto_stop_on_empty(NewState) of
true -> {stop, normal, NewState};
false -> {noreply, NewState}
end;
_ -> {noreply, NewState}
end.
DisconnectingSession = maps:fold(
-spec remove_session(session_id(), guild_state()) -> guild_state().
remove_session(SessionId, State) ->
Sessions = maps:get(sessions, State, #{}),
case maps:get(SessionId, Sessions, undefined) of
undefined ->
State;
Session ->
maybe_demonitor_session(Session),
StateAfterCleanup = cleanup_disconnecting_session(Session, State),
SessionsAfterCleanup = maps:get(sessions, StateAfterCleanup, #{}),
NewSessions = maps:remove(SessionId, SessionsAfterCleanup),
maps:put(sessions, NewSessions, StateAfterCleanup)
end.
-spec maybe_demonitor_session(session_data()) -> ok.
maybe_demonitor_session(Session) ->
Ref = maps:get(mref, Session, undefined),
case is_reference(Ref) of
true ->
demonitor(Ref, [flush]),
ok;
false ->
ok
end.
-spec find_session_by_ref(reference(), sessions_map()) -> session_data() | undefined.
find_session_by_ref(Ref, Sessions) ->
maps:fold(
fun(_K, S, Acc) ->
case maps:get(mref, S) =:= Ref of
true -> S;
@@ -118,95 +191,253 @@ handle_session_down(Ref, State) ->
end,
undefined,
Sessions
),
State1 =
case DisconnectingSession of
undefined ->
State;
Session ->
UserId = maps:get(user_id, Session),
SessionId = maps:get(session_id, Session),
StateAfterPresence = unsubscribe_from_user_presence(UserId, State),
StateAfterMemberList = guild_member_list:unsubscribe_session(
SessionId, StateAfterPresence
),
MemberSubs = maps:get(
member_subscriptions, StateAfterMemberList, guild_subscriptions:init_state()
),
NewMemberSubs = guild_subscriptions:unsubscribe_session(SessionId, MemberSubs),
maps:put(member_subscriptions, NewMemberSubs, StateAfterMemberList)
end,
NewSessions = maps:filter(fun(_K, S) -> maps:get(mref, S) =/= Ref end, Sessions),
NewState = maps:put(sessions, NewSessions, State1),
case map_size(NewSessions) of
0 ->
{stop, normal, NewState};
_ ->
{noreply, NewState}
end.
filter_sessions_for_channel(Sessions, ChannelId, SessionIdOpt, State) ->
GuildId = maps:get(id, State, 0),
lists:filter(
fun({Sid, S}) ->
UserId = maps:get(user_id, S),
Member = find_member_by_user_id(UserId, State),
ExcludeSession =
case SessionIdOpt of
undefined -> false;
SessionId -> Sid =:= SessionId
end,
case {ExcludeSession, Member} of
{true, _} ->
false;
{_, undefined} ->
logger:warning(
"[guild_sessions] Filtering out session with no member: "
"guild_id=~p session_id=~p user_id=~p",
[GuildId, Sid, UserId]
),
false;
{false, _} ->
can_view_channel(UserId, ChannelId, Member, State)
end
end,
maps:to_list(Sessions)
).
filter_sessions_for_manage_channels(Sessions, ChannelId, SessionIdOpt, State) ->
-spec cleanup_disconnecting_session(session_data() | undefined, guild_state()) -> guild_state().
cleanup_disconnecting_session(undefined, State) ->
State;
cleanup_disconnecting_session(Session, State) ->
UserId = maps:get(user_id, Session),
SessionId = maps:get(session_id, Session),
_ = maybe_notify_coordinator(session_disconnected, SessionId, UserId, State),
StateAfterPresence = unsubscribe_from_user_presence(UserId, State),
StateAfterMemberList = guild_member_list:unsubscribe_session(SessionId, StateAfterPresence),
MemberSubs = maps:get(
member_subscriptions, StateAfterMemberList, guild_subscriptions:init_state()
),
NewMemberSubs = guild_subscriptions:unsubscribe_session(SessionId, MemberSubs),
StateAfterSubs = maps:put(member_subscriptions, NewMemberSubs, StateAfterMemberList),
cleanup_connect_admission_for_session(SessionId, StateAfterSubs).
-spec cleanup_connect_admission_for_session(session_id(), guild_state()) -> guild_state().
cleanup_connect_admission_for_session(SessionId, State) ->
Pending0 = maps:get(session_connect_pending, State, undefined),
State1 =
case Pending0 of
Pending when is_map(Pending) ->
maps:put(session_connect_pending, maps:remove(SessionId, Pending), State);
_ ->
State
end,
Queue0 = maps:get(session_connect_queue, State1, undefined),
Queue = normalize_connect_queue(Queue0),
case queue:is_queue(Queue) of
true ->
Filtered = queue:filter(
fun(Item) ->
Request = maps:get(request, Item, #{}),
maps:get(session_id, Request, undefined) =/= SessionId
end,
Queue
),
maps:put(session_connect_queue, Filtered, State1);
false ->
State1
end.
-spec normalize_connect_queue(term()) -> queue:queue() | undefined.
normalize_connect_queue(Value) when is_list(Value) ->
queue:from_list(Value);
normalize_connect_queue(Value) ->
case queue:is_queue(Value) of
true -> Value;
false -> undefined
end.
-spec should_auto_stop_on_empty(guild_state()) -> boolean().
should_auto_stop_on_empty(State) ->
case maps:get(disable_auto_stop_on_empty, State, false) of
true ->
false;
false ->
case maps:get(very_large_guild_coordinator_pid, State, undefined) of
Pid when is_pid(Pid) -> false;
_ -> true
end
end.
-spec maybe_notify_coordinator(session_connected | session_disconnected, session_id(), user_id(), guild_state()) ->
ok.
maybe_notify_coordinator(Type, SessionId, UserId, State) ->
case {maps:get(very_large_guild_coordinator_pid, State, undefined),
maps:get(very_large_guild_shard_index, State, undefined)}
of
{CoordPid, ShardIndex} when is_pid(CoordPid), is_integer(ShardIndex) ->
Msg =
case Type of
session_connected ->
{very_large_guild_session_connected, ShardIndex, SessionId, UserId};
session_disconnected ->
{very_large_guild_session_disconnected, ShardIndex, SessionId, UserId}
end,
CoordPid ! Msg,
ok;
_ ->
ok
end.
-spec filter_sessions_for_channel(
sessions_map(), channel_id(), session_id() | undefined, guild_state()
) ->
[session_pair()].
filter_sessions_for_channel(Sessions, ChannelId, SessionIdOpt, State) ->
lists:filter(
fun({Sid, S}) ->
UserId = maps:get(user_id, S),
ExcludeSession =
case SessionIdOpt of
undefined -> false;
SessionId -> Sid =:= SessionId
end,
case maps:get(pending_connect, S, false) of
true ->
false;
false ->
ExcludeSession = should_exclude_session(Sid, SessionIdOpt),
case ExcludeSession of
true ->
false;
false ->
can_manage_channel(UserId, ChannelId, State)
session_can_view_channel(S, ChannelId, State)
end
end
end,
maps:to_list(Sessions)
).
filter_sessions_exclude_session(Sessions, SessionIdOpt) ->
case SessionIdOpt of
-spec filter_sessions_for_message(
sessions_map(), channel_id(), binary(), session_id() | undefined, guild_state()
) ->
[session_pair()].
filter_sessions_for_message(Sessions, ChannelId, MessageId, SessionIdOpt, State) ->
lists:filter(
fun({Sid, S}) ->
case maps:get(pending_connect, S, false) of
true ->
false;
false ->
ExcludeSession = should_exclude_session(Sid, SessionIdOpt),
case ExcludeSession of
true ->
false;
false ->
UserId = maps:get(user_id, S, undefined),
case session_can_view_channel(S, ChannelId, State) of
false ->
false;
true ->
Perms = guild_permissions:get_member_permissions(UserId, ChannelId, State),
guild_permissions:can_access_message_by_permissions(
Perms, MessageId, State
)
end
end
end
end,
maps:to_list(Sessions)
).
-spec filter_sessions_for_manage_channels(
sessions_map(), channel_id(), session_id() | undefined, guild_state()
) ->
[session_pair()].
filter_sessions_for_manage_channels(Sessions, ChannelId, SessionIdOpt, State) ->
lists:filter(
fun({Sid, S}) ->
case maps:get(pending_connect, S, false) of
true ->
false;
false ->
UserId = maps:get(user_id, S),
ExcludeSession = should_exclude_session(Sid, SessionIdOpt),
case ExcludeSession of
true -> false;
false -> guild_permissions:can_manage_channel(UserId, ChannelId, State)
end
end
end,
maps:to_list(Sessions)
).
-spec filter_sessions_exclude_session(sessions_map(), session_id() | undefined) -> [session_pair()].
filter_sessions_exclude_session(Sessions, undefined) ->
[{Sid, S} || {Sid, S} <- maps:to_list(Sessions), maps:get(pending_connect, S, false) =/= true];
filter_sessions_exclude_session(Sessions, SessionId) ->
[
{Sid, S}
|| {Sid, S} <- maps:to_list(Sessions),
Sid =/= SessionId,
maps:get(pending_connect, S, false) =/= true
].
-spec should_exclude_session(session_id(), session_id() | undefined) -> boolean().
should_exclude_session(_, undefined) -> false;
should_exclude_session(Sid, SessionId) -> Sid =:= SessionId.
-spec set_session_viewable_channels(session_id(), map(), guild_state()) -> guild_state().
set_session_viewable_channels(SessionId, ViewableChannels, State) ->
Sessions = maps:get(sessions, State, #{}),
case maps:get(SessionId, Sessions, undefined) of
undefined ->
maps:to_list(Sessions);
SessionId ->
[{Sid, S} || {Sid, S} <- maps:to_list(Sessions), Sid =/= SessionId]
State;
SessionData ->
NewSessionData = maps:put(viewable_channels, ViewableChannels, SessionData),
NewSessions = maps:put(SessionId, NewSessionData, Sessions),
maps:put(sessions, NewSessions, State)
end.
-spec refresh_all_viewable_channels(guild_state()) -> guild_state().
refresh_all_viewable_channels(State) ->
Sessions = maps:get(sessions, State, #{}),
lists:foldl(
fun({SessionId, SessionData}, AccState) ->
UserId = maps:get(user_id, SessionData, undefined),
case is_integer(UserId) of
true ->
ViewableChannels = build_viewable_channel_map(
guild_visibility:get_user_viewable_channels(UserId, AccState)
),
set_session_viewable_channels(SessionId, ViewableChannels, AccState);
false ->
AccState
end
end,
State,
maps:to_list(Sessions)
).
-spec session_can_view_channel(session_data(), channel_id(), guild_state()) -> boolean().
session_can_view_channel(SessionData, ChannelId, State) ->
UserId = maps:get(user_id, SessionData, undefined),
case {UserId, maps:get(viewable_channels, SessionData, undefined)} of
{Uid, ViewableChannels} when is_integer(Uid), is_map(ViewableChannels) ->
case maps:is_key(ChannelId, ViewableChannels) of
true ->
true;
false ->
check_member_channel_access(Uid, ChannelId, State)
end;
{Uid, _} when is_integer(Uid) ->
check_member_channel_access(Uid, ChannelId, State);
_ ->
false
end.
-spec check_member_channel_access(user_id(), channel_id(), guild_state()) -> boolean().
check_member_channel_access(UserId, ChannelId, State) ->
Member = guild_permissions:find_member_by_user_id(UserId, State),
case Member of
undefined ->
false;
_ ->
guild_permissions:can_view_channel(UserId, ChannelId, Member, State)
end.
-spec build_viewable_channel_map([channel_id()]) -> #{channel_id() => true}.
build_viewable_channel_map(ChannelIds) ->
lists:foldl(
fun(ChannelId, Acc) ->
maps:put(ChannelId, true, Acc)
end,
#{},
ChannelIds
).
-spec subscribe_to_user_presence(user_id(), guild_state()) -> guild_state().
subscribe_to_user_presence(UserId, State) ->
PresenceSubs = maps:get(presence_subscriptions, State, #{}),
CurrentCount = maps:get(UserId, PresenceSubs, 0),
@@ -221,6 +452,7 @@ subscribe_to_user_presence(UserId, State) ->
maps:put(presence_subscriptions, NewSubs, State)
end.
-spec unsubscribe_from_user_presence(user_id(), guild_state()) -> guild_state().
unsubscribe_from_user_presence(UserId, State) ->
PresenceSubs = maps:get(presence_subscriptions, State, #{}),
CurrentCount = maps:get(UserId, PresenceSubs, 0),
@@ -235,30 +467,33 @@ unsubscribe_from_user_presence(UserId, State) ->
maps:put(presence_subscriptions, NewSubs, State)
end.
-spec handle_user_offline(user_id(), guild_state()) -> guild_state().
handle_user_offline(UserId, State) ->
PresenceSubs = maps:get(presence_subscriptions, State, #{}),
case maps:get(UserId, PresenceSubs, undefined) of
0 ->
presence_bus:unsubscribe(UserId),
NewSubs = maps:remove(UserId, PresenceSubs),
maps:put(presence_subscriptions, NewSubs, State);
undefined ->
State;
StateWithSubs = maps:put(presence_subscriptions, NewSubs, State),
MemberPresence = maps:get(member_presence, StateWithSubs, #{}),
UpdatedMemberPresence = maps:remove(UserId, MemberPresence),
maps:put(member_presence, UpdatedMemberPresence, StateWithSubs);
_ ->
State
end.
-spec maybe_send_cached_presence(user_id(), guild_state()) -> guild_state().
maybe_send_cached_presence(UserId, State) ->
case presence_cache:get(UserId) of
{ok, Payload} ->
case guild_presence:handle_bus_presence(UserId, Payload, State) of
{noreply, UpdatedState} ->
UpdatedState
{noreply, UpdatedState} -> UpdatedState
end;
_ ->
State
end.
-spec set_session_active_guild(session_id(), guild_id(), guild_state()) -> guild_state().
set_session_active_guild(SessionId, GuildId, State) ->
Sessions = maps:get(sessions, State, #{}),
case maps:get(SessionId, Sessions, undefined) of
@@ -270,6 +505,7 @@ set_session_active_guild(SessionId, GuildId, State) ->
maps:put(sessions, NewSessions, State)
end.
-spec set_session_passive_guild(session_id(), guild_id(), guild_state()) -> guild_state().
set_session_passive_guild(SessionId, GuildId, State) ->
Sessions = maps:get(sessions, State, #{}),
case maps:get(SessionId, Sessions, undefined) of
@@ -282,39 +518,39 @@ set_session_passive_guild(SessionId, GuildId, State) ->
maps:put(sessions, NewSessions, State)
end.
-spec is_session_active(session_id(), guild_state()) -> boolean().
is_session_active(SessionId, State) ->
GuildId = maps:get(id, State, 0),
Sessions = maps:get(sessions, State, #{}),
case maps:get(SessionId, Sessions, undefined) of
undefined ->
false;
SessionData ->
not session_passive:is_passive(GuildId, SessionData)
undefined -> false;
SessionData -> not session_passive:is_passive(GuildId, SessionData)
end.
maybe_auto_sync_initial_guild(SessionId, GuildId, InitialGuildId, State) ->
case InitialGuildId of
GuildId ->
Sessions = maps:get(sessions, State, #{}),
case maps:get(SessionId, Sessions, undefined) of
undefined ->
State;
SessionData ->
SyncedSessionData = session_passive:mark_guild_synced(GuildId, SessionData),
NewSessions = maps:put(SessionId, SyncedSessionData, Sessions),
maps:put(sessions, NewSessions, State)
end;
_ ->
State
end.
-spec maybe_auto_sync_initial_guild(
session_id(), guild_id(), guild_id() | undefined, guild_state()
) ->
guild_state().
maybe_auto_sync_initial_guild(SessionId, GuildId, GuildId, State) ->
Sessions = maps:get(sessions, State, #{}),
case maps:get(SessionId, Sessions, undefined) of
undefined ->
State;
SessionData ->
SyncedSessionData = session_passive:mark_guild_synced(GuildId, SessionData),
NewSessions = maps:put(SessionId, SyncedSessionData, Sessions),
maps:put(sessions, NewSessions, State)
end;
maybe_auto_sync_initial_guild(_SessionId, _GuildId, _InitialGuildId, State) ->
State.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
build_initial_last_message_ids_empty_channels_test() ->
GuildState = #{<<"channels">> => []},
Result = build_initial_last_message_ids(GuildState),
?assertEqual(#{}, Result),
ok.
?assertEqual(#{}, Result).
build_initial_last_message_ids_with_channels_test() ->
GuildState = #{
@@ -324,8 +560,7 @@ build_initial_last_message_ids_with_channels_test() ->
]
},
Result = build_initial_last_message_ids(GuildState),
?assertEqual(#{<<"100">> => <<"500">>, <<"101">> => <<"600">>}, Result),
ok.
?assertEqual(#{<<"100">> => <<"500">>, <<"101">> => <<"600">>}, Result).
build_initial_last_message_ids_filters_null_test() ->
GuildState = #{
@@ -336,13 +571,237 @@ build_initial_last_message_ids_filters_null_test() ->
]
},
Result = build_initial_last_message_ids(GuildState),
?assertEqual(#{<<"100">> => <<"500">>}, Result),
ok.
?assertEqual(#{<<"100">> => <<"500">>}, Result).
build_initial_last_message_ids_no_channels_key_test() ->
GuildState = #{},
Result = build_initial_last_message_ids(GuildState),
?assertEqual(#{}, Result),
?assertEqual(#{}, Result).
build_initial_channel_versions_test() ->
GuildState = #{
<<"channels">> => [
#{<<"id">> => <<"100">>, <<"version">> => 5},
#{<<"id">> => <<"101">>}
]
},
Result = build_initial_channel_versions(GuildState),
?assertEqual(#{<<"100">> => 5, <<"101">> => 0}, Result).
filter_sessions_for_channel_uses_cached_viewable_channels_test() ->
SessionId = <<"s1">>,
SessionData = #{
session_id => SessionId,
user_id => 10,
pid => self(),
viewable_channels => #{200 => true}
},
Sessions = #{SessionId => SessionData},
State = #{
sessions => Sessions,
data => #{<<"members">> => #{}}
},
Result = filter_sessions_for_channel(Sessions, 200, undefined, State),
?assertEqual([{SessionId, SessionData}], Result).
refresh_all_viewable_channels_populates_cache_test() ->
SessionId = <<"s1">>,
UserId = 10,
GuildId = 42,
ViewPerm = constants:view_channel_permission(),
State = #{
id => GuildId,
sessions => #{
SessionId => #{
session_id => SessionId,
user_id => UserId,
pid => self()
}
},
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => integer_to_binary(ViewPerm)}
],
<<"members">> => #{
UserId => #{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
},
<<"channels">> => [
#{<<"id">> => <<"200">>, <<"permission_overwrites">> => []}
]
}
},
UpdatedState = refresh_all_viewable_channels(State),
UpdatedSessions = maps:get(sessions, UpdatedState, #{}),
UpdatedSession = maps:get(SessionId, UpdatedSessions),
ViewableChannels = maps:get(viewable_channels, UpdatedSession, #{}),
?assertEqual(true, maps:is_key(200, ViewableChannels)).
filter_sessions_exclude_session_undefined_test() ->
Sessions = #{<<"s1">> => #{}, <<"s2">> => #{}},
Result = filter_sessions_exclude_session(Sessions, undefined),
?assertEqual(2, length(Result)).
filter_sessions_exclude_session_specific_test() ->
Sessions = #{<<"s1">> => #{}, <<"s2">> => #{}},
Result = filter_sessions_exclude_session(Sessions, <<"s1">>),
?assertEqual(1, length(Result)),
?assertEqual([{<<"s2">>, #{}}], Result).
should_exclude_session_test() ->
?assertEqual(false, should_exclude_session(<<"s1">>, undefined)),
?assertEqual(true, should_exclude_session(<<"s1">>, <<"s1">>)),
?assertEqual(false, should_exclude_session(<<"s1">>, <<"s2">>)).
find_session_by_ref_found_test() ->
Ref = make_ref(),
Sessions = #{<<"s1">> => #{mref => Ref, user_id => 1}},
Result = find_session_by_ref(Ref, Sessions),
?assertEqual(#{mref => Ref, user_id => 1}, Result).
find_session_by_ref_not_found_test() ->
Ref = make_ref(),
OtherRef = make_ref(),
Sessions = #{<<"s1">> => #{mref => OtherRef, user_id => 1}},
Result = find_session_by_ref(Ref, Sessions),
?assertEqual(undefined, Result).
remove_session_removes_entry_test() ->
SessionId = <<"s1">>,
SessionData = #{
session_id => SessionId,
user_id => 1,
pid => self(),
mref => make_ref(),
active_guilds => sets:new(),
user_roles => [],
bot => false,
previous_passive_updates => #{},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
},
State = #{
sessions => #{SessionId => SessionData},
presence_subscriptions => #{1 => 1},
member_list_subscriptions => #{},
member_subscriptions => guild_subscriptions:init_state()
},
NewState = remove_session(SessionId, State),
?assertEqual(#{}, maps:get(sessions, NewState, #{})),
?assertEqual(0, maps:get(1, maps:get(presence_subscriptions, NewState, #{}), 0)).
remove_session_cleans_connect_pending_test() ->
SessionId = <<"s1">>,
SessionData = #{
session_id => SessionId,
user_id => 1,
pid => self(),
mref => make_ref(),
active_guilds => sets:new(),
user_roles => [],
bot => false,
previous_passive_updates => #{},
previous_passive_channel_versions => #{},
previous_passive_voice_states => #{}
},
State = #{
sessions => #{SessionId => SessionData},
presence_subscriptions => #{1 => 1},
member_list_subscriptions => #{},
member_subscriptions => guild_subscriptions:init_state(),
session_connect_pending => #{SessionId => 3, <<"s2">> => 1},
session_connect_queue => [
#{request => #{session_id => SessionId}, attempt => 3},
#{request => #{session_id => <<"s2">>}, attempt => 1}
]
},
NewState = remove_session(SessionId, State),
Pending = maps:get(session_connect_pending, NewState, #{}),
?assertEqual(false, maps:is_key(SessionId, Pending)),
?assertEqual(true, maps:is_key(<<"s2">>, Pending)),
Queue0 = maps:get(session_connect_queue, NewState, queue:new()),
Queue =
case Queue0 of
L when is_list(L) -> L;
_ ->
case queue:is_queue(Queue0) of
true -> queue:to_list(Queue0);
false -> []
end
end,
?assertEqual(
1,
length([
Item
|| Item <- Queue,
maps:get(session_id, maps:get(request, Item, #{}), undefined) =:= <<"s2">>
])
),
?assertEqual(
0,
length([
Item
|| Item <- Queue,
maps:get(session_id, maps:get(request, Item, #{}), undefined) =:= SessionId
])
),
ok.
cleanup_connect_admission_queue_format_test() ->
S1 = <<"s1">>,
S2 = <<"s2">>,
S3 = <<"s3">>,
Queue = queue:from_list([
#{request => #{session_id => S1}, attempt => 0},
#{request => #{session_id => S2}, attempt => 1},
#{request => #{session_id => S1}, attempt => 2},
#{request => #{session_id => S3}, attempt => 3}
]),
State0 = #{
sessions => #{},
session_connect_queue => Queue,
presence_subscriptions => #{},
member_list_subscriptions => #{},
member_subscriptions => guild_subscriptions:init_state()
},
State1 = cleanup_connect_admission_for_session(S1, State0),
ResultQueue0 = maps:get(session_connect_queue, State1),
ResultQueue =
case queue:is_queue(ResultQueue0) of
true -> queue:to_list(ResultQueue0);
false -> ResultQueue0
end,
?assertEqual(2, length(ResultQueue)),
SessionIds = [
maps:get(session_id, maps:get(request, Item, #{}), undefined)
|| Item <- ResultQueue
],
?assertEqual(false, lists:member(S1, SessionIds)),
?assertEqual(true, lists:member(S2, SessionIds)),
?assertEqual(true, lists:member(S3, SessionIds)).
pending_connect_filtered_from_channel_sessions_test() ->
NormalSession = #{
session_id => <<"s1">>,
user_id => 10,
pid => self(),
viewable_channels => #{200 => true}
},
PendingSession = #{
session_id => <<"s2">>,
user_id => 11,
pid => self(),
pending_connect => true,
viewable_channels => #{200 => true}
},
Sessions = #{<<"s1">> => NormalSession, <<"s2">> => PendingSession},
State = #{
sessions => Sessions,
data => #{<<"members">> => #{}}
},
Result = filter_sessions_for_channel(Sessions, 200, undefined, State),
?assertEqual(1, length(Result)),
[{ResultSid, _}] = Result,
?assertEqual(<<"s1">>, ResultSid).
-endif.

View File

@@ -17,24 +17,26 @@
-module(guild_state).
-export([
update_state/3
]).
-export([update_state/3]).
-import(guild_user_data, [maybe_update_cached_user_data/3]).
-import(guild_availability, [handle_unavailability_transition/2]).
-import(guild_visibility, [compute_and_dispatch_visibility_changes/2]).
-import(guild, [update_counts/1]).
-type guild_state() :: map().
-type guild_data() :: map().
-type event() :: atom().
-type event_data() :: map().
-type user_id() :: integer().
-type role_id() :: binary().
-spec update_state(event(), event_data(), guild_state()) -> guild_state().
update_state(Event, EventData, State) ->
StateWithUpdatedUser = maybe_update_cached_user_data(Event, EventData, State),
Data = maps:get(data, StateWithUpdatedUser),
StateWithUpdatedUser0 = guild_user_data:maybe_update_cached_user_data(Event, EventData, State),
Data0 = maps:get(data, StateWithUpdatedUser0),
Data = guild_data_index:normalize_data(Data0),
StateWithUpdatedUser = maps:put(data, Data, StateWithUpdatedUser0),
UpdatedData = update_data_for_event(Event, EventData, Data, State),
UpdatedState = maps:put(data, UpdatedData, StateWithUpdatedUser),
handle_post_update(Event, EventData, StateWithUpdatedUser, UpdatedState).
handle_post_update(Event, StateWithUpdatedUser, UpdatedState).
-spec update_data_for_event(event(), event_data(), guild_data(), guild_state()) -> guild_data().
update_data_for_event(guild_update, EventData, Data, _State) ->
handle_guild_update(EventData, Data);
update_data_for_event(guild_member_add, EventData, Data, _State) ->
@@ -70,23 +72,184 @@ update_data_for_event(guild_stickers_update, EventData, Data, _State) ->
update_data_for_event(_Event, _EventData, Data, _State) ->
Data.
handle_post_update(guild_update, StateWithUpdatedUser, UpdatedState) ->
handle_unavailability_transition(StateWithUpdatedUser, UpdatedState),
-spec handle_post_update(event(), event_data(), guild_state(), guild_state()) -> guild_state().
handle_post_update(guild_update, _EventData, StateWithUpdatedUser, UpdatedState) ->
guild_availability:handle_unavailability_transition(StateWithUpdatedUser, UpdatedState);
handle_post_update(guild_member_add, _EventData, _StateWithUpdatedUser, UpdatedState) ->
guild:update_counts(UpdatedState);
handle_post_update(guild_member_update, EventData, StateWithUpdatedUser, UpdatedState) ->
UserId = extract_user_id(EventData),
case is_integer(UserId) andalso UserId > 0 of
true ->
guild_visibility:compute_and_dispatch_visibility_changes_for_users(
[UserId],
StateWithUpdatedUser,
UpdatedState
);
false ->
guild_visibility:compute_and_dispatch_visibility_changes(
StateWithUpdatedUser,
UpdatedState
)
end;
handle_post_update(guild_role_create, _EventData, _StateWithUpdatedUser, UpdatedState) ->
UpdatedState;
handle_post_update(guild_member_add, _StateWithUpdatedUser, UpdatedState) ->
update_counts(UpdatedState);
handle_post_update(guild_member_remove, _StateWithUpdatedUser, UpdatedState) ->
handle_post_update(guild_role_update, EventData, StateWithUpdatedUser, UpdatedState) ->
recompute_visibility_for_roles(
extract_role_ids_from_role_update(EventData),
StateWithUpdatedUser,
UpdatedState
);
handle_post_update(guild_role_update_bulk, EventData, StateWithUpdatedUser, UpdatedState) ->
recompute_visibility_for_roles(
extract_role_ids_from_role_update_bulk(EventData),
StateWithUpdatedUser,
UpdatedState
);
handle_post_update(guild_role_delete, EventData, StateWithUpdatedUser, UpdatedState) ->
recompute_visibility_for_roles(
extract_role_ids_from_role_delete(EventData),
StateWithUpdatedUser,
UpdatedState
);
handle_post_update(channel_update, EventData, StateWithUpdatedUser, UpdatedState) ->
ChannelIds = extract_channel_ids_from_channel_update(EventData),
guild_visibility:compute_and_dispatch_visibility_changes_for_channels(
ChannelIds,
StateWithUpdatedUser,
UpdatedState
);
handle_post_update(channel_update_bulk, EventData, StateWithUpdatedUser, UpdatedState) ->
ChannelIds = extract_channel_ids_from_channel_update_bulk(EventData),
guild_visibility:compute_and_dispatch_visibility_changes_for_channels(
ChannelIds,
StateWithUpdatedUser,
UpdatedState
);
handle_post_update(guild_member_remove, EventData, _StateWithUpdatedUser, UpdatedState) ->
UserId = extract_user_id(EventData),
State1 = cleanup_removed_member_sessions(UpdatedState),
update_counts(State1);
handle_post_update(Event, StateWithUpdatedUser, UpdatedState) ->
State2 = maybe_disconnect_removed_member(UserId, State1),
guild:update_counts(State2);
handle_post_update(Event, _EventData, StateWithUpdatedUser, UpdatedState) ->
case needs_visibility_check(Event) of
true ->
compute_and_dispatch_visibility_changes(StateWithUpdatedUser, UpdatedState),
UpdatedState;
guild_visibility:compute_and_dispatch_visibility_changes(
StateWithUpdatedUser, UpdatedState
);
false ->
UpdatedState
end.
-spec maybe_disconnect_removed_member(user_id(), guild_state()) -> guild_state().
maybe_disconnect_removed_member(UserId, State) when is_integer(UserId), UserId > 0 ->
{reply, _Result, NewState} =
guild_voice_disconnect:disconnect_voice_user(
#{user_id => UserId, connection_id => null},
State
),
NewState;
maybe_disconnect_removed_member(_, State) ->
State.
-spec recompute_visibility_for_roles([integer()], guild_state(), guild_state()) -> guild_state().
recompute_visibility_for_roles(RoleIds, StateWithUpdatedUser, UpdatedState) ->
GuildId = maps:get(id, UpdatedState, 0),
case lists:any(fun(RoleId) -> RoleId =:= GuildId end, RoleIds) of
true ->
guild_visibility:compute_and_dispatch_visibility_changes(
StateWithUpdatedUser,
UpdatedState
);
false ->
UserIds = affected_user_ids_for_roles(RoleIds, StateWithUpdatedUser, UpdatedState),
case UserIds of
[] ->
UpdatedState;
_ ->
guild_visibility:compute_and_dispatch_visibility_changes_for_users(
UserIds,
StateWithUpdatedUser,
UpdatedState
)
end
end.
-spec affected_user_ids_for_roles([integer()], guild_state(), guild_state()) -> [user_id()].
affected_user_ids_for_roles(RoleIds, StateWithUpdatedUser, UpdatedState) ->
OldData = maps:get(data, StateWithUpdatedUser, #{}),
NewData = maps:get(data, UpdatedState, #{}),
OldMemberRoleIndex = guild_data_index:member_role_index(OldData),
NewMemberRoleIndex = guild_data_index:member_role_index(NewData),
UserIdSet = lists:foldl(
fun(RoleId, AccSet) ->
OldUsers = maps:keys(maps:get(RoleId, OldMemberRoleIndex, #{})),
NewUsers = maps:keys(maps:get(RoleId, NewMemberRoleIndex, #{})),
RoleUsers = OldUsers ++ NewUsers,
lists:foldl(
fun(UserId, InnerSet) -> sets:add_element(UserId, InnerSet) end,
AccSet,
RoleUsers
)
end,
sets:new(),
lists:usort(RoleIds)
),
sets:to_list(UserIdSet).
-spec extract_role_ids_from_role_update(event_data()) -> [integer()].
extract_role_ids_from_role_update(EventData) ->
RoleData = maps:get(<<"role">>, EventData, #{}),
case type_conv:to_integer(maps:get(<<"id">>, RoleData, undefined)) of
undefined -> [];
RoleId -> [RoleId]
end.
-spec extract_role_ids_from_role_update_bulk(event_data()) -> [integer()].
extract_role_ids_from_role_update_bulk(EventData) ->
Roles = maps:get(<<"roles">>, EventData, []),
lists:filtermap(
fun(RoleData) ->
case type_conv:to_integer(maps:get(<<"id">>, RoleData, undefined)) of
undefined ->
false;
RoleId ->
{true, RoleId}
end
end,
Roles
).
-spec extract_role_ids_from_role_delete(event_data()) -> [integer()].
extract_role_ids_from_role_delete(EventData) ->
case type_conv:to_integer(maps:get(<<"role_id">>, EventData, undefined)) of
undefined -> [];
RoleId -> [RoleId]
end.
-spec extract_channel_ids_from_channel_update(event_data()) -> [integer()].
extract_channel_ids_from_channel_update(EventData) ->
case type_conv:to_integer(maps:get(<<"id">>, EventData, undefined)) of
undefined -> [];
ChannelId -> [ChannelId]
end.
-spec extract_channel_ids_from_channel_update_bulk(event_data()) -> [integer()].
extract_channel_ids_from_channel_update_bulk(EventData) ->
Channels = maps:get(<<"channels">>, EventData, []),
lists:filtermap(
fun(ChannelData) ->
case type_conv:to_integer(maps:get(<<"id">>, ChannelData, undefined)) of
undefined ->
false;
ChannelId ->
{true, ChannelId}
end
end,
Channels
).
-spec needs_visibility_check(event()) -> boolean().
needs_visibility_check(guild_role_create) -> true;
needs_visibility_check(guild_role_update) -> true;
needs_visibility_check(guild_role_update_bulk) -> true;
@@ -96,33 +259,36 @@ needs_visibility_check(channel_update) -> true;
needs_visibility_check(channel_update_bulk) -> true;
needs_visibility_check(_) -> false.
-spec handle_guild_update(event_data(), guild_data()) -> guild_data().
handle_guild_update(EventData, Data) ->
Guild = maps:get(<<"guild">>, Data),
UpdatedGuild = maps:merge(Guild, EventData),
maps:put(<<"guild">>, UpdatedGuild, Data).
-spec handle_member_add(event_data(), guild_data()) -> guild_data().
handle_member_add(EventData, Data) ->
Members = maps:get(<<"members">>, Data, []),
UpdatedData = maps:put(<<"members">>, Members ++ [EventData], Data),
UpdatedData.
guild_data_index:put_member(EventData, Data).
-spec handle_member_update(event_data(), guild_data()) -> guild_data().
handle_member_update(EventData, Data) ->
Members = maps:get(<<"members">>, Data, []),
UserId = extract_user_id(EventData),
UpdatedMembers = replace_member_by_id(Members, UserId, EventData),
maps:put(<<"members">>, UpdatedMembers, Data).
Members = guild_data_index:member_map(Data),
case maps:is_key(UserId, Members) of
true ->
guild_data_index:put_member(EventData, Data);
false ->
Data
end.
-spec handle_member_remove(event_data(), guild_data(), guild_state()) -> guild_data().
handle_member_remove(EventData, Data, _State) ->
Members = maps:get(<<"members">>, Data, []),
UserId = extract_user_id(EventData),
FilteredMembers = remove_member_by_id(Members, UserId),
maps:put(<<"members">>, FilteredMembers, Data).
guild_data_index:remove_member(UserId, Data).
-spec cleanup_removed_member_sessions(guild_state()) -> guild_state().
cleanup_removed_member_sessions(State) ->
Data = maps:get(data, State),
Members = maps:get(<<"members">>, Data, []),
MemberUserIds = sets:from_list([extract_user_id_from_member(M) || M <- Members]),
MemberUserIds = sets:from_list(guild_data_index:member_ids(Data)),
Sessions = maps:get(sessions, State, #{}),
FilteredSessions = maps:filter(
fun(_K, S) ->
@@ -131,158 +297,141 @@ cleanup_removed_member_sessions(State) ->
end,
Sessions
),
maps:put(sessions, FilteredSessions, State).
Presences = maps:get(presences, State, #{}),
FilteredPresences = maps:filter(
fun(UserId, _V) ->
sets:is_element(UserId, MemberUserIds)
end,
Presences
),
State1 = maps:put(sessions, FilteredSessions, State),
maps:put(presences, FilteredPresences, State1).
extract_user_id_from_member(Member) when is_map(Member) ->
MUser = maps:get(<<"user">>, Member, #{}),
utils:binary_to_integer_safe(maps:get(<<"id">>, MUser, <<"0">>));
extract_user_id_from_member(_) ->
0.
-spec extract_user_id(event_data()) -> user_id().
extract_user_id(EventData) ->
MUser = maps:get(<<"user">>, EventData, #{}),
utils:binary_to_integer_safe(maps:get(<<"id">>, MUser, <<"0">>)).
replace_member_by_id(Members, UserId, NewMember) ->
lists:map(
fun(M) when is_map(M) ->
MMUser = maps:get(<<"user">>, M, #{}),
MUserId = utils:binary_to_integer_safe(maps:get(<<"id">>, MMUser, <<"0">>)),
case MUserId =:= UserId of
true -> NewMember;
false -> M
end
end,
Members
).
remove_member_by_id(Members, UserId) ->
lists:filter(
fun(M) when is_map(M) ->
MMUser = maps:get(<<"user">>, M, #{}),
MUserId = utils:binary_to_integer_safe(maps:get(<<"id">>, MMUser, <<"0">>)),
MUserId =/= UserId
end,
Members
).
-spec handle_role_create(event_data(), guild_data()) -> guild_data().
handle_role_create(EventData, Data) ->
Roles = maps:get(<<"roles">>, Data, []),
Roles = guild_data_index:role_list(Data),
RoleData = maps:get(<<"role">>, EventData),
maps:put(<<"roles">>, Roles ++ [RoleData], Data).
guild_data_index:put_roles(Roles ++ [RoleData], Data).
-spec handle_role_update(event_data(), guild_data()) -> guild_data().
handle_role_update(EventData, Data) ->
Roles = maps:get(<<"roles">>, Data, []),
Roles = guild_data_index:role_list(Data),
RoleData = maps:get(<<"role">>, EventData),
RoleId = maps:get(<<"id">>, RoleData),
UpdatedRoles = replace_item_by_id(Roles, RoleId, RoleData),
maps:put(<<"roles">>, UpdatedRoles, Data).
guild_data_index:put_roles(UpdatedRoles, Data).
-spec handle_role_update_bulk(event_data(), guild_data()) -> guild_data().
handle_role_update_bulk(EventData, Data) ->
Roles = maps:get(<<"roles">>, Data, []),
Roles = guild_data_index:role_list(Data),
BulkRoles = maps:get(<<"roles">>, EventData, []),
UpdatedRoles = bulk_update_items(Roles, BulkRoles),
maps:put(<<"roles">>, UpdatedRoles, Data).
guild_data_index:put_roles(UpdatedRoles, Data).
-spec handle_role_delete(event_data(), guild_data()) -> guild_data().
handle_role_delete(EventData, Data) ->
Roles = maps:get(<<"roles">>, Data, []),
Roles = guild_data_index:role_list(Data),
RoleId = maps:get(<<"role_id">>, EventData),
FilteredRoles = remove_item_by_id(Roles, RoleId),
Data1 = maps:put(<<"roles">>, FilteredRoles, Data),
Data1 = guild_data_index:put_roles(FilteredRoles, Data),
Data2 = strip_role_from_members(RoleId, Data1),
strip_role_from_channel_overwrites(RoleId, Data2).
-spec strip_role_from_members(role_id(), guild_data()) -> guild_data().
strip_role_from_members(RoleId, Data) ->
Members = maps:get(<<"members">>, Data, []),
UpdatedMembers = lists:map(
fun(Member) when is_map(Member) ->
MemberRoles = maps:get(<<"roles">>, Member, []),
FilteredRoles = lists:filter(
fun(R) ->
RoleIdInt = utils:binary_to_integer_safe(RoleId),
RInt = utils:binary_to_integer_safe(R),
RInt =/= RoleIdInt
end,
MemberRoles
),
maps:put(<<"roles">>, FilteredRoles, Member);
(Member) ->
Member
RoleIdInt = utils:binary_to_integer_safe(RoleId),
MemberRoleIndex = guild_data_index:member_role_index(Data),
AffectedUsers = maps:keys(maps:get(RoleIdInt, MemberRoleIndex, #{})),
lists:foldl(
fun(UserId, AccData) ->
case guild_data_index:get_member(UserId, AccData) of
undefined ->
AccData;
Member ->
MemberRoles = maps:get(<<"roles">>, Member, []),
FilteredRoles = lists:filter(
fun(R) ->
RInt = utils:binary_to_integer_safe(R),
RInt =/= RoleIdInt
end,
MemberRoles
),
guild_data_index:put_member(maps:put(<<"roles">>, FilteredRoles, Member), AccData)
end
end,
Members
),
maps:put(<<"members">>, UpdatedMembers, Data).
Data,
AffectedUsers
).
-spec strip_role_from_channel_overwrites(role_id(), guild_data()) -> guild_data().
strip_role_from_channel_overwrites(RoleId, Data) ->
Channels = maps:get(<<"channels">>, Data, []),
Channels = guild_data_index:channel_list(Data),
RoleIdInt = utils:binary_to_integer_safe(RoleId),
UpdatedChannels = lists:map(
fun(Channel) when is_map(Channel) ->
Overwrites = maps:get(<<"permission_overwrites">>, Channel, []),
FilteredOverwrites = lists:filter(
fun(Overwrite) when is_map(Overwrite) ->
OverwriteType = maps:get(<<"type">>, Overwrite, 0),
OverwriteId = utils:binary_to_integer_safe(maps:get(<<"id">>, Overwrite, <<"0">>)),
not (OverwriteType =:= 0 andalso OverwriteId =:= RoleIdInt);
(_) ->
true
end,
Overwrites
),
maps:put(<<"permission_overwrites">>, FilteredOverwrites, Channel);
(Channel) ->
Channel
fun
(Channel) when is_map(Channel) ->
Overwrites = maps:get(<<"permission_overwrites">>, Channel, []),
FilteredOverwrites = lists:filter(
fun
(Overwrite) when is_map(Overwrite) ->
OverwriteType = maps:get(<<"type">>, Overwrite, 0),
OverwriteId = utils:binary_to_integer_safe(
maps:get(<<"id">>, Overwrite, <<"0">>)
),
not (OverwriteType =:= 0 andalso OverwriteId =:= RoleIdInt);
(_) ->
true
end,
Overwrites
),
maps:put(<<"permission_overwrites">>, FilteredOverwrites, Channel);
(Channel) ->
Channel
end,
Channels
),
maps:put(<<"channels">>, UpdatedChannels, Data).
guild_data_index:put_channels(UpdatedChannels, Data).
-spec handle_channel_create(event_data(), guild_data()) -> guild_data().
handle_channel_create(EventData, Data) ->
Channels = maps:get(<<"channels">>, Data, []),
maps:put(<<"channels">>, Channels ++ [EventData], Data).
Channels = guild_data_index:channel_list(Data),
guild_data_index:put_channels(Channels ++ [EventData], Data).
-spec handle_channel_update(event_data(), guild_data()) -> guild_data().
handle_channel_update(EventData, Data) ->
Channels = maps:get(<<"channels">>, Data, []),
Channels = guild_data_index:channel_list(Data),
ChannelId = maps:get(<<"id">>, EventData),
UpdatedChannels = replace_item_by_id(Channels, ChannelId, EventData),
maps:put(<<"channels">>, UpdatedChannels, Data).
guild_data_index:put_channels(UpdatedChannels, Data).
-spec handle_channel_update_bulk(event_data(), guild_data()) -> guild_data().
handle_channel_update_bulk(EventData, Data) ->
Channels = maps:get(<<"channels">>, Data, []),
Channels = guild_data_index:channel_list(Data),
BulkChannels = maps:get(<<"channels">>, EventData, []),
UpdatedChannels = bulk_update_items(Channels, BulkChannels),
maps:put(<<"channels">>, UpdatedChannels, Data).
guild_data_index:put_channels(UpdatedChannels, Data).
-spec handle_channel_delete(event_data(), guild_data()) -> guild_data().
handle_channel_delete(EventData, Data) ->
Channels = maps:get(<<"channels">>, Data, []),
Channels = guild_data_index:channel_list(Data),
ChannelId = maps:get(<<"id">>, EventData),
FilteredChannels = remove_item_by_id(Channels, ChannelId),
maps:put(<<"channels">>, FilteredChannels, Data).
guild_data_index:put_channels(FilteredChannels, Data).
-spec handle_message_create(event_data(), guild_data()) -> guild_data().
handle_message_create(EventData, Data) ->
Channels = maps:get(<<"channels">>, Data, []),
Channels = guild_data_index:channel_list(Data),
ChannelId = maps:get(<<"channel_id">>, EventData),
MessageId = maps:get(<<"id">>, EventData),
UpdatedChannels = update_channel_field(Channels, ChannelId, <<"last_message_id">>, MessageId),
maps:put(<<"channels">>, UpdatedChannels, Data).
guild_data_index:put_channels(UpdatedChannels, Data).
-spec handle_channel_pins_update(event_data(), guild_data()) -> guild_data().
handle_channel_pins_update(EventData, Data) ->
Channels = maps:get(<<"channels">>, Data, []),
Channels = guild_data_index:channel_list(Data),
ChannelId = maps:get(<<"channel_id">>, EventData),
LastPin = maps:get(<<"last_pin_timestamp">>, EventData),
UpdatedChannels = update_channel_field(Channels, ChannelId, <<"last_pin_timestamp">>, LastPin),
maps:put(<<"channels">>, UpdatedChannels, Data).
guild_data_index:put_channels(UpdatedChannels, Data).
-spec update_channel_field([map()], binary(), binary(), term()) -> [map()].
update_channel_field(Channels, ChannelId, Field, Value) ->
lists:map(
fun(C) when is_map(C) ->
@@ -294,12 +443,15 @@ update_channel_field(Channels, ChannelId, Field, Value) ->
Channels
).
-spec handle_emojis_update(event_data(), guild_data()) -> guild_data().
handle_emojis_update(EventData, Data) ->
maps:put(<<"emojis">>, maps:get(<<"emojis">>, EventData, []), Data).
-spec handle_stickers_update(event_data(), guild_data()) -> guild_data().
handle_stickers_update(EventData, Data) ->
maps:put(<<"stickers">>, maps:get(<<"stickers">>, EventData, []), Data).
-spec replace_item_by_id([map()], binary(), map()) -> [map()].
replace_item_by_id(Items, Id, NewItem) ->
lists:map(
fun(Item) when is_map(Item) ->
@@ -311,6 +463,7 @@ replace_item_by_id(Items, Id, NewItem) ->
Items
).
-spec remove_item_by_id([map()], binary()) -> [map()].
remove_item_by_id(Items, Id) ->
lists:filter(
fun(Item) when is_map(Item) ->
@@ -319,6 +472,7 @@ remove_item_by_id(Items, Id) ->
Items
).
-spec bulk_update_items([map()], [map()]) -> [map()].
bulk_update_items(Items, BulkItems) ->
BulkMap = lists:foldl(
fun
@@ -333,7 +487,6 @@ bulk_update_items(Items, BulkItems) ->
#{},
BulkItems
),
lists:map(
fun
(Item) when is_map(Item) ->
@@ -358,26 +511,19 @@ handle_role_delete_strips_from_members_test() ->
#{<<"id">> => <<"100">>, <<"name">> => <<"Admin">>},
#{<<"id">> => <<"200">>, <<"name">> => <<"Moderator">>}
],
<<"members">> => [
#{
<<"user">> => #{<<"id">> => <<"1">>},
<<"roles">> => [<<"100">>, <<"200">>]
},
#{
<<"user">> => #{<<"id">> => <<"2">>},
<<"roles">> => [<<"200">>]
},
#{
<<"user">> => #{<<"id">> => <<"3">>},
<<"roles">> => [<<"100">>]
}
],
<<"members">> => #{
1 => #{<<"user">> => #{<<"id">> => <<"1">>}, <<"roles">> => [<<"100">>, <<"200">>]},
2 => #{<<"user">> => #{<<"id">> => <<"2">>}, <<"roles">> => [<<"200">>]},
3 => #{<<"user">> => #{<<"id">> => <<"3">>}, <<"roles">> => [<<"100">>]}
},
<<"channels">> => []
},
EventData = #{<<"role_id">> => RoleIdToDelete},
Result = handle_role_delete(EventData, Data),
Members = maps:get(<<"members">>, Result),
[M1, M2, M3] = Members,
M1 = maps:get(1, Members),
M2 = maps:get(2, Members),
M3 = maps:get(3, Members),
?assertEqual([<<"100">>], maps:get(<<"roles">>, M1)),
?assertEqual([], maps:get(<<"roles">>, M2)),
?assertEqual([<<"100">>], maps:get(<<"roles">>, M3)).
@@ -394,45 +540,24 @@ handle_role_delete_strips_from_channel_overwrites_test() ->
#{
<<"id">> => <<"500">>,
<<"permission_overwrites">> => [
#{<<"id">> => <<"100">>, <<"type">> => 0, <<"allow">> => <<"0">>, <<"deny">> => <<"1024">>},
#{<<"id">> => <<"200">>, <<"type">> => 0, <<"allow">> => <<"1024">>, <<"deny">> => <<"0">>},
#{<<"id">> => <<"1">>, <<"type">> => 1, <<"allow">> => <<"2048">>, <<"deny">> => <<"0">>}
]
},
#{
<<"id">> => <<"501">>,
<<"permission_overwrites">> => [
#{<<"id">> => <<"200">>, <<"type">> => 0, <<"allow">> => <<"1024">>, <<"deny">> => <<"0">>}
]
}
]
},
EventData = #{<<"role_id">> => RoleIdToDelete},
Result = handle_role_delete(EventData, Data),
Channels = maps:get(<<"channels">>, Result),
[Ch1, Ch2] = Channels,
Ch1Overwrites = maps:get(<<"permission_overwrites">>, Ch1),
Ch2Overwrites = maps:get(<<"permission_overwrites">>, Ch2),
?assertEqual(2, length(Ch1Overwrites)),
?assertEqual(0, length(Ch2Overwrites)),
OverwriteIds = [maps:get(<<"id">>, O) || O <- Ch1Overwrites],
?assert(lists:member(<<"100">>, OverwriteIds)),
?assert(lists:member(<<"1">>, OverwriteIds)),
?assertNot(lists:member(<<"200">>, OverwriteIds)).
handle_role_delete_preserves_user_overwrites_test() ->
RoleIdToDelete = <<"200">>,
Data = #{
<<"roles">> => [
#{<<"id">> => <<"200">>, <<"name">> => <<"Moderator">>}
],
<<"members">> => [],
<<"channels">> => [
#{
<<"id">> => <<"500">>,
<<"permission_overwrites">> => [
#{<<"id">> => <<"200">>, <<"type">> => 0, <<"allow">> => <<"1024">>, <<"deny">> => <<"0">>},
#{<<"id">> => <<"200">>, <<"type">> => 1, <<"allow">> => <<"2048">>, <<"deny">> => <<"0">>}
#{
<<"id">> => <<"100">>,
<<"type">> => 0,
<<"allow">> => <<"0">>,
<<"deny">> => <<"1024">>
},
#{
<<"id">> => <<"200">>,
<<"type">> => 0,
<<"allow">> => <<"1024">>,
<<"deny">> => <<"0">>
},
#{
<<"id">> => <<"1">>,
<<"type">> => 1,
<<"allow">> => <<"2048">>,
<<"deny">> => <<"0">>
}
]
}
]
@@ -441,26 +566,141 @@ handle_role_delete_preserves_user_overwrites_test() ->
Result = handle_role_delete(EventData, Data),
Channels = maps:get(<<"channels">>, Result),
[Ch1] = Channels,
Overwrites = maps:get(<<"permission_overwrites">>, Ch1),
?assertEqual(1, length(Overwrites)),
[RemainingOverwrite] = Overwrites,
?assertEqual(1, maps:get(<<"type">>, RemainingOverwrite)).
Ch1Overwrites = maps:get(<<"permission_overwrites">>, Ch1),
?assertEqual(2, length(Ch1Overwrites)).
handle_role_delete_removes_role_from_roles_list_test() ->
RoleIdToDelete = <<"200">>,
handle_member_add_test() ->
Data = #{<<"members">> => #{1 => #{<<"user">> => #{<<"id">> => <<"1">>}}}},
EventData = #{<<"user">> => #{<<"id">> => <<"2">>}},
Result = handle_member_add(EventData, Data),
Members = maps:get(<<"members">>, Result),
?assertEqual(2, map_size(Members)).
handle_member_update_test() ->
Data = #{
<<"roles">> => [
#{<<"id">> => <<"100">>, <<"name">> => <<"Admin">>},
#{<<"id">> => <<"200">>, <<"name">> => <<"Moderator">>}
],
<<"members">> => [],
<<"channels">> => []
<<"members">> => #{
1 => #{<<"user">> => #{<<"id">> => <<"1">>}, <<"nick">> => <<"OldNick">>}
}
},
EventData = #{<<"role_id">> => RoleIdToDelete},
Result = handle_role_delete(EventData, Data),
Roles = maps:get(<<"roles">>, Result),
?assertEqual(1, length(Roles)),
[RemainingRole] = Roles,
?assertEqual(<<"100">>, maps:get(<<"id">>, RemainingRole)).
EventData = #{<<"user">> => #{<<"id">> => <<"1">>}, <<"nick">> => <<"NewNick">>},
Result = handle_member_update(EventData, Data),
Members = maps:get(<<"members">>, Result),
Member = maps:get(1, Members),
?assertEqual(<<"NewNick">>, maps:get(<<"nick">>, Member)).
handle_channel_create_test() ->
Data = #{<<"channels">> => []},
EventData = #{<<"id">> => <<"100">>, <<"name">> => <<"general">>},
Result = handle_channel_create(EventData, Data),
Channels = maps:get(<<"channels">>, Result),
?assertEqual(1, length(Channels)).
bulk_update_items_test() ->
Items = [
#{<<"id">> => <<"1">>, <<"value">> => <<"old1">>},
#{<<"id">> => <<"2">>, <<"value">> => <<"old2">>}
],
BulkItems = [
#{<<"id">> => <<"1">>, <<"value">> => <<"new1">>}
],
Result = bulk_update_items(Items, BulkItems),
[Item1, Item2] = Result,
?assertEqual(<<"new1">>, maps:get(<<"value">>, Item1)),
?assertEqual(<<"old2">>, maps:get(<<"value">>, Item2)).
needs_visibility_check_test() ->
?assertEqual(true, needs_visibility_check(guild_role_update)),
?assertEqual(true, needs_visibility_check(channel_update)),
?assertEqual(false, needs_visibility_check(message_create)),
?assertEqual(false, needs_visibility_check(unknown_event)).
extract_channel_ids_from_channel_update_test() ->
?assertEqual([42], extract_channel_ids_from_channel_update(#{<<"id">> => <<"42">>})),
?assertEqual([], extract_channel_ids_from_channel_update(#{})).
extract_channel_ids_from_channel_update_bulk_test() ->
EventData = #{
<<"channels">> => [
#{<<"id">> => <<"10">>},
#{<<"id">> => <<"11">>},
#{<<"name">> => <<"missing_id">>}
]
},
?assertEqual([10, 11], extract_channel_ids_from_channel_update_bulk(EventData)).
guild_member_remove_disconnects_voice_test() ->
Self = self(),
TestFun = fun(GuildId, ChannelId, UserId, ConnectionId) ->
Self ! {force_disconnect, GuildId, ChannelId, UserId, ConnectionId},
{ok, #{success => true}}
end,
GuildId = 42,
UserId = 5,
ChannelId = 20,
Data = #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => <<"0">>}
],
<<"members">> => #{
UserId => #{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
},
<<"channels">> => [#{<<"id">> => integer_to_binary(ChannelId)}]
},
VoiceStates = #{
<<"conn">> => #{
<<"user_id">> => integer_to_binary(UserId),
<<"guild_id">> => integer_to_binary(GuildId),
<<"channel_id">> => integer_to_binary(ChannelId),
<<"connection_id">> => <<"conn">>
}
},
Sessions = #{<<"s1">> => #{user_id => UserId, pid => self()}},
State = #{
id => GuildId,
data => Data,
voice_states => VoiceStates,
sessions => Sessions,
test_force_disconnect_fun => TestFun
},
EventData = #{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}},
UpdatedState = update_state(guild_member_remove, EventData, State),
?assertEqual(#{}, maps:get(voice_states, UpdatedState)),
?assertEqual(#{}, maps:get(sessions, UpdatedState, #{})),
receive
{force_disconnect, GuildId, ChannelId, UserId, <<"conn">>} -> ok
after 200 ->
?assert(false)
end.
guild_update_syncs_unavailability_cache_test() ->
GuildId = 420042,
CleanupState = #{
id => GuildId,
data => #{
<<"guild">> => #{<<"features">> => []}
}
},
_ = guild_availability:update_unavailability_cache_for_state(CleanupState),
try
State0 = #{
id => GuildId,
data => #{
<<"guild">> => #{<<"features">> => []},
<<"members">> => []
},
sessions => #{}
},
State1 = update_state(
guild_update,
#{<<"features">> => [<<"UNAVAILABLE_FOR_EVERYONE">>]},
State0
),
?assertEqual(unavailable_for_everyone, guild_availability:get_cached_unavailability_mode(GuildId)),
_State2 = update_state(guild_update, #{<<"features">> => []}, State1),
?assertEqual(available, guild_availability:get_cached_unavailability_mode(GuildId))
after
_ = guild_availability:update_unavailability_cache_for_state(CleanupState)
end.
-endif.

View File

@@ -77,10 +77,8 @@ unsubscribe_session(SessionId, State) ->
update_subscriptions(SessionId, NewMemberIds, State) ->
CurrentlySubscribed = get_user_ids_for_session(SessionId, State),
NewMemberIdSet = sets:from_list(NewMemberIds),
ToRemove = sets:subtract(CurrentlySubscribed, NewMemberIdSet),
ToAdd = sets:subtract(NewMemberIdSet, CurrentlySubscribed),
State1 = sets:fold(
fun(UserId, AccState) ->
unsubscribe(SessionId, UserId, AccState)
@@ -88,7 +86,6 @@ update_subscriptions(SessionId, NewMemberIds, State) ->
State,
ToRemove
),
sets:fold(
fun(UserId, AccState) ->
subscribe(SessionId, UserId, AccState)
@@ -182,4 +179,11 @@ update_subscriptions_test() ->
?assert(is_subscribed(<<"session1">>, 200, State3)),
?assert(is_subscribed(<<"session1">>, 300, State3)).
get_user_ids_for_session_test() ->
State0 = init_state(),
State1 = subscribe(<<"session1">>, 100, State0),
State2 = subscribe(<<"session1">>, 200, State1),
UserIds = get_user_ids_for_session(<<"session1">>, State2),
?assertEqual([100, 200], lists:sort(sets:to_list(UserIds))).
-endif.

View File

@@ -19,5 +19,6 @@
-export([send_guild_sync/2]).
-spec send_guild_sync(pid(), binary()) -> ok.
send_guild_sync(GuildPid, SessionId) ->
gen_server:cast(GuildPid, {send_guild_sync, SessionId}).

View File

@@ -23,95 +23,88 @@
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec handle_subscriptions(map(), pid(), map()) -> ok.
-type session_state() :: map().
-type session_id() :: binary() | undefined.
-spec handle_subscriptions(map(), pid(), session_state()) -> ok.
handle_subscriptions(Data, SocketPid, SessionState) ->
Subscriptions = maps:get(<<"subscriptions">>, Data, #{}),
Guilds = maps:get(guilds, SessionState, #{}),
SessionId = maps:get(id, SessionState, undefined),
logger:debug("[guild_unified_subscriptions] Processing ~p guild subscriptions for session ~p", [
map_size(Subscriptions), SessionId
]),
maps:foreach(
fun(GuildIdBin, GuildSubData) ->
process_guild_subscription(GuildIdBin, GuildSubData, Guilds, SessionId, SocketPid, SessionState)
process_guild_subscription(
GuildIdBin, GuildSubData, Guilds, SessionId, SocketPid, SessionState
)
end,
Subscriptions
),
ok.
-spec process_guild_subscription(binary(), map(), map(), binary() | undefined, pid(), map()) -> ok.
-spec process_guild_subscription(binary(), map(), map(), session_id(), pid(), session_state()) ->
ok.
process_guild_subscription(GuildIdBin, GuildSubData, Guilds, SessionId, SocketPid, SessionState) ->
case validation:validate_snowflake(<<"guild_id">>, GuildIdBin) of
{ok, GuildId} ->
case maps:get(GuildId, Guilds, undefined) of
{GuildPid, _Ref} when is_pid(GuildPid) ->
process_guild_sub_options(GuildId, GuildPid, GuildSubData, SessionId, SocketPid, SessionState);
process_guild_sub_options(
GuildId, GuildPid, GuildSubData, SessionId, SocketPid, SessionState
);
undefined ->
logger:warning("[guild_unified_subscriptions] Guild ~p not found in session state", [GuildId]),
ok;
_ ->
ok
end;
{error, _, Reason} ->
logger:warning("[guild_unified_subscriptions] Invalid guild_id ~p: ~p", [GuildIdBin, Reason]),
{error, _, _} ->
ok
end.
-spec process_guild_sub_options(integer(), pid(), map(), binary() | undefined, pid(), map()) -> ok.
-spec process_guild_sub_options(integer(), pid(), map(), session_id(), pid(), session_state()) ->
ok.
process_guild_sub_options(GuildId, GuildPid, GuildSubData, SessionId, SocketPid, SessionState) ->
WasActive = not session_passive:is_passive(GuildId, SessionState),
ActiveChanged = process_active_flag(GuildSubData, GuildPid, SessionId, WasActive),
process_sync_flag(GuildSubData, GuildId, GuildPid, SessionId, ActiveChanged),
process_member_list_channels(GuildSubData, GuildId, GuildPid, SessionId, SocketPid),
process_member_subscriptions(GuildSubData, GuildPid, SessionId),
process_typing_flag(GuildSubData, GuildPid, SessionId),
ok.
-spec process_active_flag(map(), pid(), binary() | undefined, boolean()) -> boolean().
-spec process_active_flag(map(), pid(), session_id(), boolean()) -> boolean().
process_active_flag(GuildSubData, GuildPid, SessionId, WasActive) ->
case maps:get(<<"active">>, GuildSubData, undefined) of
undefined ->
false;
true ->
gen_server:cast(GuildPid, {set_session_active, SessionId}),
logger:debug("[guild_unified_subscriptions] Set session ~p active", [SessionId]),
not WasActive;
false ->
gen_server:cast(GuildPid, {set_session_passive, SessionId}),
logger:debug("[guild_unified_subscriptions] Set session ~p passive", [SessionId]),
WasActive
end.
-spec process_sync_flag(map(), integer(), pid(), binary() | undefined, boolean()) -> ok.
process_sync_flag(GuildSubData, GuildId, GuildPid, SessionId, ActiveChanged) ->
-spec process_sync_flag(map(), integer(), pid(), session_id(), boolean()) -> ok.
process_sync_flag(GuildSubData, _GuildId, GuildPid, SessionId, ActiveChanged) ->
ShouldSync = maps:get(<<"sync">>, GuildSubData, false) =:= true orelse ActiveChanged,
case ShouldSync of
true ->
guild_sync:send_guild_sync(GuildPid, SessionId),
logger:debug("[guild_unified_subscriptions] Sent guild sync for guild ~p", [GuildId]);
guild_sync:send_guild_sync(GuildPid, SessionId);
false ->
ok
end.
-spec process_member_list_channels(map(), integer(), pid(), binary() | undefined, pid()) -> ok.
-spec process_member_list_channels(map(), integer(), pid(), session_id(), pid()) -> ok.
process_member_list_channels(GuildSubData, GuildId, GuildPid, SessionId, SocketPid) ->
case maps:get(<<"member_list_channels">>, GuildSubData, undefined) of
undefined ->
ok;
MemberListChannels when is_map(MemberListChannels) ->
logger:debug("[guild_unified_subscriptions] Processing ~p member list channels for guild ~p", [
map_size(MemberListChannels), GuildId
]),
maps:foreach(
fun(ChannelIdBin, Ranges) ->
process_channel_lazy_subscribe(ChannelIdBin, Ranges, GuildId, GuildPid, SessionId, SocketPid)
process_channel_lazy_subscribe(
ChannelIdBin, Ranges, GuildId, GuildPid, SessionId, SocketPid
)
end,
MemberListChannels
);
@@ -119,28 +112,29 @@ process_member_list_channels(GuildSubData, GuildId, GuildPid, SessionId, SocketP
ok
end.
-spec process_channel_lazy_subscribe(binary(), list(), integer(), pid(), binary() | undefined, pid()) -> ok.
-spec process_channel_lazy_subscribe(binary(), list(), integer(), pid(), session_id(), pid()) -> ok.
process_channel_lazy_subscribe(ChannelIdBin, Ranges, _GuildId, GuildPid, SessionId, _SocketPid) ->
case validation:validate_snowflake(<<"channel_id">>, ChannelIdBin) of
{ok, ChannelId} ->
ParsedRanges = parse_ranges(Ranges),
logger:debug("[guild_unified_subscriptions] Lazy subscribe channel ~p with ranges ~p", [
ChannelId, ParsedRanges
]),
case gen_server:call(GuildPid, {lazy_subscribe, #{
session_id => SessionId,
channel_id => ChannelId,
ranges => ParsedRanges
}}, 10000) of
case
gen_server:call(
GuildPid,
{lazy_subscribe, #{
session_id => SessionId,
channel_id => ChannelId,
ranges => ParsedRanges
}},
10000
)
of
ok ->
ok;
Error ->
logger:error("[guild_unified_subscriptions] lazy_subscribe failed for channel ~p: ~p", [
ChannelId, Error
])
_Error ->
ok
end;
{error, _, Reason} ->
logger:warning("[guild_unified_subscriptions] Invalid channel_id ~p: ~p", [ChannelIdBin, Reason])
{error, _, _} ->
ok
end,
ok.
@@ -160,16 +154,13 @@ parse_ranges(Ranges) when is_list(Ranges) ->
parse_ranges(_) ->
[].
-spec process_member_subscriptions(map(), pid(), binary() | undefined) -> ok.
-spec process_member_subscriptions(map(), pid(), session_id()) -> ok.
process_member_subscriptions(GuildSubData, GuildPid, SessionId) ->
case maps:get(<<"members">>, GuildSubData, undefined) of
undefined ->
ok;
Members when is_list(Members) ->
MemberIds = parse_member_ids(Members),
logger:debug("[guild_unified_subscriptions] Updating member subscriptions with ~p members", [
length(MemberIds)
]),
gen_server:cast(GuildPid, {update_member_subscriptions, SessionId, MemberIds});
_ ->
ok
@@ -189,16 +180,13 @@ parse_member_ids(Members) when is_list(Members) ->
parse_member_ids(_) ->
[].
-spec process_typing_flag(map(), pid(), binary() | undefined) -> ok.
-spec process_typing_flag(map(), pid(), session_id()) -> ok.
process_typing_flag(GuildSubData, GuildPid, SessionId) ->
case maps:get(<<"typing">>, GuildSubData, undefined) of
undefined ->
ok;
TypingFlag when is_boolean(TypingFlag) ->
gen_server:cast(GuildPid, {set_session_typing_override, SessionId, TypingFlag}),
logger:debug("[guild_unified_subscriptions] Set typing override to ~p for session ~p", [
TypingFlag, SessionId
]);
gen_server:cast(GuildPid, {set_session_typing_override, SessionId, TypingFlag});
_ ->
ok
end.

View File

@@ -24,129 +24,155 @@
check_user_data_differs/2
]).
-import(guild_permissions, [find_member_by_user_id/2]).
-type guild_state() :: map().
-type user_id() :: integer().
-type member() :: map().
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec update_user_data(map(), guild_state()) -> {noreply, guild_state()}.
update_user_data(EventData, State) ->
UserId = utils:binary_to_integer_safe(maps:get(<<"id">>, EventData)),
Data = maps:get(data, State),
Members = maps:get(<<"members">>, Data, []),
UpdatedMembers = lists:map(
fun(Member) when is_map(Member) ->
MUser = maps:get(<<"user">>, Member, #{}),
MemberId =
case is_map(MUser) of
true ->
utils:binary_to_integer_safe(maps:get(<<"id">>, MUser, <<"0">>));
false ->
undefined
end,
if
MemberId =:= UserId ->
maps:put(<<"user">>, EventData, Member);
true ->
Member
end
Members = guild_data_index:member_map(Data),
UpdatedMembers = maps:map(
fun(_MemberUserId, Member) ->
maybe_update_member_user(Member, UserId, EventData)
end,
Members
),
UpdatedData = maps:put(<<"members">>, UpdatedMembers, Data),
UpdatedData = guild_data_index:put_member_map(UpdatedMembers, Data),
UpdatedState = maps:put(data, UpdatedData, State),
UpdatedMember = find_member_by_user_id(UserId, UpdatedState),
case UpdatedMember of
undefined -> ok;
M -> gen_server:cast(self(), {dispatch, #{event => guild_member_update, data => M}})
end,
dispatch_member_update_if_found(UserId, UpdatedState),
{noreply, UpdatedState}.
handle_user_data_update(UserId, UserData, State) ->
Data = maps:get(data, State),
Members = maps:get(<<"members">>, Data, []),
-spec maybe_update_member_user(member(), user_id(), map()) -> member().
maybe_update_member_user(Member, UserId, EventData) ->
MUser = maps:get(<<"user">>, Member, #{}),
MemberId =
case is_map(MUser) of
true -> utils:binary_to_integer_safe(maps:get(<<"id">>, MUser, <<"0">>));
false -> undefined
end,
case MemberId =:= UserId of
true -> maps:put(<<"user">>, EventData, Member);
false -> Member
end.
CurrentMember = find_member_by_user_id(UserId, State),
case CurrentMember of
-spec dispatch_member_update_if_found(user_id(), guild_state()) -> ok.
dispatch_member_update_if_found(UserId, State) ->
case guild_permissions:find_member_by_user_id(UserId, State) of
undefined -> ok;
M -> gen_server:cast(self(), {dispatch, #{event => guild_member_update, data => M}})
end.
-spec handle_user_data_update(user_id(), map(), guild_state()) -> guild_state().
handle_user_data_update(UserId, UserData, State) ->
case guild_permissions:find_member_by_user_id(UserId, State) of
undefined ->
State;
Member ->
CurrentUserData = maps:get(<<"user">>, Member, #{}),
IsDifferent = check_user_data_differs(CurrentUserData, UserData),
if
IsDifferent ->
UpdatedMembers = lists:map(
fun(M) when is_map(M) ->
MUser = maps:get(<<"user">>, M, #{}),
MemberId =
case is_map(MUser) of
true ->
utils:binary_to_integer_safe(
maps:get(<<"id">>, MUser, <<"0">>)
);
false ->
undefined
end,
if
MemberId =:= UserId ->
maps:put(<<"user">>, UserData, M);
true ->
M
end
end,
Members
),
UpdatedData = maps:put(<<"members">>, UpdatedMembers, Data),
UpdatedState = maps:put(data, UpdatedData, State),
UpdatedMember = find_member_by_user_id(UserId, UpdatedState),
case UpdatedMember of
undefined ->
ok;
M ->
GuildId = maps:get(id, UpdatedState),
MemberUpdateData = maps:put(
<<"guild_id">>, integer_to_binary(GuildId), M
),
gen_server:cast(
self(),
{dispatch, #{
event => guild_member_update, data => MemberUpdateData
}}
)
end,
UpdatedState;
case check_user_data_differs(CurrentUserData, UserData) of
false ->
State;
true ->
State
apply_user_data_update(UserId, UserData, State)
end
end.
-spec apply_user_data_update(user_id(), map(), guild_state()) -> guild_state().
apply_user_data_update(UserId, UserData, State) ->
Data = maps:get(data, State),
Members = guild_data_index:member_map(Data),
UpdatedMembers = maps:map(
fun(_MemberUserId, Member) ->
maybe_update_member_user(Member, UserId, UserData)
end,
Members
),
UpdatedData = guild_data_index:put_member_map(UpdatedMembers, Data),
UpdatedState = maps:put(data, UpdatedData, State),
dispatch_guild_member_update(UserId, UpdatedState),
UpdatedState.
-spec dispatch_guild_member_update(user_id(), guild_state()) -> ok.
dispatch_guild_member_update(UserId, State) ->
case guild_permissions:find_member_by_user_id(UserId, State) of
undefined ->
ok;
M ->
GuildId = maps:get(id, State),
MemberUpdateData = maps:put(<<"guild_id">>, integer_to_binary(GuildId), M),
gen_server:cast(
self(),
{dispatch, #{event => guild_member_update, data => MemberUpdateData}}
)
end.
-spec check_user_data_differs(map(), map()) -> boolean().
check_user_data_differs(CurrentUserData, NewUserData) ->
utils:check_user_data_differs(CurrentUserData, NewUserData).
maybe_update_cached_user_data(Event, EventData, State) ->
case Event of
E when E =:= message_create; E =:= message_update ->
case maps:get(<<"author">>, EventData, undefined) of
-spec maybe_update_cached_user_data(atom(), map(), guild_state()) -> guild_state().
maybe_update_cached_user_data(Event, EventData, State) when
Event =:= message_create; Event =:= message_update
->
case maps:get(<<"author">>, EventData, undefined) of
undefined ->
State;
AuthorData ->
UserId = utils:binary_to_integer_safe(maps:get(<<"id">>, AuthorData, <<"0">>)),
case guild_permissions:find_member_by_user_id(UserId, State) of
undefined ->
State;
AuthorData ->
UserId = utils:binary_to_integer_safe(maps:get(<<"id">>, AuthorData, <<"0">>)),
case find_member_by_user_id(UserId, State) of
undefined ->
State;
Member ->
CurrentUserData = maps:get(<<"user">>, Member, #{}),
case check_user_data_differs(CurrentUserData, AuthorData) of
true ->
handle_user_data_update(UserId, AuthorData, State);
false ->
State
end
Member ->
CurrentUserData = maps:get(<<"user">>, Member, #{}),
case check_user_data_differs(CurrentUserData, AuthorData) of
true ->
handle_user_data_update(UserId, AuthorData, State);
false ->
State
end
end;
_ ->
State
end.
end
end;
maybe_update_cached_user_data(_, _, State) ->
State.
-ifdef(TEST).
update_user_data_updates_member_test() ->
State = test_state(),
EventData = #{<<"id">> => <<"100">>, <<"username">> => <<"updated">>},
{noreply, UpdatedState} = update_user_data(EventData, State),
Data = maps:get(data, UpdatedState),
Member = maps:get(100, maps:get(<<"members">>, Data)),
User = maps:get(<<"user">>, Member),
?assertEqual(<<"updated">>, maps:get(<<"username">>, User)).
handle_user_data_update_no_change_test() ->
State = test_state(),
UserData = #{<<"id">> => <<"100">>, <<"username">> => <<"alice">>},
NewState = handle_user_data_update(100, UserData, State),
?assertEqual(State, NewState).
check_user_data_differs_test() ->
Current = #{<<"username">> => <<"alice">>},
Same = #{<<"username">> => <<"alice">>},
Different = #{<<"username">> => <<"bob">>},
?assertEqual(false, check_user_data_differs(Current, Same)),
?assertEqual(true, check_user_data_differs(Current, Different)).
test_state() ->
#{
id => 42,
data => #{
<<"members">> => #{
100 => #{<<"user">> => #{<<"id">> => <<"100">>, <<"username">> => <<"alice">>}}
}
}
}.
-endif.

View File

@@ -23,18 +23,37 @@
has_virtual_access/3,
get_virtual_channels_for_user/2,
get_users_with_virtual_access/2,
dispatch_channel_visibility_change/4
dispatch_channel_visibility_change/4,
mark_pending_join/3,
clear_pending_join/3,
is_pending_join/3,
mark_preserve/3,
clear_preserve/3,
has_preserve/3,
mark_move_pending/3,
clear_move_pending/3,
is_move_pending/3
]).
-import(guild_permissions, [find_channel_by_id/2]).
-type guild_state() :: map().
-type user_id() :: integer().
-type channel_id() :: integer().
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec add_virtual_access(user_id(), channel_id(), guild_state()) -> guild_state().
add_virtual_access(UserId, ChannelId, State) ->
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
UserChannels = maps:get(UserId, VirtualAccess, sets:new()),
UpdatedUserChannels = sets:add_element(ChannelId, UserChannels),
UpdatedVirtualAccess = maps:put(UserId, UpdatedUserChannels, VirtualAccess),
maps:put(virtual_channel_access, UpdatedVirtualAccess, State).
State1 = maps:put(virtual_channel_access, UpdatedVirtualAccess, State),
State2 = update_user_session_view_cache(UserId, ChannelId, add, State1),
mark_pending_join(UserId, ChannelId, State2).
-spec remove_virtual_access(user_id(), channel_id(), guild_state()) -> guild_state().
remove_virtual_access(UserId, ChannelId, State) ->
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
case maps:get(UserId, VirtualAccess, undefined) of
@@ -44,32 +63,149 @@ remove_virtual_access(UserId, ChannelId, State) ->
UpdatedUserChannels = sets:del_element(ChannelId, UserChannels),
case sets:size(UpdatedUserChannels) of
0 ->
UpdatedVirtualAccess = maps:remove(UserId, VirtualAccess),
maps:put(virtual_channel_access, UpdatedVirtualAccess, State);
remove_all_user_virtual_access(UserId, State);
_ ->
UpdatedVirtualAccess = maps:put(UserId, UpdatedUserChannels, VirtualAccess),
maps:put(virtual_channel_access, UpdatedVirtualAccess, State)
update_user_virtual_access(UserId, ChannelId, UpdatedUserChannels, State)
end
end.
-spec remove_all_user_virtual_access(user_id(), guild_state()) -> guild_state().
remove_all_user_virtual_access(UserId, State) ->
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
PendingMap = maps:get(virtual_channel_access_pending, State, #{}),
PreserveMap = maps:get(virtual_channel_access_preserve, State, #{}),
MoveMap = maps:get(virtual_channel_access_move_pending, State, #{}),
UpdatedVirtualAccess = maps:remove(UserId, VirtualAccess),
UpdatedPending = maps:remove(UserId, PendingMap),
UpdatedPreserve = maps:remove(UserId, PreserveMap),
UpdatedMove = maps:remove(UserId, MoveMap),
State1 = maps:put(virtual_channel_access, UpdatedVirtualAccess, State),
State2 = maps:put(virtual_channel_access_pending, UpdatedPending, State1),
State3 = maps:put(virtual_channel_access_preserve, UpdatedPreserve, State2),
State4 = maps:put(virtual_channel_access_move_pending, UpdatedMove, State3),
clear_user_session_view_cache(UserId, State4).
-spec update_user_virtual_access(user_id(), channel_id(), sets:set(), guild_state()) ->
guild_state().
update_user_virtual_access(UserId, ChannelId, UpdatedUserChannels, State) ->
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
PendingMap = maps:get(virtual_channel_access_pending, State, #{}),
PreserveMap = maps:get(virtual_channel_access_preserve, State, #{}),
MoveMap = maps:get(virtual_channel_access_move_pending, State, #{}),
UpdatedVirtualAccess = maps:put(UserId, UpdatedUserChannels, VirtualAccess),
UpdatedUserPending = sets:del_element(ChannelId, maps:get(UserId, PendingMap, sets:new())),
UpdatedPending = maps:put(UserId, UpdatedUserPending, PendingMap),
UpdatedUserPreserve = sets:del_element(ChannelId, maps:get(UserId, PreserveMap, sets:new())),
UpdatedPreserve = maps:put(UserId, UpdatedUserPreserve, PreserveMap),
UpdatedUserMove = sets:del_element(ChannelId, maps:get(UserId, MoveMap, sets:new())),
UpdatedMove = maps:put(UserId, UpdatedUserMove, MoveMap),
State1 = maps:put(virtual_channel_access, UpdatedVirtualAccess, State),
State2 = maps:put(virtual_channel_access_pending, UpdatedPending, State1),
State3 = maps:put(virtual_channel_access_preserve, UpdatedPreserve, State2),
State4 = maps:put(virtual_channel_access_move_pending, UpdatedMove, State3),
update_user_session_view_cache(UserId, ChannelId, remove, State4).
-spec has_virtual_access(user_id(), channel_id(), guild_state()) -> boolean().
has_virtual_access(UserId, ChannelId, State) ->
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
case maps:get(UserId, VirtualAccess, undefined) of
undefined ->
false;
UserChannels ->
sets:is_element(ChannelId, UserChannels)
undefined -> false;
UserChannels -> sets:is_element(ChannelId, UserChannels)
end.
-spec get_virtual_channels_for_user(user_id(), guild_state()) -> [channel_id()].
get_virtual_channels_for_user(UserId, State) ->
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
case maps:get(UserId, VirtualAccess, undefined) of
undefined ->
[];
UserChannels ->
sets:to_list(UserChannels)
undefined -> [];
UserChannels -> sets:to_list(UserChannels)
end.
-spec mark_pending_join(user_id(), channel_id(), guild_state()) -> guild_state().
mark_pending_join(UserId, ChannelId, State) ->
PendingMap = maps:get(virtual_channel_access_pending, State, #{}),
UserPending = maps:get(UserId, PendingMap, sets:new()),
UpdatedUserPending = sets:add_element(ChannelId, UserPending),
UpdatedPending = maps:put(UserId, UpdatedUserPending, PendingMap),
maps:put(virtual_channel_access_pending, UpdatedPending, State).
-spec clear_pending_join(user_id(), channel_id(), guild_state()) -> guild_state().
clear_pending_join(UserId, ChannelId, State) ->
PendingMap = maps:get(virtual_channel_access_pending, State, #{}),
UserPending = maps:get(UserId, PendingMap, sets:new()),
UpdatedUserPending = sets:del_element(ChannelId, UserPending),
UpdatedPending =
case sets:size(UpdatedUserPending) of
0 -> maps:remove(UserId, PendingMap);
_ -> maps:put(UserId, UpdatedUserPending, PendingMap)
end,
maps:put(virtual_channel_access_pending, UpdatedPending, State).
-spec is_pending_join(user_id(), channel_id(), guild_state()) -> boolean().
is_pending_join(UserId, ChannelId, State) ->
PendingMap = maps:get(virtual_channel_access_pending, State, #{}),
case maps:get(UserId, PendingMap, undefined) of
undefined -> false;
UserPending -> sets:is_element(ChannelId, UserPending)
end.
-spec mark_preserve(user_id(), channel_id(), guild_state()) -> guild_state().
mark_preserve(UserId, ChannelId, State) ->
PreserveMap = maps:get(virtual_channel_access_preserve, State, #{}),
UserPreserve = maps:get(UserId, PreserveMap, sets:new()),
UpdatedUserPreserve = sets:add_element(ChannelId, UserPreserve),
UpdatedPreserve = maps:put(UserId, UpdatedUserPreserve, PreserveMap),
maps:put(virtual_channel_access_preserve, UpdatedPreserve, State).
-spec clear_preserve(user_id(), channel_id(), guild_state()) -> guild_state().
clear_preserve(UserId, ChannelId, State) ->
PreserveMap = maps:get(virtual_channel_access_preserve, State, #{}),
UserPreserve = maps:get(UserId, PreserveMap, sets:new()),
UpdatedUserPreserve = sets:del_element(ChannelId, UserPreserve),
UpdatedPreserve =
case sets:size(UpdatedUserPreserve) of
0 -> maps:remove(UserId, PreserveMap);
_ -> maps:put(UserId, UpdatedUserPreserve, PreserveMap)
end,
maps:put(virtual_channel_access_preserve, UpdatedPreserve, State).
-spec has_preserve(user_id(), channel_id(), guild_state()) -> boolean().
has_preserve(UserId, ChannelId, State) ->
PreserveMap = maps:get(virtual_channel_access_preserve, State, #{}),
case maps:get(UserId, PreserveMap, undefined) of
undefined -> false;
UserPreserve -> sets:is_element(ChannelId, UserPreserve)
end.
-spec mark_move_pending(user_id(), channel_id(), guild_state()) -> guild_state().
mark_move_pending(UserId, ChannelId, State) ->
MoveMap = maps:get(virtual_channel_access_move_pending, State, #{}),
UserMoves = maps:get(UserId, MoveMap, sets:new()),
UpdatedUserMoves = sets:add_element(ChannelId, UserMoves),
UpdatedMoveMap = maps:put(UserId, UpdatedUserMoves, MoveMap),
maps:put(virtual_channel_access_move_pending, UpdatedMoveMap, State).
-spec clear_move_pending(user_id(), channel_id(), guild_state()) -> guild_state().
clear_move_pending(UserId, ChannelId, State) ->
MoveMap = maps:get(virtual_channel_access_move_pending, State, #{}),
UserMoves = maps:get(UserId, MoveMap, sets:new()),
UpdatedUserMoves = sets:del_element(ChannelId, UserMoves),
UpdatedMoveMap =
case sets:size(UpdatedUserMoves) of
0 -> maps:remove(UserId, MoveMap);
_ -> maps:put(UserId, UpdatedUserMoves, MoveMap)
end,
maps:put(virtual_channel_access_move_pending, UpdatedMoveMap, State).
-spec is_move_pending(user_id(), channel_id(), guild_state()) -> boolean().
is_move_pending(UserId, ChannelId, State) ->
MoveMap = maps:get(virtual_channel_access_move_pending, State, #{}),
case maps:get(UserId, MoveMap, undefined) of
undefined -> false;
UserMoves -> sets:is_element(ChannelId, UserMoves)
end.
-spec get_users_with_virtual_access(channel_id(), guild_state()) -> [user_id()].
get_users_with_virtual_access(ChannelId, State) ->
VirtualAccess = maps:get(virtual_channel_access, State, #{}),
maps:fold(
@@ -83,45 +219,160 @@ get_users_with_virtual_access(ChannelId, State) ->
VirtualAccess
).
-spec dispatch_channel_visibility_change(user_id(), channel_id(), add | remove, guild_state()) ->
ok.
dispatch_channel_visibility_change(UserId, ChannelId, Action, State) ->
Channel = find_channel_by_id(ChannelId, State),
Channel = guild_permissions:find_channel_by_id(ChannelId, State),
case Channel of
undefined ->
ok;
_ ->
Sessions = maps:get(sessions, State, #{}),
GuildId = maps:get(id, State),
UserSessions = maps:filter(
fun(_Sid, SessionData) ->
maps:get(user_id, SessionData) =:= UserId
end,
Sessions
),
case Action of
add ->
ChannelWithGuild = maps:put(
<<"guild_id">>, integer_to_binary(GuildId), Channel
),
maps:foreach(
fun(_Sid, SessionData) ->
Pid = maps:get(pid, SessionData),
gen_server:cast(Pid, {dispatch, channel_create, ChannelWithGuild})
end,
UserSessions
);
remove ->
ChannelDelete = #{
<<"id">> => integer_to_binary(ChannelId),
<<"guild_id">> => integer_to_binary(GuildId)
},
maps:foreach(
fun(_Sid, SessionData) ->
Pid = maps:get(pid, SessionData),
gen_server:cast(Pid, {dispatch, channel_delete, ChannelDelete})
end,
UserSessions
)
end
dispatch_to_user_sessions(Action, Channel, ChannelId, GuildId, UserSessions)
end.
-spec dispatch_to_user_sessions(add | remove, map(), channel_id(), integer(), map()) -> ok.
dispatch_to_user_sessions(add, Channel, _ChannelId, GuildId, UserSessions) ->
ChannelWithGuild = maps:put(<<"guild_id">>, integer_to_binary(GuildId), Channel),
maps:foreach(
fun(_Sid, SessionData) ->
Pid = maps:get(pid, SessionData),
gen_server:cast(Pid, {dispatch, channel_create, ChannelWithGuild})
end,
UserSessions
);
dispatch_to_user_sessions(remove, _Channel, ChannelId, GuildId, UserSessions) ->
ChannelDelete = #{
<<"id">> => integer_to_binary(ChannelId),
<<"guild_id">> => integer_to_binary(GuildId)
},
maps:foreach(
fun(_Sid, SessionData) ->
Pid = maps:get(pid, SessionData),
gen_server:cast(Pid, {dispatch, channel_delete, ChannelDelete})
end,
UserSessions
).
-spec update_user_session_view_cache(user_id(), channel_id(), add | remove, guild_state()) ->
guild_state().
update_user_session_view_cache(UserId, ChannelId, Action, State) ->
Sessions = maps:get(sessions, State, #{}),
UpdatedSessions = maps:map(
fun(_SessionId, SessionData) ->
case maps:get(user_id, SessionData, undefined) of
UserId ->
update_session_view_cache(SessionData, ChannelId, Action);
_ ->
SessionData
end
end,
Sessions
),
maps:put(sessions, UpdatedSessions, State).
-spec clear_user_session_view_cache(user_id(), guild_state()) -> guild_state().
clear_user_session_view_cache(UserId, State) ->
Sessions = maps:get(sessions, State, #{}),
UpdatedSessions = maps:map(
fun(_SessionId, SessionData) ->
case maps:get(user_id, SessionData, undefined) of
UserId ->
maps:put(viewable_channels, #{}, SessionData);
_ ->
SessionData
end
end,
Sessions
),
maps:put(sessions, UpdatedSessions, State).
-spec update_session_view_cache(map(), channel_id(), add | remove) -> map().
update_session_view_cache(SessionData, ChannelId, add) ->
ViewableChannels = ensure_viewable_channel_map(maps:get(viewable_channels, SessionData, #{})),
UpdatedViewableChannels = maps:put(ChannelId, true, ViewableChannels),
maps:put(viewable_channels, UpdatedViewableChannels, SessionData);
update_session_view_cache(SessionData, ChannelId, remove) ->
ViewableChannels = ensure_viewable_channel_map(maps:get(viewable_channels, SessionData, #{})),
UpdatedViewableChannels = maps:remove(ChannelId, ViewableChannels),
maps:put(viewable_channels, UpdatedViewableChannels, SessionData).
-spec ensure_viewable_channel_map(term()) -> map().
ensure_viewable_channel_map(ViewableChannels) when is_map(ViewableChannels) ->
ViewableChannels;
ensure_viewable_channel_map(_) ->
#{}.
-ifdef(TEST).
add_virtual_access_test() ->
State = #{},
State1 = add_virtual_access(100, 500, State),
?assertEqual(true, has_virtual_access(100, 500, State1)),
?assertEqual(true, is_pending_join(100, 500, State1)).
add_virtual_access_updates_session_cache_test() ->
State = #{
sessions => #{
<<"s1">> => #{user_id => 100, viewable_channels => #{}}
}
},
State1 = add_virtual_access(100, 500, State),
Session = maps:get(<<"s1">>, maps:get(sessions, State1)),
ViewableChannels = maps:get(viewable_channels, Session, #{}),
?assertEqual(true, maps:is_key(500, ViewableChannels)).
remove_virtual_access_test() ->
State = add_virtual_access(100, 500, #{}),
State1 = remove_virtual_access(100, 500, State),
?assertEqual(false, has_virtual_access(100, 500, State1)).
remove_virtual_access_updates_session_cache_test() ->
State = #{
sessions => #{
<<"s1">> => #{user_id => 100, viewable_channels => #{500 => true}}
},
virtual_channel_access => #{100 => sets:from_list([500])},
virtual_channel_access_pending => #{100 => sets:from_list([500])},
virtual_channel_access_preserve => #{100 => sets:new()},
virtual_channel_access_move_pending => #{100 => sets:new()}
},
State1 = remove_virtual_access(100, 500, State),
Session = maps:get(<<"s1">>, maps:get(sessions, State1)),
ViewableChannels = maps:get(viewable_channels, Session, #{}),
?assertEqual(false, maps:is_key(500, ViewableChannels)).
get_virtual_channels_for_user_test() ->
State = add_virtual_access(100, 500, #{}),
State1 = add_virtual_access(100, 501, State),
Channels = lists:sort(get_virtual_channels_for_user(100, State1)),
?assertEqual([500, 501], Channels).
get_users_with_virtual_access_test() ->
State = add_virtual_access(100, 500, #{}),
State1 = add_virtual_access(101, 500, State),
Users = lists:sort(get_users_with_virtual_access(500, State1)),
?assertEqual([100, 101], Users).
mark_and_clear_preserve_test() ->
State = add_virtual_access(100, 500, #{}),
State1 = mark_preserve(100, 500, State),
?assertEqual(true, has_preserve(100, 500, State1)),
State2 = clear_preserve(100, 500, State1),
?assertEqual(false, has_preserve(100, 500, State2)).
mark_and_clear_move_pending_test() ->
State = add_virtual_access(100, 500, #{}),
State1 = mark_move_pending(100, 500, State),
?assertEqual(true, is_move_pending(100, 500, State1)),
State2 = clear_move_pending(100, 500, State1),
?assertEqual(false, is_move_pending(100, 500, State2)).
-endif.

View File

@@ -20,19 +20,25 @@
-export([
get_user_viewable_channels/2,
compute_and_dispatch_visibility_changes/2,
compute_and_dispatch_visibility_changes_for_users/3,
compute_and_dispatch_visibility_changes_for_channels/3,
viewable_channel_set/2,
have_shared_viewable_channel/3
]).
-import(guild_member_list, [calculate_list_id/2, build_sync_response/4]).
-import(guild_permissions, [can_view_channel/4, find_member_by_user_id/2, find_channel_by_id/2]).
-type guild_state() :: map().
-type user_id() :: integer().
-type channel_id() :: integer().
-spec get_user_viewable_channels(integer(), map()) -> [integer()].
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec get_user_viewable_channels(user_id(), guild_state()) -> [channel_id()].
get_user_viewable_channels(UserId, State) ->
Data = map_utils:ensure_map(map_utils:get_safe(State, data, #{})),
Channels = map_utils:ensure_list(maps:get(<<"channels">>, Data, [])),
Member = find_member_by_user_id(UserId, State),
Member = guild_permissions:find_member_by_user_id(UserId, State),
case Member of
undefined ->
[];
@@ -44,7 +50,9 @@ get_user_viewable_channels(UserId, State) ->
undefined ->
false;
_ ->
case can_view_channel(UserId, ChannelId, Member, State) of
case
guild_permissions:can_view_channel(UserId, ChannelId, Member, State)
of
true -> {true, ChannelId};
false -> false
end
@@ -54,62 +62,346 @@ get_user_viewable_channels(UserId, State) ->
)
end.
-spec viewable_channel_set(integer(), map()) -> sets:set().
-spec viewable_channel_set(user_id(), guild_state()) -> sets:set(channel_id()).
viewable_channel_set(UserId, State) when is_integer(UserId) ->
sets:from_list(get_user_viewable_channels(UserId, State));
case get_cached_viewable_channel_map(UserId, State) of
undefined ->
sets:from_list(get_user_viewable_channels(UserId, State));
ViewableChannelMap ->
sets:from_list(maps:keys(ViewableChannelMap))
end;
viewable_channel_set(_, _) ->
sets:new().
-spec have_shared_viewable_channel(integer(), integer(), map()) -> boolean().
have_shared_viewable_channel(UserId, OtherUserId, State) when is_integer(UserId), is_integer(OtherUserId), UserId =/= OtherUserId ->
-spec get_cached_viewable_channel_map(user_id(), guild_state()) -> map() | undefined.
get_cached_viewable_channel_map(UserId, State) ->
Sessions = maps:get(sessions, State, #{}),
maps:fold(
fun(_SessionId, SessionData, Acc) ->
case Acc of
undefined ->
case maps:get(user_id, SessionData, undefined) of
UserId ->
case maps:get(viewable_channels, SessionData, undefined) of
ViewableChannels when is_map(ViewableChannels) ->
ViewableChannels;
_ ->
undefined
end;
_ ->
undefined
end;
_ ->
Acc
end
end,
undefined,
Sessions
).
-spec have_shared_viewable_channel(user_id(), user_id(), guild_state()) -> boolean().
have_shared_viewable_channel(UserId, OtherUserId, State) when
is_integer(UserId), is_integer(OtherUserId), UserId =/= OtherUserId
->
SetA = viewable_channel_set(UserId, State),
SetB = viewable_channel_set(OtherUserId, State),
not sets:is_empty(sets:intersection(SetA, SetB));
have_shared_viewable_channel(_, _, _) ->
false.
-spec compute_and_dispatch_visibility_changes(map(), map()) -> ok.
compute_and_dispatch_visibility_changes(OldState, NewState) ->
Sessions = maps:get(sessions, NewState, #{}),
GuildId = maps:get(id, NewState, 0),
-spec filter_connected_session_entries(map()) -> [{binary(), map()}].
filter_connected_session_entries(Sessions) ->
[{Sid, S} || {Sid, S} <- maps:to_list(Sessions), maps:get(pending_connect, S, false) =/= true].
lists:foreach(
fun({SessionId, SessionData}) ->
-spec compute_and_dispatch_visibility_changes(guild_state(), guild_state()) -> guild_state().
compute_and_dispatch_visibility_changes(OldState, NewState) ->
compute_and_dispatch_visibility_changes_for_sessions(
filter_connected_session_entries(maps:get(sessions, NewState, #{})),
OldState,
NewState
).
-spec compute_and_dispatch_visibility_changes_for_channels(
[channel_id()], guild_state(), guild_state()
) ->
guild_state().
compute_and_dispatch_visibility_changes_for_channels(ChannelIds, OldState, NewState) ->
ValidChannelIds = lists:usort([Id || Id <- ChannelIds, is_integer(Id), Id > 0]),
case ValidChannelIds of
[] ->
compute_and_dispatch_visibility_changes(OldState, NewState);
_ ->
compute_and_dispatch_channel_visibility_changes_for_sessions(
filter_connected_session_entries(maps:get(sessions, NewState, #{})),
ValidChannelIds,
OldState,
NewState
)
end.
-spec compute_and_dispatch_visibility_changes_for_users([user_id()], guild_state(), guild_state()) ->
guild_state().
compute_and_dispatch_visibility_changes_for_users(UserIds, OldState, NewState) ->
Sessions = maps:get(sessions, NewState, #{}),
UserIdSet = sets:from_list(UserIds),
TargetSessions = lists:filter(
fun({_SessionId, SessionData}) ->
maps:get(pending_connect, SessionData, false) =/= true andalso
begin
SessionUserId = maps:get(user_id, SessionData, undefined),
is_integer(SessionUserId) andalso sets:is_element(SessionUserId, UserIdSet)
end
end,
maps:to_list(Sessions)
),
compute_and_dispatch_visibility_changes_for_sessions(TargetSessions, OldState, NewState).
-spec compute_and_dispatch_channel_visibility_changes_for_sessions(
[{binary(), map()}], [channel_id()], guild_state(), guild_state()
) ->
guild_state().
compute_and_dispatch_channel_visibility_changes_for_sessions(
SessionEntries, ChannelIds, OldState, NewState
) ->
GuildId = maps:get(id, NewState, 0),
lists:foldl(
fun({SessionId, SessionData}, AccState) ->
compute_channel_visibility_changes_for_session(
SessionId,
SessionData,
ChannelIds,
OldState,
AccState,
GuildId
)
end,
NewState,
SessionEntries
).
-spec compute_channel_visibility_changes_for_session(
binary(), map(), [channel_id()], guild_state(), guild_state(), integer()
) ->
guild_state().
compute_channel_visibility_changes_for_session(
SessionId, SessionData, ChannelIds, OldState, NewState, GuildId
) ->
UserId = maps:get(user_id, SessionData, undefined),
case is_integer(UserId) of
false ->
NewState;
true ->
Pid = maps:get(pid, SessionData, undefined),
OldMember = guild_permissions:find_member_by_user_id(UserId, OldState),
ConnectedSet = connected_voice_channel_set(UserId, NewState),
InitialViewableMap = ensure_viewable_channel_map(SessionData, UserId, OldState),
{FinalViewableMap, StateAfterChannels} = lists:foldl(
fun(ChannelId, {ViewableMapAcc, StateAcc}) ->
OldVisible = channel_is_visible(UserId, ChannelId, OldMember, OldState),
{StateAfterPreserve, NewVisible} = ensure_new_channel_visibility(
UserId,
ChannelId,
ConnectedSet,
StateAcc
),
UpdatedViewableMap = update_viewable_map_for_channel(
ViewableMapAcc, ChannelId, NewVisible
),
case {OldVisible, NewVisible} of
{true, false} ->
dispatch_channel_delete(ChannelId, Pid, OldState, GuildId),
{UpdatedViewableMap, StateAfterPreserve};
{false, true} ->
dispatch_channel_create(
ChannelId, Pid, StateAfterPreserve, GuildId
),
send_member_list_sync(
SessionId,
SessionData,
ChannelId,
GuildId,
StateAfterPreserve
),
{UpdatedViewableMap, StateAfterPreserve};
_ ->
{UpdatedViewableMap, StateAfterPreserve}
end
end,
{InitialViewableMap, NewState},
ChannelIds
),
guild_sessions:set_session_viewable_channels(
SessionId,
FinalViewableMap,
StateAfterChannels
)
end.
-spec ensure_viewable_channel_map(map(), user_id(), guild_state()) -> #{channel_id() => true}.
ensure_viewable_channel_map(SessionData, UserId, State) ->
case maps:get(viewable_channels, SessionData, undefined) of
ViewableChannels when is_map(ViewableChannels) ->
ViewableChannels;
_ ->
viewable_channel_map(sets:from_list(get_user_viewable_channels(UserId, State)))
end.
-spec ensure_new_channel_visibility(
user_id(), channel_id(), sets:set(channel_id()), guild_state()
) ->
{guild_state(), boolean()}.
ensure_new_channel_visibility(UserId, ChannelId, ConnectedSet, State) ->
NewMember = guild_permissions:find_member_by_user_id(UserId, State),
NewVisible0 = channel_is_visible(UserId, ChannelId, NewMember, State),
case {NewVisible0, sets:is_element(ChannelId, ConnectedSet)} of
{true, _} ->
{State, true};
{false, false} ->
{State, false};
{false, true} ->
case guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, State) of
true ->
{State, true};
false ->
State1 = guild_virtual_channel_access:add_virtual_access(
UserId,
ChannelId,
State
),
State2 = guild_virtual_channel_access:clear_pending_join(
UserId,
ChannelId,
State1
),
{State2, true}
end
end.
-spec channel_is_visible(user_id(), channel_id(), map() | undefined, guild_state()) -> boolean().
channel_is_visible(UserId, ChannelId, Member, State) ->
guild_permissions:can_view_channel(UserId, ChannelId, Member, State).
-spec update_viewable_map_for_channel(map(), channel_id(), boolean()) -> map().
update_viewable_map_for_channel(ViewableMap, ChannelId, true) ->
maps:put(ChannelId, true, ViewableMap);
update_viewable_map_for_channel(ViewableMap, ChannelId, false) ->
maps:remove(ChannelId, ViewableMap).
-spec compute_and_dispatch_visibility_changes_for_sessions(
[{binary(), map()}], guild_state(), guild_state()
) -> guild_state().
compute_and_dispatch_visibility_changes_for_sessions(SessionEntries, OldState, NewState) ->
GuildId = maps:get(id, NewState, 0),
lists:foldl(
fun({SessionId, SessionData}, AccState) ->
UserId = maps:get(user_id, SessionData),
Pid = maps:get(pid, SessionData),
OldViewable = get_user_viewable_channels(UserId, OldState),
NewViewable = get_user_viewable_channels(UserId, NewState),
OldSet = sets:from_list(OldViewable),
OldSet = cached_viewable_channel_set(SessionData, UserId, OldState),
NewViewable = get_user_viewable_channels(UserId, AccState),
NewSet = sets:from_list(NewViewable),
Removed = sets:subtract(OldSet, NewSet),
Added = sets:subtract(NewSet, OldSet),
ConnectedSet = connected_voice_channel_set(UserId, AccState),
Removed0 = sets:subtract(OldSet, NewSet),
{StateWithAccess, PreservedSet} = preserve_connected_channels(
UserId,
Removed0,
ConnectedSet,
AccState
),
NewSet2 = sets:union(NewSet, PreservedSet),
StateWithCachedVisibility = guild_sessions:set_session_viewable_channels(
SessionId,
viewable_channel_map(NewSet2),
StateWithAccess
),
Removed = sets:subtract(OldSet, NewSet2),
Added = sets:subtract(NewSet2, OldSet),
lists:foreach(
fun(ChannelId) ->
dispatch_channel_delete(ChannelId, Pid, OldState, GuildId)
end,
sets:to_list(Removed)
),
lists:foreach(
fun(ChannelId) ->
dispatch_channel_create(ChannelId, Pid, NewState, GuildId),
send_member_list_sync(SessionId, SessionData, ChannelId, GuildId, NewState)
dispatch_channel_create(ChannelId, Pid, StateWithCachedVisibility, GuildId),
send_member_list_sync(
SessionId, SessionData, ChannelId, GuildId, StateWithCachedVisibility
)
end,
sets:to_list(Added)
)
),
StateWithCachedVisibility
end,
maps:to_list(Sessions)
),
ok.
NewState,
SessionEntries
).
-spec cached_viewable_channel_set(map(), user_id(), guild_state()) -> sets:set(channel_id()).
cached_viewable_channel_set(SessionData, UserId, State) ->
case maps:get(viewable_channels, SessionData, undefined) of
ViewableMap when is_map(ViewableMap) ->
sets:from_list(maps:keys(ViewableMap));
_ ->
sets:from_list(get_user_viewable_channels(UserId, State))
end.
-spec viewable_channel_map(sets:set(channel_id())) -> #{channel_id() => true}.
viewable_channel_map(ChannelSet) ->
sets:fold(
fun(ChannelId, Acc) ->
maps:put(ChannelId, true, Acc)
end,
#{},
ChannelSet
).
-spec connected_voice_channel_set(user_id(), guild_state()) -> sets:set(channel_id()).
connected_voice_channel_set(UserId, State) ->
VoiceStates = voice_state_utils:voice_states(State),
maps:fold(
fun(_ConnId, VoiceState, Acc) ->
case voice_state_utils:voice_state_user_id(VoiceState) of
UserId ->
case voice_state_utils:voice_state_channel_id(VoiceState) of
undefined -> Acc;
ChannelId -> sets:add_element(ChannelId, Acc)
end;
_ ->
Acc
end
end,
sets:new(),
VoiceStates
).
-spec preserve_connected_channels(
user_id(), sets:set(channel_id()), sets:set(channel_id()), guild_state()
) ->
{guild_state(), sets:set(channel_id())}.
preserve_connected_channels(UserId, RemovedSet, ConnectedSet, State) ->
ToPreserve = sets:intersection(RemovedSet, ConnectedSet),
UpdatedState = sets:fold(
fun(ChannelId, AccState) ->
case guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, AccState) of
true ->
AccState;
false ->
State1 = guild_virtual_channel_access:add_virtual_access(
UserId, ChannelId, AccState
),
guild_virtual_channel_access:clear_pending_join(UserId, ChannelId, State1)
end
end,
State,
ToPreserve
),
{UpdatedState, ToPreserve}.
-spec dispatch_channel_delete(channel_id(), pid(), guild_state(), integer()) -> ok.
dispatch_channel_delete(ChannelId, SessionPid, OldState, GuildId) ->
case is_pid(SessionPid) of
true ->
case find_channel_by_id(ChannelId, OldState) of
case guild_permissions:find_channel_by_id(ChannelId, OldState) of
undefined ->
ok;
_Channel ->
@@ -123,10 +415,11 @@ dispatch_channel_delete(ChannelId, SessionPid, OldState, GuildId) ->
ok
end.
-spec dispatch_channel_create(channel_id(), pid(), guild_state(), integer()) -> ok.
dispatch_channel_create(ChannelId, SessionPid, NewState, GuildId) ->
case is_pid(SessionPid) of
true ->
case find_channel_by_id(ChannelId, NewState) of
case guild_permissions:find_channel_by_id(ChannelId, NewState) of
undefined ->
ok;
Channel ->
@@ -139,13 +432,14 @@ dispatch_channel_create(ChannelId, SessionPid, NewState, GuildId) ->
ok
end.
-spec send_member_list_sync(binary(), map(), channel_id(), integer(), guild_state()) -> ok.
send_member_list_sync(SessionId, SessionData, ChannelId, GuildId, State) ->
SessionPid = maps:get(pid, SessionData),
case is_pid(SessionPid) of
false ->
ok;
true ->
ListId = calculate_list_id(ChannelId, State),
ListId = guild_member_list:calculate_list_id(ChannelId, State),
MemberListSubs = maps:get(member_list_subscriptions, State, #{}),
ListSubs = maps:get(ListId, MemberListSubs, #{}),
Ranges = maps:get(SessionId, ListSubs, []),
@@ -156,15 +450,254 @@ send_member_list_sync(SessionId, SessionData, ChannelId, GuildId, State) ->
SessionUserId = maps:get(user_id, SessionData),
case can_send_member_list(SessionUserId, ChannelId, State) of
true ->
SyncResponse = build_sync_response(GuildId, ListId, Ranges, State),
SyncResponseWithChannel = maps:put(<<"channel_id">>, integer_to_binary(ChannelId), SyncResponse),
gen_server:cast(SessionPid, {dispatch, guild_member_list_update, SyncResponseWithChannel});
SyncResponse = guild_member_list:build_sync_response(
GuildId, ListId, Ranges, State
),
SyncResponseWithChannel = maps:put(
<<"channel_id">>, integer_to_binary(ChannelId), SyncResponse
),
gen_server:cast(
SessionPid,
{dispatch, guild_member_list_update, SyncResponseWithChannel}
);
false ->
ok
end
end
end.
-spec can_send_member_list(user_id() | undefined, channel_id(), guild_state()) -> boolean().
can_send_member_list(UserId, ChannelId, State) ->
is_integer(UserId) andalso
guild_permissions:can_view_channel(UserId, ChannelId, undefined, State).
-ifdef(TEST).
get_user_viewable_channels_returns_empty_for_non_member_test() ->
State = #{
data => #{
<<"channels">> => [#{<<"id">> => <<"100">>, <<"type">> => 0}],
<<"members">> => []
}
},
?assertEqual([], get_user_viewable_channels(999, State)).
viewable_channel_set_returns_empty_for_invalid_user_test() ->
State = #{data => #{}},
?assertEqual(sets:new(), viewable_channel_set(undefined, State)).
have_shared_viewable_channel_same_user_test() ->
State = #{data => #{}},
?assertEqual(false, have_shared_viewable_channel(100, 100, State)).
preserves_connected_channel_visibility_on_permission_loss_test() ->
UserId = 10,
GuildId = 1,
ChannelId = 5,
ViewPerm = constants:view_channel_permission(),
OldState = visibility_state(GuildId, UserId, ChannelId, ViewPerm, true),
NewState = visibility_state(GuildId, UserId, ChannelId, 0, true),
UpdatedState = compute_and_dispatch_visibility_changes(OldState, NewState),
?assert(guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, UpdatedState)),
?assertEqual(
false,
guild_virtual_channel_access:is_pending_join(UserId, ChannelId, UpdatedState)
),
?assert(guild_permissions:can_view_channel(UserId, ChannelId, undefined, UpdatedState)).
does_not_add_virtual_access_when_not_connected_test() ->
UserId = 20,
GuildId = 2,
ChannelId = 6,
ViewPerm = constants:view_channel_permission(),
OldState = visibility_state(GuildId, UserId, ChannelId, ViewPerm, false),
NewState = visibility_state(GuildId, UserId, ChannelId, 0, false),
UpdatedState = compute_and_dispatch_visibility_changes(OldState, NewState),
?assertNot(guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, UpdatedState)).
does_not_add_virtual_access_when_permission_remains_test() ->
UserId = 30,
GuildId = 3,
ChannelId = 7,
ViewPerm = constants:view_channel_permission(),
OldState = visibility_state(GuildId, UserId, ChannelId, ViewPerm, true),
NewState = visibility_state(GuildId, UserId, ChannelId, ViewPerm, true),
UpdatedState = compute_and_dispatch_visibility_changes(OldState, NewState),
?assertNot(guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, UpdatedState)).
compute_and_dispatch_visibility_changes_for_users_targets_selected_users_test() ->
GuildId = 33,
ChannelId = 77,
RoleId = 101,
ViewPerm = constants:view_channel_permission(),
Session10 = #{
session_id => <<"s10">>,
user_id => 10,
pid => self(),
viewable_channels => #{ChannelId => true}
},
Session20 = #{
session_id => <<"s20">>,
user_id => 20,
pid => self(),
viewable_channels => #{ChannelId => true}
},
OldState = #{
id => GuildId,
sessions => #{<<"s10">> => Session10, <<"s20">> => Session20},
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => <<"0">>},
#{<<"id">> => integer_to_binary(RoleId), <<"permissions">> => integer_to_binary(ViewPerm)}
],
<<"members">> => #{
10 => #{<<"user">> => #{<<"id">> => <<"10">>}, <<"roles">> => [integer_to_binary(RoleId)]},
20 => #{<<"user">> => #{<<"id">> => <<"20">>}, <<"roles">> => [integer_to_binary(RoleId)]}
},
<<"channels">> => [#{<<"id">> => integer_to_binary(ChannelId), <<"permission_overwrites">> => []}]
}
},
NewState = #{
id => GuildId,
sessions => #{<<"s10">> => Session10, <<"s20">> => Session20},
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{<<"id">> => integer_to_binary(GuildId), <<"permissions">> => <<"0">>},
#{<<"id">> => integer_to_binary(RoleId), <<"permissions">> => integer_to_binary(ViewPerm)}
],
<<"members">> => #{
10 => #{<<"user">> => #{<<"id">> => <<"10">>}, <<"roles">> => []},
20 => #{<<"user">> => #{<<"id">> => <<"20">>}, <<"roles">> => [integer_to_binary(RoleId)]}
},
<<"channels">> => [#{<<"id">> => integer_to_binary(ChannelId), <<"permission_overwrites">> => []}]
}
},
UpdatedState = compute_and_dispatch_visibility_changes_for_users([10], OldState, NewState),
UpdatedSessions = maps:get(sessions, UpdatedState),
UpdatedSession10 = maps:get(<<"s10">>, UpdatedSessions),
UpdatedSession20 = maps:get(<<"s20">>, UpdatedSessions),
?assertEqual(false, maps:is_key(ChannelId, maps:get(viewable_channels, UpdatedSession10, #{}))),
?assertEqual(true, maps:is_key(ChannelId, maps:get(viewable_channels, UpdatedSession20, #{}))).
compute_and_dispatch_visibility_changes_for_channels_limits_to_changed_channels_test() ->
GuildId = 44,
UserId = 10,
ChannelA = 100,
ChannelB = 101,
ViewPerm = constants:view_channel_permission(),
BaseRole = #{
<<"id">> => integer_to_binary(GuildId),
<<"permissions">> => integer_to_binary(ViewPerm)
},
Session = #{
session_id => <<"s10">>,
user_id => UserId,
pid => self(),
viewable_channels => #{ChannelA => true, ChannelB => true}
},
OldState = #{
id => GuildId,
sessions => #{<<"s10">> => Session},
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [BaseRole],
<<"members">> => #{
UserId => #{
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
<<"roles">> => []
}
},
<<"channels">> => [
#{<<"id">> => integer_to_binary(ChannelA), <<"permission_overwrites">> => []},
#{<<"id">> => integer_to_binary(ChannelB), <<"permission_overwrites">> => []}
]
}
},
NewState = #{
id => GuildId,
sessions => #{<<"s10">> => Session},
data => #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [BaseRole],
<<"members">> => #{
UserId => #{
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
<<"roles">> => []
}
},
<<"channels">> => [
#{
<<"id">> => integer_to_binary(ChannelA),
<<"permission_overwrites">> => [
#{
<<"id">> => integer_to_binary(GuildId),
<<"type">> => 0,
<<"allow">> => <<"0">>,
<<"deny">> => integer_to_binary(ViewPerm)
}
]
},
#{<<"id">> => integer_to_binary(ChannelB), <<"permission_overwrites">> => []}
]
}
},
UpdatedState = compute_and_dispatch_visibility_changes_for_channels(
[ChannelA],
OldState,
NewState
),
UpdatedSession = maps:get(<<"s10">>, maps:get(sessions, UpdatedState)),
UpdatedViewable = maps:get(viewable_channels, UpdatedSession, #{}),
?assertEqual(false, maps:is_key(ChannelA, UpdatedViewable)),
?assertEqual(true, maps:is_key(ChannelB, UpdatedViewable)).
visibility_state(GuildId, UserId, ChannelId, Perms, Connected) ->
VoiceStates =
case Connected of
true ->
#{
<<"conn">> => #{
<<"user_id">> => integer_to_binary(UserId),
<<"guild_id">> => integer_to_binary(GuildId),
<<"channel_id">> => integer_to_binary(ChannelId),
<<"connection_id">> => <<"conn">>
}
};
false ->
#{}
end,
Sessions = #{<<"s1">> => #{user_id => UserId, pid => self()}},
Data = #{
<<"guild">> => #{<<"owner_id">> => <<"999">>},
<<"roles">> => [
#{
<<"id">> => integer_to_binary(GuildId),
<<"permissions">> => integer_to_binary(Perms)
}
],
<<"members">> => [
#{<<"user">> => #{<<"id">> => integer_to_binary(UserId)}, <<"roles">> => []}
],
<<"channels">> => [
#{<<"id">> => integer_to_binary(ChannelId), <<"permission_overwrites">> => []}
]
},
#{
id => GuildId,
data => Data,
sessions => Sessions,
voice_states => VoiceStates
}.
filter_connected_session_entries_excludes_pending_test() ->
Normal = #{session_id => <<"s1">>, user_id => 1, pending_connect => false},
Pending = #{session_id => <<"s2">>, user_id => 2, pending_connect => true},
NoPending = #{session_id => <<"s3">>, user_id => 3},
Sessions = #{<<"s1">> => Normal, <<"s2">> => Pending, <<"s3">> => NoPending},
Result = filter_connected_session_entries(Sessions),
ResultIds = lists:sort([Sid || {Sid, _} <- Result]),
?assertEqual([<<"s1">>, <<"s3">>], ResultIds).
-endif.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -17,6 +17,8 @@
-module(dm_voice).
-import(guild_voice_unclaimed_account_utils, [parse_unclaimed_error/1]).
-export([voice_state_update/2]).
-export([get_voice_state/2]).
-export([get_voice_token/6]).
@@ -24,58 +26,48 @@
-export([broadcast_voice_state_update/3]).
-export([join_or_create_call/5, join_or_create_call/6]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type dm_state() :: map().
-type voice_state() :: map().
-type voice_state_map() :: #{binary() => voice_state()}.
-spec voice_state_update(map(), dm_state()) ->
{reply, map(), dm_state()} | {reply, {error, atom(), atom()}, dm_state()}.
voice_state_update(Request, State) ->
#{
user_id := UserId,
channel_id := ChannelId
} = Request,
ConnectionId = maps:get(connection_id, Request, undefined),
VoiceStates = maps:get(dm_voice_states, State, #{}),
case ChannelId of
null ->
handle_dm_disconnect(ConnectionId, UserId, VoiceStates, State);
ChannelIdValue ->
Channels = maps:get(channels, State, #{}),
UserId = maps:get(user_id, State),
logger:info(
"[dm_voice] Looking up channel ~p for user ~p, channels map has ~p entries",
[ChannelIdValue, UserId, maps:size(Channels)]
),
case maps:get(ChannelIdValue, Channels, undefined) of
undefined ->
logger:info(
"[dm_voice] Channel ~p not found locally for user ~p, trying RPC fallback",
[ChannelIdValue, UserId]
),
case fetch_dm_channel_via_rpc(ChannelIdValue, UserId) of
{ok, Channel} ->
NewChannels = maps:put(ChannelIdValue, Channel, Channels),
NewState = maps:put(channels, NewChannels, State),
logger:info(
"[dm_voice] RPC fallback found channel ~p for user ~p, added to local map",
[ChannelIdValue, UserId]
),
handle_dm_voice_with_channel(
Channel, ChannelIdValue, UserId, Request, NewState
);
{error, Reason} ->
logger:warning(
"[dm_voice] Channel ~p not found for user ~p via RPC: ~p",
[ChannelIdValue, UserId, Reason]
),
{error, _Reason} ->
{reply, gateway_errors:error(dm_channel_not_found), State}
end;
Channel ->
logger:info(
"[dm_voice] Found channel ~p for user ~p, type: ~p",
[ChannelIdValue, UserId, maps:get(<<"type">>, Channel, 0)]
),
handle_dm_voice_with_channel(Channel, ChannelIdValue, UserId, Request, State)
end
end.
-spec handle_dm_voice_with_channel(map(), integer(), integer(), map(), dm_state()) ->
{reply, map(), dm_state()} | {reply, {error, atom(), atom()}, dm_state()}.
handle_dm_voice_with_channel(Channel, ChannelIdValue, UserId, Request, State) ->
#{
session_id := SessionId,
@@ -86,11 +78,10 @@ handle_dm_voice_with_channel(Channel, ChannelIdValue, UserId, Request, State) ->
SelfStream = maps:get(self_stream, Request, false),
ConnectionId = maps:get(connection_id, Request, undefined),
IsMobile = maps:get(is_mobile, Request, false),
ViewerStreamKey = maps:get(viewer_stream_key, Request, undefined),
ViewerStreamKeys = maps:get(viewer_stream_keys, Request, undefined),
Latitude = maps:get(latitude, Request, null),
Longitude = maps:get(longitude, Request, null),
VoiceStates = maps:get(dm_voice_states, State, #{}),
ChannelType = maps:get(<<"type">>, Channel, 0),
case is_dm_channel_type(ChannelType) of
false ->
@@ -109,7 +100,7 @@ handle_dm_voice_with_channel(Channel, ChannelIdValue, UserId, Request, State) ->
SelfDeaf,
SelfVideo,
SelfStream,
ViewerStreamKey,
ViewerStreamKeys,
IsMobile,
Latitude,
Longitude,
@@ -119,6 +110,8 @@ handle_dm_voice_with_channel(Channel, ChannelIdValue, UserId, Request, State) ->
end
end.
-spec handle_dm_disconnect(binary() | undefined, integer(), voice_state_map(), dm_state()) ->
{reply, map(), dm_state()} | {reply, {error, atom(), atom()}, dm_state()}.
handle_dm_disconnect(undefined, _UserId, _VoiceStates, State) ->
{reply, gateway_errors:error(voice_missing_connection_id), State};
handle_dm_disconnect(ConnectionId, _UserId, VoiceStates, State) ->
@@ -128,13 +121,11 @@ handle_dm_disconnect(ConnectionId, _UserId, VoiceStates, State) ->
OldVoiceState ->
NewVoiceStates = maps:remove(ConnectionId, VoiceStates),
NewState = maps:put(dm_voice_states, NewVoiceStates, State),
OldChannelId = maps:get(<<"channel_id">>, OldVoiceState, null),
DisconnectVoiceState = maps:put(
<<"channel_id">>, null, maps:put(<<"connection_id">>, ConnectionId, OldVoiceState)
),
SessionId = maps:get(id, State),
case OldChannelId of
null ->
ok;
@@ -154,7 +145,6 @@ handle_dm_disconnect(ConnectionId, _UserId, VoiceStates, State) ->
end
end)
end,
case OldChannelId of
null ->
ok;
@@ -164,15 +154,29 @@ handle_dm_disconnect(ConnectionId, _UserId, VoiceStates, State) ->
broadcast_voice_state_update(
OldChannelIdInt, DisconnectVoiceState, NewState
);
{error, _, Reason} ->
logger:warning("[dm_voice] Invalid channel_id: ~p", [Reason]),
{error, _, _Reason} ->
ok
end
end,
{reply, #{success => true}, NewState}
end.
-spec handle_dm_connect_or_update(
binary() | undefined | null,
integer(),
integer(),
binary(),
boolean(),
boolean(),
boolean(),
boolean(),
term(),
boolean(),
term(),
term(),
voice_state_map(),
dm_state()
) -> {reply, map(), dm_state()} | {reply, {error, atom(), atom()}, dm_state()}.
handle_dm_connect_or_update(
ConnectionId,
ChannelIdValue,
@@ -182,7 +186,7 @@ handle_dm_connect_or_update(
SelfDeaf,
SelfVideo,
SelfStream,
ViewerStreamKey,
ViewerStreamKeys,
IsMobile,
Latitude,
Longitude,
@@ -190,7 +194,7 @@ handle_dm_connect_or_update(
State
) when ConnectionId =:= undefined; ConnectionId =:= null ->
VoiceStates = maps:get(dm_voice_states, State, #{}),
case validate_dm_viewer_stream_key(ViewerStreamKey, ChannelIdValue, VoiceStates) of
case validate_dm_viewer_stream_keys(ViewerStreamKeys, ChannelIdValue, VoiceStates) of
{error, ErrorAtom} ->
{reply, gateway_errors:error(ErrorAtom), State};
{ok, ParsedViewerKey} ->
@@ -218,7 +222,7 @@ handle_dm_connect_or_update(
SelfDeaf,
SelfVideo,
SelfStream,
ViewerStreamKey,
ViewerStreamKeys,
IsMobile,
_Latitude,
_Longitude,
@@ -229,43 +233,49 @@ handle_dm_connect_or_update(
undefined ->
{reply, gateway_errors:error(voice_connection_not_found), State};
ExistingVoiceState ->
ExistingSessionId = maps:get(<<"session_id">>, ExistingVoiceState, undefined),
EffectiveSessionId = resolve_effective_session_id(ExistingSessionId, SessionId),
ValidViewerKey = validate_dm_viewer_stream_key(
ViewerStreamKey, ChannelIdValue, VoiceStates
),
case ValidViewerKey of
{error, ErrorAtom} ->
{reply, gateway_errors:error(ErrorAtom), State};
{ok, ParsedViewerKey} ->
UpdatedVoiceState = ExistingVoiceState#{
<<"channel_id">> => integer_to_binary(ChannelIdValue),
<<"session_id">> => EffectiveSessionId,
<<"self_mute">> => SelfMute,
<<"self_deaf">> => SelfDeaf,
<<"self_video">> => SelfVideo,
<<"self_stream">> => SelfStream,
<<"is_mobile">> => IsMobile,
<<"viewer_stream_key">> => ParsedViewerKey
},
NewVoiceStates = maps:put(ConnectionId, UpdatedVoiceState, VoiceStates),
NewState = maps:put(dm_voice_states, NewVoiceStates, State),
broadcast_voice_state_update(ChannelIdValue, UpdatedVoiceState, NewState),
OldChannelId = maps:get(<<"channel_id">>, ExistingVoiceState, null),
NewChannelIdBin = integer_to_binary(ChannelIdValue),
NeedsToken = OldChannelId =/= NewChannelIdBin,
maybe_spawn_join_call(
NeedsToken, ChannelIdValue, UserId, UpdatedVoiceState, SessionId
case guild_voice_state:user_matches_voice_state(ExistingVoiceState, UserId) of
false ->
{reply, gateway_errors:error(voice_user_mismatch), State};
true ->
ExistingSessionId = maps:get(<<"session_id">>, ExistingVoiceState, undefined),
EffectiveSessionId = resolve_effective_session_id(ExistingSessionId, SessionId),
ValidViewerKey = validate_dm_viewer_stream_keys(
ViewerStreamKeys, ChannelIdValue, VoiceStates
),
{reply, #{success => true, needs_token => NeedsToken}, NewState}
case ValidViewerKey of
{error, ErrorAtom} ->
{reply, gateway_errors:error(ErrorAtom), State};
{ok, ParsedViewerKey} ->
UpdatedVoiceState = ExistingVoiceState#{
<<"channel_id">> => integer_to_binary(ChannelIdValue),
<<"session_id">> => EffectiveSessionId,
<<"self_mute">> => SelfMute,
<<"self_deaf">> => SelfDeaf,
<<"self_video">> => SelfVideo,
<<"self_stream">> => SelfStream,
<<"is_mobile">> => IsMobile,
<<"viewer_stream_keys">> => ParsedViewerKey
},
NewVoiceStates = maps:put(ConnectionId, UpdatedVoiceState, VoiceStates),
NewState = maps:put(dm_voice_states, NewVoiceStates, State),
broadcast_voice_state_update(ChannelIdValue, UpdatedVoiceState, NewState),
OldChannelId = maps:get(<<"channel_id">>, ExistingVoiceState, null),
NewChannelIdBin = integer_to_binary(ChannelIdValue),
NeedsToken = OldChannelId =/= NewChannelIdBin,
maybe_spawn_join_call(
NeedsToken,
ChannelIdValue,
UserId,
UpdatedVoiceState,
EffectiveSessionId,
State
),
{reply, #{success => true, needs_token => NeedsToken}, NewState}
end
end
end.
-spec normalize_session_id(term()) -> binary() | undefined.
normalize_session_id(undefined) ->
undefined;
normalize_session_id(SessionId) when is_binary(SessionId) ->
@@ -281,32 +291,50 @@ normalize_session_id(SessionId) ->
_:_ -> SessionId
end.
validate_dm_viewer_stream_key(RawKey, ChannelIdValue, VoiceStates) ->
case RawKey of
undefined ->
{ok, null};
null ->
{ok, null};
_ when not is_binary(RawKey) -> {error, voice_invalid_state};
_ ->
case voice_state_utils:parse_stream_key(RawKey) of
{ok, #{scope := dm, channel_id := ParsedChannelId, connection_id := ConnId}} when
ParsedChannelId =:= ChannelIdValue
->
case maps:get(ConnId, VoiceStates, undefined) of
undefined ->
{error, voice_connection_not_found};
StreamVS ->
case map_utils:get_integer(StreamVS, <<"channel_id">>, undefined) of
ChannelIdValue -> {ok, RawKey};
_ -> {error, voice_invalid_state}
end
end;
_ ->
{error, voice_invalid_state}
end
-spec validate_dm_viewer_stream_keys(term(), integer(), voice_state_map()) ->
{ok, list()} | {error, atom()}.
validate_dm_viewer_stream_keys(RawKeys, _ChannelIdValue, _VoiceStates) when RawKeys =:= undefined; RawKeys =:= null ->
{ok, []};
validate_dm_viewer_stream_keys(RawKeys, _ChannelIdValue, _VoiceStates) when not is_list(RawKeys) ->
{error, voice_invalid_state};
validate_dm_viewer_stream_keys(Keys, ChannelIdValue, VoiceStates) ->
validate_dm_viewer_stream_keys_list(Keys, ChannelIdValue, VoiceStates, []).
-spec validate_dm_viewer_stream_keys_list(list(), integer(), voice_state_map(), list()) ->
{ok, list()} | {error, atom()}.
validate_dm_viewer_stream_keys_list([], _ChannelIdValue, _VoiceStates, Acc) ->
{ok, lists:reverse(Acc)};
validate_dm_viewer_stream_keys_list([Key | Rest], ChannelIdValue, VoiceStates, Acc) ->
case validate_single_dm_viewer_stream_key(Key, ChannelIdValue, VoiceStates) of
{ok, ValidKey} ->
validate_dm_viewer_stream_keys_list(Rest, ChannelIdValue, VoiceStates, [ValidKey | Acc]);
{error, _} = Error ->
Error
end.
-spec validate_single_dm_viewer_stream_key(term(), integer(), voice_state_map()) ->
{ok, binary()} | {error, atom()}.
validate_single_dm_viewer_stream_key(RawKey, _ChannelIdValue, _VoiceStates) when not is_binary(RawKey) ->
{error, voice_invalid_state};
validate_single_dm_viewer_stream_key(RawKey, ChannelIdValue, VoiceStates) ->
case voice_state_utils:parse_stream_key(RawKey) of
{ok, #{scope := dm, channel_id := ParsedChannelId, connection_id := ConnId}} when
ParsedChannelId =:= ChannelIdValue
->
case maps:get(ConnId, VoiceStates, undefined) of
undefined ->
{error, voice_connection_not_found};
StreamVS ->
case map_utils:get_integer(StreamVS, <<"channel_id">>, undefined) of
ChannelIdValue -> {ok, RawKey};
_ -> {error, voice_invalid_state}
end
end;
_ ->
{error, voice_invalid_state}
end.
-spec resolve_effective_session_id(term(), term()) -> binary() | undefined.
resolve_effective_session_id(ExistingSessionId, RequestSessionId) ->
ExistingNormalized = normalize_session_id(ExistingSessionId),
RequestNormalized = normalize_session_id(RequestSessionId),
@@ -316,28 +344,43 @@ resolve_effective_session_id(ExistingSessionId, RequestSessionId) ->
_ -> ExistingNormalized
end.
maybe_spawn_join_call(false, _ChannelId, _UserId, _VoiceState, _SessionId) ->
-spec maybe_spawn_join_call(
boolean(), integer(), integer(), voice_state(), binary() | undefined, dm_state()
) -> ok.
maybe_spawn_join_call(false, _ChannelId, _UserId, _VoiceState, _SessionId, _State) ->
ok;
maybe_spawn_join_call(true, ChannelId, UserId, VoiceState, SessionId) ->
spawn(fun() ->
try
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, self())
catch
_:_ -> ok
end
end).
maybe_spawn_join_call(true, ChannelId, UserId, VoiceState, SessionId, State) when
is_binary(SessionId)
->
SessionPid = maps:get(session_pid, State, undefined),
case SessionPid of
Pid when is_pid(Pid) ->
spawn(fun() ->
try
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, Pid)
catch
_:_ -> ok
end
end),
ok;
_ ->
ok
end;
maybe_spawn_join_call(true, _ChannelId, _UserId, _VoiceState, _SessionId, _State) ->
ok.
-spec get_voice_token(integer(), integer(), binary(), pid(), term(), term()) -> ok | error.
get_voice_token(ChannelId, UserId, _SessionId, SessionPid, Latitude, Longitude) ->
Req = voice_utils:build_voice_token_rpc_request(
null, ChannelId, UserId, null, Latitude, Longitude
),
case rpc_client:call(Req) of
Region = resolve_call_region(ChannelId),
ReqWithRegion = voice_utils:add_rtc_region_to_request(Req, Region),
case rpc_client:call(ReqWithRegion) of
{ok, Data} ->
Token = maps:get(<<"token">>, Data),
Endpoint = maps:get(<<"endpoint">>, Data),
ConnectionId = maps:get(<<"connectionId">>, Data),
SessionPid !
{voice_server_update, #{
channel_id => integer_to_binary(ChannelId),
@@ -357,6 +400,20 @@ get_voice_token(ChannelId, UserId, _SessionId, SessionPid, Latitude, Longitude)
error
end.
-spec get_dm_voice_token_and_create_state(
integer(),
integer(),
binary(),
boolean(),
boolean(),
boolean(),
boolean(),
term(),
boolean(),
term(),
term(),
dm_state()
) -> {reply, map(), dm_state()} | {reply, {error, atom(), atom()}, dm_state()}.
get_dm_voice_token_and_create_state(
UserId,
ChannelId,
@@ -365,7 +422,7 @@ get_dm_voice_token_and_create_state(
SelfDeaf,
SelfVideo,
SelfStream,
ViewerStreamKey,
ViewerStreamKeys,
IsMobile,
Latitude,
Longitude,
@@ -374,8 +431,9 @@ get_dm_voice_token_and_create_state(
Req = voice_utils:build_voice_token_rpc_request(
null, ChannelId, UserId, null, Latitude, Longitude
),
case rpc_client:call(Req) of
Region = resolve_call_region(ChannelId, State),
ReqWithRegion = voice_utils:add_rtc_region_to_request(Req, Region),
case rpc_client:call(ReqWithRegion) of
{ok, Data} ->
handle_dm_token_success(
Data,
@@ -386,7 +444,7 @@ get_dm_voice_token_and_create_state(
SelfDeaf,
SelfVideo,
SelfStream,
ViewerStreamKey,
ViewerStreamKeys,
IsMobile,
State
);
@@ -399,6 +457,48 @@ get_dm_voice_token_and_create_state(
{reply, gateway_errors:error(voice_token_failed), State}
end.
-spec resolve_call_region(integer()) -> binary() | null.
resolve_call_region(ChannelId) ->
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
{ok, CallPid} ->
case gen_server:call(CallPid, {get_state}, 5000) of
{ok, CallData} -> maps:get(region, CallData, null);
_ -> null
end;
_ ->
null
end.
-spec resolve_call_region(integer(), dm_state()) -> binary() | null.
resolve_call_region(ChannelId, State) ->
Calls = maps:get(calls, State, #{}),
case maps:get(ChannelId, Calls, undefined) of
{CallPid, _Ref} when is_pid(CallPid) ->
try
case gen_server:call(CallPid, {get_state}, 250) of
{ok, CallData} -> maps:get(region, CallData, null);
_ -> null
end
catch
_:_ -> null
end;
_ ->
null
end.
-spec handle_dm_token_success(
map(),
integer(),
integer(),
binary(),
boolean(),
boolean(),
boolean(),
boolean(),
term(),
boolean(),
dm_state()
) -> {reply, map(), dm_state()}.
handle_dm_token_success(
Data,
UserId,
@@ -408,14 +508,13 @@ handle_dm_token_success(
SelfDeaf,
SelfVideo,
SelfStream,
ViewerStreamKey,
ViewerStreamKeys,
IsMobile,
State
) ->
Token = maps:get(<<"token">>, Data),
Endpoint = maps:get(<<"endpoint">>, Data),
ConnectionId = maps:get(<<"connectionId">>, Data),
VoiceState = #{
<<"user_id">> => integer_to_binary(UserId),
<<"channel_id">> => integer_to_binary(ChannelId),
@@ -426,15 +525,12 @@ handle_dm_token_success(
<<"self_deaf">> => SelfDeaf,
<<"self_video">> => SelfVideo,
<<"self_stream">> => SelfStream,
<<"viewer_stream_key">> => ViewerStreamKey
<<"viewer_stream_keys">> => ViewerStreamKeys
},
VoiceStates = maps:get(dm_voice_states, State, #{}),
NewVoiceStates = maps:put(ConnectionId, VoiceState, VoiceStates),
NewState = maps:put(dm_voice_states, NewVoiceStates, State),
broadcast_voice_state_update(ChannelId, VoiceState, NewState),
SessionPid = maps:get(session_pid, State),
VoiceServerUpdate = #{
<<"token">> => Token,
@@ -443,7 +539,6 @@ handle_dm_token_success(
<<"connection_id">> => ConnectionId
},
gen_server:cast(SessionPid, {dispatch, voice_server_update, VoiceServerUpdate}),
GatewaySessionId = maps:get(id, State),
spawn(fun() ->
try
@@ -452,23 +547,22 @@ handle_dm_token_success(
_:_ -> ok
end
end),
{reply, #{success => true, needs_token => false, connection_id => ConnectionId}, NewState}.
-spec get_voice_state(binary(), dm_state()) -> voice_state() | undefined.
get_voice_state(ConnectionId, State) ->
VoiceStates = maps:get(dm_voice_states, State, #{}),
maps:get(ConnectionId, VoiceStates, undefined).
-spec disconnect_voice_user(integer(), dm_state()) -> {reply, map(), dm_state()}.
disconnect_voice_user(UserId, State) ->
VoiceStates = maps:get(dm_voice_states, State, #{}),
UserVoiceStates = maps:filter(
fun(_ConnectionId, VoiceState) ->
maps:get(<<"user_id">>, VoiceState) =:= integer_to_binary(UserId)
end,
VoiceStates
),
case maps:size(UserVoiceStates) of
0 ->
{reply, #{success => true}, State};
@@ -481,71 +575,58 @@ disconnect_voice_user(UserId, State) ->
UserVoiceStates
),
NewState = maps:put(dm_voice_states, NewVoiceStates, State),
maps:foreach(
fun(_ConnectionId, VoiceState) ->
ChannelId = maps:get(<<"channel_id">>, VoiceState, null),
DisconnectVoiceState = maps:put(
<<"channel_id">>,
null,
maps:put(<<"connection_id">>, _ConnectionId, VoiceState)
),
case ChannelId of
null ->
ok;
_ ->
case validation:validate_snowflake(<<"channel_id">>, ChannelId) of
{ok, ChannelIdInt} ->
broadcast_voice_state_update(
ChannelIdInt, DisconnectVoiceState, NewState
);
{error, _, Reason} ->
logger:warning(
"[dm_voice] Invalid channel_id in voice state: ~p", [Reason]
),
ok
end
end
end,
UserVoiceStates
),
spawn(fun() ->
maps:foreach(
fun(ConnId, VoiceState) ->
ChannelId = maps:get(<<"channel_id">>, VoiceState, null),
DisconnectVoiceState = maps:put(
<<"channel_id">>,
null,
maps:put(<<"connection_id">>, ConnId, VoiceState)
),
case ChannelId of
null ->
ok;
_ ->
case validation:validate_snowflake(<<"channel_id">>, ChannelId) of
{ok, ChannelIdInt} ->
broadcast_voice_state_update(
ChannelIdInt, DisconnectVoiceState, NewState
);
{error, _, _Reason} ->
ok
end
end
end,
UserVoiceStates
)
end),
{reply, #{success => true}, NewState}
end.
parse_unclaimed_error(Body) when is_binary(Body) ->
try jsx:decode(Body, [return_maps]) of
#{<<"code">> := <<"UNCLAIMED_ACCOUNT_RESTRICTED">>} -> true;
#{<<"error">> := #{<<"code">> := <<"UNCLAIMED_ACCOUNT_RESTRICTED">>}} -> true;
_ -> false
catch
_:_ -> false
end;
parse_unclaimed_error(_) ->
false.
-spec broadcast_voice_state_update(integer(), voice_state(), dm_state()) -> ok.
broadcast_voice_state_update(ChannelId, VoiceState, State) ->
Channels = maps:get(channels, State, #{}),
case maps:get(ChannelId, Channels, undefined) of
undefined ->
ok;
Channel ->
Recipients = maps:get(<<"recipient_ids">>, Channel, []),
UserId = maps:get(user_id, State),
AllRecipients = lists:usort([UserId | Recipients]),
Event = voice_state_update,
lists:foreach(
fun(RecipientId) ->
presence_manager:dispatch_to_user(RecipientId, Event, VoiceState)
end,
AllRecipients
)
spawn(fun() ->
lists:foreach(
fun(RecipientId) ->
presence_manager:dispatch_to_user(RecipientId, Event, VoiceState)
end,
AllRecipients
)
end),
ok
end.
-spec check_recipient(integer(), integer(), dm_state()) -> boolean().
check_recipient(UserId, ChannelId, State) ->
Channels = maps:get(channels, State, #{}),
case maps:get(ChannelId, Channels, undefined) of
@@ -556,20 +637,24 @@ check_recipient(UserId, ChannelId, State) ->
is_dm_channel_type(ChannelType) andalso is_channel_recipient(UserId, Channel, State)
end.
-spec is_dm_channel_type(integer()) -> boolean().
is_dm_channel_type(1) -> true;
is_dm_channel_type(3) -> true;
is_dm_channel_type(_) -> false.
-spec is_channel_recipient(integer(), map(), dm_state()) -> boolean().
is_channel_recipient(UserId, Channel, State) ->
Recipients = maps:get(<<"recipient_ids">>, Channel, []),
CurrentUserId = maps:get(user_id, State),
lists:member(UserId, [CurrentUserId | Recipients]).
-spec join_or_create_call(integer(), integer(), voice_state(), binary(), pid()) -> ok.
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, SessionPid) ->
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, SessionPid, 10).
join_or_create_call(_ChannelId, UserId, _VoiceState, _SessionId, _SessionPid, 0) ->
logger:warning("[dm_voice] Failed to join call after retries, user ~p could not join", [UserId]),
-spec join_or_create_call(integer(), integer(), voice_state(), binary(), pid(), non_neg_integer()) ->
ok.
join_or_create_call(_ChannelId, _UserId, _VoiceState, _SessionId, _SessionPid, 0) ->
ok;
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, SessionPid, Retries) ->
ConnectionId = maps:get(<<"connection_id">>, VoiceState, undefined),
@@ -597,6 +682,7 @@ join_or_create_call(ChannelId, UserId, VoiceState, SessionId, SessionPid, Retrie
join_or_create_call(ChannelId, UserId, VoiceState, SessionId, SessionPid, Retries - 1)
end.
-spec fetch_dm_channel_via_rpc(integer(), integer()) -> {ok, map()} | {error, term()}.
fetch_dm_channel_via_rpc(ChannelId, UserId) ->
Req = #{
<<"type">> => <<"get_dm_channel">>,
@@ -614,6 +700,7 @@ fetch_dm_channel_via_rpc(ChannelId, UserId) ->
{error, Reason}
end.
-spec convert_api_channel_to_gateway_format(map(), integer()) -> map().
convert_api_channel_to_gateway_format(Channel, CurrentUserId) ->
ChannelType = maps:get(<<"type">>, Channel, 0),
Recipients = maps:get(<<"recipients">>, Channel, []),
@@ -627,6 +714,7 @@ convert_api_channel_to_gateway_format(Channel, CurrentUserId) ->
<<"recipient_ids">> => RecipientIds
}.
-spec extract_recipient_id(term(), integer()) -> {true, integer()} | false.
extract_recipient_id(Recipient, CurrentUserId) when is_map(Recipient) ->
case maps:get(<<"id">>, Recipient, undefined) of
undefined -> false;
@@ -635,6 +723,7 @@ extract_recipient_id(Recipient, CurrentUserId) when is_map(Recipient) ->
extract_recipient_id(Id, CurrentUserId) ->
filter_recipient_id(parse_id(Id), CurrentUserId).
-spec parse_id(term()) -> integer() | null.
parse_id(Id) when is_integer(Id) -> Id;
parse_id(Id) when is_binary(Id) ->
case validation:validate_snowflake(<<"id">>, Id) of
@@ -644,6 +733,111 @@ parse_id(Id) when is_binary(Id) ->
parse_id(_) ->
null.
-spec filter_recipient_id(integer() | null, integer()) -> {true, integer()} | false.
filter_recipient_id(null, _CurrentUserId) -> false;
filter_recipient_id(Id, Id) -> false;
filter_recipient_id(Id, _CurrentUserId) -> {true, Id}.
-ifdef(TEST).
is_dm_channel_type_test() ->
?assert(is_dm_channel_type(1)),
?assert(is_dm_channel_type(3)),
?assertNot(is_dm_channel_type(0)),
?assertNot(is_dm_channel_type(2)),
?assertNot(is_dm_channel_type(4)).
normalize_session_id_test() ->
?assertEqual(undefined, normalize_session_id(undefined)),
?assertEqual(<<"abc">>, normalize_session_id(<<"abc">>)),
?assertEqual(<<"123">>, normalize_session_id(123)),
?assertEqual(<<"test">>, normalize_session_id("test")).
validate_dm_viewer_stream_keys_null_test() ->
?assertEqual({ok, []}, validate_dm_viewer_stream_keys(undefined, 123, #{})),
?assertEqual({ok, []}, validate_dm_viewer_stream_keys(null, 123, #{})).
validate_dm_viewer_stream_keys_invalid_type_test() ->
?assertEqual({error, voice_invalid_state}, validate_dm_viewer_stream_keys(123, 456, #{})).
resolve_effective_session_id_test() ->
?assertEqual(<<"req">>, resolve_effective_session_id(undefined, <<"req">>)),
?assertEqual(<<"existing">>, resolve_effective_session_id(<<"existing">>, <<"req">>)),
?assertEqual(<<"same">>, resolve_effective_session_id(<<"same">>, <<"same">>)).
filter_recipient_id_test() ->
?assertEqual(false, filter_recipient_id(null, 1)),
?assertEqual(false, filter_recipient_id(1, 1)),
?assertEqual({true, 2}, filter_recipient_id(2, 1)).
parse_id_test() ->
?assertEqual(123, parse_id(123)),
?assertEqual(null, parse_id(invalid)).
handle_dm_connect_or_update_user_mismatch_test() ->
VoiceStates = #{
<<"conn-1">> => #{
<<"channel_id">> => <<"100">>,
<<"user_id">> => <<"20">>,
<<"session_id">> => <<"sess">>
}
},
State = #{dm_voice_states => VoiceStates},
{reply, {error, validation_error, voice_user_mismatch}, _} =
handle_dm_connect_or_update(
<<"conn-1">>,
100,
10,
<<"sess">>,
false,
false,
false,
false,
undefined,
false,
null,
null,
VoiceStates,
State
).
handle_dm_connect_or_update_owner_match_proceeds_test() ->
VoiceStates = #{
<<"conn-1">> => #{
<<"channel_id">> => <<"100">>,
<<"user_id">> => <<"10">>,
<<"session_id">> => <<"sess">>
}
},
State = #{
dm_voice_states => VoiceStates,
channels => #{100 => #{<<"type">> => 1, <<"recipient_ids">> => [10]}},
user_id => 10,
id => <<"sess">>,
session_pid => self()
},
case
handle_dm_connect_or_update(
<<"conn-1">>,
100,
10,
<<"sess">>,
false,
false,
false,
false,
undefined,
false,
null,
null,
VoiceStates,
State
)
of
{reply, {error, validation_error, voice_user_mismatch}, _} ->
error(should_not_get_user_mismatch);
{reply, _, _} ->
ok
end.
-endif.

View File

@@ -26,7 +26,7 @@
-export([confirm_voice_connection_from_livekit/2]).
-export([move_member/2]).
-export([broadcast_voice_state_update/3]).
-export([broadcast_voice_server_update_to_session/6]).
-export([broadcast_voice_server_update_to_session/7]).
-export([send_voice_server_update_for_move/5]).
-export([send_voice_server_updates_for_move/4]).
-export([switch_voice_region_handler/2]).
@@ -35,67 +35,88 @@
-export([handle_virtual_channel_access_for_move/4]).
-export([cleanup_virtual_access_on_disconnect/2]).
voice_state_update(Request, State) ->
case guild_voice_connection:voice_state_update(Request, State) of
{reply, Response, NewState} ->
{reply, Response, NewState};
{error, Category, Message} ->
{reply, {error, Category, Message}, State}
end.
-type guild_state() :: map().
-type voice_state() :: map().
-spec voice_state_update(map(), guild_state()) ->
{reply, map(), guild_state()} | {reply, {error, atom(), atom()}, guild_state()}.
voice_state_update(Request, State) ->
guild_voice_connection:voice_state_update(Request, State).
-spec get_voice_state(map(), guild_state()) -> {reply, map(), guild_state()}.
get_voice_state(Request, State) ->
guild_voice_state:get_voice_state(Request, State).
-spec get_voice_states_list(guild_state()) -> [voice_state()].
get_voice_states_list(State) ->
guild_voice_state:get_voice_states_list(State).
-spec update_member_voice(map(), guild_state()) -> {reply, map(), guild_state()}.
update_member_voice(Request, State) ->
guild_voice_member:update_member_voice(Request, State).
-spec disconnect_voice_user(map(), guild_state()) -> {reply, map(), guild_state()}.
disconnect_voice_user(Request, State) ->
guild_voice_disconnect:disconnect_voice_user(Request, State).
-spec disconnect_voice_user_if_in_channel(map(), guild_state()) -> {reply, map(), guild_state()}.
disconnect_voice_user_if_in_channel(Request, State) ->
guild_voice_disconnect:disconnect_voice_user_if_in_channel(Request, State).
-spec disconnect_all_voice_users_in_channel(map(), guild_state()) -> {reply, map(), guild_state()}.
disconnect_all_voice_users_in_channel(Request, State) ->
guild_voice_disconnect:disconnect_all_voice_users_in_channel(Request, State).
-spec confirm_voice_connection_from_livekit(map(), guild_state()) ->
{reply, map(), guild_state()} | {error, atom(), atom()}.
confirm_voice_connection_from_livekit(Request, State) ->
case guild_voice_connection:confirm_voice_connection_from_livekit(Request, State) of
{reply, Response, NewState} ->
{reply, Response, NewState};
{error, Category, Message} ->
{reply, {error, Category, Message}, State}
end.
guild_voice_connection:confirm_voice_connection_from_livekit(Request, State).
-spec move_member(map(), guild_state()) -> {reply, map(), guild_state()}.
move_member(Request, State) ->
guild_voice_move:move_member(Request, State).
-spec send_voice_server_update_for_move(integer(), integer(), integer(), binary(), pid()) -> ok.
send_voice_server_update_for_move(GuildId, ChannelId, UserId, SessionId, GuildPid) ->
guild_voice_move:send_voice_server_update_for_move(
GuildId, ChannelId, UserId, SessionId, GuildPid
).
-spec send_voice_server_updates_for_move(integer(), integer(), [map()], pid()) -> ok.
send_voice_server_updates_for_move(GuildId, ChannelId, SessionDataList, GuildPid) ->
guild_voice_move:send_voice_server_updates_for_move(
GuildId, ChannelId, SessionDataList, GuildPid
).
-spec broadcast_voice_state_update(voice_state(), guild_state(), binary() | null) -> ok.
broadcast_voice_state_update(VoiceState, State, OldChannelIdBin) ->
guild_voice_broadcast:broadcast_voice_state_update(VoiceState, State, OldChannelIdBin).
broadcast_voice_server_update_to_session(GuildId, SessionId, Token, Endpoint, ConnectionId, State) ->
-spec broadcast_voice_server_update_to_session(
integer(), integer(), binary(), binary(), binary(), binary(), guild_state()
) -> ok.
broadcast_voice_server_update_to_session(
GuildId,
ChannelId,
SessionId,
Token,
Endpoint,
ConnectionId,
State
) ->
guild_voice_broadcast:broadcast_voice_server_update_to_session(
GuildId, SessionId, Token, Endpoint, ConnectionId, State
GuildId, ChannelId, SessionId, Token, Endpoint, ConnectionId, State
).
-spec switch_voice_region_handler(map(), guild_state()) -> {reply, map(), guild_state()}.
switch_voice_region_handler(Request, State) ->
guild_voice_region:switch_voice_region_handler(Request, State).
-spec switch_voice_region(integer(), integer(), pid()) -> ok | {error, term()}.
switch_voice_region(GuildId, ChannelId, GuildPid) ->
guild_voice_region:switch_voice_region(GuildId, ChannelId, GuildPid).
-spec handle_virtual_channel_access_for_move(integer(), integer(), map(), pid()) -> ok.
handle_virtual_channel_access_for_move(UserId, ChannelId, _ConnectionsToMove, GuildPid) ->
case gen_server:call(GuildPid, {get_sessions}, 10000) of
State when is_map(State) ->
@@ -104,16 +125,21 @@ handle_virtual_channel_access_for_move(UserId, ChannelId, _ConnectionsToMove, Gu
undefined ->
ok;
_ ->
HasViewPermission = guild_permissions:can_view_channel_by_permissions(
UserId, ChannelId, Member, State
Permissions = guild_permissions:get_member_permissions(
UserId, ChannelId, State
),
case HasViewPermission of
ViewPerm = constants:view_channel_permission(),
ConnectPerm = constants:connect_permission(),
HasView = (Permissions band ViewPerm) =:= ViewPerm,
HasConnect = (Permissions band ConnectPerm) =:= ConnectPerm,
case HasView andalso HasConnect of
true ->
ok;
false ->
gen_server:cast(
gen_server:call(
GuildPid,
{add_virtual_channel_access, UserId, ChannelId}
{add_virtual_channel_access, UserId, ChannelId},
10000
)
end
end;
@@ -121,5 +147,6 @@ handle_virtual_channel_access_for_move(UserId, ChannelId, _ConnectionsToMove, Gu
ok
end.
-spec cleanup_virtual_access_on_disconnect(integer(), pid()) -> ok.
cleanup_virtual_access_on_disconnect(UserId, GuildPid) ->
gen_server:cast(GuildPid, {cleanup_virtual_access_for_user, UserId}).

View File

@@ -18,28 +18,23 @@
-module(guild_voice_broadcast).
-export([broadcast_voice_state_update/3]).
-export([broadcast_voice_server_update_to_session/6]).
-export([broadcast_voice_server_update_to_session/7]).
-ifdef(TEST).
-define(WARN_MISSING_CONN(_VoiceState), ok).
-else.
-define(WARN_MISSING_CONN(VoiceState),
logger:warning(
"[guild_voice_broadcast] Skipping VOICE_STATE_UPDATE broadcast - missing connection_id: ~p",
[VoiceState]
)
).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type guild_state() :: map().
-type voice_state() :: map().
-spec broadcast_voice_state_update(voice_state(), guild_state(), binary() | null) -> ok.
broadcast_voice_state_update(VoiceState, State, OldChannelIdBin) ->
case maps:get(<<"connection_id">>, VoiceState, undefined) of
undefined ->
?WARN_MISSING_CONN(VoiceState),
ok;
ConnectionId ->
_ConnectionId ->
Sessions = maps:get(sessions, State, #{}),
ChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
FilterChannelIdBin =
case ChannelIdBin of
null ->
@@ -47,71 +42,100 @@ broadcast_voice_state_update(VoiceState, State, OldChannelIdBin) ->
_ ->
ChannelIdBin
end,
FilterChannelId = utils:binary_to_integer_safe(FilterChannelIdBin),
UserId = maps:get(<<"user_id">>, VoiceState, <<"unknown">>),
GuildId = maps:get(id, State, 0),
AllSessionDetails = [{Sid, maps:get(user_id, S)} || {Sid, S} <- maps:to_list(Sessions)],
logger:info(
"[guild_voice_broadcast] Broadcasting voice state update: "
"guild_id=~p user_id=~p channel_id=~p connection_id=~p "
"total_sessions=~p all_sessions=~p filter_channel_id=~p",
[
GuildId,
UserId,
ChannelIdBin,
ConnectionId,
maps:size(Sessions),
AllSessionDetails,
FilterChannelId
]
),
FilteredSessions = guild_sessions:filter_sessions_for_channel(
Sessions, FilterChannelId, undefined, State
),
SessionDetails = [{Sid, maps:get(user_id, S)} || {Sid, S} <- FilteredSessions],
Pids = [maps:get(pid, S) || {_Sid, S} <- FilteredSessions],
logger:info(
"[guild_voice_broadcast] Filtered sessions: "
"guild_id=~p user_id=~p filtered_count=~p session_details=~p pids=~p",
[GuildId, UserId, length(FilteredSessions), SessionDetails, Pids]
),
lists:foreach(
fun(Pid) when is_pid(Pid) ->
logger:info(
"[guild_voice_broadcast] Sending voice_state_update to session pid ~p",
[Pid]
),
gen_server:cast(Pid, {dispatch, voice_state_update, VoiceState})
end,
Pids
)
),
maybe_relay_voice_state_update(VoiceState, OldChannelIdBin, State),
ok
end.
broadcast_voice_server_update_to_session(GuildId, SessionId, Token, Endpoint, ConnectionId, State) ->
-spec broadcast_voice_server_update_to_session(
integer(), integer(), binary(), binary(), binary(), binary(), guild_state()
) -> ok.
broadcast_voice_server_update_to_session(
GuildId,
ChannelId,
SessionId,
Token,
Endpoint,
ConnectionId,
State
) ->
VoiceServerUpdate = #{
<<"token">> => Token,
<<"endpoint">> => Endpoint,
<<"guild_id">> => integer_to_binary(GuildId),
<<"channel_id">> => integer_to_binary(ChannelId),
<<"connection_id">> => ConnectionId
},
Sessions = maps:get(sessions, State, #{}),
case maps:get(SessionId, Sessions, undefined) of
undefined ->
maybe_relay_voice_server_update(
GuildId, ChannelId, SessionId, Token, Endpoint, ConnectionId, State
),
ok;
SessionData ->
SessionPid = maps:get(pid, SessionData, null),
case SessionPid of
Pid when is_pid(Pid) ->
gen_server:cast(Pid, {dispatch, voice_server_update, VoiceServerUpdate});
gen_server:cast(Pid, {dispatch, voice_server_update, VoiceServerUpdate}),
ok;
_ ->
maybe_relay_voice_server_update(
GuildId, ChannelId, SessionId, Token, Endpoint, ConnectionId, State
),
ok
end
end.
-spec maybe_relay_voice_state_update(map(), binary() | null, guild_state()) -> ok.
maybe_relay_voice_state_update(VoiceState, OldChannelIdBin, State) ->
case {maps:get(very_large_guild_coordinator_pid, State, undefined),
maps:get(very_large_guild_shard_index, State, undefined)}
of
{CoordPid, ShardIndex} when is_pid(CoordPid), is_integer(ShardIndex) ->
CoordPid ! {very_large_guild_voice_state_update, ShardIndex, VoiceState, OldChannelIdBin},
ok;
_ ->
ok
end.
-spec maybe_relay_voice_server_update(
integer(),
integer(),
binary(),
binary(),
binary(),
binary(),
guild_state()
) -> ok.
maybe_relay_voice_server_update(GuildId, ChannelId, SessionId, Token, Endpoint, ConnectionId, State) ->
case {maps:get(very_large_guild_coordinator_pid, State, undefined),
maps:get(very_large_guild_shard_index, State, undefined)}
of
{CoordPid, ShardIndex} when is_pid(CoordPid), is_integer(ShardIndex) ->
CoordPid !
{very_large_guild_voice_server_update, ShardIndex, GuildId, ChannelId, SessionId,
Token, Endpoint, ConnectionId},
ok;
_ ->
ok
end.
-ifdef(TEST).
broadcast_voice_state_update_missing_connection_id_test() ->
VoiceState = #{<<"user_id">> => <<"1">>},
State = #{sessions => #{}},
?assertEqual(ok, broadcast_voice_state_update(VoiceState, State, null)).
-endif.

File diff suppressed because it is too large Load Diff

View File

@@ -23,15 +23,18 @@
-export([disconnect_voice_user_if_in_channel/2]).
-export([disconnect_all_voice_users_in_channel/2]).
-export([cleanup_virtual_channel_access_for_user/2]).
-type guild_state() :: map().
-type voice_state() :: map().
-type voice_state_map() :: #{binary() => voice_state()}.
-export([recently_disconnected_voice_states/1]).
-export([clear_recently_disconnected/2]).
-export([clear_recently_disconnected_for_channel/2]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type guild_state() :: map().
-type voice_state() :: map().
-type voice_state_map() :: #{binary() => voice_state()}.
-spec handle_voice_disconnect(
binary() | undefined,
term(),
@@ -45,7 +48,10 @@ handle_voice_disconnect(ConnectionId, _SessionId, UserId, VoiceStates0, State) -
VoiceStates = voice_state_utils:ensure_voice_states(VoiceStates0),
case maps:get(ConnectionId, VoiceStates, undefined) of
undefined ->
{reply, #{success => true}, State};
%% Voice state not in voice_states - check if it's still pending
%% (user disconnected before LiveKit confirmation)
State1 = clear_pending_voice_connection(ConnectionId, State),
{reply, #{success => true}, State1};
OldVoiceState ->
case guild_voice_state:user_matches_voice_state(OldVoiceState, UserId) of
false ->
@@ -64,16 +70,43 @@ handle_voice_disconnect(ConnectionId, _SessionId, UserId, VoiceStates0, State) -
{GuildId, ChannelId} ->
maybe_force_disconnect(GuildId, ChannelId, UserId, ConnectionId, State),
NewVoiceStates = maps:remove(ConnectionId, VoiceStates),
NewState = maps:put(voice_states, NewVoiceStates, State),
NewState0 = maps:put(voice_states, NewVoiceStates, State),
NewState = clear_recently_disconnected(ConnectionId, NewState0),
voice_state_utils:broadcast_disconnects(
#{ConnectionId => OldVoiceState}, NewState
),
FinalState = cleanup_virtual_channel_access_for_user(UserId, NewState),
FinalState = maybe_cleanup_after_disconnect(
UserId, ChannelId, NewState
),
{reply, #{success => true}, FinalState}
end
end
end.
-spec clear_pending_voice_connection(binary(), guild_state()) -> guild_state().
clear_pending_voice_connection(ConnectionId, State) ->
PendingConnections = maps:get(pending_voice_connections, State, #{}),
case maps:is_key(ConnectionId, PendingConnections) of
false ->
State;
true ->
NewPendingConnections = maps:remove(ConnectionId, PendingConnections),
maps:put(pending_voice_connections, NewPendingConnections, State)
end.
-spec maybe_cleanup_after_disconnect(integer(), integer(), guild_state()) -> guild_state().
maybe_cleanup_after_disconnect(UserId, ChannelId, State) ->
case
guild_virtual_channel_access:is_pending_join(UserId, ChannelId, State) orelse
guild_virtual_channel_access:has_preserve(UserId, ChannelId, State) orelse
guild_virtual_channel_access:is_move_pending(UserId, ChannelId, State)
of
true ->
State;
false ->
cleanup_virtual_channel_access_for_user(UserId, State)
end.
-spec disconnect_voice_user(map(), guild_state()) -> {reply, map(), guild_state()}.
disconnect_voice_user(#{user_id := UserId} = Request, State) ->
ConnectionId = maps:get(connection_id, Request, null),
@@ -85,12 +118,22 @@ disconnect_voice_user(#{user_id := UserId} = Request, State) ->
end),
case maps:size(UserVoiceStates) of
0 ->
{reply, #{success => true}, State};
%% No active voice states - also clean up any pending connections
State1 = clear_pending_voice_connections_for_user(UserId, State),
{reply, #{success => true}, State1};
_ ->
maybe_force_disconnect_voice_states(UserVoiceStates, State),
NewVoiceStates = voice_state_utils:drop_voice_states(
UserVoiceStates, VoiceStates
),
NewState = maps:put(voice_states, NewVoiceStates, State),
NewState0 = maps:put(voice_states, NewVoiceStates, State),
NewState = maps:fold(
fun(ConnId, _, AccState) ->
clear_recently_disconnected(ConnId, AccState)
end,
NewState0,
UserVoiceStates
),
voice_state_utils:broadcast_disconnects(UserVoiceStates, NewState),
FinalState = cleanup_virtual_channel_access_for_user(UserId, NewState),
{reply, #{success => true}, FinalState}
@@ -98,14 +141,18 @@ disconnect_voice_user(#{user_id := UserId} = Request, State) ->
SpecificConnection ->
case maps:get(SpecificConnection, VoiceStates, undefined) of
undefined ->
{reply, #{success => true}, State};
%% Not found in voice_states - also clean up pending connection
State1 = clear_pending_voice_connection(SpecificConnection, State),
{reply, #{success => true}, State1};
VoiceState ->
case voice_state_utils:voice_state_user_id(VoiceState) of
undefined ->
{reply, gateway_errors:error(voice_invalid_state), State};
VoiceStateUserId when VoiceStateUserId =:= UserId ->
maybe_force_disconnect_voice_state(SpecificConnection, VoiceState, State),
NewVoiceStates = maps:remove(SpecificConnection, VoiceStates),
NewState = maps:put(voice_states, NewVoiceStates, State),
NewState0 = maps:put(voice_states, NewVoiceStates, State),
NewState = clear_recently_disconnected(SpecificConnection, NewState0),
voice_state_utils:broadcast_disconnects(
#{SpecificConnection => VoiceState}, NewState
),
@@ -117,6 +164,19 @@ disconnect_voice_user(#{user_id := UserId} = Request, State) ->
end
end.
-spec clear_pending_voice_connections_for_user(integer(), guild_state()) -> guild_state().
clear_pending_voice_connections_for_user(UserId, State) ->
PendingConnections = maps:get(pending_voice_connections, State, #{}),
FilteredPending = maps:filter(
fun(_ConnId, PendingData) ->
PendingUserId = maps:get(user_id, PendingData, undefined),
PendingUserId =/= UserId
end,
PendingConnections
),
maps:put(pending_voice_connections, FilteredPending, State).
-spec disconnect_voice_user_if_in_channel(map(), guild_state()) -> {reply, map(), guild_state()}.
disconnect_voice_user_if_in_channel(
#{user_id := UserId, expected_channel_id := ExpectedChannelId} = Request,
State
@@ -131,27 +191,36 @@ disconnect_voice_user_if_in_channel(
end),
case maps:size(UserVoiceStates) of
0 ->
%% Not found in voice_states - also clean up any pending connections
%% for this user/channel (user disconnected before LiveKit confirmation)
State1 = clear_pending_voice_connections_for_user_channel(
UserId, ExpectedChannelId, State
),
{reply,
#{
success => true,
ignored => true,
reason => <<"not_in_expected_channel">>
},
State};
State1};
_ ->
NewVoiceStates = voice_state_utils:drop_voice_states(
UserVoiceStates, VoiceStates
),
NewState = maps:put(voice_states, NewVoiceStates, State),
NewState0 = maps:put(voice_states, NewVoiceStates, State),
NewState = cache_recently_disconnected(UserVoiceStates, NewState0),
voice_state_utils:broadcast_disconnects(UserVoiceStates, NewState),
{reply, #{success => true}, NewState}
end;
ConnId ->
case maps:get(ConnId, VoiceStates, undefined) of
undefined ->
%% Not found in voice_states - also clean up pending connection
%% (user disconnected before LiveKit confirmation)
State1 = clear_pending_voice_connection(ConnId, State),
{reply,
#{success => true, ignored => true, reason => <<"connection_not_found">>},
State};
State1};
VoiceState ->
case
{
@@ -161,7 +230,10 @@ disconnect_voice_user_if_in_channel(
of
{UserId, ExpectedChannelId} ->
NewVoiceStates = maps:remove(ConnId, VoiceStates),
NewState = maps:put(voice_states, NewVoiceStates, State),
NewState0 = maps:put(voice_states, NewVoiceStates, State),
NewState = cache_recently_disconnected(
#{ConnId => VoiceState}, NewState0
),
voice_state_utils:broadcast_disconnects(
#{ConnId => VoiceState}, NewState
),
@@ -178,105 +250,157 @@ disconnect_voice_user_if_in_channel(
end
end.
-spec clear_pending_voice_connections_for_user_channel(integer(), integer(), guild_state()) ->
guild_state().
clear_pending_voice_connections_for_user_channel(UserId, ChannelId, State) ->
PendingConnections = maps:get(pending_voice_connections, State, #{}),
FilteredPending = maps:filter(
fun(_ConnId, PendingData) ->
PendingUserId = maps:get(user_id, PendingData, undefined),
PendingChannelId = maps:get(channel_id, PendingData, undefined),
not (PendingUserId =:= UserId andalso PendingChannelId =:= ChannelId)
end,
PendingConnections
),
maps:put(pending_voice_connections, FilteredPending, State).
-spec clear_pending_voice_connections_for_channel(integer(), guild_state()) -> guild_state().
clear_pending_voice_connections_for_channel(ChannelId, State) ->
PendingConnections = maps:get(pending_voice_connections, State, #{}),
FilteredPending = maps:filter(
fun(_ConnId, PendingData) ->
PendingChannelId = maps:get(channel_id, PendingData, undefined),
PendingChannelId =/= ChannelId
end,
PendingConnections
),
maps:put(pending_voice_connections, FilteredPending, State).
-spec disconnect_all_voice_users_in_channel(map(), guild_state()) -> {reply, map(), guild_state()}.
disconnect_all_voice_users_in_channel(#{channel_id := ChannelId}, State) ->
VoiceStates = voice_state_utils:voice_states(State),
ChannelVoiceStates = voice_state_utils:filter_voice_states(VoiceStates, fun(_, V) ->
voice_state_utils:voice_state_channel_id(V) =:= ChannelId
end),
%% Also clean up any pending connections for this channel
%% (users that requested tokens but haven't confirmed via LiveKit yet)
State1 = clear_pending_voice_connections_for_channel(ChannelId, State),
case maps:size(ChannelVoiceStates) of
0 ->
{reply, #{success => true, disconnected_count => 0}, State};
{reply, #{success => true, disconnected_count => 0}, State1};
Count ->
maybe_force_disconnect_voice_states(ChannelVoiceStates, State1),
NewVoiceStates = voice_state_utils:drop_voice_states(ChannelVoiceStates, VoiceStates),
NewState = maps:put(voice_states, NewVoiceStates, State),
NewState0 = maps:put(voice_states, NewVoiceStates, State1),
NewState = clear_recently_disconnected_for_channel(ChannelId, NewState0),
voice_state_utils:broadcast_disconnects(ChannelVoiceStates, NewState),
{reply, #{success => true, disconnected_count => Count}, NewState}
end.
-spec maybe_force_disconnect_voice_states(voice_state_map(), guild_state()) -> ok.
maybe_force_disconnect_voice_states(VoiceStates, State) ->
maps:foreach(
fun(ConnId, VoiceState) ->
maybe_force_disconnect_voice_state(ConnId, VoiceState, State)
end,
VoiceStates
),
ok.
-spec maybe_force_disconnect_voice_state(binary(), voice_state(), guild_state()) -> ok.
maybe_force_disconnect_voice_state(ConnectionId, VoiceState, State) ->
UserId = voice_state_utils:voice_state_user_id(VoiceState),
ChannelId = voice_state_utils:voice_state_channel_id(VoiceState),
GuildId = resolve_guild_id(VoiceState, State),
case {GuildId, ChannelId, UserId} of
{GId, CId, UId} when is_integer(GId), is_integer(CId), is_integer(UId) ->
_ = maybe_force_disconnect(GId, CId, UId, ConnectionId, State),
ok;
_ ->
ok
end.
-spec resolve_guild_id(voice_state(), guild_state()) -> integer() | undefined.
resolve_guild_id(VoiceState, State) ->
case voice_state_utils:voice_state_guild_id(VoiceState) of
undefined ->
map_utils:get_integer(State, id, undefined);
GuildId ->
GuildId
end.
-spec force_disconnect_participant(integer(), integer(), integer(), binary()) ->
{ok, map()} | {error, term()}.
force_disconnect_participant(GuildId, ChannelId, UserId, ConnectionId) ->
Req = voice_utils:build_force_disconnect_rpc_request(GuildId, ChannelId, UserId, ConnectionId),
case rpc_client:call(Req) of
{ok, _Data} ->
logger:debug(
"[guild_voice_disconnect] Force disconnected participant via RPC ~p",
[
[
{guildId, GuildId},
{channelId, ChannelId},
{userId, UserId},
{connectionId, ConnectionId}
]
]
),
{ok, #{success => true}};
{error, Reason} ->
logger:error(
"[guild_voice_disconnect] Failed to force disconnect participant via RPC ~p",
[
[
{guildId, GuildId},
{channelId, ChannelId},
{userId, UserId},
{connectionId, ConnectionId},
{error, Reason}
]
]
),
{error, Reason}
end.
-spec cleanup_virtual_channel_access_for_user(integer(), guild_state()) -> guild_state().
cleanup_virtual_channel_access_for_user(UserId, State) ->
VoiceStates = voice_state_utils:voice_states(State),
HasVoiceConnection = maps:fold(
VirtualChannels = guild_virtual_channel_access:get_virtual_channels_for_user(UserId, State),
lists:foldl(
fun(ChannelId, AccState) ->
case user_has_voice_connection_in_channel(UserId, ChannelId, VoiceStates) of
true ->
AccState;
false ->
case
guild_virtual_channel_access:is_pending_join(UserId, ChannelId, AccState) orelse
guild_virtual_channel_access:has_preserve(UserId, ChannelId, AccState) orelse
guild_virtual_channel_access:is_move_pending(
UserId, ChannelId, AccState
)
of
true ->
AccState;
false ->
ok = maybe_dispatch_visibility_remove(UserId, ChannelId, AccState),
guild_virtual_channel_access:remove_virtual_access(
UserId, ChannelId, AccState
)
end
end
end,
State,
VirtualChannels
).
-spec user_has_voice_connection_in_channel(integer(), integer(), voice_state_map()) -> boolean().
user_has_voice_connection_in_channel(UserId, ChannelId, VoiceStates) ->
ChannelIdBin = integer_to_binary(ChannelId),
maps:fold(
fun(_ConnId, VoiceState, Acc) ->
case Acc of
true -> true;
false -> voice_state_utils:voice_state_user_id(VoiceState) =:= UserId
true ->
true;
false ->
voice_state_utils:voice_state_user_id(VoiceState) =:= UserId andalso
maps:get(<<"channel_id">>, VoiceState, null) =:= ChannelIdBin
end
end,
false,
VoiceStates
),
case HasVoiceConnection of
).
-spec maybe_dispatch_visibility_remove(integer(), integer(), guild_state()) -> ok.
maybe_dispatch_visibility_remove(UserId, ChannelId, State) ->
case guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, State) of
true ->
State;
guild_virtual_channel_access:dispatch_channel_visibility_change(
UserId, ChannelId, remove, State
);
false ->
VirtualChannels = guild_virtual_channel_access:get_virtual_channels_for_user(
UserId, State
),
lists:foldl(
fun(ChannelId, AccState) ->
Member = guild_permissions:find_member_by_user_id(UserId, AccState),
case Member of
undefined ->
AccState;
_ ->
HasViewPermission = guild_permissions:can_view_channel_by_permissions(
UserId, ChannelId, Member, AccState
),
case HasViewPermission of
true ->
guild_virtual_channel_access:remove_virtual_access(
UserId, ChannelId, AccState
);
false ->
guild_virtual_channel_access:dispatch_channel_visibility_change(
UserId, ChannelId, remove, AccState
),
guild_virtual_channel_access:remove_virtual_access(
UserId, ChannelId, AccState
)
end
end
end,
State,
VirtualChannels
)
ok
end.
-spec maybe_force_disconnect(integer(), integer(), integer(), binary(), guild_state()) ->
{ok, map()} | {error, term()}.
maybe_force_disconnect(GuildId, ChannelId, UserId, ConnectionId, State) ->
case maps:get(test_force_disconnect_fun, State, undefined) of
Fun when is_function(Fun, 4) ->
@@ -285,6 +409,59 @@ maybe_force_disconnect(GuildId, ChannelId, UserId, ConnectionId, State) ->
force_disconnect_participant(GuildId, ChannelId, UserId, ConnectionId)
end.
-define(RECENTLY_DISCONNECTED_TTL_MS, 60000).
-spec recently_disconnected_voice_states(guild_state()) -> map().
recently_disconnected_voice_states(State) ->
case maps:get(recently_disconnected_voice_states, State, undefined) of
Map when is_map(Map) -> Map;
_ -> #{}
end.
-spec cache_recently_disconnected(voice_state_map(), guild_state()) -> guild_state().
cache_recently_disconnected(VoiceStatesToCache, State) ->
Now = erlang:system_time(millisecond),
Existing = recently_disconnected_voice_states(State),
Swept = sweep_expired_recently_disconnected(Existing, Now),
NewEntries = maps:fold(
fun(ConnId, VoiceState, Acc) ->
maps:put(ConnId, #{voice_state => VoiceState, disconnected_at => Now}, Acc)
end,
Swept,
VoiceStatesToCache
),
maps:put(recently_disconnected_voice_states, NewEntries, State).
-spec sweep_expired_recently_disconnected(map(), integer()) -> map().
sweep_expired_recently_disconnected(Cache, Now) ->
maps:filter(
fun(_ConnId, #{disconnected_at := DisconnectedAt}) ->
(Now - DisconnectedAt) < ?RECENTLY_DISCONNECTED_TTL_MS;
(_ConnId, _) ->
false
end,
Cache
).
-spec clear_recently_disconnected(binary(), guild_state()) -> guild_state().
clear_recently_disconnected(ConnectionId, State) ->
Cache = recently_disconnected_voice_states(State),
NewCache = maps:remove(ConnectionId, Cache),
maps:put(recently_disconnected_voice_states, NewCache, State).
-spec clear_recently_disconnected_for_channel(integer(), guild_state()) -> guild_state().
clear_recently_disconnected_for_channel(ChannelId, State) ->
Cache = recently_disconnected_voice_states(State),
NewCache = maps:filter(
fun(_ConnId, #{voice_state := VS}) ->
voice_state_utils:voice_state_channel_id(VS) =/= ChannelId;
(_ConnId, _) ->
false
end,
Cache
),
maps:put(recently_disconnected_voice_states, NewCache, State).
-ifdef(TEST).
disconnect_voice_user_removes_all_connections_test() ->
@@ -292,7 +469,10 @@ disconnect_voice_user_removes_all_connections_test() ->
<<"a">> => voice_state_fixture(5, 10, 20),
<<"b">> => voice_state_fixture(5, 10, 21)
},
State = #{voice_states => VoiceStates},
State = #{
voice_states => VoiceStates,
test_force_disconnect_fun => fun(_, _, _, _) -> {ok, #{success => true}} end
},
{reply, #{success := true}, #{voice_states := #{}}} =
disconnect_voice_user(#{user_id => 5, connection_id => null}, State).
@@ -309,6 +489,353 @@ disconnect_voice_user_if_in_channel_ignored_test() ->
{reply, #{ignored := true}, _} =
disconnect_voice_user_if_in_channel(#{user_id => 5, expected_channel_id => 99}, State).
user_has_voice_connection_in_channel_test() ->
VoiceStates = #{
<<"conn">> => #{<<"user_id">> => <<"10">>, <<"channel_id">> => <<"100">>}
},
?assert(user_has_voice_connection_in_channel(10, 100, VoiceStates)),
?assertNot(user_has_voice_connection_in_channel(10, 200, VoiceStates)),
?assertNot(user_has_voice_connection_in_channel(20, 100, VoiceStates)).
recently_disconnected_voice_states_default_test() ->
?assertEqual(#{}, recently_disconnected_voice_states(#{})).
recently_disconnected_voice_states_returns_map_test() ->
Cache = #{<<"conn">> => #{voice_state => #{}, disconnected_at => 1000}},
State = #{recently_disconnected_voice_states => Cache},
?assertEqual(Cache, recently_disconnected_voice_states(State)).
cache_recently_disconnected_test() ->
VS = voice_state_fixture(5, 10, 20),
State = #{},
NewState = cache_recently_disconnected(#{<<"conn">> => VS}, State),
Cache = recently_disconnected_voice_states(NewState),
?assert(maps:is_key(<<"conn">>, Cache)),
#{<<"conn">> := #{voice_state := CachedVS}} = Cache,
?assertEqual(VS, CachedVS).
clear_recently_disconnected_test() ->
VS = voice_state_fixture(5, 10, 20),
State0 = cache_recently_disconnected(#{<<"conn">> => VS}, #{}),
State1 = clear_recently_disconnected(<<"conn">>, State0),
?assertEqual(#{}, recently_disconnected_voice_states(State1)).
clear_recently_disconnected_for_channel_test() ->
VS1 = voice_state_fixture(5, 10, 20),
VS2 = voice_state_fixture(6, 10, 30),
State0 = cache_recently_disconnected(#{<<"a">> => VS1, <<"b">> => VS2}, #{}),
State1 = clear_recently_disconnected_for_channel(20, State0),
Cache = recently_disconnected_voice_states(State1),
?assertNot(maps:is_key(<<"a">>, Cache)),
?assert(maps:is_key(<<"b">>, Cache)).
sweep_expired_recently_disconnected_test() ->
Now = erlang:system_time(millisecond),
Cache = #{
<<"fresh">> => #{voice_state => #{}, disconnected_at => Now - 1000},
<<"expired">> => #{voice_state => #{}, disconnected_at => Now - 70000}
},
Swept = sweep_expired_recently_disconnected(Cache, Now),
?assert(maps:is_key(<<"fresh">>, Swept)),
?assertNot(maps:is_key(<<"expired">>, Swept)).
disconnect_voice_user_if_in_channel_caches_by_connection_test() ->
VS = voice_state_fixture(5, 10, 20),
VoiceStates = #{<<"conn">> => VS},
State = #{voice_states => VoiceStates},
{reply, #{success := true}, NewState} =
disconnect_voice_user_if_in_channel(
#{user_id => 5, expected_channel_id => 20, connection_id => <<"conn">>},
State
),
Cache = recently_disconnected_voice_states(NewState),
?assert(maps:is_key(<<"conn">>, Cache)).
disconnect_voice_user_calls_force_disconnect_for_all_connections_test() ->
Self = self(),
TestFun = fun(GuildId, ChannelId, UserId, ConnectionId) ->
Self ! {force_disconnect, GuildId, ChannelId, UserId, ConnectionId},
{ok, #{success => true}}
end,
VoiceStates = #{
<<"a">> => voice_state_fixture(5, 10, 20),
<<"b">> => voice_state_fixture(5, 10, 21)
},
State = #{
id => 10,
voice_states => VoiceStates,
test_force_disconnect_fun => TestFun
},
{reply, #{success := true}, #{voice_states := #{}}} =
disconnect_voice_user(#{user_id => 5, connection_id => null}, State),
Msgs = collect_force_disconnect_messages(2),
?assertEqual(2, length(Msgs)),
?assert(lists:member({force_disconnect, 10, 20, 5, <<"a">>}, Msgs)),
?assert(lists:member({force_disconnect, 10, 21, 5, <<"b">>}, Msgs)).
disconnect_voice_user_calls_force_disconnect_for_specific_connection_test() ->
Self = self(),
TestFun = fun(GuildId, ChannelId, UserId, ConnectionId) ->
Self ! {force_disconnect, GuildId, ChannelId, UserId, ConnectionId},
{ok, #{success => true}}
end,
VoiceStates = #{
<<"a">> => voice_state_fixture(5, 10, 20),
<<"b">> => voice_state_fixture(5, 10, 21)
},
State = #{
id => 10,
voice_states => VoiceStates,
test_force_disconnect_fun => TestFun
},
{reply, #{success := true}, NewState} =
disconnect_voice_user(#{user_id => 5, connection_id => <<"a">>}, State),
Msgs = collect_force_disconnect_messages(1),
?assertEqual(1, length(Msgs)),
?assert(lists:member({force_disconnect, 10, 20, 5, <<"a">>}, Msgs)),
Remaining = maps:get(voice_states, NewState),
?assert(maps:is_key(<<"b">>, Remaining)),
?assertNot(maps:is_key(<<"a">>, Remaining)).
disconnect_all_voice_users_in_channel_calls_force_disconnect_test() ->
Self = self(),
TestFun = fun(GuildId, ChannelId, UserId, ConnectionId) ->
Self ! {force_disconnect, GuildId, ChannelId, UserId, ConnectionId},
{ok, #{success => true}}
end,
VoiceStates = #{
<<"a">> => voice_state_fixture(5, 10, 20),
<<"b">> => voice_state_fixture(6, 10, 20),
<<"c">> => voice_state_fixture(7, 10, 30)
},
State = #{
id => 10,
voice_states => VoiceStates,
test_force_disconnect_fun => TestFun
},
{reply, #{success := true, disconnected_count := 2}, NewState} =
disconnect_all_voice_users_in_channel(#{channel_id => 20}, State),
Msgs = collect_force_disconnect_messages(2),
?assertEqual(2, length(Msgs)),
?assert(lists:member({force_disconnect, 10, 20, 5, <<"a">>}, Msgs)),
?assert(lists:member({force_disconnect, 10, 20, 6, <<"b">>}, Msgs)),
Remaining = maps:get(voice_states, NewState),
?assert(maps:is_key(<<"c">>, Remaining)),
?assertNot(maps:is_key(<<"a">>, Remaining)),
?assertNot(maps:is_key(<<"b">>, Remaining)).
disconnect_voice_user_if_in_channel_skips_force_disconnect_test() ->
Self = self(),
TestFun = fun(_, _, _, _) ->
Self ! force_disconnect_called,
{ok, #{success => true}}
end,
VoiceStates = #{<<"a">> => voice_state_fixture(5, 10, 20)},
State = #{
voice_states => VoiceStates,
test_force_disconnect_fun => TestFun
},
{reply, #{success := true}, _} =
disconnect_voice_user_if_in_channel(
#{user_id => 5, expected_channel_id => 20, connection_id => <<"a">>},
State
),
receive
force_disconnect_called -> ?assert(false)
after 0 ->
ok
end.
%% Tests for clear_pending_voice_connection/2
clear_pending_voice_connection_removes_connection_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 1, channel_id => 100},
<<"conn2">> => #{user_id => 2, channel_id => 200}
},
State = #{pending_voice_connections => PendingConnections},
NewState = clear_pending_voice_connection(<<"conn1">>, State),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
?assert(maps:is_key(<<"conn2">>, NewPending)).
clear_pending_voice_connection_ignores_missing_test() ->
PendingConnections = #{<<"conn1">> => #{user_id => 1, channel_id => 100}},
State = #{pending_voice_connections => PendingConnections},
NewState = clear_pending_voice_connection(<<"missing">>, State),
?assertEqual(State, NewState).
clear_pending_voice_connection_handles_empty_pending_test() ->
State = #{voice_states => #{}},
NewState = clear_pending_voice_connection(<<"conn">>, State),
?assertEqual(#{}, maps:get(pending_voice_connections, NewState, #{})).
%% Tests for clear_pending_voice_connections_for_user/2
clear_pending_voice_connections_for_user_removes_all_user_connections_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 5, channel_id => 100},
<<"conn2">> => #{user_id => 5, channel_id => 200},
<<"conn3">> => #{user_id => 6, channel_id => 100}
},
State = #{pending_voice_connections => PendingConnections},
NewState = clear_pending_voice_connections_for_user(5, State),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
?assertNot(maps:is_key(<<"conn2">>, NewPending)),
?assert(maps:is_key(<<"conn3">>, NewPending)).
%% Tests for clear_pending_voice_connections_for_user_channel/3
clear_pending_voice_connections_for_user_channel_removes_matching_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 5, channel_id => 100},
<<"conn2">> => #{user_id => 5, channel_id => 200},
<<"conn3">> => #{user_id => 6, channel_id => 100}
},
State = #{pending_voice_connections => PendingConnections},
NewState = clear_pending_voice_connections_for_user_channel(5, 100, State),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
?assert(maps:is_key(<<"conn2">>, NewPending)),
?assert(maps:is_key(<<"conn3">>, NewPending)).
%% Tests for clear_pending_voice_connections_for_channel/2
clear_pending_voice_connections_for_channel_removes_all_channel_connections_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 5, channel_id => 100},
<<"conn2">> => #{user_id => 6, channel_id => 100},
<<"conn3">> => #{user_id => 7, channel_id => 200}
},
State = #{pending_voice_connections => PendingConnections},
NewState = clear_pending_voice_connections_for_channel(100, State),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
?assertNot(maps:is_key(<<"conn2">>, NewPending)),
?assert(maps:is_key(<<"conn3">>, NewPending)).
%% Tests for disconnect handlers cleaning up pending connections
handle_voice_disconnect_cleans_pending_when_not_in_voice_states_test() ->
PendingConnections = #{<<"conn1">> => #{user_id => 5, channel_id => 100}},
State = #{
voice_states => #{},
pending_voice_connections => PendingConnections
},
{reply, #{success := true}, NewState} =
handle_voice_disconnect(<<"conn1">>, undefined, 5, #{}, State),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"conn1">>, NewPending)).
disconnect_voice_user_cleans_pending_when_no_active_states_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 5, channel_id => 100},
<<"conn2">> => #{user_id => 6, channel_id => 100}
},
State = #{
id => 10,
voice_states => #{},
pending_voice_connections => PendingConnections
},
{reply, #{success := true}, NewState} =
disconnect_voice_user(#{user_id => 5}, State),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
?assert(maps:is_key(<<"conn2">>, NewPending)).
disconnect_voice_user_cleans_pending_for_specific_connection_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 5, channel_id => 100},
<<"conn2">> => #{user_id => 5, channel_id => 200}
},
State = #{
id => 10,
voice_states => #{},
pending_voice_connections => PendingConnections
},
{reply, #{success := true}, NewState} =
disconnect_voice_user(#{user_id => 5, connection_id => <<"conn1">>}, State),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
?assert(maps:is_key(<<"conn2">>, NewPending)).
disconnect_voice_user_if_in_channel_cleans_pending_when_not_found_test() ->
PendingConnections = #{
<<"conn1">> => #{user_id => 5, channel_id => 100},
<<"conn2">> => #{user_id => 6, channel_id => 100}
},
State = #{
id => 10,
voice_states => #{},
pending_voice_connections => PendingConnections
},
{reply, #{success := true, ignored := true}, NewState} =
disconnect_voice_user_if_in_channel(
#{user_id => 5, expected_channel_id => 100},
State
),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"conn1">>, NewPending)),
?assert(maps:is_key(<<"conn2">>, NewPending)).
disconnect_all_voice_users_in_channel_cleans_pending_test() ->
Self = self(),
TestFun = fun(GuildId, ChannelId, UserId, ConnectionId) ->
Self ! {force_disconnect, GuildId, ChannelId, UserId, ConnectionId},
{ok, #{success => true}}
end,
VoiceStates = #{
<<"a">> => voice_state_fixture(5, 10, 20)
},
PendingConnections = #{
<<"pending1">> => #{user_id => 6, channel_id => 20},
<<"pending2">> => #{user_id => 7, channel_id => 30}
},
State = #{
id => 10,
voice_states => VoiceStates,
pending_voice_connections => PendingConnections,
test_force_disconnect_fun => TestFun
},
{reply, #{success := true, disconnected_count := 1}, NewState} =
disconnect_all_voice_users_in_channel(#{channel_id => 20}, State),
_ = collect_force_disconnect_messages(1),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"pending1">>, NewPending)),
?assert(maps:is_key(<<"pending2">>, NewPending)).
disconnect_all_voice_users_in_channel_cleans_pending_when_no_active_states_test() ->
PendingConnections = #{
<<"pending1">> => #{user_id => 5, channel_id => 20},
<<"pending2">> => #{user_id => 6, channel_id => 30}
},
State = #{
id => 10,
voice_states => #{},
pending_voice_connections => PendingConnections
},
{reply, #{success := true, disconnected_count := 0}, NewState} =
disconnect_all_voice_users_in_channel(#{channel_id => 20}, State),
NewPending = maps:get(pending_voice_connections, NewState),
?assertNot(maps:is_key(<<"pending1">>, NewPending)),
?assert(maps:is_key(<<"pending2">>, NewPending)).
collect_force_disconnect_messages(Count) ->
collect_force_disconnect_messages(Count, []).
collect_force_disconnect_messages(0, Acc) ->
lists:reverse(Acc);
collect_force_disconnect_messages(Count, Acc) when Count > 0 ->
receive
{force_disconnect, _, _, _, _} = Msg ->
collect_force_disconnect_messages(Count - 1, [Msg | Acc]);
_Other ->
collect_force_disconnect_messages(Count, Acc)
after 200 ->
lists:reverse(Acc)
end.
voice_state_fixture(UserId, GuildId, ChannelId) ->
#{
<<"user_id">> => integer_to_binary(UserId),

View File

@@ -40,7 +40,6 @@ update_member_voice(Request, State) ->
#{user_id := UserId, mute := Mute, deaf := Deaf} = Request,
VoiceStates = voice_state_utils:voice_states(State),
GuildId = map_utils:get_integer(State, id, 0),
case find_member_by_user_id(UserId, State) of
undefined ->
{reply, gateway_errors:error(voice_member_not_found), State};
@@ -48,7 +47,6 @@ update_member_voice(Request, State) ->
UpdatedMember = set_member_voice_flags(Member, Mute, Deaf),
StateWithUpdatedMember = store_member(UpdatedMember, State),
UserVoiceStates = user_voice_states(UserId, VoiceStates),
case maps:size(UserVoiceStates) of
0 ->
{reply, #{success => true}, StateWithUpdatedMember};
@@ -64,43 +62,22 @@ update_member_voice(Request, State) ->
end
end.
-spec find_member_by_user_id(integer(), guild_state()) -> member() | undefined.
find_member_by_user_id(UserId, State) ->
guild_permissions:find_member_by_user_id(UserId, State).
-spec find_channel_by_id(integer(), guild_state()) -> map() | undefined.
find_channel_by_id(ChannelId, State) ->
guild_permissions:find_channel_by_id(ChannelId, State).
-spec enforce_participant_state_in_livekit(integer(), integer(), integer(), boolean(), boolean()) ->
ok.
enforce_participant_state_in_livekit(GuildId, ChannelId, UserId, Mute, Deaf) ->
Req = voice_utils:build_update_participant_rpc_request(GuildId, ChannelId, UserId, Mute, Deaf),
case rpc_client:call(Req) of
{ok, _Data} ->
logger:debug(
"[guild_voice_member] Enforced participant state in LiveKit ~p",
[
[
{guildId, GuildId},
{channelId, ChannelId},
{userId, UserId},
{mute, Mute},
{deaf, Deaf}
]
]
),
ok;
{error, Reason} ->
logger:warning(
"[guild_voice_member] Failed to enforce participant state in LiveKit ~p",
[
[
{guildId, GuildId},
{channelId, ChannelId},
{userId, UserId},
{mute, Mute},
{deaf, Deaf},
{error, Reason}
]
]
),
{error, _Reason} ->
ok
end.
@@ -108,16 +85,10 @@ enforce_participant_state_in_livekit(GuildId, ChannelId, UserId, Mute, Deaf) ->
guild_data(State) ->
map_utils:ensure_map(map_utils:get_safe(State, data, #{})).
-spec guild_members(guild_state()) -> [member()].
guild_members(State) ->
map_utils:ensure_list(maps:get(<<"members">>, guild_data(State), [])).
-spec member_user_id(member()) -> integer() | undefined.
member_user_id(Member) when is_map(Member) ->
User = map_utils:ensure_map(maps:get(<<"user">>, Member, #{})),
map_utils:get_integer(User, <<"id">>, undefined);
member_user_id(_) ->
undefined.
map_utils:get_integer(User, <<"id">>, undefined).
-spec set_member_voice_flags(member(), boolean(), boolean()) -> member().
set_member_voice_flags(Member, Mute, Deaf) ->
@@ -128,19 +99,9 @@ store_member(Member, State) ->
case member_user_id(Member) of
undefined ->
State;
TargetId ->
_TargetId ->
Data = guild_data(State),
Members = guild_members(State),
UpdatedMembers = lists:map(
fun(Current) ->
case member_user_id(Current) of
TargetId -> Member;
_ -> Current
end
end,
Members
),
UpdatedData = maps:put(<<"members">>, UpdatedMembers, Data),
UpdatedData = guild_data_index:put_member(Member, Data),
maps:put(data, UpdatedData, State)
end.
@@ -246,9 +207,17 @@ update_member_voice_updates_voice_states_test() ->
?assert(false)
end.
update_member_voice_member_not_found_test() ->
State = voice_member_test_state(#{}),
Request = #{user_id => 999, mute => true, deaf => false},
{reply, Error, _} = update_member_voice(Request, State),
?assertEqual({error, not_found, voice_member_not_found}, Error).
voice_member_test_state(Overrides) ->
BaseData = #{
<<"members">> => [member_fixture(10)]
<<"members">> => #{
10 => member_fixture(10)
}
},
BaseState = #{
id => 42,

View File

@@ -19,9 +19,16 @@
-export([move_member/2]).
-export([send_voice_server_update_for_move/5]).
-export([send_voice_server_update_for_move/6]).
-export([send_voice_server_updates_for_move/4]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type guild_state() :: map().
-type voice_state() :: map().
-type voice_state_map() :: #{binary() => voice_state()}.
-type move_request() :: #{
user_id := integer(),
moderator_id := integer(),
@@ -31,10 +38,6 @@
deaf := boolean()
}.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec move_member(move_request(), guild_state()) -> {reply, map(), guild_state()}.
move_member(Request, State) ->
#{
@@ -44,10 +47,17 @@ move_member(Request, State) ->
} = Request,
ConnectionId = maps:get(connection_id, Request, null),
ChannelId = normalize_channel_id(ChannelIdRaw),
logger:debug(
"Handling voice move_member request",
#{
user_id => UserId,
moderator_id => ModeratorId,
channel_id => ChannelId,
connection_id => ConnectionId
}
),
VoiceStates = voice_state_utils:voice_states(State),
UserVoiceStates = find_user_voice_states(UserId, VoiceStates),
case maps:size(UserVoiceStates) of
0 ->
{reply, gateway_errors:error(voice_user_not_in_voice), State};
@@ -55,11 +65,20 @@ move_member(Request, State) ->
ConnectionsToMove = select_connections_to_move(
ConnectionId, UserId, VoiceStates, UserVoiceStates
),
logger:debug(
"Selected voice connections to move",
#{
user_id => UserId,
connection_id => ConnectionId,
connections_to_move_count => maps:size(ConnectionsToMove)
}
),
handle_move(
ConnectionsToMove, ChannelId, UserId, ModeratorId, ConnectionId, VoiceStates, State
)
end.
-spec find_user_voice_states(integer(), voice_state_map()) -> voice_state_map().
find_user_voice_states(UserId, VoiceStates) ->
maps:filter(
fun(_ConnId, VoiceState) ->
@@ -68,6 +87,8 @@ find_user_voice_states(UserId, VoiceStates) ->
VoiceStates
).
-spec select_connections_to_move(binary() | null, integer(), voice_state_map(), voice_state_map()) ->
voice_state_map().
select_connections_to_move(null, _UserId, _VoiceStates, UserVoiceStates) ->
UserVoiceStates;
select_connections_to_move(ConnectionId, UserId, VoiceStates, _UserVoiceStates) ->
@@ -83,11 +104,16 @@ select_connections_to_move(ConnectionId, UserId, VoiceStates, _UserVoiceStates)
end
end.
-spec handle_move(
voice_state_map(),
integer() | null,
integer(),
integer(),
binary() | null,
voice_state_map(),
guild_state()
) -> {reply, map(), guild_state()}.
handle_move(ConnectionsToMove, ChannelId, UserId, ModeratorId, ConnectionId, VoiceStates, State) ->
logger:info(
"[guild_voice_move] handle_move user_id=~p moderator_id=~p channel_id=~p connection_id=~p connections=~p",
[UserId, ModeratorId, ChannelId, ConnectionId, maps:keys(ConnectionsToMove)]
),
case maps:size(ConnectionsToMove) of
0 ->
Error =
@@ -99,14 +125,24 @@ handle_move(ConnectionsToMove, ChannelId, UserId, ModeratorId, ConnectionId, Voi
_ ->
case ChannelId of
null ->
logger:debug(
"Disconnect move requested",
#{user_id => UserId, connection_id => ConnectionId}
),
handle_disconnect_move(ConnectionsToMove, UserId, VoiceStates, State);
ChannelIdValue ->
logger:debug(
"Channel move requested",
#{user_id => UserId, channel_id => ChannelIdValue, connection_id => ConnectionId}
),
handle_channel_move(
ConnectionsToMove, ChannelIdValue, UserId, ModeratorId, VoiceStates, State
)
end
end.
-spec handle_disconnect_move(voice_state_map(), integer(), voice_state_map(), guild_state()) ->
{reply, map(), guild_state()}.
handle_disconnect_move(ConnectionsToMove, UserId, VoiceStates, State) ->
NewVoiceStates = maps:fold(
fun(ConnId, _VoiceState, Acc) -> maps:remove(ConnId, Acc) end,
@@ -114,79 +150,107 @@ handle_disconnect_move(ConnectionsToMove, UserId, VoiceStates, State) ->
ConnectionsToMove
),
NewState = maps:put(voice_states, NewVoiceStates, State),
maps:foreach(
fun(_ConnId, VoiceState) ->
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
DisconnectVoiceState = maps:put(<<"channel_id">>, null, VoiceState),
guild_voice_broadcast:broadcast_voice_state_update(
DisconnectVoiceState, NewState, OldChannelIdBin
)
end,
ConnectionsToMove
),
spawn(fun() ->
maps:foreach(
fun(_ConnId, VoiceState) ->
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
DisconnectVoiceState = maps:put(<<"channel_id">>, null, VoiceState),
guild_voice_broadcast:broadcast_voice_state_update(
DisconnectVoiceState, NewState, OldChannelIdBin
)
end,
ConnectionsToMove
)
end),
{reply, #{success => true, user_id => UserId, connections_moved => ConnectionsToMove},
NewState}.
-spec handle_channel_move(
voice_state_map(), integer(), integer(), integer(), voice_state_map(), guild_state()
) -> {reply, map(), guild_state()}.
handle_channel_move(ConnectionsToMove, ChannelIdValue, UserId, ModeratorId, VoiceStates, State) ->
logger:info(
"[guild_voice_move] handle_channel_move user_id=~p moderator_id=~p target_channel_id=~p connections=~p",
[UserId, ModeratorId, ChannelIdValue, maps:keys(ConnectionsToMove)]
),
Channel = guild_voice_member:find_channel_by_id(ChannelIdValue, State),
case Channel of
undefined ->
{reply, gateway_errors:error(voice_channel_not_found), State};
_ ->
StateWithPending0 = guild_virtual_channel_access:mark_pending_join(
UserId, ChannelIdValue, State
),
StateWithPending1 = guild_virtual_channel_access:mark_preserve(
UserId, ChannelIdValue, StateWithPending0
),
StateWithPending2 = guild_virtual_channel_access:mark_move_pending(
UserId, ChannelIdValue, StateWithPending1
),
ChannelType = maps:get(<<"type">>, Channel, 0),
case ChannelType of
2 ->
check_move_permissions_and_execute(
ConnectionsToMove, ChannelIdValue, UserId, ModeratorId, VoiceStates, State
ConnectionsToMove,
ChannelIdValue,
UserId,
ModeratorId,
VoiceStates,
StateWithPending2
);
_ ->
{reply, gateway_errors:error(voice_channel_not_voice), State}
end
end.
-spec check_move_permissions_and_execute(
voice_state_map(), integer(), integer(), integer(), voice_state_map(), guild_state()
) -> {reply, map(), guild_state()}.
check_move_permissions_and_execute(
ConnectionsToMove, ChannelIdValue, _UserId, ModeratorId, VoiceStates, State
ConnectionsToMove, ChannelIdValue, UserId, ModeratorId, VoiceStates, State
) ->
ViewPerm = constants:view_channel_permission(),
ConnectPerm = constants:connect_permission(),
ModPerms = guild_permissions:get_member_permissions(ModeratorId, ChannelIdValue, State),
ModHasConnect = (ModPerms band ConnectPerm) =:= ConnectPerm,
ModHasView = (ModPerms band ViewPerm) =:= ViewPerm,
case ModHasConnect andalso ModHasView of
false ->
{reply, gateway_errors:error(voice_moderator_missing_connect), State};
true ->
execute_move(ConnectionsToMove, VoiceStates, State)
execute_move(ConnectionsToMove, ChannelIdValue, UserId, VoiceStates, State)
end.
execute_move(ConnectionsToMove, VoiceStates, State) ->
-spec execute_move(voice_state_map(), integer(), integer(), voice_state_map(), guild_state()) ->
{reply, map(), guild_state()}.
execute_move(ConnectionsToMove, ChannelIdValue, UserId, VoiceStates, State) ->
StatePending = guild_virtual_channel_access:mark_pending_join(UserId, ChannelIdValue, State),
StatePending2 = guild_virtual_channel_access:mark_preserve(
UserId, ChannelIdValue, StatePending
),
StatePending3 = guild_virtual_channel_access:mark_move_pending(
UserId, ChannelIdValue, StatePending2
),
logger:debug(
"Executing voice channel move",
#{user_id => UserId, channel_id => ChannelIdValue}
),
NewVoiceStates = maps:fold(
fun(ConnId, _VoiceState, Acc) -> maps:remove(ConnId, Acc) end,
VoiceStates,
ConnectionsToMove
),
StateAfterDisconnect = maps:put(voice_states, NewVoiceStates, State),
maps:foreach(
fun(_ConnId, VoiceState) ->
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
DisconnectVoiceState = maps:put(<<"channel_id">>, null, VoiceState),
guild_voice_broadcast:broadcast_voice_state_update(
DisconnectVoiceState, StateAfterDisconnect, OldChannelIdBin
)
end,
ConnectionsToMove
),
StateAfterDisconnect = maps:put(voice_states, NewVoiceStates, StatePending3),
StateWithVirtualAccess = maybe_add_virtual_access(UserId, ChannelIdValue, StateAfterDisconnect),
spawn(fun() ->
maps:foreach(
fun(_ConnId, VoiceState) ->
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
DisconnectVoiceState = maps:put(<<"channel_id">>, null, VoiceState),
guild_voice_broadcast:broadcast_voice_state_update(
DisconnectVoiceState, StateWithVirtualAccess, OldChannelIdBin
)
end,
ConnectionsToMove
)
end),
SessionData = extract_session_data(ConnectionsToMove),
{reply,
#{
success => true,
@@ -194,8 +258,9 @@ execute_move(ConnectionsToMove, VoiceStates, State) ->
session_data => SessionData,
connections_to_move => ConnectionsToMove
},
StateAfterDisconnect}.
StateWithVirtualAccess}.
-spec extract_session_data(voice_state_map()) -> [map()].
extract_session_data(ConnectionsToMove) ->
{_ConnectionIds, SessionData} = maps:fold(
fun(ConnId, VoiceState, {AccConnIds, AccSessionData}) ->
@@ -223,6 +288,159 @@ member_user_id(Member) ->
User = map_utils:ensure_map(maps:get(<<"user">>, map_utils:ensure_map(Member), #{})),
map_utils:get_integer(User, <<"id">>, undefined).
-spec send_voice_server_update_for_move(
integer(), integer(), integer(), binary() | undefined, pid()
) -> ok.
send_voice_server_update_for_move(GuildId, ChannelId, UserId, SessionId, GuildPid) ->
send_voice_server_update_for_move(GuildId, ChannelId, UserId, SessionId, null, GuildPid).
-spec send_voice_server_update_for_move(
integer(), integer(), integer(), binary() | undefined, binary() | null, pid()
) -> ok.
send_voice_server_update_for_move(GuildId, ChannelId, UserId, SessionId, OldConnectionId, GuildPid) ->
case SessionId of
undefined ->
ok;
_ ->
spawn(fun() ->
case gen_server:call(GuildPid, {get_sessions}, 10000) of
State when is_map(State) ->
VoicePermissions = voice_utils:compute_voice_permissions(
UserId, ChannelId, State
),
case
guild_voice_connection:request_voice_token(
GuildId, ChannelId, UserId, OldConnectionId, VoicePermissions
)
of
{ok, TokenData} ->
Token = maps:get(token, TokenData),
Endpoint = maps:get(endpoint, TokenData),
ConnectionId = maps:get(connection_id, TokenData),
guild_voice_broadcast:broadcast_voice_server_update_to_session(
GuildId,
ChannelId,
SessionId,
Token,
Endpoint,
ConnectionId,
State
);
{error, _Reason} ->
ok
end;
_ ->
ok
end
end),
ok
end.
-spec maybe_add_virtual_access(integer(), integer(), guild_state()) -> guild_state().
maybe_add_virtual_access(UserId, ChannelId, State) ->
Member = guild_permissions:find_member_by_user_id(UserId, State),
case Member of
undefined ->
State;
_ ->
Permissions = guild_permissions:get_member_permissions(UserId, ChannelId, State),
ViewPerm = constants:view_channel_permission(),
ConnectPerm = constants:connect_permission(),
HasView = (Permissions band ViewPerm) =:= ViewPerm,
HasConnect = (Permissions band ConnectPerm) =:= ConnectPerm,
case HasView andalso HasConnect of
true ->
State;
false ->
NewState = guild_virtual_channel_access:add_virtual_access(
UserId, ChannelId, State
),
guild_virtual_channel_access:dispatch_channel_visibility_change(
UserId, ChannelId, add, NewState
),
NewState
end
end.
-spec send_voice_server_updates_for_move(integer(), integer(), [map()], pid()) -> ok.
send_voice_server_updates_for_move(GuildId, ChannelId, SessionDataList, GuildPid) ->
spawn(fun() ->
lists:foreach(
fun(SessionInfo) ->
send_single_voice_server_update(GuildId, ChannelId, SessionInfo, GuildPid)
end,
SessionDataList
)
end),
ok.
-spec send_single_voice_server_update(integer(), integer(), map(), pid()) -> ok.
send_single_voice_server_update(GuildId, ChannelId, SessionInfo, GuildPid) ->
SessionId = maps:get(session_id, SessionInfo),
SelfMute = maps:get(self_mute, SessionInfo),
SelfDeaf = maps:get(self_deaf, SessionInfo),
SelfVideo = maps:get(self_video, SessionInfo),
SelfStream = maps:get(self_stream, SessionInfo),
IsMobile = maps:get(is_mobile, SessionInfo),
OldConnectionId = maps:get(connection_id, SessionInfo, null),
Member = maps:get(member, SessionInfo),
ServerMute = maps:get(<<"mute">>, Member, false),
ServerDeaf = maps:get(<<"deaf">>, Member, false),
case member_user_id(Member) of
undefined ->
ok;
UserId ->
case gen_server:call(GuildPid, {get_sessions}, 10000) of
StateData when is_map(StateData) ->
VoicePermissions = voice_utils:compute_voice_permissions(
UserId, ChannelId, StateData
),
case
guild_voice_connection:request_voice_token(
GuildId, ChannelId, UserId, OldConnectionId, VoicePermissions
)
of
{ok, TokenData} ->
Token = maps:get(token, TokenData),
Endpoint = maps:get(endpoint, TokenData),
NewConnectionId = maps:get(connection_id, TokenData),
PendingMetadata = #{
<<"user_id">> => UserId,
<<"guild_id">> => GuildId,
<<"channel_id">> => ChannelId,
<<"connection_id">> => NewConnectionId,
<<"session_id">> => SessionId,
<<"self_mute">> => SelfMute,
<<"self_deaf">> => SelfDeaf,
<<"self_video">> => SelfVideo,
<<"self_stream">> => SelfStream,
<<"is_mobile">> => IsMobile,
<<"server_mute">> => ServerMute,
<<"server_deaf">> => ServerDeaf,
<<"member">> => Member
},
_ = gen_server:call(
GuildPid,
{store_pending_connection, NewConnectionId, PendingMetadata},
10000
),
guild_voice_broadcast:broadcast_voice_server_update_to_session(
GuildId,
ChannelId,
SessionId,
Token,
Endpoint,
NewConnectionId,
StateData
);
{error, _Reason} ->
ok
end;
_ ->
ok
end
end.
-ifdef(TEST).
move_member_user_not_in_voice_test() ->
@@ -253,6 +471,12 @@ select_connections_to_move_specific_connection_test() ->
?assertEqual(#{<<"conn-b">> => maps:get(<<"conn-b">>, VoiceStates)}, Selected),
?assertEqual(#{}, select_connections_to_move(<<"conn-b">>, 10, VoiceStates, #{})).
normalize_channel_id_test() ->
?assertEqual(null, normalize_channel_id(null)),
?assertEqual(123, normalize_channel_id(123)),
?assertEqual(456, normalize_channel_id(<<"456">>)),
?assertEqual(null, normalize_channel_id(undefined)).
test_state(VoiceStates) ->
#{
id => 1,
@@ -274,105 +498,3 @@ voice_state_fixture(UserId, ChannelId, ConnId) ->
}.
-endif.
send_voice_server_update_for_move(GuildId, ChannelId, UserId, SessionId, GuildPid) ->
case SessionId of
undefined ->
ok;
_ ->
case gen_server:call(GuildPid, {get_sessions}, 10000) of
State when is_map(State) ->
VoicePermissions = voice_utils:compute_voice_permissions(
UserId, ChannelId, State
),
case
guild_voice_connection:request_voice_token(
GuildId, ChannelId, UserId, VoicePermissions
)
of
{ok, TokenData} ->
Token = maps:get(token, TokenData),
Endpoint = maps:get(endpoint, TokenData),
ConnectionId = maps:get(connection_id, TokenData),
guild_voice_broadcast:broadcast_voice_server_update_to_session(
GuildId, SessionId, Token, Endpoint, ConnectionId, State
);
{error, _Reason} ->
ok
end;
_ ->
ok
end
end.
send_voice_server_updates_for_move(GuildId, ChannelId, SessionDataList, GuildPid) ->
lists:foreach(
fun(SessionInfo) ->
send_single_voice_server_update(GuildId, ChannelId, SessionInfo, GuildPid)
end,
SessionDataList
).
send_single_voice_server_update(GuildId, ChannelId, SessionInfo, GuildPid) ->
SessionId = maps:get(session_id, SessionInfo),
SelfMute = maps:get(self_mute, SessionInfo),
SelfDeaf = maps:get(self_deaf, SessionInfo),
SelfVideo = maps:get(self_video, SessionInfo),
SelfStream = maps:get(self_stream, SessionInfo),
IsMobile = maps:get(is_mobile, SessionInfo),
Member = maps:get(member, SessionInfo),
ServerMute = maps:get(<<"mute">>, Member, false),
ServerDeaf = maps:get(<<"deaf">>, Member, false),
case member_user_id(Member) of
undefined ->
logger:warning(
"[guild_voice_move] Missing user_id in member while sending voice server update: ~p",
[SessionInfo]
),
ok;
UserId ->
case gen_server:call(GuildPid, {get_sessions}, 10000) of
StateData when is_map(StateData) ->
VoicePermissions = voice_utils:compute_voice_permissions(
UserId, ChannelId, StateData
),
case
guild_voice_connection:request_voice_token(
GuildId, ChannelId, UserId, VoicePermissions
)
of
{ok, TokenData} ->
Token = maps:get(token, TokenData),
Endpoint = maps:get(endpoint, TokenData),
NewConnectionId = maps:get(connection_id, TokenData),
PendingMetadata = #{
<<"user_id">> => UserId,
<<"guild_id">> => GuildId,
<<"channel_id">> => ChannelId,
<<"connection_id">> => NewConnectionId,
<<"session_id">> => SessionId,
<<"self_mute">> => SelfMute,
<<"self_deaf">> => SelfDeaf,
<<"self_video">> => SelfVideo,
<<"self_stream">> => SelfStream,
<<"is_mobile">> => IsMobile,
<<"server_mute">> => ServerMute,
<<"server_deaf">> => ServerDeaf,
<<"member">> => Member
},
gen_server:cast(
GuildPid,
{store_pending_connection, NewConnectionId, PendingMetadata}
),
guild_voice_broadcast:broadcast_voice_server_update_to_session(
GuildId, SessionId, Token, Endpoint, NewConnectionId, StateData
);
{error, _Reason} ->
ok
end;
_ ->
ok
end
end.

View File

@@ -26,23 +26,23 @@
-type guild_state() :: map().
-type voice_state() :: map().
-type user_id() :: integer().
-type channel_id() :: integer().
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec sync_user_voice_permissions(integer(), guild_state()) -> ok.
-spec sync_user_voice_permissions(user_id(), guild_state()) -> ok.
sync_user_voice_permissions(UserId, State) ->
VoiceStates = voice_state_utils:voice_states(State),
GuildId = map_utils:get_integer(State, id, 0),
UserVoiceStates = maps:filter(
fun(_ConnId, VoiceState) ->
voice_state_utils:voice_state_user_id(VoiceState) =:= UserId
end,
VoiceStates
),
maps:foreach(
fun(_ConnId, VoiceState) ->
sync_voice_state_permissions(GuildId, UserId, VoiceState, State)
@@ -51,18 +51,16 @@ sync_user_voice_permissions(UserId, State) ->
),
ok.
-spec sync_all_voice_permissions_for_channel(integer(), guild_state()) -> ok.
-spec sync_all_voice_permissions_for_channel(channel_id(), guild_state()) -> ok.
sync_all_voice_permissions_for_channel(ChannelId, State) ->
VoiceStates = voice_state_utils:voice_states(State),
GuildId = map_utils:get_integer(State, id, 0),
ChannelVoiceStates = maps:filter(
fun(_ConnId, VoiceState) ->
voice_state_utils:voice_state_channel_id(VoiceState) =:= ChannelId
end,
VoiceStates
),
maps:foreach(
fun(_ConnId, VoiceState) ->
UserId = voice_state_utils:voice_state_user_id(VoiceState),
@@ -84,15 +82,12 @@ maybe_sync_permissions_on_role_update(RoleUpdate, State) ->
_ ->
OldPermissions = maps:get(<<"old_permissions">>, RoleUpdate, 0),
NewPermissions = maps:get(<<"permissions">>, RoleUpdate, 0),
AdminPerm = constants:administrator_permission(),
SpeakPerm = constants:speak_permission(),
StreamPerm = constants:stream_permission(),
VoicePerms = AdminPerm bor SpeakPerm bor StreamPerm,
OldVoicePerms = OldPermissions band VoicePerms,
NewVoicePerms = NewPermissions band VoicePerms,
case OldVoicePerms =/= NewVoicePerms of
true ->
sync_users_with_role(RoleId, State);
@@ -110,7 +105,6 @@ maybe_sync_permissions_on_member_update(MemberUpdate, State) ->
_ ->
OldRoles = maps:get(<<"old_roles">>, MemberUpdate, []),
NewRoles = maps:get(<<"roles">>, MemberUpdate, []),
case OldRoles =/= NewRoles of
true ->
sync_user_voice_permissions(UserId, State);
@@ -119,11 +113,10 @@ maybe_sync_permissions_on_member_update(MemberUpdate, State) ->
end
end.
-spec sync_voice_state_permissions(integer(), integer(), voice_state(), guild_state()) -> ok.
-spec sync_voice_state_permissions(integer(), user_id(), voice_state(), guild_state()) -> ok.
sync_voice_state_permissions(GuildId, UserId, VoiceState, State) ->
ChannelId = voice_state_utils:voice_state_channel_id(VoiceState),
ConnectionId = maps:get(<<"connection_id">>, VoiceState, undefined),
case {ChannelId, ConnectionId} of
{undefined, _} ->
ok;
@@ -134,7 +127,9 @@ sync_voice_state_permissions(GuildId, UserId, VoiceState, State) ->
dispatch_permission_update(GuildId, ChId, UserId, ConnId, VoicePermissions, State)
end.
-spec dispatch_permission_update(integer(), integer(), integer(), binary(), map(), guild_state()) ->
-spec dispatch_permission_update(
integer(), channel_id(), user_id(), binary(), map(), guild_state()
) ->
ok.
dispatch_permission_update(GuildId, ChannelId, UserId, ConnectionId, VoicePermissions, State) ->
case maps:get(test_permission_sync_fun, State, undefined) of
@@ -149,7 +144,7 @@ dispatch_permission_update(GuildId, ChannelId, UserId, ConnectionId, VoicePermis
end.
-spec enforce_voice_permissions_in_livekit(
integer(), integer(), integer(), binary(), map()
integer(), channel_id(), user_id(), binary(), map()
) -> ok.
enforce_voice_permissions_in_livekit(GuildId, ChannelId, UserId, ConnectionId, VoicePermissions) ->
Req = voice_utils:build_update_participant_permissions_rpc_request(
@@ -157,33 +152,8 @@ enforce_voice_permissions_in_livekit(GuildId, ChannelId, UserId, ConnectionId, V
),
case rpc_client:call(Req) of
{ok, _Data} ->
logger:debug(
"[guild_voice_permission_sync] Synced voice permissions ~p",
[
[
{guildId, GuildId},
{channelId, ChannelId},
{userId, UserId},
{connectionId, ConnectionId},
{permissions, VoicePermissions}
]
]
),
ok;
{error, Reason} ->
logger:warning(
"[guild_voice_permission_sync] Failed to sync voice permissions ~p",
[
[
{guildId, GuildId},
{channelId, ChannelId},
{userId, UserId},
{connectionId, ConnectionId},
{permissions, VoicePermissions},
{error, Reason}
]
]
),
{error, _Reason} ->
ok
end.
@@ -192,7 +162,6 @@ sync_users_with_role(RoleId, State) ->
RoleIdBin = ensure_binary(RoleId),
VoiceStates = voice_state_utils:voice_states(State),
GuildId = map_utils:get_integer(State, id, 0),
maps:foreach(
fun(_ConnId, VoiceState) ->
UserId = voice_state_utils:voice_state_user_id(VoiceState),
@@ -212,7 +181,7 @@ sync_users_with_role(RoleId, State) ->
),
ok.
-spec user_has_role(integer(), binary(), guild_state()) -> boolean().
-spec user_has_role(user_id(), binary(), guild_state()) -> boolean().
user_has_role(UserId, RoleIdBin, State) ->
case guild_voice_member:find_member_by_user_id(UserId, State) of
undefined ->
@@ -222,7 +191,7 @@ user_has_role(UserId, RoleIdBin, State) ->
lists:member(RoleIdBin, Roles)
end.
-spec get_member_user_id(map()) -> integer() | undefined.
-spec get_member_user_id(map()) -> user_id() | undefined.
get_member_user_id(MemberUpdate) ->
User = maps:get(<<"user">>, MemberUpdate, #{}),
map_utils:get_integer(User, <<"id">>, undefined).
@@ -242,19 +211,16 @@ sync_user_voice_permissions_syncs_connected_user_test() ->
ChannelId = 500,
GuildId = 42,
RoleId = 999,
VoiceState = #{
<<"user_id">> => integer_to_binary(UserId),
<<"channel_id">> => integer_to_binary(ChannelId),
<<"connection_id">> => <<"test-conn">>
},
Permissions =
constants:view_channel_permission() bor
constants:connect_permission() bor
constants:speak_permission() bor
constants:stream_permission(),
State = #{
id => GuildId,
voice_states => #{<<"conn">> => VoiceState},
@@ -301,4 +267,17 @@ sync_user_voice_permissions_no_voice_state_test() ->
},
ok = sync_user_voice_permissions(10, State).
maybe_sync_permissions_on_member_update_no_role_change_test() ->
State = #{id => 42, voice_states => #{}},
MemberUpdate = #{
<<"user">> => #{<<"id">> => <<"10">>},
<<"roles">> => [<<"1">>],
<<"old_roles">> => [<<"1">>]
},
?assertEqual(ok, maybe_sync_permissions_on_member_update(MemberUpdate, State)).
ensure_binary_test() ->
?assertEqual(<<"123">>, ensure_binary(123)),
?assertEqual(<<"abc">>, ensure_binary(<<"abc">>)).
-endif.

View File

@@ -19,15 +19,15 @@
-export([check_voice_permissions_and_limits/6]).
-type guild_state() :: map().
-type voice_state_map() :: #{binary() => map()}.
-type channel() :: map().
-import(utils, [parse_iso8601_to_unix_ms/1]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-import(utils, [parse_iso8601_to_unix_ms/1]).
-type guild_state() :: map().
-type voice_state_map() :: #{binary() => map()}.
-type channel() :: map().
-spec check_voice_permissions_and_limits(
integer(), integer(), channel(), voice_state_map(), guild_state(), boolean()
@@ -45,8 +45,10 @@ check_voice_permissions_and_limits(UserId, ChannelIdValue, Channel, VoiceStates,
case
channel_has_capacity(UserId, ChannelIdValue, Channel, VoiceStates, IsUpdate)
of
true -> {ok, allowed};
false -> gateway_errors:error(voice_channel_full)
true ->
{ok, allowed};
false ->
gateway_errors:error(voice_channel_full)
end
end
end.
@@ -57,18 +59,26 @@ has_view_and_connect_perms(UserId, ChannelIdValue, State) ->
true ->
true;
false ->
Permissions = resolve_permissions(UserId, ChannelIdValue, State),
ViewPerm = constants:view_channel_permission(),
ConnectPerm = constants:connect_permission(),
(Permissions band ViewPerm) =:= ViewPerm andalso
(Permissions band ConnectPerm) =:= ConnectPerm
case guild_virtual_channel_access:is_move_pending(UserId, ChannelIdValue, State) of
true ->
true;
false ->
Permissions = resolve_permissions(UserId, ChannelIdValue, State),
ViewPerm = constants:view_channel_permission(),
ConnectPerm = constants:connect_permission(),
HasView = (Permissions band ViewPerm) =:= ViewPerm,
HasConnect = (Permissions band ConnectPerm) =:= ConnectPerm,
HasView andalso HasConnect
end
end.
-spec channel_has_capacity(integer(), integer(), channel(), voice_state_map(), boolean()) ->
boolean().
channel_has_capacity(UserId, ChannelIdValue, Channel, VoiceStates, IsUpdate) ->
UserLimit = maps:get(<<"user_limit">>, Channel, 0),
case UserLimit of
AnyCameraActive = any_camera_active_in_channel(ChannelIdValue, VoiceStates),
EffectiveLimit = effective_user_limit(UserLimit, AnyCameraActive),
case EffectiveLimit of
0 ->
true;
Limit when Limit > 0 ->
@@ -85,6 +95,26 @@ channel_has_capacity(UserId, ChannelIdValue, Channel, VoiceStates, IsUpdate) ->
true
end.
-spec any_camera_active_in_channel(integer(), voice_state_map()) -> boolean().
any_camera_active_in_channel(ChannelIdValue, VoiceStates) ->
lists:any(
fun({_ConnId, VS}) ->
case map_utils:get_integer(VS, <<"channel_id">>, undefined) of
ChannelIdValue ->
maps:get(<<"self_video">>, VS, false) =:= true;
_ ->
false
end
end,
maps:to_list(VoiceStates)
).
-spec effective_user_limit(integer(), boolean()) -> integer().
effective_user_limit(0, false) -> 0;
effective_user_limit(0, true) -> 25;
effective_user_limit(Limit, false) -> Limit;
effective_user_limit(Limit, true) -> min(Limit, 25).
-spec is_member_timed_out(integer(), guild_state()) -> boolean().
is_member_timed_out(UserId, State) ->
case guild_permissions:find_member_by_user_id(UserId, State) of
@@ -98,13 +128,11 @@ is_member_timed_out(UserId, State) ->
undefined ->
false;
Value when is_integer(Value) ->
Value > erlang:system_time(millisecond);
_ ->
false
Value > erlang:system_time(millisecond)
end
end.
-spec users_in_channel(integer(), voice_state_map()) -> sets:set().
-spec users_in_channel(integer(), voice_state_map()) -> sets:set(integer()).
users_in_channel(ChannelIdValue, VoiceStates0) ->
VoiceStates = voice_state_utils:ensure_voice_states(VoiceStates0),
maps:fold(
@@ -161,6 +189,18 @@ voice_permissions_existing_user_update_test() ->
),
?assertEqual({ok, allowed}, Result).
users_in_channel_test() ->
VoiceStates = #{
<<"conn1">> => #{<<"channel_id">> => <<"10">>, <<"user_id">> => <<"1">>},
<<"conn2">> => #{<<"channel_id">> => <<"10">>, <<"user_id">> => <<"2">>},
<<"conn3">> => #{<<"channel_id">> => <<"20">>, <<"user_id">> => <<"3">>}
},
Result = users_in_channel(10, VoiceStates),
?assertEqual(2, sets:size(Result)),
?assert(sets:is_element(1, Result)),
?assert(sets:is_element(2, Result)),
?assertNot(sets:is_element(3, Result)).
required_voice_perms() ->
constants:view_channel_permission() bor constants:connect_permission().

View File

@@ -20,9 +20,17 @@
-export([switch_voice_region_handler/2]).
-export([switch_voice_region/3]).
-type guild_state() :: map().
-type guild_reply(T) :: {reply, T, guild_state()}.
-type voice_state() :: map().
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec switch_voice_region_handler(map(), guild_state()) -> guild_reply(map()).
switch_voice_region_handler(Request, State) ->
#{channel_id := ChannelId} = Request,
Channel = guild_voice_member:find_channel_by_id(ChannelId, State),
case Channel of
undefined ->
@@ -37,42 +45,21 @@ switch_voice_region_handler(Request, State) ->
end
end.
-spec switch_voice_region(integer(), integer(), pid()) -> ok.
switch_voice_region(GuildId, ChannelId, GuildPid) ->
case gen_server:call(GuildPid, {get_sessions}, 10000) of
State when is_map(State) ->
VoiceStates = voice_state_utils:voice_states(State),
UsersInChannel = maps:fold(
fun(ConnectionId, VoiceState, Acc) ->
case voice_state_utils:voice_state_channel_id(VoiceState) of
ChannelId ->
case voice_state_utils:voice_state_user_id(VoiceState) of
undefined ->
logger:warning(
"[guild_voice_region] Missing user_id for connection ~p",
[ConnectionId]
),
Acc;
UserId ->
SessionId = maps:get(<<"session_id">>, VoiceState, undefined),
[{UserId, SessionId, VoiceState} | Acc]
end;
_ ->
Acc
end
end,
[],
VoiceStates
),
UsersInChannel = collect_users_in_channel(VoiceStates, ChannelId),
lists:foreach(
fun({UserId, SessionId, VoiceState}) ->
fun({UserId, SessionId, ExistingConnectionId, VoiceState}) ->
case SessionId of
undefined ->
ok;
_ ->
send_voice_server_update_for_region_switch(
GuildId, ChannelId, UserId, SessionId, VoiceState, GuildPid
GuildId, ChannelId, UserId, SessionId, ExistingConnectionId,
VoiceState, GuildPid
)
end
end,
@@ -82,42 +69,61 @@ switch_voice_region(GuildId, ChannelId, GuildPid) ->
ok
end.
-spec collect_users_in_channel(map(), integer()) ->
[{integer(), binary() | undefined, binary(), voice_state()}].
collect_users_in_channel(VoiceStates, ChannelId) ->
maps:fold(
fun(ConnectionId, VoiceState, Acc) ->
case voice_state_utils:voice_state_channel_id(VoiceState) of
ChannelId ->
case voice_state_utils:voice_state_user_id(VoiceState) of
undefined ->
Acc;
UserId ->
SessionId = maps:get(<<"session_id">>, VoiceState, undefined),
[{UserId, SessionId, ConnectionId, VoiceState} | Acc]
end;
_ ->
Acc
end
end,
[],
VoiceStates
).
-spec send_voice_server_update_for_region_switch(
integer(), integer(), integer(), binary(), binary(), voice_state(), pid()
) -> ok.
send_voice_server_update_for_region_switch(
GuildId, ChannelId, UserId, SessionId, ExistingVoiceState, GuildPid
GuildId, ChannelId, UserId, SessionId, ExistingConnectionId, ExistingVoiceState, GuildPid
) ->
case gen_server:call(GuildPid, {get_sessions}, 10000) of
State when is_map(State) ->
VoicePermissions = voice_utils:compute_voice_permissions(UserId, ChannelId, State),
TokenNonce = voice_utils:generate_token_nonce(),
case
guild_voice_connection:request_voice_token(
GuildId, ChannelId, UserId, VoicePermissions
GuildId, ChannelId, UserId, ExistingConnectionId, VoicePermissions, TokenNonce
)
of
{ok, TokenData} ->
Token = maps:get(token, TokenData),
Endpoint = maps:get(endpoint, TokenData),
ConnectionId = maps:get(connection_id, TokenData),
PendingMetadata = #{
user_id => UserId,
guild_id => GuildId,
channel_id => ChannelId,
session_id => SessionId,
self_mute => maps:get(<<"self_mute">>, ExistingVoiceState, false),
self_deaf => maps:get(<<"self_deaf">>, ExistingVoiceState, false),
self_video => maps:get(<<"self_video">>, ExistingVoiceState, false),
self_stream => maps:get(<<"self_stream">>, ExistingVoiceState, false),
is_mobile => maps:get(<<"is_mobile">>, ExistingVoiceState, false),
server_mute => maps:get(<<"mute">>, ExistingVoiceState, false),
server_deaf => maps:get(<<"deaf">>, ExistingVoiceState, false),
member => maps:get(<<"member">>, ExistingVoiceState, #{})
},
gen_server:cast(
GuildPid, {store_pending_connection, ConnectionId, PendingMetadata}
PendingMetadata = build_pending_metadata(
UserId, GuildId, ChannelId, SessionId, ExistingVoiceState, TokenNonce
),
_ = gen_server:call(
GuildPid, {store_pending_connection, ConnectionId, PendingMetadata}, 10000
),
guild_voice_broadcast:broadcast_voice_server_update_to_session(
GuildId, SessionId, Token, Endpoint, ConnectionId, State
GuildId,
ChannelId,
SessionId,
Token,
Endpoint,
ConnectionId,
State
);
{error, _Reason} ->
ok
@@ -125,3 +131,73 @@ send_voice_server_update_for_region_switch(
_ ->
ok
end.
-spec build_pending_metadata(integer(), integer(), integer(), binary(), voice_state(), binary()) -> map().
build_pending_metadata(UserId, GuildId, ChannelId, SessionId, ExistingVoiceState, TokenNonce) ->
Now = erlang:system_time(millisecond),
#{
user_id => UserId,
guild_id => GuildId,
channel_id => ChannelId,
session_id => SessionId,
self_mute => maps:get(<<"self_mute">>, ExistingVoiceState, false),
self_deaf => maps:get(<<"self_deaf">>, ExistingVoiceState, false),
self_video => maps:get(<<"self_video">>, ExistingVoiceState, false),
self_stream => maps:get(<<"self_stream">>, ExistingVoiceState, false),
is_mobile => maps:get(<<"is_mobile">>, ExistingVoiceState, false),
server_mute => maps:get(<<"mute">>, ExistingVoiceState, false),
server_deaf => maps:get(<<"deaf">>, ExistingVoiceState, false),
member => maps:get(<<"member">>, ExistingVoiceState, #{}),
viewer_stream_keys => [],
token_nonce => TokenNonce,
created_at => Now,
expires_at => Now + 30000
}.
-ifdef(TEST).
switch_voice_region_handler_not_found_test() ->
State = #{data => #{<<"channels">> => []}},
Request = #{channel_id => 999},
{reply, Error, _} = switch_voice_region_handler(Request, State),
?assertEqual({error, not_found, voice_channel_not_found}, Error).
switch_voice_region_handler_not_voice_test() ->
State = #{
data => #{
<<"channels">> => [
#{<<"id">> => <<"100">>, <<"type">> => 0}
]
}
},
Request = #{channel_id => 100},
{reply, Error, _} = switch_voice_region_handler(Request, State),
?assertEqual({error, validation_error, voice_channel_not_voice}, Error).
switch_voice_region_handler_success_test() ->
State = #{
data => #{
<<"channels">> => [
#{<<"id">> => <<"100">>, <<"type">> => 2}
]
}
},
Request = #{channel_id => 100},
{reply, Reply, _} = switch_voice_region_handler(Request, State),
?assertEqual(true, maps:get(success, Reply)).
collect_users_in_channel_test() ->
VoiceState = #{
<<"channel_id">> => <<"100">>,
<<"user_id">> => <<"10">>,
<<"session_id">> => <<"sess1">>
},
VoiceStates = #{<<"conn1">> => VoiceState},
Result = collect_users_in_channel(VoiceStates, 100),
?assertEqual(1, length(Result)),
[{UserId, SessionId, ConnectionId, _}] = Result,
?assertEqual(10, UserId),
?assertEqual(<<"sess1">>, SessionId),
?assertEqual(<<"conn1">>, ConnectionId).
-endif.

View File

@@ -26,14 +26,14 @@
-export([create_voice_state/8]).
-export([extract_session_info_from_voice_state/2]).
-type guild_state() :: map().
-type voice_state() :: map().
-type voice_state_map() :: #{binary() => voice_state()}.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type guild_state() :: map().
-type voice_state() :: map().
-type voice_state_map() :: #{binary() => voice_state()}.
-spec get_voice_state(map(), guild_state()) -> {reply, map(), guild_state()}.
get_voice_state(Request, State) ->
case maps:get(connection_id, Request, null) of
@@ -58,7 +58,7 @@ get_voice_states_list(State) ->
voice_state_map(),
guild_state(),
boolean(),
term()
list()
) -> {reply, map(), guild_state()}.
update_voice_state_data(
ConnectionId,
@@ -69,39 +69,85 @@ update_voice_state_data(
VoiceStates,
State,
NeedsToken,
ViewerStreamKey
ViewerStreamKeys
) ->
#voice_flags{
self_mute = SelfMute,
self_deaf = SelfDeaf,
self_video = SelfVideo,
self_stream = SelfStream,
is_mobile = IsMobile
#{
self_mute := SelfMute,
self_deaf := SelfDeaf,
self_video := SelfVideo,
self_stream := SelfStream,
is_mobile := IsMobile
} = Flags,
ServerMute = maps:get(<<"mute">>, Member, false),
ServerDeaf = maps:get(<<"deaf">>, Member, false),
OldVersion = maps:get(<<"version">>, ExistingVoiceState, 0),
UpdatedVoiceState = ExistingVoiceState#{
<<"channel_id">> => ChannelIdBin,
<<"mute">> => ServerMute,
<<"deaf">> => ServerDeaf,
<<"self_mute">> => SelfMute,
<<"self_deaf">> => SelfDeaf,
<<"self_video">> => SelfVideo,
<<"self_stream">> => SelfStream,
<<"is_mobile">> => IsMobile,
<<"viewer_stream_key">> => ViewerStreamKey,
<<"version">> => OldVersion + 1
},
NewVoiceStates = maps:put(ConnectionId, UpdatedVoiceState, VoiceStates),
NewState = maps:put(voice_states, NewVoiceStates, State),
guild_voice_broadcast:broadcast_voice_state_update(UpdatedVoiceState, NewState, ChannelIdBin),
Reply =
case NeedsToken of
true -> #{success => true, voice_state => UpdatedVoiceState, needs_token => true};
false -> #{success => true, voice_state => UpdatedVoiceState}
end,
{reply, Reply, NewState}.
OldChannelIdBin = maps:get(<<"channel_id">>, ExistingVoiceState, null),
IsChannelChange = OldChannelIdBin =/= ChannelIdBin,
HasStateChange = has_voice_state_change(
ExistingVoiceState, ChannelIdBin, ServerMute, ServerDeaf,
SelfMute, SelfDeaf, SelfVideo, SelfStream, IsMobile, ViewerStreamKeys
),
case HasStateChange of
false ->
Reply = #{success => true, voice_state => ExistingVoiceState},
{reply, Reply, State};
true ->
OldVersion = maps:get(<<"version">>, ExistingVoiceState, 0),
UpdatedVoiceState = ExistingVoiceState#{
<<"channel_id">> => ChannelIdBin,
<<"mute">> => ServerMute,
<<"deaf">> => ServerDeaf,
<<"self_mute">> => SelfMute,
<<"self_deaf">> => SelfDeaf,
<<"self_video">> => SelfVideo,
<<"self_stream">> => SelfStream,
<<"is_mobile">> => IsMobile,
<<"viewer_stream_keys">> => ViewerStreamKeys,
<<"version">> => OldVersion + 1
},
NewVoiceStates = maps:put(ConnectionId, UpdatedVoiceState, VoiceStates),
NewState = maps:put(voice_states, NewVoiceStates, State),
case IsChannelChange of
true ->
DisconnectState = ExistingVoiceState#{
<<"channel_id">> => null,
<<"connection_id">> => ConnectionId
},
guild_voice_broadcast:broadcast_voice_state_update(
DisconnectState, NewState, OldChannelIdBin
),
guild_voice_broadcast:broadcast_voice_state_update(
UpdatedVoiceState, NewState, ChannelIdBin
);
false ->
guild_voice_broadcast:broadcast_voice_state_update(
UpdatedVoiceState, NewState, ChannelIdBin
)
end,
Reply =
case NeedsToken of
true -> #{success => true, voice_state => UpdatedVoiceState, needs_token => true};
false -> #{success => true, voice_state => UpdatedVoiceState}
end,
{reply, Reply, NewState}
end.
-spec has_voice_state_change(
voice_state(), binary(), boolean(), boolean(),
boolean(), boolean(), boolean(), boolean(), boolean(), term()
) -> boolean().
has_voice_state_change(
ExistingVoiceState, ChannelIdBin, ServerMute, ServerDeaf,
SelfMute, SelfDeaf, SelfVideo, SelfStream, IsMobile, ViewerStreamKeys
) ->
maps:get(<<"channel_id">>, ExistingVoiceState, null) =/= ChannelIdBin orelse
maps:get(<<"mute">>, ExistingVoiceState, false) =/= ServerMute orelse
maps:get(<<"deaf">>, ExistingVoiceState, false) =/= ServerDeaf orelse
maps:get(<<"self_mute">>, ExistingVoiceState, false) =/= SelfMute orelse
maps:get(<<"self_deaf">>, ExistingVoiceState, false) =/= SelfDeaf orelse
maps:get(<<"self_video">>, ExistingVoiceState, false) =/= SelfVideo orelse
maps:get(<<"self_stream">>, ExistingVoiceState, false) =/= SelfStream orelse
maps:get(<<"is_mobile">>, ExistingVoiceState, false) =/= IsMobile orelse
maps:get(<<"viewer_stream_keys">>, ExistingVoiceState, []) =/= ViewerStreamKeys.
-spec user_matches_voice_state(voice_state(), integer() | binary()) -> boolean().
user_matches_voice_state(VoiceState, UserId) when is_integer(UserId) ->
@@ -122,7 +168,7 @@ user_matches_voice_state(_VoiceState, _UserId) ->
boolean(),
boolean(),
voice_flags(),
term()
list()
) -> voice_state().
create_voice_state(
GuildIdBin,
@@ -132,14 +178,14 @@ create_voice_state(
ServerMute,
ServerDeaf,
Flags,
ViewerStreamKey
ViewerStreamKeys
) ->
#voice_flags{
self_mute = SelfMute,
self_deaf = SelfDeaf,
self_video = SelfVideo,
self_stream = SelfStream,
is_mobile = IsMobile
#{
self_mute := SelfMute,
self_deaf := SelfDeaf,
self_video := SelfVideo,
self_stream := SelfStream,
is_mobile := IsMobile
} = Flags,
#{
<<"guild_id">> => GuildIdBin,
@@ -153,7 +199,7 @@ create_voice_state(
<<"self_video">> => SelfVideo,
<<"self_stream">> => SelfStream,
<<"is_mobile">> => IsMobile,
<<"viewer_stream_key">> => ViewerStreamKey,
<<"viewer_stream_keys">> => ViewerStreamKeys,
<<"version">> => 0
}.
@@ -177,29 +223,138 @@ user_matches_voice_state_integer_test() ->
?assert(user_matches_voice_state(VoiceState, 10)),
?assertNot(user_matches_voice_state(VoiceState, 11)).
update_voice_state_data_updates_version_test() ->
VoiceState = #{<<"version">> => 1, <<"channel_id">> => <<"1">>},
Member = #{<<"mute">> => true, <<"deaf">> => false},
Flags = #voice_flags{
self_mute = true,
self_deaf = false,
self_video = false,
self_stream = false,
is_mobile = false
user_matches_voice_state_binary_test() ->
VoiceState = #{<<"user_id">> => <<"123">>},
?assert(user_matches_voice_state(VoiceState, <<"123">>)),
?assertNot(user_matches_voice_state(VoiceState, <<"456">>)).
user_matches_voice_state_undefined_test() ->
VoiceState = #{},
?assertNot(user_matches_voice_state(VoiceState, 10)).
create_voice_state_test() ->
Flags = #{
self_mute => true,
self_deaf => false,
self_video => true,
self_stream => false,
is_mobile => true
},
{reply, #{voice_state := Updated}, _} =
update_voice_state_data(
<<"conn">>,
<<"2">>,
Flags,
Member,
VoiceState,
#{<<"conn">> => VoiceState},
#{voice_states => #{}},
false,
null
),
?assertEqual(2, maps:get(<<"version">>, Updated)),
?assertEqual(<<"2">>, maps:get(<<"channel_id">>, Updated)).
VS = create_voice_state(
<<"1">>, <<"2">>, <<"3">>, <<"conn">>, false, false, Flags, []
),
?assertEqual(<<"1">>, maps:get(<<"guild_id">>, VS)),
?assertEqual(<<"2">>, maps:get(<<"channel_id">>, VS)),
?assertEqual(<<"3">>, maps:get(<<"user_id">>, VS)),
?assertEqual(<<"conn">>, maps:get(<<"connection_id">>, VS)),
?assertEqual(true, maps:get(<<"self_mute">>, VS)),
?assertEqual(false, maps:get(<<"self_deaf">>, VS)),
?assertEqual(true, maps:get(<<"self_video">>, VS)),
?assertEqual(false, maps:get(<<"self_stream">>, VS)),
?assertEqual(true, maps:get(<<"is_mobile">>, VS)),
?assertEqual(0, maps:get(<<"version">>, VS)).
extract_session_info_from_voice_state_test() ->
VoiceState = #{
<<"session_id">> => <<"sess">>,
<<"self_mute">> => true,
<<"self_deaf">> => false,
<<"self_video">> => true,
<<"self_stream">> => false,
<<"is_mobile">> => true,
<<"member">> => #{<<"id">> => <<"m">>}
},
Info = extract_session_info_from_voice_state(<<"conn">>, VoiceState),
?assertEqual(<<"conn">>, maps:get(connection_id, Info)),
?assertEqual(<<"sess">>, maps:get(session_id, Info)),
?assertEqual(true, maps:get(self_mute, Info)),
?assertEqual(#{<<"id">> => <<"m">>}, maps:get(member, Info)).
has_voice_state_change_no_change_test() ->
ExistingVoiceState = #{
<<"channel_id">> => <<"100">>,
<<"mute">> => false,
<<"deaf">> => false,
<<"self_mute">> => true,
<<"self_deaf">> => false,
<<"self_video">> => false,
<<"self_stream">> => false,
<<"is_mobile">> => false,
<<"viewer_stream_keys">> => []
},
?assertNot(has_voice_state_change(
ExistingVoiceState, <<"100">>, false, false, true, false, false, false, false, []
)).
has_voice_state_change_channel_change_test() ->
ExistingVoiceState = #{
<<"channel_id">> => <<"100">>,
<<"mute">> => false,
<<"deaf">> => false,
<<"self_mute">> => false,
<<"self_deaf">> => false,
<<"self_video">> => false,
<<"self_stream">> => false,
<<"is_mobile">> => false,
<<"viewer_stream_keys">> => []
},
?assert(has_voice_state_change(
ExistingVoiceState, <<"200">>, false, false, false, false, false, false, false, []
)).
has_voice_state_change_self_mute_change_test() ->
ExistingVoiceState = #{
<<"channel_id">> => <<"100">>,
<<"mute">> => false,
<<"deaf">> => false,
<<"self_mute">> => false,
<<"self_deaf">> => false,
<<"self_video">> => false,
<<"self_stream">> => false,
<<"is_mobile">> => false,
<<"viewer_stream_keys">> => []
},
?assert(has_voice_state_change(
ExistingVoiceState, <<"100">>, false, false, true, false, false, false, false, []
)).
has_voice_state_change_server_mute_change_test() ->
ExistingVoiceState = #{
<<"channel_id">> => <<"100">>,
<<"mute">> => false,
<<"deaf">> => false,
<<"self_mute">> => false,
<<"self_deaf">> => false,
<<"self_video">> => false,
<<"self_stream">> => false,
<<"is_mobile">> => false,
<<"viewer_stream_keys">> => []
},
?assert(has_voice_state_change(
ExistingVoiceState, <<"100">>, true, false, false, false, false, false, false, []
)).
has_voice_state_change_viewer_stream_keys_change_test() ->
ExistingVoiceState = #{
<<"channel_id">> => <<"100">>,
<<"mute">> => false,
<<"deaf">> => false,
<<"self_mute">> => false,
<<"self_deaf">> => false,
<<"self_video">> => false,
<<"self_stream">> => false,
<<"is_mobile">> => false,
<<"viewer_stream_keys">> => []
},
?assert(has_voice_state_change(
ExistingVoiceState, <<"100">>, false, false, false, false, false, false, false,
[<<"999:100:conn">>]
)).
has_voice_state_change_defaults_test() ->
ExistingVoiceState = #{},
?assertNot(has_voice_state_change(
ExistingVoiceState, null, false, false, false, false, false, false, false, []
)).
-endif.

View File

@@ -0,0 +1,99 @@
%% Copyright (C) 2026 Fluxer Contributors
%%
%% This file is part of Fluxer.
%%
%% Fluxer is free software: you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as published by
%% the Free Software Foundation, either version 3 of the License, or
%% (at your option) any later version.
%%
%% Fluxer is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%% GNU Affero General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
-module(guild_voice_unclaimed_account_utils).
-export([parse_unclaimed_error/1]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-spec parse_unclaimed_error(iodata() | term()) -> boolean().
parse_unclaimed_error(Body) when is_binary(Body); is_list(Body) ->
try json:decode(iolist_to_binary(Body)) of
Map when is_map(Map) ->
case get_unclaimed_error_code(Map) of
Code when is_binary(Code) -> is_voice_unclaimed_error_code(Code);
_ -> false
end;
_ ->
false
catch
_:_ -> false
end;
parse_unclaimed_error(_) ->
false.
-spec get_unclaimed_error_code(map()) -> binary() | undefined.
get_unclaimed_error_code(Map) when is_map(Map) ->
case maps:get(<<"code">>, Map, undefined) of
Code when is_binary(Code) ->
Code;
_ ->
case maps:get(<<"error">>, Map, undefined) of
Error when is_map(Error) ->
maps:get(<<"code">>, Error, undefined);
_ ->
undefined
end
end.
-spec is_voice_unclaimed_error_code(binary()) -> boolean().
is_voice_unclaimed_error_code(Code) when is_binary(Code) ->
lists:member(
Code,
[
<<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_ONE_ON_ONE_VOICE_CALLS">>,
<<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_VOICE_CHANNELS">>
]
).
-ifdef(TEST).
parse_unclaimed_error_with_direct_code_test() ->
Body = json:encode(#{<<"code">> => <<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_VOICE_CHANNELS">>}),
?assertEqual(true, parse_unclaimed_error(Body)).
parse_unclaimed_error_with_nested_code_test() ->
Body = json:encode(#{
<<"error">> => #{<<"code">> => <<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_ONE_ON_ONE_VOICE_CALLS">>}
}),
?assertEqual(true, parse_unclaimed_error(Body)).
parse_unclaimed_error_with_unknown_code_test() ->
Body = json:encode(#{<<"code">> => <<"SOME_OTHER_ERROR">>}),
?assertEqual(false, parse_unclaimed_error(Body)).
parse_unclaimed_error_with_invalid_json_test() ->
?assertEqual(false, parse_unclaimed_error(<<"not json">>)).
parse_unclaimed_error_with_non_binary_test() ->
?assertEqual(false, parse_unclaimed_error(undefined)),
?assertEqual(false, parse_unclaimed_error(123)).
is_voice_unclaimed_error_code_test() ->
?assertEqual(
true, is_voice_unclaimed_error_code(<<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_VOICE_CHANNELS">>)
),
?assertEqual(
true,
is_voice_unclaimed_error_code(<<"UNCLAIMED_ACCOUNT_CANNOT_JOIN_ONE_ON_ONE_VOICE_CALLS">>)
),
?assertEqual(false, is_voice_unclaimed_error_code(<<"OTHER_ERROR">>)).
-endif.

View File

@@ -24,6 +24,10 @@
channel_has_capacity/3
]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type user_id() :: integer().
-type session_id() :: binary().
-type session_pid() :: pid().
@@ -103,3 +107,56 @@ channel_has_capacity(ChannelId, UserLimit, VoiceStates) ->
-spec ensure_binary(binary() | integer()) -> binary().
ensure_binary(Value) when is_binary(Value) -> Value;
ensure_binary(Value) when is_integer(Value) -> integer_to_binary(Value).
-ifdef(TEST).
find_session_by_user_id_test() ->
Pid = self(),
Ref = make_ref(),
Sessions = #{
<<"session1">> => {100, Pid, Ref},
<<"session2">> => {200, Pid, make_ref()}
},
?assertMatch({ok, <<"session1">>, _, _}, find_session_by_user_id(100, Sessions)),
?assertEqual(not_found, find_session_by_user_id(999, Sessions)).
disconnect_user_not_found_test() ->
VoiceStates = #{},
Sessions = #{},
CleanupFun = fun(_, _) -> ok end,
?assertMatch({not_found, _, _}, disconnect_user(100, VoiceStates, Sessions, CleanupFun)).
disconnect_user_if_in_channel_mismatch_test() ->
VoiceStates = #{100 => #{<<"channel_id">> => <<"999">>}},
Sessions = #{},
CleanupFun = fun(_, _) -> ok end,
Result = disconnect_user_if_in_channel(100, 123, VoiceStates, Sessions, CleanupFun),
?assertMatch({channel_mismatch, _, _}, Result).
disconnect_user_if_in_channel_not_found_test() ->
VoiceStates = #{},
Sessions = #{},
CleanupFun = fun(_, _) -> ok end,
Result = disconnect_user_if_in_channel(100, 123, VoiceStates, Sessions, CleanupFun),
?assertMatch({not_found, _, _}, Result).
channel_has_capacity_unlimited_test() ->
VoiceStates = #{
1 => #{<<"channel_id">> => <<"100">>},
2 => #{<<"channel_id">> => <<"100">>}
},
?assert(channel_has_capacity(100, 0, VoiceStates)).
channel_has_capacity_limited_test() ->
VoiceStates = #{
1 => #{<<"channel_id">> => <<"100">>},
2 => #{<<"channel_id">> => <<"100">>}
},
?assertNot(channel_has_capacity(100, 2, VoiceStates)),
?assert(channel_has_capacity(100, 3, VoiceStates)).
ensure_binary_test() ->
?assertEqual(<<"123">>, ensure_binary(123)),
?assertEqual(<<"abc">>, ensure_binary(<<"abc">>)).
-endif.

View File

@@ -24,6 +24,10 @@
confirm_pending_connection/2
]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type connection_id() :: binary().
-type pending_metadata() :: map().
-type pending_map() :: #{connection_id() => pending_metadata()}.
@@ -56,3 +60,35 @@ confirm_pending_connection(ConnectionId, PendingMap) ->
_Metadata ->
{confirmed, maps:remove(ConnectionId, PendingMap)}
end.
-ifdef(TEST).
add_pending_connection_test() ->
PendingMap = #{},
Metadata = #{user_id => 1, channel_id => 2},
Result = add_pending_connection(<<"conn">>, Metadata, PendingMap),
?assert(maps:is_key(<<"conn">>, Result)),
StoredMetadata = maps:get(<<"conn">>, Result),
?assertEqual(1, maps:get(user_id, StoredMetadata)),
?assertEqual(2, maps:get(channel_id, StoredMetadata)),
?assert(maps:is_key(joined_at, StoredMetadata)).
remove_pending_connection_test() ->
PendingMap = #{<<"conn">> => #{user_id => 1}},
?assertEqual(#{}, remove_pending_connection(<<"conn">>, PendingMap)),
?assertEqual(PendingMap, remove_pending_connection(undefined, PendingMap)),
?assertEqual(PendingMap, remove_pending_connection(<<"other">>, PendingMap)).
get_pending_connection_test() ->
PendingMap = #{<<"conn">> => #{user_id => 1}},
?assertEqual(#{user_id => 1}, get_pending_connection(<<"conn">>, PendingMap)),
?assertEqual(undefined, get_pending_connection(<<"other">>, PendingMap)),
?assertEqual(undefined, get_pending_connection(undefined, PendingMap)).
confirm_pending_connection_test() ->
PendingMap = #{<<"conn">> => #{user_id => 1}},
?assertMatch({confirmed, #{}}, confirm_pending_connection(<<"conn">>, PendingMap)),
?assertMatch({not_found, _}, confirm_pending_connection(<<"other">>, PendingMap)),
?assertMatch({not_found, _}, confirm_pending_connection(undefined, PendingMap)).
-endif.

View File

@@ -33,66 +33,86 @@
build_stream_key/3
]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type voice_state() :: map().
-type voice_state_map() :: #{binary() => voice_state()}.
-type guild_state() :: map().
-type stream_key_result() :: #{
scope := guild | dm,
guild_id := integer() | undefined,
channel_id := integer(),
connection_id := binary()
}.
-spec voice_states(guild_state()) -> voice_state_map().
voice_states(State) when is_map(State) ->
case maps:get(voice_states, State, undefined) of
Map when is_map(Map) -> Map;
_ -> #{}
end.
-spec ensure_voice_states(term()) -> voice_state_map().
ensure_voice_states(Map) when is_map(Map) ->
Map;
ensure_voice_states(_) ->
#{}.
-spec voice_state_user_id(voice_state()) -> integer() | undefined.
voice_state_user_id(VoiceState) ->
map_utils:get_integer(VoiceState, <<"user_id">>, undefined).
-spec voice_state_channel_id(voice_state()) -> integer() | undefined.
voice_state_channel_id(VoiceState) ->
map_utils:get_integer(VoiceState, <<"channel_id">>, undefined).
-spec voice_state_guild_id(voice_state()) -> integer() | undefined.
voice_state_guild_id(VoiceState) ->
map_utils:get_integer(VoiceState, <<"guild_id">>, undefined).
-spec filter_voice_states(voice_state_map(), fun((binary(), voice_state()) -> boolean())) ->
voice_state_map().
filter_voice_states(VoiceStates, Predicate) when is_map(VoiceStates) ->
maps:filter(Predicate, VoiceStates);
filter_voice_states(_, _) ->
#{}.
-spec drop_voice_states(voice_state_map(), voice_state_map()) -> voice_state_map().
drop_voice_states(ToDrop, VoiceStates) ->
maps:fold(fun(ConnId, _VoiceState, Acc) -> maps:remove(ConnId, Acc) end, VoiceStates, ToDrop).
-spec broadcast_disconnects(voice_state_map(), guild_state()) -> ok.
broadcast_disconnects(VoiceStates, State) ->
maps:foreach(
fun(ConnId, VoiceState) ->
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
DisconnectVoiceState = VoiceState#{
<<"channel_id">> => null,
<<"connection_id">> => ConnId
},
guild_voice_broadcast:broadcast_voice_state_update(
DisconnectVoiceState, State, OldChannelIdBin
)
end,
VoiceStates
).
spawn(fun() ->
maps:foreach(
fun(ConnId, VoiceState) ->
OldChannelIdBin = maps:get(<<"channel_id">>, VoiceState, null),
DisconnectVoiceState = VoiceState#{
<<"channel_id">> => null,
<<"connection_id">> => ConnId
},
guild_voice_broadcast:broadcast_voice_state_update(
DisconnectVoiceState, State, OldChannelIdBin
)
end,
VoiceStates
)
end),
ok.
-spec voice_flags_from_context(map()) -> voice_flags().
voice_flags_from_context(Context) ->
#voice_flags{
self_mute = maps:get(self_mute, Context, false),
self_deaf = maps:get(self_deaf, Context, false),
self_video = maps:get(self_video, Context, false),
self_stream = maps:get(self_stream, Context, false),
is_mobile = maps:get(is_mobile, Context, false)
#{
self_mute => maps:get(self_mute, Context, false),
self_deaf => maps:get(self_deaf, Context, false),
self_video => maps:get(self_video, Context, false),
self_stream => maps:get(self_stream, Context, false),
is_mobile => maps:get(is_mobile, Context, false)
}.
-spec parse_stream_key(term()) ->
{ok, #{
scope := guild | dm,
guild_id := integer() | undefined,
channel_id := integer(),
connection_id := binary()
}}
| {error, invalid_stream_key}.
-spec parse_stream_key(term()) -> {ok, stream_key_result()} | {error, invalid_stream_key}.
parse_stream_key(StreamKey) when is_binary(StreamKey) ->
Parts = binary:split(StreamKey, <<":">>, [global]),
case Parts of
@@ -126,12 +146,7 @@ parse_channel_bin(ChannelBin) ->
Chan.
-spec build_stream_key_result({dm, undefined} | {guild, integer()}, integer(), binary()) ->
{ok, #{
scope := guild | dm,
guild_id := integer() | undefined,
channel_id := integer(),
connection_id := binary()
}}.
{ok, stream_key_result()}.
build_stream_key_result({dm, _}, ChannelId, ConnId) ->
{ok, #{
scope => dm,
@@ -162,3 +177,84 @@ build_stream_key(GuildId, ChannelId, ConnectionId) when
":",
ConnectionId/binary
>>.
-ifdef(TEST).
voice_states_returns_map_test() ->
State = #{voice_states => #{<<"a">> => #{}}},
?assertEqual(#{<<"a">> => #{}}, voice_states(State)),
?assertEqual(#{}, voice_states(#{})),
?assertEqual(#{}, voice_states(#{voice_states => not_a_map})).
ensure_voice_states_test() ->
?assertEqual(#{<<"a">> => 1}, ensure_voice_states(#{<<"a">> => 1})),
?assertEqual(#{}, ensure_voice_states(not_a_map)).
voice_state_user_id_test() ->
?assertEqual(123, voice_state_user_id(#{<<"user_id">> => <<"123">>})),
?assertEqual(undefined, voice_state_user_id(#{})).
voice_state_channel_id_test() ->
?assertEqual(456, voice_state_channel_id(#{<<"channel_id">> => <<"456">>})),
?assertEqual(undefined, voice_state_channel_id(#{})).
voice_state_guild_id_test() ->
?assertEqual(789, voice_state_guild_id(#{<<"guild_id">> => <<"789">>})),
?assertEqual(undefined, voice_state_guild_id(#{})).
filter_voice_states_test() ->
VoiceStates = #{
<<"a">> => #{<<"user_id">> => <<"1">>},
<<"b">> => #{<<"user_id">> => <<"2">>}
},
Filtered = filter_voice_states(VoiceStates, fun(_, V) ->
maps:get(<<"user_id">>, V) =:= <<"1">>
end),
?assertEqual(#{<<"a">> => #{<<"user_id">> => <<"1">>}}, Filtered).
drop_voice_states_test() ->
VoiceStates = #{<<"a">> => #{}, <<"b">> => #{}, <<"c">> => #{}},
ToDrop = #{<<"a">> => #{}, <<"c">> => #{}},
Result = drop_voice_states(ToDrop, VoiceStates),
?assertEqual(#{<<"b">> => #{}}, Result).
voice_flags_from_context_test() ->
Context = #{
self_mute => true,
self_deaf => false,
self_video => true,
self_stream => false,
is_mobile => true
},
Flags = voice_flags_from_context(Context),
?assertEqual(true, maps:get(self_mute, Flags)),
?assertEqual(false, maps:get(self_deaf, Flags)),
?assertEqual(true, maps:get(self_video, Flags)),
?assertEqual(false, maps:get(self_stream, Flags)),
?assertEqual(true, maps:get(is_mobile, Flags)).
parse_stream_key_dm_test() ->
Result = parse_stream_key(<<"dm:123:conn-id">>),
?assertMatch({ok, #{scope := dm, channel_id := 123, connection_id := <<"conn-id">>}}, Result).
parse_stream_key_guild_test() ->
Result = parse_stream_key(<<"999:123:conn-id">>),
?assertMatch(
{ok, #{scope := guild, guild_id := 999, channel_id := 123, connection_id := <<"conn-id">>}},
Result
).
parse_stream_key_invalid_test() ->
?assertEqual({error, invalid_stream_key}, parse_stream_key(<<"invalid">>)),
?assertEqual({error, invalid_stream_key}, parse_stream_key(<<"a:b">>)),
?assertEqual({error, invalid_stream_key}, parse_stream_key(123)).
build_stream_key_dm_test() ->
Result = build_stream_key(undefined, 123, <<"conn">>),
?assertEqual(<<"dm:123:conn">>, Result).
build_stream_key_guild_test() ->
Result = build_stream_key(999, 123, <<"conn">>),
?assertEqual(<<"999:123:conn">>, Result).
-endif.

View File

@@ -20,45 +20,56 @@
-export([
build_voice_token_rpc_request/6,
build_voice_token_rpc_request/7,
build_voice_token_rpc_request/8,
build_force_disconnect_rpc_request/4,
build_update_participant_rpc_request/5,
build_update_participant_permissions_rpc_request/5,
add_geolocation_to_request/3,
compute_voice_permissions/3
add_rtc_region_to_request/2,
compute_voice_permissions/3,
generate_token_nonce/0
]).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-type guild_state() :: map().
-type voice_permissions() :: #{
can_speak := boolean(),
can_stream := boolean(),
can_video := boolean()
}.
-spec build_voice_token_rpc_request(
integer() | null,
integer(),
integer(),
binary() | integer() | null,
binary() | number() | list() | null,
binary() | number() | list() | null
) -> map().
build_voice_token_rpc_request(GuildId, ChannelId, UserId, ConnectionId, Latitude, Longitude) ->
BaseReq0 = #{
<<"type">> => <<"voice_get_token">>,
<<"channel_id">> => integer_to_binary(ChannelId),
<<"user_id">> => integer_to_binary(UserId)
},
BaseReq =
case GuildId of
null ->
#{
<<"type">> => <<"voice_get_token">>,
<<"channel_id">> => integer_to_binary(ChannelId),
<<"user_id">> => integer_to_binary(UserId)
};
_ ->
BaseMap = #{
<<"type">> => <<"voice_get_token">>,
<<"guild_id">> => integer_to_binary(GuildId),
<<"channel_id">> => integer_to_binary(ChannelId),
<<"user_id">> => integer_to_binary(UserId)
},
case ConnectionId of
null ->
BaseMap;
ConnectionId when is_binary(ConnectionId) ->
maps:put(<<"connection_id">>, ConnectionId, BaseMap);
ConnectionId when is_integer(ConnectionId) ->
maps:put(<<"connection_id">>, integer_to_binary(ConnectionId), BaseMap);
_ ->
BaseMap
end
null -> BaseReq0;
_ -> maps:put(<<"guild_id">>, integer_to_binary(GuildId), BaseReq0)
end,
WithConnection = add_connection_id_to_request(BaseReq, ConnectionId),
add_geolocation_to_request(WithConnection, Latitude, Longitude).
add_geolocation_to_request(BaseReq, Latitude, Longitude).
-spec add_geolocation_to_request(
map(),
binary() | number() | list() | null,
binary() | number() | list() | null
) -> map().
add_geolocation_to_request(RequestMap, Latitude, Longitude) ->
case {Latitude, Longitude} of
case {normalise_coordinate(Latitude), normalise_coordinate(Longitude)} of
{Lat, Long} when is_binary(Lat) andalso is_binary(Long) ->
maps:merge(RequestMap, #{
<<"latitude">> => Lat,
@@ -68,6 +79,47 @@ add_geolocation_to_request(RequestMap, Latitude, Longitude) ->
RequestMap
end.
-spec normalise_coordinate(binary() | number() | list() | null) -> binary() | undefined.
normalise_coordinate(null) ->
undefined;
normalise_coordinate(Value) when is_binary(Value) ->
Value;
normalise_coordinate(Value) when is_integer(Value) ->
integer_to_binary(Value);
normalise_coordinate(Value) when is_float(Value) ->
float_to_binary(Value, [short]);
normalise_coordinate(Value) when is_list(Value) ->
try
list_to_binary(Value)
catch
error:badarg -> undefined
end;
normalise_coordinate(_Value) ->
undefined.
-spec add_rtc_region_to_request(map(), binary() | null) -> map().
add_rtc_region_to_request(RequestMap, Region) ->
case Region of
RegionBin when is_binary(RegionBin) ->
maps:put(<<"rtc_region">>, RegionBin, RequestMap);
_ ->
RequestMap
end.
-spec add_connection_id_to_request(map(), binary() | integer() | null) -> map().
add_connection_id_to_request(RequestMap, ConnectionId) ->
case ConnectionId of
null ->
RequestMap;
ConnectionIdBin when is_binary(ConnectionIdBin) ->
maps:put(<<"connection_id">>, ConnectionIdBin, RequestMap);
ConnectionIdInt when is_integer(ConnectionIdInt) ->
maps:put(<<"connection_id">>, integer_to_binary(ConnectionIdInt), RequestMap);
_ ->
RequestMap
end.
-spec build_force_disconnect_rpc_request(integer() | null, integer(), integer(), binary()) -> map().
build_force_disconnect_rpc_request(GuildId, ChannelId, UserId, ConnectionId) ->
BaseReq = #{
<<"type">> => <<"voice_force_disconnect_participant">>,
@@ -82,6 +134,9 @@ build_force_disconnect_rpc_request(GuildId, ChannelId, UserId, ConnectionId) ->
maps:put(<<"guild_id">>, integer_to_binary(GuildId), BaseReq)
end.
-spec build_update_participant_rpc_request(
integer() | null, integer(), integer(), boolean(), boolean()
) -> map().
build_update_participant_rpc_request(GuildId, ChannelId, UserId, Mute, Deaf) ->
BaseReq = #{
<<"type">> => <<"voice_update_participant">>,
@@ -97,6 +152,9 @@ build_update_participant_rpc_request(GuildId, ChannelId, UserId, Mute, Deaf) ->
maps:put(<<"guild_id">>, integer_to_binary(GuildId), BaseReq)
end.
-spec build_update_participant_permissions_rpc_request(
integer() | null, integer(), integer(), binary(), voice_permissions()
) -> map().
build_update_participant_permissions_rpc_request(
GuildId, ChannelId, UserId, ConnectionId, VoicePermissions
) ->
@@ -116,35 +174,181 @@ build_update_participant_permissions_rpc_request(
maps:put(<<"guild_id">>, integer_to_binary(GuildId), BaseReq)
end.
-spec compute_voice_permissions(integer(), integer(), map()) -> map().
-spec compute_voice_permissions(integer(), integer(), guild_state()) -> voice_permissions().
compute_voice_permissions(UserId, ChannelId, State) ->
Permissions = guild_permissions:get_member_permissions(UserId, ChannelId, State),
SpeakPerm = constants:speak_permission(),
StreamPerm = constants:stream_permission(),
AdminPerm = constants:administrator_permission(),
IsAdmin = (Permissions band AdminPerm) =:= AdminPerm,
CanSpeak = IsAdmin orelse ((Permissions band SpeakPerm) =:= SpeakPerm),
CanStream = IsAdmin orelse ((Permissions band StreamPerm) =:= StreamPerm),
HasVirtualAccess = guild_virtual_channel_access:has_virtual_access(UserId, ChannelId, State),
FinalCanSpeak = CanSpeak orelse HasVirtualAccess,
FinalCanStream = CanStream orelse HasVirtualAccess,
#{
can_speak => FinalCanSpeak,
can_stream => FinalCanStream,
can_video => FinalCanStream
}.
-spec build_voice_token_rpc_request(
integer() | null,
integer(),
integer(),
binary() | integer() | null,
binary() | null,
binary() | null,
voice_permissions()
) -> map().
build_voice_token_rpc_request(
GuildId, ChannelId, UserId, ConnectionId, Latitude, Longitude, VoicePermissions
) ->
build_voice_token_rpc_request(
GuildId, ChannelId, UserId, ConnectionId, Latitude, Longitude, VoicePermissions, null
).
-spec build_voice_token_rpc_request(
integer() | null,
integer(),
integer(),
binary() | integer() | null,
binary() | null,
binary() | null,
voice_permissions(),
binary() | null
) -> map().
build_voice_token_rpc_request(
GuildId, ChannelId, UserId, ConnectionId, Latitude, Longitude, VoicePermissions, TokenNonce
) ->
BaseReq = build_voice_token_rpc_request(
GuildId, ChannelId, UserId, ConnectionId, Latitude, Longitude
),
maps:merge(BaseReq, #{
Req0 = maps:merge(BaseReq, #{
<<"can_speak">> => maps:get(can_speak, VoicePermissions, true),
<<"can_stream">> => maps:get(can_stream, VoicePermissions, true),
<<"can_video">> => maps:get(can_video, VoicePermissions, true)
}).
}),
case TokenNonce of
null -> Req0;
undefined -> Req0;
_ when is_binary(TokenNonce) -> maps:put(<<"token_nonce">>, TokenNonce, Req0);
_ -> Req0
end.
-spec generate_token_nonce() -> binary().
generate_token_nonce() ->
Bytes = crypto:strong_rand_bytes(16),
binary:encode_hex(Bytes, lowercase).
-ifdef(TEST).
build_voice_token_rpc_request_guild_test() ->
Req = build_voice_token_rpc_request(123, 456, 789, null, null, null),
?assertEqual(<<"voice_get_token">>, maps:get(<<"type">>, Req)),
?assertEqual(<<"123">>, maps:get(<<"guild_id">>, Req)),
?assertEqual(<<"456">>, maps:get(<<"channel_id">>, Req)),
?assertEqual(<<"789">>, maps:get(<<"user_id">>, Req)),
?assertNot(maps:is_key(<<"connection_id">>, Req)).
build_voice_token_rpc_request_dm_test() ->
Req = build_voice_token_rpc_request(null, 456, 789, null, null, null),
?assertEqual(<<"voice_get_token">>, maps:get(<<"type">>, Req)),
?assertNot(maps:is_key(<<"guild_id">>, Req)),
?assertEqual(<<"456">>, maps:get(<<"channel_id">>, Req)).
build_voice_token_rpc_request_with_connection_test() ->
Req = build_voice_token_rpc_request(123, 456, 789, <<"conn-id">>, null, null),
?assertEqual(<<"conn-id">>, maps:get(<<"connection_id">>, Req)).
build_voice_token_rpc_request_dm_with_connection_test() ->
Req = build_voice_token_rpc_request(null, 456, 789, <<"conn-id">>, null, null),
?assertEqual(<<"conn-id">>, maps:get(<<"connection_id">>, Req)).
add_geolocation_to_request_test() ->
BaseReq = #{<<"type">> => <<"test">>},
WithGeo = add_geolocation_to_request(BaseReq, <<"1.0">>, <<"2.0">>),
?assertEqual(<<"1.0">>, maps:get(<<"latitude">>, WithGeo)),
?assertEqual(<<"2.0">>, maps:get(<<"longitude">>, WithGeo)),
WithoutGeo = add_geolocation_to_request(BaseReq, null, null),
?assertNot(maps:is_key(<<"latitude">>, WithoutGeo)).
add_geolocation_to_request_number_test() ->
BaseReq = #{<<"type">> => <<"test">>},
WithGeo = add_geolocation_to_request(BaseReq, 1.5, 2.25),
?assertEqual(<<"1.5">>, maps:get(<<"latitude">>, WithGeo)),
?assertEqual(<<"2.25">>, maps:get(<<"longitude">>, WithGeo)).
add_rtc_region_to_request_test() ->
BaseReq = #{<<"type">> => <<"test">>},
WithRegion = add_rtc_region_to_request(BaseReq, <<"us-east">>),
?assertEqual(<<"us-east">>, maps:get(<<"rtc_region">>, WithRegion)),
WithoutRegion = add_rtc_region_to_request(BaseReq, null),
?assertNot(maps:is_key(<<"rtc_region">>, WithoutRegion)).
build_force_disconnect_rpc_request_test() ->
Req = build_force_disconnect_rpc_request(123, 456, 789, <<"conn">>),
?assertEqual(<<"voice_force_disconnect_participant">>, maps:get(<<"type">>, Req)),
?assertEqual(<<"123">>, maps:get(<<"guild_id">>, Req)),
?assertEqual(<<"conn">>, maps:get(<<"connection_id">>, Req)).
build_update_participant_rpc_request_test() ->
Req = build_update_participant_rpc_request(123, 456, 789, true, false),
?assertEqual(<<"voice_update_participant">>, maps:get(<<"type">>, Req)),
?assertEqual(true, maps:get(<<"mute">>, Req)),
?assertEqual(false, maps:get(<<"deaf">>, Req)).
generate_token_nonce_format_test() ->
Nonce = generate_token_nonce(),
?assert(is_binary(Nonce)),
?assertEqual(32, byte_size(Nonce)),
?assert(lists:all(fun(C) ->
(C >= $0 andalso C =< $9) orelse (C >= $a andalso C =< $f)
end, binary_to_list(Nonce))).
generate_token_nonce_unique_test() ->
Nonce1 = generate_token_nonce(),
Nonce2 = generate_token_nonce(),
Nonce3 = generate_token_nonce(),
?assertNot(Nonce1 =:= Nonce2),
?assertNot(Nonce2 =:= Nonce3),
?assertNot(Nonce1 =:= Nonce3).
build_voice_token_rpc_request_with_nonce_test() ->
VoicePerms = #{
can_speak => true,
can_stream => false,
can_video => false
},
Req = build_voice_token_rpc_request(
123, 456, 789, null, null, null, VoicePerms, <<"test-nonce-123">>
),
?assertEqual(<<"test-nonce-123">>, maps:get(<<"token_nonce">>, Req)),
?assertEqual(true, maps:get(<<"can_speak">>, Req)),
?assertEqual(false, maps:get(<<"can_stream">>, Req)).
build_voice_token_rpc_request_without_nonce_test() ->
VoicePerms = #{
can_speak => true,
can_stream => true,
can_video => true
},
Req = build_voice_token_rpc_request(
123, 456, 789, null, null, null, VoicePerms, null
),
?assertNot(maps:is_key(<<"token_nonce">>, Req)),
?assertEqual(true, maps:get(<<"can_speak">>, Req)).
build_voice_token_rpc_request_undefined_nonce_test() ->
VoicePerms = #{
can_speak => false,
can_stream => true,
can_video => true
},
Req = build_voice_token_rpc_request(
123, 456, 789, null, null, null, VoicePerms, undefined
),
?assertNot(maps:is_key(<<"token_nonce">>, Req)),
?assertEqual(false, maps:get(<<"can_speak">>, Req)).
-endif.

View File

@@ -21,20 +21,63 @@
-export([start_link/1]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-type user_id() :: integer().
-type session_id() :: binary().
-type status() :: online | offline | idle | dnd | invisible.
-type custom_status() :: map() | null.
-type session_entry() :: #{
session_id := session_id(),
status := status(),
afk := boolean(),
mobile := boolean(),
pid := pid(),
mref := reference(),
socket_pid := pid() | undefined
}.
-type sessions() :: #{session_id() => session_entry()}.
-type subscription_entry() :: #{friend := boolean(), gdm_channels := #{integer() => true}}.
-type subscriptions() :: #{user_id() => subscription_entry()}.
-type push_buffer_entry() :: #{channel_id := integer(), message_id := integer(), params := map()}.
-type state() :: #{
user_id := user_id(),
user_data := map(),
sessions := sessions(),
push_buffer := [push_buffer_entry()],
custom_status := custom_status(),
status := status(),
guild_ids := #{integer() => true},
temporary_guild_ids := #{integer() => true},
friends := #{user_id() => true},
group_dm_recipients := #{integer() => #{user_id() => true}},
subscriptions := subscriptions(),
is_bot := boolean(),
initial_presences_sent := boolean(),
last_published_presence := map() | undefined
}.
-type presence_data() :: #{
user_id := user_id(),
user_data := map(),
guild_ids => [integer()],
friend_ids => [user_id()],
group_dm_recipients => #{integer() => [user_id()] | #{user_id() => true}},
status := status(),
custom_status => custom_status()
}.
-spec start_link(presence_data()) -> {ok, pid()} | {error, term()}.
start_link(PresenceData) ->
gen_server:start_link(?MODULE, PresenceData, []).
-spec init(presence_data()) -> {ok, state()}.
init(PresenceData) ->
process_flag(trap_exit, true),
UserId = maps:get(user_id, PresenceData),
UserData = maps:get(user_data, PresenceData),
Status = maps:get(status, PresenceData),
IsBot0 = maps:get(<<"bot">>, UserData, false),
IsBot =
case IsBot0 of
true -> true;
_ -> false
end,
IsBot = IsBot0 =:= true,
GuildIds0 = maps:get(guild_ids, PresenceData, []),
FriendIds0 = maps:get(friend_ids, PresenceData, []),
GroupDmRecipients0 = maps:get(group_dm_recipients, PresenceData, #{}),
@@ -50,6 +93,7 @@ init(PresenceData) ->
user_id => UserId,
user_data => UserData,
sessions => #{},
push_buffer => [],
custom_status => CustomStatus,
status => Status,
guild_ids => map_from_ids(GuildIds0),
@@ -68,6 +112,8 @@ init(PresenceData) ->
{ok, StateWithSubs}.
-spec handle_call(term(), gen_server:from(), state()) ->
{reply, term(), state()} | {stop, normal, ok, state()}.
handle_call({session_connect, Request}, {Pid, _}, State) ->
Result = presence_session:handle_session_connect(Request, Pid, State),
publish_global_if_needed(Result);
@@ -84,9 +130,7 @@ handle_call({terminate_session, SessionIdHashes}, _From, State) ->
handle_call({dispatch, EventAtom, Data}, _From, State) ->
Sessions = maps:get(sessions, State),
UserId = maps:get(user_id, State),
SessionPids = [maps:get(pid, S) || S <- maps:values(Sessions)],
lists:foreach(
fun(Pid) when is_pid(Pid) ->
case erlang:is_process_alive(Pid) of
@@ -98,7 +142,6 @@ handle_call({dispatch, EventAtom, Data}, _From, State) ->
end,
SessionPids
),
case EventAtom of
user_update ->
CurrentUserData = maps:get(user_data, State, #{}),
@@ -115,34 +158,9 @@ handle_call({dispatch, EventAtom, Data}, _From, State) ->
FinalState = force_publish_global_presence(NewState),
{reply, ok, FinalState};
message_create ->
HasMobile = lists:any(
fun(Session) ->
maps:get(mobile, Session, false)
end,
maps:values(Sessions)
),
AllAfk = lists:all(
fun(Session) ->
maps:get(afk, Session, false)
end,
maps:values(Sessions)
),
ShouldSendPush =
(map_size(Sessions) =:= 0) orelse ((not HasMobile) andalso AllAfk),
case ShouldSendPush of
true ->
AuthorIdBin = maps:get(<<"id">>, maps:get(<<"author">>, Data, #{}), <<"0">>),
AuthorId = validation:snowflake_or_default(AuthorIdBin, 0),
push:handle_message_create(#{
message_data => Data,
user_ids => [UserId],
guild_id => 0,
author_id => AuthorId
});
false ->
ok
end,
{reply, ok, State};
{reply, ok, handle_message_create_event(Data, State)};
message_ack ->
{reply, ok, handle_message_ack_event(Data, State)};
_ ->
{reply, ok, State}
end;
@@ -175,10 +193,10 @@ handle_call({terminate, SessionIdHashes}, _From, State) ->
handle_call(_, _From, State) ->
{reply, ok, State}.
-spec handle_cast(term(), state()) -> {noreply, state()}.
handle_cast({dispatch, Event, Data}, State) ->
Sessions = maps:get(sessions, State),
UserId = maps:get(user_id, State),
SessionPids = [maps:get(pid, S) || S <- maps:values(Sessions)],
lists:foreach(
fun(Pid) when is_pid(Pid) ->
@@ -186,7 +204,6 @@ handle_cast({dispatch, Event, Data}, State) ->
end,
SessionPids
),
case Event of
user_update ->
CurrentUserData = maps:get(user_data, State, #{}),
@@ -203,34 +220,9 @@ handle_cast({dispatch, Event, Data}, State) ->
FinalState = force_publish_global_presence(NewState),
{noreply, FinalState};
message_create ->
HasMobile = lists:any(
fun(Session) ->
maps:get(mobile, Session, false)
end,
maps:values(Sessions)
),
AllAfk = lists:all(
fun(Session) ->
maps:get(afk, Session, false)
end,
maps:values(Sessions)
),
ShouldSendPush =
(map_size(Sessions) =:= 0) orelse ((not HasMobile) andalso AllAfk),
case ShouldSendPush of
true ->
AuthorIdBin = maps:get(<<"id">>, maps:get(<<"author">>, Data, #{}), <<"0">>),
AuthorId = validation:snowflake_or_default(AuthorIdBin, 0),
push:handle_message_create(#{
message_data => Data,
user_ids => [UserId],
guild_id => 0,
author_id => AuthorId
});
false ->
ok
end,
{noreply, State};
{noreply, handle_message_create_event(Data, State)};
message_ack ->
{noreply, handle_message_ack_event(Data, State)};
_ ->
{noreply, State}
end;
@@ -264,9 +256,28 @@ handle_cast({sync_friends, FriendIds}, State) ->
handle_cast({sync_group_dm_recipients, RecipientsByChannel}, State) ->
NewState = sync_group_dm_subscriptions(RecipientsByChannel, State),
{noreply, NewState};
handle_cast({join_guild, GuildId}, State) ->
{reply, _Reply, NewState} = handle_join_guild(GuildId, State),
{noreply, NewState};
handle_cast({leave_guild, GuildId}, State) ->
{reply, _Reply, NewState} = handle_leave_guild(GuildId, State),
{noreply, NewState};
handle_cast({add_temporary_guild, GuildId}, State) ->
{reply, _JoinReply, JoinedState} = handle_join_guild(GuildId, State),
TemporaryGuildIds = maps:get(temporary_guild_ids, JoinedState, #{}),
NewTemporaryGuildIds = maps:put(GuildId, true, TemporaryGuildIds),
NewState = maps:put(temporary_guild_ids, NewTemporaryGuildIds, JoinedState),
{noreply, NewState};
handle_cast({remove_temporary_guild, GuildId}, State) ->
{reply, _LeaveReply, LeftState} = handle_leave_guild(GuildId, State),
TemporaryGuildIds = maps:get(temporary_guild_ids, LeftState, #{}),
NewTemporaryGuildIds = maps:remove(GuildId, TemporaryGuildIds),
NewState = maps:put(temporary_guild_ids, NewTemporaryGuildIds, LeftState),
{noreply, NewState};
handle_cast(_, State) ->
{noreply, State}.
-spec handle_info(term(), state()) -> {noreply, state()} | {stop, normal, state()}.
handle_info({presence, TargetId, Payload}, State) ->
dispatch_global_presence(TargetId, Payload, State);
handle_info({initial_presences, Presences}, State) ->
@@ -284,18 +295,22 @@ handle_info({'DOWN', Ref, process, _Pid, Reason}, State) ->
handle_info(_, State) ->
{noreply, State}.
-spec terminate(term(), state() | term()) -> ok.
terminate(_Reason, State) when not is_map(State) ->
ok;
terminate(_Reason, State) ->
flush_push_buffer(State),
UserId = maps:get(user_id, State),
presence_cache:delete(UserId),
publish_offline_on_terminate(UserId, State),
kick_temporary_members_on_terminate(UserId, State),
ok.
-spec code_change(term(), state(), term()) -> {ok, state()}.
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
-spec kick_temporary_members_on_terminate(user_id(), state()) -> ok.
kick_temporary_members_on_terminate(UserId, State) ->
TemporaryGuildIds = maps:get(temporary_guild_ids, State, #{}),
case map_size(TemporaryGuildIds) of
@@ -309,18 +324,11 @@ kick_temporary_members_on_terminate(UserId, State) ->
<<"user_id">> => type_conv:to_binary(UserId),
<<"guild_ids">> => [type_conv:to_binary(Gid) || Gid <- GuildIdsList]
},
case rpc_client:call(Request) of
{ok, _} ->
ok;
{error, Reason} ->
logger:warning(
"[presence] Failed to kick temporary member ~p from guilds ~p: ~p",
[UserId, GuildIdsList, Reason]
)
end
rpc_client:call(Request)
end)
end.
-spec publish_offline_on_terminate(user_id(), state()) -> ok.
publish_offline_on_terminate(UserId, State) ->
LastPublished = maps:get(last_published_presence, State, undefined),
WasVisible =
@@ -351,9 +359,10 @@ publish_offline_on_terminate(UserId, State) ->
ok
end.
-spec handle_process_down(reference(), term(), state()) ->
{noreply, state()} | {stop, normal, state()}.
handle_process_down(Ref, _Reason, State) ->
Sessions = maps:get(sessions, State),
case presence_session:find_session_by_ref(Ref, Sessions) of
{ok, SessionId} ->
NewSessions = maps:remove(SessionId, Sessions),
@@ -370,6 +379,7 @@ handle_process_down(Ref, _Reason, State) ->
{noreply, State}
end.
-spec ensure_initial_global_subscriptions(state()) -> state().
ensure_initial_global_subscriptions(State) ->
case maps:get(is_bot, State, false) of
true ->
@@ -399,6 +409,8 @@ ensure_initial_global_subscriptions(State) ->
)
end.
-spec publish_global_if_needed({reply, term(), state()} | {noreply, state()}) ->
{reply, term(), state()} | {noreply, state()}.
publish_global_if_needed({reply, Reply, NewState}) ->
FinalState = publish_global_presence(maps:get(sessions, NewState), NewState),
{reply, Reply, FinalState};
@@ -406,21 +418,167 @@ publish_global_if_needed({noreply, NewState}) ->
FinalState = publish_global_presence(maps:get(sessions, NewState), NewState),
{noreply, FinalState}.
-spec publish_global_presence(sessions(), state()) -> state().
publish_global_presence(_Sessions, State) ->
{Payload, CurrentExternal, ExternalStatus} = build_presence_external(State),
LastPublished = maps:get(last_published_presence, State, undefined),
case presence_changed(LastPublished, CurrentExternal) of
true ->
publish_presence_payload(State, Payload, CurrentExternal, ExternalStatus);
NewState = publish_presence_payload(State, Payload, CurrentExternal, ExternalStatus),
maybe_update_push_eligibility(NewState);
false ->
maybe_update_push_eligibility(State)
end.
-spec force_publish_global_presence(state()) -> state().
force_publish_global_presence(State) ->
{Payload, CurrentExternal, ExternalStatus} = build_presence_external(State),
NewState = publish_presence_payload(State, Payload, CurrentExternal, ExternalStatus),
maybe_update_push_eligibility(NewState).
-spec handle_message_create_event(map(), state()) -> state().
handle_message_create_event(Data, State) ->
UserId = maps:get(user_id, State),
Sessions = maps:get(sessions, State, #{}),
case build_push_create_params(UserId, Data) of
undefined ->
State;
Params ->
Eligible = is_push_eligible(Sessions),
case Eligible of
true ->
FlushedState = flush_push_buffer(State),
push:handle_message_create(Params),
FlushedState;
false ->
buffer_push_notification(Params, State)
end
end.
-spec handle_message_ack_event(map(), state()) -> state().
handle_message_ack_event(Data, State) ->
ChannelId = parse_snowflake(<<"channel_id">>, maps:get(<<"channel_id">>, Data, undefined)),
MessageId = parse_snowflake(<<"message_id">>, maps:get(<<"message_id">>, Data, undefined)),
case {ChannelId, MessageId} of
{ParsedChannelId, ParsedMessageId}
when is_integer(ParsedChannelId), is_integer(ParsedMessageId)
->
ack_push_buffer(ParsedChannelId, ParsedMessageId, State);
_ ->
State
end.
force_publish_global_presence(State) ->
{Payload, CurrentExternal, ExternalStatus} = build_presence_external(State),
publish_presence_payload(State, Payload, CurrentExternal, ExternalStatus).
-spec build_push_create_params(user_id(), map()) -> map() | undefined.
build_push_create_params(UserId, Data) ->
AuthorIdBin = maps:get(<<"id">>, maps:get(<<"author">>, Data, #{}), undefined),
case parse_snowflake(<<"author_id">>, AuthorIdBin) of
undefined ->
undefined;
AuthorId ->
#{
message_data => Data,
user_ids => [UserId],
guild_id => 0,
author_id => AuthorId
}
end.
-spec buffer_push_notification(map(), state()) -> state().
buffer_push_notification(Params, State) ->
case make_push_buffer_entry(Params) of
undefined ->
State;
Entry ->
Buffer = maps:get(push_buffer, State, []),
maps:put(push_buffer, [Entry | Buffer], State)
end.
-spec flush_push_buffer(state()) -> state().
flush_push_buffer(State) ->
Buffer = maps:get(push_buffer, State, []),
case Buffer of
[] ->
State;
_ ->
Entries = lists:reverse(Buffer),
lists:foreach(
fun(Entry) ->
push:handle_message_create(maps:get(params, Entry))
end,
Entries
),
maps:put(push_buffer, [], State)
end.
-spec ack_push_buffer(integer(), integer(), state()) -> state().
ack_push_buffer(ChannelId, MessageId, State) when ChannelId > 0, MessageId > 0 ->
Buffer = maps:get(push_buffer, State, []),
FilteredBuffer = lists:filter(
fun(Entry) -> not should_drop_buffer_entry(Entry, ChannelId, MessageId) end,
Buffer
),
maps:put(push_buffer, FilteredBuffer, State);
ack_push_buffer(_, _, State) ->
State.
-spec should_drop_buffer_entry(push_buffer_entry(), integer(), integer()) -> boolean().
should_drop_buffer_entry(Entry, ChannelId, MessageId) ->
EntryChannel = maps:get(channel_id, Entry),
EntryMessage = maps:get(message_id, Entry),
EntryChannel =:= ChannelId andalso EntryMessage =< MessageId.
-spec make_push_buffer_entry(map()) -> push_buffer_entry() | undefined.
make_push_buffer_entry(Params) ->
MessageData = maps:get(message_data, Params, #{}),
ChannelId = parse_snowflake(<<"channel_id">>, maps:get(<<"channel_id">>, MessageData, undefined)),
MessageId = parse_snowflake(<<"id">>, maps:get(<<"id">>, MessageData, undefined)),
case {ChannelId, MessageId} of
{ParsedChannelId, ParsedMessageId}
when is_integer(ParsedChannelId), is_integer(ParsedMessageId)
->
#{
channel_id => ParsedChannelId,
message_id => ParsedMessageId,
params => Params
};
_ ->
undefined
end.
-spec parse_snowflake(binary(), term()) -> integer() | undefined.
parse_snowflake(FieldName, Value) ->
case validation:validate_snowflake(FieldName, Value) of
{ok, Id} -> Id;
{error, _, _} -> undefined
end.
-spec is_push_eligible(sessions()) -> boolean().
is_push_eligible(Sessions) ->
case map_size(Sessions) of
0 ->
true;
_ ->
HasMobile = lists:any(
fun(Session) -> maps:get(mobile, Session, false) end,
maps:values(Sessions)
),
AllAfk = lists:all(
fun(Session) -> maps:get(afk, Session, false) end,
maps:values(Sessions)
),
(not HasMobile) andalso AllAfk
end.
-spec maybe_update_push_eligibility(state()) -> state().
maybe_update_push_eligibility(State) ->
Sessions = maps:get(sessions, State, #{}),
Eligible = is_push_eligible(Sessions),
case {Eligible, maps:get(push_buffer, State, [])} of
{true, [_ | _]} -> flush_push_buffer(State);
_ -> State
end.
-spec build_presence_external(state()) -> {map(), map(), binary()}.
build_presence_external(State) ->
Payload = build_presence_payload(State),
ExternalStatus = maps:get(<<"status">>, Payload, <<"offline">>),
@@ -435,6 +593,7 @@ build_presence_external(State) ->
},
{Payload, CurrentExternal, ExternalStatus}.
-spec publish_presence_payload(state(), map(), map(), binary()) -> state().
publish_presence_payload(State, Payload, CurrentExternal, ExternalStatus) ->
UserId = maps:get(user_id, State),
case ExternalStatus of
@@ -446,11 +605,13 @@ publish_presence_payload(State, Payload, CurrentExternal, ExternalStatus) ->
presence_bus:publish(UserId, Payload),
maps:put(last_published_presence, CurrentExternal, State).
-spec presence_changed(map() | undefined, map()) -> boolean().
presence_changed(undefined, _Current) ->
true;
presence_changed(Last, Current) ->
Last =/= Current.
-spec publish_user_update_to_bus(user_id(), map(), state()) -> ok.
publish_user_update_to_bus(UserId, UserData, State) ->
LastPublished = maps:get(last_published_presence, State, undefined),
WasVisible = is_last_published_visible(LastPublished),
@@ -466,6 +627,7 @@ publish_user_update_to_bus(UserId, UserData, State) ->
ok
end.
-spec is_last_published_visible(map() | undefined) -> boolean().
is_last_published_visible(undefined) ->
false;
is_last_published_visible(#{status := Status}) when
@@ -477,6 +639,7 @@ is_last_published_visible(#{status := Status}) when
is_last_published_visible(_) ->
false.
-spec dispatch_global_presence(user_id(), map(), state()) -> {noreply, state()}.
dispatch_global_presence(TargetId, Payload, State) ->
UserId = maps:get(user_id, State),
case TargetId =:= UserId of
@@ -495,6 +658,7 @@ dispatch_global_presence(TargetId, Payload, State) ->
{noreply, State}
end.
-spec sync_friend_subscriptions([user_id()], state()) -> state().
sync_friend_subscriptions(FriendIds, State) ->
case maps:get(is_bot, State, false) of
true ->
@@ -525,6 +689,8 @@ sync_friend_subscriptions(FriendIds, State) ->
maybe_force_offline(Removals, State4)
end.
-spec sync_group_dm_subscriptions(#{integer() => [user_id()] | #{user_id() => true}}, state()) ->
state().
sync_group_dm_subscriptions(RecipientsByChannel, State) ->
case maps:get(is_bot, State, false) of
true ->
@@ -558,6 +724,10 @@ sync_group_dm_subscriptions(RecipientsByChannel, State) ->
maps:put(group_dm_recipients, Normalized, State4)
end.
-spec diff_group_dm_recipients(#{integer() => #{user_id() => true}}, #{
integer() => #{user_id() => true}
}) ->
{[{user_id(), integer()}], [{user_id(), integer()}]}.
diff_group_dm_recipients(Old, New) ->
OldPairs =
lists:append(
@@ -578,6 +748,7 @@ diff_group_dm_recipients(Old, New) ->
lists:subtract(OldPairs, NewPairs)
}.
-spec ensure_subscription(user_id(), friend | gdm, integer() | undefined, state()) -> state().
ensure_subscription(UserId, Reason, ChannelId, State) ->
case UserId =:= maps:get(user_id, State) of
true ->
@@ -602,6 +773,8 @@ ensure_subscription(UserId, Reason, ChannelId, State) ->
maps:put(subscriptions, NewSubscriptions, State)
end.
-spec remove_subscription_reason(user_id(), friend | gdm, integer() | undefined, state()) ->
state().
remove_subscription_reason(UserId, Reason, ChannelId, State) ->
Subscriptions = maps:get(subscriptions, State, #{}),
Entry0 = maps:get(UserId, Subscriptions, #{friend => false, gdm_channels => #{}}),
@@ -625,10 +798,15 @@ remove_subscription_reason(UserId, Reason, ChannelId, State) ->
end,
maps:put(subscriptions, NewSubscriptions, State).
-spec has_subscription(subscription_entry()) -> boolean().
has_subscription(Entry) ->
(maps:get(friend, Entry, false) =:= true) orelse
(map_size(maps:get(gdm_channels, Entry, #{})) > 0).
-spec normalize_group_dm_recipients(
#{integer() => [user_id()] | #{user_id() => true}}, user_id(), boolean()
) ->
#{integer() => #{user_id() => true}}.
normalize_group_dm_recipients(RecipientsByChannel, UserId, IsBot) ->
case IsBot of
true ->
@@ -646,6 +824,7 @@ normalize_group_dm_recipients(RecipientsByChannel, UserId, IsBot) ->
)
end.
-spec handle_join_guild(integer(), state()) -> {reply, ok, state()}.
handle_join_guild(GuildId, State) ->
Guilds = maps:get(guild_ids, State, #{}),
case maps:is_key(GuildId, Guilds) of
@@ -658,6 +837,7 @@ handle_join_guild(GuildId, State) ->
{reply, ok, NewState}
end.
-spec handle_leave_guild(integer(), state()) -> {reply, ok, state()}.
handle_leave_guild(GuildId, State) ->
Guilds = maps:get(guild_ids, State, #{}),
case maps:is_key(GuildId, Guilds) of
@@ -673,9 +853,11 @@ handle_leave_guild(GuildId, State) ->
{reply, ok, NewState}
end.
-spec map_from_ids([term()]) -> #{term() => true}.
map_from_ids(Ids) when is_list(Ids) ->
maps:from_list([{Id, true} || Id <- Ids]).
-spec cache_if_visible(user_id(), map()) -> ok.
cache_if_visible(UserId, Payload) when is_integer(UserId), is_map(Payload) ->
Status = maps:get(<<"status">>, Payload, <<"offline">>),
case Status of
@@ -686,6 +868,7 @@ cache_if_visible(UserId, Payload) when is_integer(UserId), is_map(Payload) ->
cache_if_visible(_, _) ->
ok.
-spec build_presence_payload(state()) -> map().
build_presence_payload(State) ->
Sessions = maps:get(sessions, State),
Status = presence_status:get_current_status(Sessions),
@@ -695,6 +878,7 @@ build_presence_payload(State) ->
CustomStatus = maps:get(custom_status, State, null),
presence_payload:build(UserData, Status, Mobile, Afk, CustomStatus).
-spec maybe_handle_custom_status(map(), state()) -> {map(), state()}.
maybe_handle_custom_status(Request, State) ->
case maps:find(<<"custom_status">>, Request) of
error ->
@@ -716,6 +900,7 @@ maybe_handle_custom_status(Request, State) ->
{Request, State}
end.
-spec validate_custom_status(map(), map(), state()) -> {map(), state()}.
validate_custom_status(CustomStatus, Request, State) ->
UserId = maps:get(user_id, State),
case custom_status_validation:validate(UserId, CustomStatus) of
@@ -725,14 +910,11 @@ validate_custom_status(CustomStatus, Request, State) ->
{ok, _} ->
UpdatedRequest = maps:put(<<"custom_status">>, null, Request),
{UpdatedRequest, maps:put(custom_status, null, State)};
{error, Reason} ->
logger:warning(
"[presence] Custom status validation failed for user ~p: ~p",
[UserId, Reason]
),
{error, _Reason} ->
{Request, State}
end.
-spec custom_status_comparator(custom_status()) -> map() | null.
custom_status_comparator(null) ->
null;
custom_status_comparator(Map) when is_map(Map) ->
@@ -743,6 +925,7 @@ custom_status_comparator(Map) when is_map(Map) ->
<<"emoji_name">> => field_or_null(Map, <<"emoji_name">>)
}.
-spec handle_user_settings_update(map(), state()) -> state().
handle_user_settings_update(Data, State) ->
case maps:find(<<"custom_status">>, Data) of
error ->
@@ -752,6 +935,7 @@ handle_user_settings_update(Data, State) ->
maps:put(custom_status, Normalized, State)
end.
-spec normalize_state_custom_status(term()) -> custom_status().
normalize_state_custom_status(null) ->
null;
normalize_state_custom_status(Map) when is_map(Map) ->
@@ -759,12 +943,14 @@ normalize_state_custom_status(Map) when is_map(Map) ->
normalize_state_custom_status(_) ->
null.
-spec field_or_null(map(), binary()) -> term() | null.
field_or_null(Map, Key) ->
case maps:get(Key, Map, undefined) of
undefined -> null;
Value -> Value
end.
-spec maybe_send_cached_presences([user_id()], state()) -> state().
maybe_send_cached_presences(UserIds, State) ->
case UserIds of
[] ->
@@ -784,6 +970,7 @@ maybe_send_cached_presences(UserIds, State) ->
State
end.
-spec maybe_force_offline([user_id()], state()) -> state().
maybe_force_offline(UserIds, State) ->
Subscriptions = maps:get(subscriptions, State, #{}),
lists:foldl(
@@ -807,6 +994,7 @@ maybe_force_offline(UserIds, State) ->
UserIds
).
-spec notify_sessions_presence(map(), state()) -> state().
notify_sessions_presence(Payload, State) ->
Sessions = maps:get(sessions, State, #{}),
SessionPids = [maps:get(pid, S) || S <- maps:values(Sessions)],
@@ -818,6 +1006,7 @@ notify_sessions_presence(Payload, State) ->
),
State.
-spec fetch_initial_presences(pid(), state()) -> ok.
fetch_initial_presences(PresencePid, State) ->
case maps:get(is_bot, State, false) of
true ->
@@ -850,6 +1039,7 @@ fetch_initial_presences(PresencePid, State) ->
end
end.
-spec recipient_list([user_id()] | #{user_id() => true} | term()) -> [user_id()].
recipient_list(Value) when is_list(Value) ->
Value;
recipient_list(Value) when is_map(Value) ->
@@ -877,12 +1067,61 @@ gdm_subscription_add_remove_test() ->
Entry1 = maps:get(10, Subscriptions1),
GdmChannels1 = maps:get(gdm_channels, Entry1, #{}),
?assertEqual(true, maps:get(1, GdmChannels1)),
State2 = sync_group_dm_subscriptions(#{}, State1),
Subscriptions2 = maps:get(subscriptions, State2, #{}),
?assertEqual(false, maps:is_key(10, Subscriptions2)),
ok.
map_from_ids_test() ->
?assertEqual(#{}, map_from_ids([])),
?assertEqual(#{1 => true, 2 => true}, map_from_ids([1, 2])).
has_subscription_test() ->
?assertEqual(false, has_subscription(#{friend => false, gdm_channels => #{}})),
?assertEqual(true, has_subscription(#{friend => true, gdm_channels => #{}})),
?assertEqual(true, has_subscription(#{friend => false, gdm_channels => #{1 => true}})).
is_push_eligible_test() ->
?assertEqual(true, is_push_eligible(#{})),
?assertEqual(false, is_push_eligible(#{<<"s1">> => #{mobile => true, afk => false}})),
?assertEqual(true, is_push_eligible(#{<<"s1">> => #{mobile => false, afk => true}})),
?assertEqual(false, is_push_eligible(#{<<"s1">> => #{mobile => false, afk => false}})).
presence_changed_test() ->
?assertEqual(true, presence_changed(undefined, #{status => <<"online">>})),
?assertEqual(false, presence_changed(#{status => <<"online">>}, #{status => <<"online">>})),
?assertEqual(true, presence_changed(#{status => <<"online">>}, #{status => <<"idle">>})).
is_last_published_visible_test() ->
?assertEqual(false, is_last_published_visible(undefined)),
?assertEqual(true, is_last_published_visible(#{status => <<"online">>})),
?assertEqual(true, is_last_published_visible(#{status => <<"idle">>})),
?assertEqual(true, is_last_published_visible(#{status => <<"dnd">>})),
?assertEqual(false, is_last_published_visible(#{status => <<"offline">>})),
?assertEqual(false, is_last_published_visible(#{status => <<"invisible">>})).
custom_status_comparator_test() ->
?assertEqual(null, custom_status_comparator(null)),
Expected = #{
<<"text">> => <<"hello">>,
<<"expires_at">> => null,
<<"emoji_id">> => null,
<<"emoji_name">> => null
},
?assertEqual(Expected, custom_status_comparator(#{<<"text">> => <<"hello">>})).
normalize_state_custom_status_test() ->
?assertEqual(null, normalize_state_custom_status(null)),
?assertEqual(
#{<<"text">> => <<"hi">>}, normalize_state_custom_status(#{<<"text">> => <<"hi">>})
),
?assertEqual(null, normalize_state_custom_status(<<"invalid">>)).
recipient_list_test() ->
?assertEqual([1, 2], recipient_list([1, 2])),
?assertEqual([1], recipient_list(#{1 => true})),
?assertEqual([], recipient_list(undefined)).
maybe_start_presence_bus() ->
case whereis(presence_bus) of
undefined ->

View File

@@ -45,9 +45,8 @@ publish(UserId, Payload) when is_integer(UserId) ->
-spec init(list()) -> {ok, state()}.
init([]) ->
process_flag(trap_exit, true),
{ShardCount, Source} = determine_shard_count(presence_bus_shards),
{ShardCount, _Source} = determine_shard_count(presence_bus_shards),
Shards = start_shards(ShardCount, #{}),
maybe_log_shard_source(presence_bus, ShardCount, Source),
{ok, #{shards => Shards, shard_count => ShardCount}}.
-spec handle_call(term(), gen_server:from(), state()) -> {reply, term(), state()}.
@@ -60,8 +59,7 @@ handle_call({unsubscribe, UserId, Pid}, _From, State) ->
handle_call({publish, UserId, Payload}, _From, State) ->
{Reply, NewState} = forward_call(UserId, {publish, UserId, Payload}, State),
{reply, Reply, NewState};
handle_call(Request, _From, State) ->
logger:warning("[presence_bus] unknown request ~p", [Request]),
handle_call(_Request, _From, State) ->
{reply, ok, State}.
-spec handle_cast(term(), state()) -> {noreply, state()}.
@@ -69,21 +67,19 @@ handle_cast(_Msg, State) ->
{noreply, State}.
-spec handle_info(term(), state()) -> {noreply, state()}.
handle_info({'DOWN', Ref, process, _Pid, Reason}, State) ->
handle_info({'DOWN', Ref, process, _Pid, _Reason}, State) ->
Shards = maps:get(shards, State),
case find_shard_by_ref(Ref, Shards) of
{ok, Index} ->
logger:warning("[presence_bus] shard ~p crashed: ~p", [Index, Reason]),
{_Shard, NewState} = restart_shard(Index, State),
{noreply, NewState};
not_found ->
{noreply, State}
end;
handle_info({'EXIT', Pid, Reason}, State) ->
handle_info({'EXIT', Pid, _Reason}, State) ->
Shards = maps:get(shards, State),
case find_shard_by_pid(Pid, Shards) of
{ok, Index} ->
logger:warning("[presence_bus] shard ~p exited: ~p", [Index, Reason]),
{_Shard, NewState} = restart_shard(Index, State),
{noreply, NewState};
not_found ->
@@ -132,17 +128,6 @@ default_shard_count() ->
],
lists:max([C || C <- Candidates, is_integer(C), C > 0] ++ [1]).
-spec maybe_log_shard_source(atom(), pos_integer(), configured | auto) -> ok.
maybe_log_shard_source(Name, Count, configured) ->
logger:info("[~p] starting with ~p shards (configured)", [Name, Count]),
ok;
maybe_log_shard_source(Name, Count, auto) ->
logger:info(
"[~p] starting with ~p shards (auto, set FLUXER_GATEWAY_PRESENCE_BUS_SHARDS for cross-node consistency)",
[Name, Count]
),
ok.
-spec start_shards(pos_integer(), #{}) -> #{non_neg_integer() => shard()}.
start_shards(Count, Acc) ->
lists:foldl(
@@ -150,8 +135,7 @@ start_shards(Count, Acc) ->
case start_shard(Index) of
{ok, Shard} ->
maps:put(Index, Shard, MapAcc);
{error, Reason} ->
logger:warning("[presence_bus] failed to start shard ~p: ~p", [Index, Reason]),
{error, _Reason} ->
MapAcc
end
end,
@@ -176,8 +160,7 @@ restart_shard(Index, State) ->
Shards = maps:get(shards, State),
Updated = State#{shards := maps:put(Index, Shard, Shards)},
{Shard, Updated};
{error, Reason} ->
logger:error("[presence_bus] failed to restart shard ~p: ~p", [Index, Reason]),
{error, _Reason} ->
Dummy = #{pid => spawn(fun() -> exit(normal) end), ref => make_ref()},
{Dummy, State}
end.
@@ -284,6 +267,21 @@ unsubscribe_stops_delivery_test() ->
end,
?assertEqual(ok, gen_server:stop(Pid)).
select_shard_test() ->
?assert(select_shard(100, 4) >= 0),
?assert(select_shard(100, 4) < 4).
find_shard_by_ref_test() ->
Ref1 = make_ref(),
Shards = #{0 => #{pid => self(), ref => Ref1}},
?assertEqual({ok, 0}, find_shard_by_ref(Ref1, Shards)),
?assertEqual(not_found, find_shard_by_ref(make_ref(), Shards)).
find_shard_by_pid_test() ->
Shards = #{0 => #{pid => self(), ref => make_ref()}},
?assertEqual({ok, 0}, find_shard_by_pid(self(), Shards)),
?assertEqual(not_found, find_shard_by_pid(spawn(fun() -> ok end), Shards)).
maybe_start_for_test() ->
case whereis(?MODULE) of
undefined -> start_link();

View File

@@ -60,16 +60,11 @@ handle_cast(_Msg, State) ->
{noreply, State}.
-spec handle_info(term(), state()) -> {noreply, state()}.
handle_info({'EXIT', PgPid, Reason}, State) ->
handle_info({'EXIT', PgPid, _Reason}, State) ->
StoredPgPid = maps:get(pg_pid, State),
case PgPid =:= StoredPgPid of
true ->
Scope = maps:get(scope, State),
ShardIndex = maps:get(shard_index, State),
logger:warning(
"[presence_bus_shard ~p] pg process exited: ~p; restarting scope",
[ShardIndex, Reason]
),
case ensure_pg_scope(Scope) of
{ok, NewPgPid} ->
{noreply, State#{pg_pid := NewPgPid}};
@@ -98,8 +93,7 @@ do_subscribe(Scope, UserId, Pid) ->
case catch pg:join(Scope, Group, Pid) of
ok ->
ok;
{'EXIT', Reason} ->
logger:warning("[presence_bus_shard] failed to join group ~p: ~p", [Group, Reason]),
{'EXIT', _Reason} ->
ok;
_ ->
ok
@@ -111,8 +105,7 @@ do_unsubscribe(Scope, UserId, Pid) ->
case catch pg:leave(Scope, Group, Pid) of
ok ->
ok;
{'EXIT', Reason} ->
logger:warning("[presence_bus_shard] failed to leave group ~p: ~p", [Group, Reason]),
{'EXIT', _Reason} ->
ok;
_ ->
ok
@@ -157,3 +150,11 @@ ensure_pg_scope(Scope) ->
-spec scope_name(non_neg_integer()) -> atom().
scope_name(Index) ->
list_to_atom(atom_to_list(?SCOPE_PREFIX) ++ "_" ++ integer_to_list(Index)).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
scope_name_test() ->
?assertEqual(presence_bus_0, scope_name(0)),
?assertEqual(presence_bus_42, scope_name(42)).
-endif.

View File

@@ -44,7 +44,7 @@ delete(UserId) when is_integer(UserId) ->
get(UserId) when is_integer(UserId) ->
gen_server:call(?MODULE, {get, UserId}, ?DEFAULT_GEN_SERVER_TIMEOUT).
-spec bulk_get([term()]) -> [map()].
-spec bulk_get([integer()]) -> [map()].
bulk_get(UserIds) when is_list(UserIds) ->
gen_server:call(?MODULE, {bulk_get, UserIds}, ?DEFAULT_GEN_SERVER_TIMEOUT).
@@ -55,9 +55,8 @@ get_memory_stats() ->
-spec init(list()) -> {ok, state()}.
init([]) ->
process_flag(trap_exit, true),
{ShardCount, Source} = determine_shard_count(presence_cache_shards),
{ShardCount, _Source} = determine_shard_count(presence_cache_shards),
Shards = start_shards(ShardCount, #{}),
maybe_log_shard_source(presence_cache, ShardCount, Source),
{ok, #{shards => Shards, shard_count => ShardCount}}.
-spec handle_call(term(), gen_server:from(), state()) -> {reply, term(), state()}.
@@ -76,23 +75,30 @@ handle_call({bulk_get, UserIds}, _From, State) ->
handle_call(get_memory_stats, _From, State) ->
Count = maps:get(shard_count, State),
WordSize = erlang:system_info(wordsize),
TotalMemory = lists:foldl(fun(Index, Acc) ->
TableName = presence_cache_shard:table_name(Index),
case ets:info(TableName, memory) of
undefined -> Acc;
Words -> Acc + (Words * WordSize)
end
end, 0, lists:seq(0, Count - 1)),
TotalEntries = lists:foldl(fun(Index, Acc) ->
TableName = presence_cache_shard:table_name(Index),
case ets:info(TableName, size) of
undefined -> Acc;
Size -> Acc + Size
end
end, 0, lists:seq(0, Count - 1)),
TotalMemory = lists:foldl(
fun(Index, Acc) ->
TableName = presence_cache_shard:table_name(Index),
case ets:info(TableName, memory) of
undefined -> Acc;
Words -> Acc + (Words * WordSize)
end
end,
0,
lists:seq(0, Count - 1)
),
TotalEntries = lists:foldl(
fun(Index, Acc) ->
TableName = presence_cache_shard:table_name(Index),
case ets:info(TableName, size) of
undefined -> Acc;
Size -> Acc + Size
end
end,
0,
lists:seq(0, Count - 1)
),
{reply, {ok, #{memory_bytes => TotalMemory, entry_count => TotalEntries}}, State};
handle_call(Request, _From, State) ->
logger:warning("[presence_cache] unknown request ~p", [Request]),
handle_call(_Request, _From, State) ->
{reply, ok, State}.
-spec handle_cast(term(), state()) -> {noreply, state()}.
@@ -100,21 +106,19 @@ handle_cast(_Msg, State) ->
{noreply, State}.
-spec handle_info(term(), state()) -> {noreply, state()}.
handle_info({'DOWN', Ref, process, _Pid, Reason}, State) ->
handle_info({'DOWN', Ref, process, _Pid, _Reason}, State) ->
Shards = maps:get(shards, State),
case find_shard_by_ref(Ref, Shards) of
{ok, Index} ->
logger:warning("[presence_cache] shard ~p crashed: ~p", [Index, Reason]),
{_Shard, NewState} = restart_shard(Index, State),
{noreply, NewState};
not_found ->
{noreply, State}
end;
handle_info({'EXIT', Pid, Reason}, State) ->
handle_info({'EXIT', Pid, _Reason}, State) ->
Shards = maps:get(shards, State),
case find_shard_by_pid(Pid, Shards) of
{ok, Index} ->
logger:warning("[presence_cache] shard ~p exited: ~p", [Index, Reason]),
{_Shard, NewState} = restart_shard(Index, State),
{noreply, NewState};
not_found ->
@@ -163,14 +167,6 @@ default_shard_count() ->
],
lists:max([C || C <- Candidates, is_integer(C), C > 0] ++ [1]).
-spec maybe_log_shard_source(atom(), pos_integer(), configured | auto) -> ok.
maybe_log_shard_source(Name, Count, configured) ->
logger:info("[~p] starting with ~p shards (configured)", [Name, Count]),
ok;
maybe_log_shard_source(Name, Count, auto) ->
logger:info("[~p] starting with ~p shards (auto)", [Name, Count]),
ok.
-spec start_shards(pos_integer(), #{}) -> #{non_neg_integer() => shard()}.
start_shards(Count, Acc) ->
lists:foldl(
@@ -178,8 +174,7 @@ start_shards(Count, Acc) ->
case start_shard(Index) of
{ok, Shard} ->
maps:put(Index, Shard, MapAcc);
{error, Reason} ->
logger:warning("[presence_cache] failed to start shard ~p: ~p", [Index, Reason]),
{error, _Reason} ->
MapAcc
end
end,
@@ -204,8 +199,7 @@ restart_shard(Index, State) ->
Shards = maps:get(shards, State),
Updated = State#{shards := maps:put(Index, Shard, Shards)},
{Shard, Updated};
{error, Reason} ->
logger:error("[presence_cache] failed to restart shard ~p: ~p", [Index, Reason]),
{error, _Reason} ->
Dummy = #{pid => spawn(fun() -> exit(normal) end), ref => make_ref()},
{Dummy, State}
end.
@@ -215,7 +209,7 @@ forward_call(Key, Request, State) ->
{Index, State1} = ensure_shard(Key, State),
call_shard(Index, Request, State1).
-spec forward_bulk_get([term()], state()) -> {term(), state()}.
-spec forward_bulk_get([integer()], state()) -> {[map()], state()}.
forward_bulk_get(UserIds, State) ->
Count = maps:get(shard_count, State),
Unique = lists:usort(UserIds),
@@ -326,6 +320,16 @@ bulk_get_across_shards_test() ->
?assertEqual(2, length(Results)),
?assertEqual(ok, gen_server:stop(Pid)).
select_shard_test() ->
?assert(select_shard(100, 4) >= 0),
?assert(select_shard(100, 4) < 4).
find_shard_by_ref_test() ->
Ref1 = make_ref(),
Shards = #{0 => #{pid => self(), ref => Ref1}},
?assertEqual({ok, 0}, find_shard_by_ref(Ref1, Shards)),
?assertEqual(not_found, find_shard_by_ref(make_ref(), Shards)).
maybe_start_for_test() ->
case whereis(?MODULE) of
undefined -> start_link();

View File

@@ -31,6 +31,10 @@
start_link(ShardIndex) ->
gen_server:start_link(?MODULE, #{shard_index => ShardIndex}, []).
-spec table_name(non_neg_integer()) -> atom().
table_name(Index) ->
list_to_atom(atom_to_list(?TABLE_PREFIX) ++ "_" ++ integer_to_list(Index)).
-spec init(map()) -> {ok, state()}.
init(#{shard_index := ShardIndex}) ->
process_flag(trap_exit, true),
@@ -113,6 +117,33 @@ ensure_table(Table) ->
ok
end.
-spec table_name(non_neg_integer()) -> atom().
table_name(Index) ->
list_to_atom(atom_to_list(?TABLE_PREFIX) ++ "_" ++ integer_to_list(Index)).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
table_name_test() ->
?assertEqual(presence_cache_0, table_name(0)),
?assertEqual(presence_cache_5, table_name(5)).
do_put_online_inserts_test() ->
Table = test_cache_table,
ets:new(Table, [named_table, public, set]),
?assertEqual(ok, do_put(Table, 1, #{<<"status">> => <<"online">>})),
?assertMatch([{1, _}], ets:lookup(Table, 1)),
ets:delete(Table).
do_put_offline_deletes_test() ->
Table = test_cache_table_2,
ets:new(Table, [named_table, public, set]),
ets:insert(Table, {1, #{<<"status">> => <<"online">>}}),
?assertEqual(ok, do_put(Table, 1, #{<<"status">> => <<"offline">>})),
?assertEqual([], ets:lookup(Table, 1)),
ets:delete(Table).
do_put_invisible_deletes_test() ->
Table = test_cache_table_3,
ets:new(Table, [named_table, public, set]),
ets:insert(Table, {1, #{<<"status">> => <<"online">>}}),
?assertEqual(ok, do_put(Table, 1, #{<<"status">> => <<"invisible">>})),
?assertEqual([], ets:lookup(Table, 1)),
ets:delete(Table).
-endif.

View File

@@ -20,7 +20,10 @@
-include_lib("fluxer_gateway/include/timeout_config.hrl").
-export([start_link/0, lookup/1, dispatch_to_user/3, terminate_all_sessions/1]).
-define(PID_CACHE_TABLE, presence_pid_cache).
-define(CACHE_TTL_MS, 300000).
-export([start_link/0, lookup/1, lookup_async/2, dispatch_to_user/3, terminate_all_sessions/1]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-type user_id() :: integer().
@@ -34,7 +37,59 @@ start_link() ->
-spec lookup(user_id()) -> {ok, pid()} | {error, not_found}.
lookup(UserId) ->
gen_server:call(?MODULE, {lookup, UserId}, ?DEFAULT_GEN_SERVER_TIMEOUT).
case check_cache(UserId) of
{hit, Pid} ->
{ok, Pid};
miss ->
lookup_and_cache(UserId)
end.
-spec lookup_async(user_id(), term()) -> ok.
lookup_async(UserId, Message) ->
case check_cache(UserId) of
{hit, Pid} ->
gen_server:cast(Pid, Message);
miss ->
spawn(fun() -> lookup_and_cast(UserId, Message) end)
end,
ok.
-spec check_cache(user_id()) -> {hit, pid()} | miss.
check_cache(UserId) ->
case ets:lookup(?PID_CACHE_TABLE, UserId) of
[{UserId, Pid, Timestamp}] ->
IsFresh = erlang:monotonic_time(millisecond) - Timestamp < ?CACHE_TTL_MS,
IsAlive = erlang:is_process_alive(Pid),
case {IsFresh, IsAlive} of
{true, true} ->
{hit, Pid};
_ ->
ets:delete(?PID_CACHE_TABLE, UserId),
miss
end;
[] ->
miss
end.
-spec lookup_and_cache(user_id()) -> {ok, pid()} | {error, not_found}.
lookup_and_cache(UserId) ->
case gen_server:call(?MODULE, {lookup, UserId}, ?DEFAULT_GEN_SERVER_TIMEOUT) of
{ok, Pid} ->
ets:insert(?PID_CACHE_TABLE, {UserId, Pid, erlang:monotonic_time(millisecond)}),
{ok, Pid};
{error, not_found} = Error ->
Error
end.
-spec lookup_and_cast(user_id(), term()) -> ok.
lookup_and_cast(UserId, Message) ->
case gen_server:call(?MODULE, {lookup, UserId}, ?DEFAULT_GEN_SERVER_TIMEOUT) of
{ok, Pid} ->
ets:insert(?PID_CACHE_TABLE, {UserId, Pid, erlang:monotonic_time(millisecond)}),
gen_server:cast(Pid, Message);
{error, not_found} ->
ok
end.
-spec terminate_all_sessions(user_id()) -> ok | {error, term()}.
terminate_all_sessions(UserId) ->
@@ -47,23 +102,20 @@ dispatch_to_user(UserId, Event, Data) ->
-spec init(list()) -> {ok, state()}.
init([]) ->
process_flag(trap_exit, true),
{ShardCount, Source} = determine_shard_count(),
ets:new(?PID_CACHE_TABLE, [named_table, public, set]),
{ShardCount, _Source} = determine_shard_count(),
{ShardMap, _} = lists:foldl(
fun(Index, {Acc, Counter}) ->
case start_shard(Index) of
{ok, Shard} ->
{maps:put(Index, Shard, Acc), Counter + 1};
{error, Reason} ->
logger:warning("[presence_manager] failed to start shard ~p: ~p", [
Index, Reason
]),
{error, _Reason} ->
{Acc, Counter}
end
end,
{#{}, 0},
lists:seq(0, ShardCount - 1)
),
maybe_log_shard_source(presence_manager, ShardCount, Source),
{ok, #{shards => ShardMap, shard_count => ShardCount}}.
-spec handle_call(term(), gen_server:from(), state()) -> {reply, term(), state()}.
@@ -86,8 +138,7 @@ handle_call(get_local_count, _From, State) ->
handle_call(get_global_count, _From, State) ->
{Count, NewState} = aggregate_counts(get_global_count, State),
{reply, {ok, Count}, NewState};
handle_call(Request, _From, State) ->
logger:warning("[presence_manager] unknown request ~p", [Request]),
handle_call(_Request, _From, State) ->
{reply, ok, State}.
-spec handle_cast(term(), state()) -> {noreply, state()}.
@@ -95,21 +146,21 @@ handle_cast(_Msg, State) ->
{noreply, State}.
-spec handle_info(term(), state()) -> {noreply, state()}.
handle_info({'DOWN', Ref, process, _Pid, Reason}, State) ->
handle_info({'DOWN', Ref, process, Pid, _Reason}, State) ->
clean_cache_by_pid(Pid),
Shards = maps:get(shards, State),
case find_shard_by_ref(Ref, Shards) of
{ok, Index} ->
logger:warning("[presence_manager] shard ~p crashed: ~p", [Index, Reason]),
{_ShardEntry, UpdatedState} = restart_shard(Index, State),
{noreply, UpdatedState};
not_found ->
{noreply, State}
end;
handle_info({'EXIT', Pid, Reason}, State) ->
handle_info({'EXIT', Pid, _Reason}, State) ->
clean_cache_by_pid(Pid),
Shards = maps:get(shards, State),
case find_shard_by_pid(Pid, Shards) of
{ok, Index} ->
logger:warning("[presence_manager] shard ~p exited: ~p", [Index, Reason]),
{_ShardEntry, UpdatedState} = restart_shard(Index, State),
{noreply, UpdatedState};
not_found ->
@@ -120,10 +171,10 @@ handle_info(_Info, State) ->
-spec terminate(term(), state()) -> ok.
terminate(_Reason, State) ->
catch ets:delete(?PID_CACHE_TABLE),
Shards = maps:get(shards, State),
lists:foreach(
fun(Shard) ->
Pid = maps:get(pid, Shard),
fun(#{pid := Pid}) ->
catch gen_server:stop(Pid, shutdown, 5000)
end,
maps:values(Shards)
@@ -168,8 +219,7 @@ restart_shard(Index, State) ->
Shards = maps:get(shards, State),
Updated = State#{shards := maps:put(Index, Shard, Shards)},
{Shard, Updated};
{error, Reason} ->
logger:error("[presence_manager] failed to restart shard ~p: ~p", [Index, Reason]),
{error, _Reason} ->
Dummy = #{pid => spawn(fun() -> exit(normal) end), ref => make_ref()},
{Dummy, State}
end.
@@ -178,8 +228,7 @@ restart_shard(Index, State) ->
forward_call(Key, Request, State) ->
{ShardIndex, State1} = ensure_shard(Key, State),
Shards = maps:get(shards, State1),
Shard = maps:get(ShardIndex, Shards),
Pid = maps:get(pid, Shard),
#{pid := Pid} = maps:get(ShardIndex, Shards),
case catch gen_server:call(Pid, Request, ?DEFAULT_GEN_SERVER_TIMEOUT) of
{'EXIT', _} ->
{_ShardEntry, State2} = restart_shard(ShardIndex, State1),
@@ -191,17 +240,16 @@ forward_call(Key, Request, State) ->
-spec aggregate_counts(term(), state()) -> {non_neg_integer(), state()}.
aggregate_counts(Request, State) ->
Shards = maps:get(shards, State),
Results =
[
begin
Pid = maps:get(pid, Shard),
case catch gen_server:call(Pid, Request, ?DEFAULT_GEN_SERVER_TIMEOUT) of
{ok, Count} -> Count;
_ -> 0
end
Results = [
begin
#{pid := Pid} = Shard,
case catch gen_server:call(Pid, Request, ?DEFAULT_GEN_SERVER_TIMEOUT) of
{ok, Count} -> Count;
_ -> 0
end
|| Shard <- maps:values(Shards)
],
end
|| Shard <- maps:values(Shards)
],
{lists:sum(Results), State}.
-spec ensure_shard(user_id(), state()) -> {non_neg_integer(), state()}.
@@ -213,7 +261,7 @@ ensure_shard(Key, State) ->
undefined ->
{_ShardEntry, NewState} = restart_shard(Index, State),
{Index, NewState};
#{pid := Pid} when is_pid(Pid) ->
#{pid := Pid} ->
case erlang:is_process_alive(Pid) of
true ->
{Index, State};
@@ -258,17 +306,24 @@ find_shard_by_pid(Pid, Shards) ->
-spec default_shard_count() -> pos_integer().
default_shard_count() ->
Candidates = [
erlang:system_info(logical_processors_available), erlang:system_info(schedulers_online)
erlang:system_info(logical_processors_available),
erlang:system_info(schedulers_online)
],
lists:max([C || C <- Candidates, is_integer(C), C > 0] ++ [1]).
-spec maybe_log_shard_source(atom(), pos_integer(), configured | auto) -> ok.
maybe_log_shard_source(Name, Count, configured) ->
logger:info("[~p] starting with ~p shards (configured)", [Name, Count]),
ok;
maybe_log_shard_source(Name, Count, auto) ->
logger:info("[~p] starting with ~p shards (auto)", [Name, Count]),
ok.
-spec clean_cache_by_pid(pid()) -> ok.
clean_cache_by_pid(Pid) ->
ets:foldl(
fun
({UserId, CachedPid, _}, Acc) when CachedPid =:= Pid ->
ets:delete(?PID_CACHE_TABLE, UserId),
Acc;
(_, Acc) ->
Acc
end,
ok,
?PID_CACHE_TABLE
).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
@@ -284,6 +339,25 @@ determine_shard_count_auto_test() ->
?assert(Count > 0)
end).
select_shard_test() ->
?assert(select_shard(100, 4) >= 0),
?assert(select_shard(100, 4) < 4).
extract_user_id_test() ->
?assertEqual(123, extract_user_id({start_or_lookup, #{user_id => 123}})),
?assertEqual(0, extract_user_id(unknown)).
find_shard_by_ref_test() ->
Ref1 = make_ref(),
Ref2 = make_ref(),
Shards = #{0 => #{pid => self(), ref => Ref1}, 1 => #{pid => self(), ref => Ref2}},
?assertEqual({ok, 0}, find_shard_by_ref(Ref1, Shards)),
?assertEqual({ok, 1}, find_shard_by_ref(Ref2, Shards)),
?assertEqual(not_found, find_shard_by_ref(make_ref(), Shards)).
default_shard_count_test() ->
?assert(default_shard_count() >= 1).
with_runtime_config(Key, Value, Fun) ->
Original = fluxer_gateway_env:get(Key),
fluxer_gateway_env:patch(#{Key => Value}),

View File

@@ -20,6 +20,9 @@
-include_lib("fluxer_gateway/include/timeout_config.hrl").
-define(PID_CACHE_TABLE, presence_pid_cache).
-define(CACHE_TTL_MS, 300000).
-export([start_link/1]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
@@ -168,6 +171,7 @@ do_start_or_lookup(Request, State) ->
NewPresences = maps:put(
UserId, {RegisteredPid, Ref}, CleanPresences
),
update_cache(UserId, RegisteredPid),
{reply, {ok, RegisteredPid}, State#{
presences := NewPresences
}};
@@ -182,6 +186,7 @@ do_start_or_lookup(Request, State) ->
_ExistingPid ->
case process_registry:lookup_or_monitor(PresenceName, UserId, Presences) of
{ok, Pid, _Ref, NewPresences} ->
update_cache(UserId, Pid),
{reply, {ok, Pid}, State#{presences := NewPresences}};
{error, not_found} ->
{reply, {error, process_disappeared}, State}
@@ -201,12 +206,14 @@ lookup_presence(UserId, State) ->
{ok, Pid, Ref, NewPresences0} ->
CleanPresences = maps:remove(PresenceName, NewPresences0),
FinalPresences = maps:put(UserId, {Pid, Ref}, CleanPresences),
update_cache(UserId, Pid),
{ok, Pid, State#{presences := FinalPresences}};
{error, not_found} ->
{error, not_found, State}
end
end.
-spec terminate_sessions_for_user(user_id(), state()) -> {ok, state()}.
terminate_sessions_for_user(UserId, State) ->
Presences = maps:get(presences, State),
case maps:get(UserId, Presences, undefined) of
@@ -219,9 +226,31 @@ terminate_sessions_for_user(UserId, State) ->
{ok, Pid, Ref, NewPresences0} ->
CleanPresences = maps:remove(PresenceName, NewPresences0),
FinalPresences = maps:put(UserId, {Pid, Ref}, CleanPresences),
update_cache(UserId, Pid),
gen_server:cast(Pid, {terminate_all_sessions}),
{ok, State#{presences := FinalPresences}};
{error, not_found} ->
{ok, State}
end
end.
-spec update_cache(user_id(), pid()) -> ok.
update_cache(UserId, Pid) ->
Timestamp = erlang:monotonic_time(millisecond),
try
ets:insert(?PID_CACHE_TABLE, {UserId, Pid, Timestamp}),
ok
catch
_:_ -> ok
end.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
init_returns_empty_presences_test() ->
{ok, State} = init(#{shard_index => 0}),
?assertEqual(#{}, maps:get(presences, State)).
update_cache_handles_missing_table_test() ->
?assertEqual(ok, update_cache(999, self())).
-endif.

View File

@@ -19,6 +19,10 @@
-export([build/5]).
-type status() :: online | offline | idle | dnd | invisible | binary().
-type custom_status() :: map() | null.
-spec build(map(), status(), boolean(), boolean(), custom_status()) -> map().
build(UserData, Status, Mobile, Afk, CustomStatus) ->
StatusBin = ensure_status_binary(Status),
#{
@@ -29,13 +33,16 @@ build(UserData, Status, Mobile, Afk, CustomStatus) ->
<<"custom_status">> => custom_status_for(StatusBin, CustomStatus)
}.
ensure_status_binary(Status) when is_atom(Status) ->
constants:status_type_atom(Status);
ensure_status_binary(Status) when is_binary(Status) ->
Status;
ensure_status_binary(_) ->
<<"offline">>.
-spec ensure_status_binary(status()) -> binary().
ensure_status_binary(online) -> <<"online">>;
ensure_status_binary(offline) -> <<"offline">>;
ensure_status_binary(idle) -> <<"idle">>;
ensure_status_binary(dnd) -> <<"dnd">>;
ensure_status_binary(invisible) -> <<"invisible">>;
ensure_status_binary(Status) when is_binary(Status) -> Status;
ensure_status_binary(_) -> <<"offline">>.
-spec custom_status_for(binary(), custom_status()) -> custom_status().
custom_status_for(StatusBin, CustomStatus) ->
case StatusBin of
<<"offline">> ->
@@ -46,9 +53,49 @@ custom_status_for(StatusBin, CustomStatus) ->
normalize_custom_status(CustomStatus)
end.
-spec normalize_custom_status(term()) -> custom_status().
normalize_custom_status(null) ->
null;
normalize_custom_status(CustomStatus) when is_map(CustomStatus) ->
CustomStatus;
normalize_custom_status(_) ->
null.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
ensure_status_binary_atom_test() ->
?assertEqual(<<"online">>, ensure_status_binary(online)),
?assertEqual(<<"offline">>, ensure_status_binary(offline)),
?assertEqual(<<"idle">>, ensure_status_binary(idle)),
?assertEqual(<<"dnd">>, ensure_status_binary(dnd)),
?assertEqual(<<"invisible">>, ensure_status_binary(invisible)).
ensure_status_binary_binary_test() ->
?assertEqual(<<"online">>, ensure_status_binary(<<"online">>)),
?assertEqual(<<"custom">>, ensure_status_binary(<<"custom">>)).
ensure_status_binary_unknown_test() ->
?assertEqual(<<"offline">>, ensure_status_binary(123)),
?assertEqual(<<"offline">>, ensure_status_binary(undefined)).
custom_status_for_visible_test() ->
Status = #{<<"text">> => <<"hello">>},
?assertEqual(Status, custom_status_for(<<"online">>, Status)),
?assertEqual(Status, custom_status_for(<<"idle">>, Status)),
?assertEqual(Status, custom_status_for(<<"dnd">>, Status)).
custom_status_for_invisible_test() ->
Status = #{<<"text">> => <<"hello">>},
?assertEqual(null, custom_status_for(<<"offline">>, Status)),
?assertEqual(null, custom_status_for(<<"invisible">>, Status)).
custom_status_for_null_test() ->
?assertEqual(null, custom_status_for(<<"online">>, null)).
normalize_custom_status_test() ->
?assertEqual(null, normalize_custom_status(null)),
?assertEqual(#{<<"text">> => <<"hi">>}, normalize_custom_status(#{<<"text">> => <<"hi">>})),
?assertEqual(null, normalize_custom_status(<<"invalid">>)),
?assertEqual(null, normalize_custom_status(123)).
-endif.

View File

@@ -26,13 +26,40 @@
find_session_by_ref/2
]).
-type session_id() :: binary().
-type status() :: online | offline | idle | dnd | invisible.
-type session_entry() :: #{
session_id := session_id(),
status := status(),
afk := boolean(),
mobile := boolean(),
pid := pid(),
mref := reference(),
socket_pid := pid() | undefined
}.
-type sessions() :: #{session_id() => session_entry()}.
-type state() :: #{sessions := sessions(), _ => _}.
-type connect_request() :: #{
session_id := session_id(),
status := status(),
afk => boolean(),
mobile => boolean(),
socket_pid => pid() | undefined
}.
-type update_request() :: #{
session_id := session_id(),
status := status(),
afk => boolean()
}.
-spec handle_session_connect(connect_request(), pid(), state()) ->
{reply, {ok, [map()]}, state()}.
handle_session_connect(Request, Pid, State) ->
#{session_id := SessionId, status := Status} = Request,
Afk = maps:get(afk, Request, false),
Mobile = maps:get(mobile, Request, false),
SocketPid = maps:get(socket_pid, Request, undefined),
Sessions = maps:get(sessions, State),
case maps:is_key(SessionId, Sessions) of
true ->
SessionsData = presence_status:collect_sessions_for_replace(Sessions),
@@ -50,16 +77,15 @@ handle_session_connect(Request, Pid, State) ->
},
NewSessions = maps:put(SessionId, SessionEntry, Sessions),
NewState = maps:put(sessions, NewSessions, State),
SessionsData = presence_status:collect_sessions_for_replace(NewSessions),
{reply, {ok, SessionsData}, NewState}
end.
-spec handle_presence_update(update_request(), state()) -> {noreply, state()}.
handle_presence_update(Request, State) ->
#{session_id := SessionId, status := Status} = Request,
Afk = maps:get(afk, Request, false),
Sessions = maps:get(sessions, State),
case maps:get(SessionId, Sessions, undefined) of
undefined ->
{noreply, State};
@@ -71,18 +97,20 @@ handle_presence_update(Request, State) ->
{noreply, NewState}
end.
-spec dispatch_sessions_replace(state()) -> ok.
dispatch_sessions_replace(State) ->
Sessions = maps:get(sessions, State),
SessionsData = presence_status:collect_sessions_for_replace(Sessions),
SessionPids = [maps:get(pid, S) || S <- maps:values(Sessions)],
lists:foreach(
fun(Pid) when is_pid(Pid) ->
gen_server:cast(Pid, {dispatch, sessions_replace, SessionsData})
end,
SessionPids
).
),
ok.
-spec notify_sessions_guild_join(integer(), state()) -> ok.
notify_sessions_guild_join(GuildId, State) ->
Sessions = maps:get(sessions, State),
SessionPids = [maps:get(pid, S) || S <- maps:values(Sessions)],
@@ -91,8 +119,10 @@ notify_sessions_guild_join(GuildId, State) ->
gen_server:cast(Pid, {guild_join, GuildId})
end,
SessionPids
).
),
ok.
-spec notify_sessions_guild_leave(integer(), state()) -> ok.
notify_sessions_guild_leave(GuildId, State) ->
Sessions = maps:get(sessions, State),
SessionPids = [maps:get(pid, S) || S <- maps:values(Sessions)],
@@ -101,8 +131,10 @@ notify_sessions_guild_leave(GuildId, State) ->
gen_server:cast(Pid, {guild_leave, GuildId})
end,
SessionPids
).
),
ok.
-spec find_session_by_ref(reference(), sessions()) -> {ok, session_id()} | not_found.
find_session_by_ref(Ref, Sessions) ->
maps:fold(
fun(SessionId, S, Acc) ->
@@ -114,3 +146,25 @@ find_session_by_ref(Ref, Sessions) ->
not_found,
Sessions
).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
find_session_by_ref_found_test() ->
Ref = make_ref(),
Sessions = #{
<<"s1">> => #{session_id => <<"s1">>, mref => make_ref()},
<<"s2">> => #{session_id => <<"s2">>, mref => Ref}
},
?assertEqual({ok, <<"s2">>}, find_session_by_ref(Ref, Sessions)).
find_session_by_ref_not_found_test() ->
Ref = make_ref(),
Sessions = #{
<<"s1">> => #{session_id => <<"s1">>, mref => make_ref()}
},
?assertEqual(not_found, find_session_by_ref(Ref, Sessions)).
find_session_by_ref_empty_test() ->
?assertEqual(not_found, find_session_by_ref(make_ref(), #{})).
-endif.

View File

@@ -24,15 +24,19 @@
collect_sessions_for_replace/1
]).
-type session_id() :: binary().
-type status() :: online | offline | idle | dnd | invisible.
-type session_entry() :: #{status := status(), afk := boolean(), mobile := boolean(), _ => _}.
-type sessions() :: #{session_id() => session_entry()}.
-spec get_current_status(sessions()) -> status().
get_current_status(Sessions) ->
AllStatuses = [maps:get(status, S) || S <- maps:values(Sessions)],
case lists:member(invisible, AllStatuses) of
true ->
invisible;
false ->
StatusPrecedence = [online, dnd, idle],
lists:foldl(
fun(Status, Acc) ->
case Acc of
@@ -50,6 +54,7 @@ get_current_status(Sessions) ->
)
end.
-spec get_flattened_mobile(sessions()) -> boolean().
get_flattened_mobile(Sessions) ->
lists:any(
fun(Session) ->
@@ -58,6 +63,7 @@ get_flattened_mobile(Sessions) ->
maps:values(Sessions)
).
-spec get_flattened_afk(sessions()) -> boolean().
get_flattened_afk(Sessions) ->
HasMobile = lists:any(
fun(Session) ->
@@ -65,7 +71,6 @@ get_flattened_afk(Sessions) ->
end,
maps:values(Sessions)
),
case HasMobile of
true ->
false;
@@ -83,6 +88,7 @@ get_flattened_afk(Sessions) ->
end
end.
-spec collect_sessions_for_replace(sessions()) -> [map()].
collect_sessions_for_replace(Sessions) ->
Status = get_current_status(Sessions),
Mobile = get_flattened_mobile(Sessions),
@@ -95,7 +101,6 @@ collect_sessions_for_replace(Sessions) ->
<<"afk">> => Afk
}
],
SessionEntries = lists:map(
fun({SessionId, Session}) ->
#{
@@ -107,5 +112,85 @@ collect_sessions_for_replace(Sessions) ->
end,
maps:to_list(Sessions)
),
BaseSessions ++ SessionEntries.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
get_current_status_empty_test() ->
?assertEqual(offline, get_current_status(#{})).
get_current_status_online_test() ->
Sessions = #{<<"s1">> => #{status => online, afk => false, mobile => false}},
?assertEqual(online, get_current_status(Sessions)).
get_current_status_precedence_test() ->
Sessions = #{
<<"s1">> => #{status => idle, afk => false, mobile => false},
<<"s2">> => #{status => online, afk => false, mobile => false}
},
?assertEqual(online, get_current_status(Sessions)).
get_current_status_dnd_over_idle_test() ->
Sessions = #{
<<"s1">> => #{status => idle, afk => false, mobile => false},
<<"s2">> => #{status => dnd, afk => false, mobile => false}
},
?assertEqual(dnd, get_current_status(Sessions)).
get_current_status_invisible_test() ->
Sessions = #{
<<"s1">> => #{status => invisible, afk => false, mobile => false},
<<"s2">> => #{status => online, afk => false, mobile => false}
},
?assertEqual(invisible, get_current_status(Sessions)).
get_flattened_mobile_true_test() ->
Sessions = #{
<<"s1">> => #{status => online, afk => false, mobile => true},
<<"s2">> => #{status => online, afk => false, mobile => false}
},
?assertEqual(true, get_flattened_mobile(Sessions)).
get_flattened_mobile_false_test() ->
Sessions = #{
<<"s1">> => #{status => online, afk => false, mobile => false}
},
?assertEqual(false, get_flattened_mobile(Sessions)).
get_flattened_mobile_empty_test() ->
?assertEqual(false, get_flattened_mobile(#{})).
get_flattened_afk_all_afk_test() ->
Sessions = #{
<<"s1">> => #{status => online, afk => true, mobile => false},
<<"s2">> => #{status => online, afk => true, mobile => false}
},
?assertEqual(true, get_flattened_afk(Sessions)).
get_flattened_afk_some_not_afk_test() ->
Sessions = #{
<<"s1">> => #{status => online, afk => true, mobile => false},
<<"s2">> => #{status => online, afk => false, mobile => false}
},
?assertEqual(false, get_flattened_afk(Sessions)).
get_flattened_afk_mobile_overrides_test() ->
Sessions = #{
<<"s1">> => #{status => online, afk => true, mobile => true}
},
?assertEqual(false, get_flattened_afk(Sessions)).
get_flattened_afk_empty_test() ->
?assertEqual(false, get_flattened_afk(#{})).
collect_sessions_for_replace_test() ->
Sessions = #{
<<"s1">> => #{status => online, afk => false, mobile => false}
},
Result = collect_sessions_for_replace(Sessions),
?assertEqual(2, length(Result)),
[AllSession | Rest] = Result,
?assertEqual(<<"all">>, maps:get(<<"session_id">>, AllSession)),
?assertEqual(1, length(Rest)).
-endif.

View File

@@ -22,6 +22,17 @@
group_dm_recipients_from_state/1
]).
-type user_id() :: integer().
-type channel_id() :: integer().
-type relationship_type() :: integer().
-type state() :: #{
user_id := user_id(),
relationships => #{user_id() => relationship_type()},
channels => #{channel_id() => map()},
_ => _
}.
-spec friend_ids_from_state(state()) -> [user_id()].
friend_ids_from_state(State) ->
Relationships = maps:get(relationships, State, #{}),
[
@@ -30,6 +41,7 @@ friend_ids_from_state(State) ->
Type =:= 1 orelse Type =:= 3
].
-spec group_dm_recipients_from_state(state()) -> #{channel_id() => #{user_id() => true}}.
group_dm_recipients_from_state(State) ->
UserId = maps:get(user_id, State),
Channels = maps:get(channels, State, #{}),
@@ -42,6 +54,7 @@ group_dm_recipients_from_state(State) ->
]
).
-spec extract_recipient_ids(map()) -> [user_id()].
extract_recipient_ids(Channel) ->
Recipients = maps:get(<<"recipients">>, Channel, maps:get(<<"recipient_ids">>, Channel, [])),
Unique =
@@ -62,6 +75,7 @@ extract_recipient_ids(Channel) ->
),
lists:reverse(Unique).
-spec extract_recipient_id(term()) -> user_id() | undefined.
extract_recipient_id(Entry) when is_map(Entry) ->
type_conv:extract_id(Entry, <<"id">>);
extract_recipient_id(Entry) ->
@@ -74,6 +88,7 @@ extract_recipient_id(Entry) ->
undefined
end.
-spec map_from_ids([user_id()]) -> #{user_id() => true}.
map_from_ids(Ids) when is_list(Ids) ->
maps:from_list([{Id, true} || Id <- Ids]).
@@ -91,7 +106,78 @@ friend_ids_from_state_filters_relationship_types_test() ->
}
},
Ids = lists:sort(friend_ids_from_state(State)),
?assertEqual([10, 11], Ids),
ok.
?assertEqual([10, 11], Ids).
friend_ids_from_state_empty_test() ->
State = #{relationships => #{}},
?assertEqual([], friend_ids_from_state(State)).
friend_ids_from_state_missing_key_test() ->
State = #{},
?assertEqual([], friend_ids_from_state(State)).
group_dm_recipients_from_state_test() ->
State = #{
user_id => 1,
channels => #{
100 => #{
<<"type">> => 3,
<<"recipients">> => [
#{<<"id">> => <<"2">>},
#{<<"id">> => <<"3">>}
]
},
200 => #{
<<"type">> => 0,
<<"recipients">> => [#{<<"id">> => <<"4">>}]
}
}
},
Result = group_dm_recipients_from_state(State),
?assertEqual(#{100 => #{2 => true, 3 => true}}, Result).
group_dm_recipients_excludes_self_test() ->
State = #{
user_id => 2,
channels => #{
100 => #{
<<"type">> => 3,
<<"recipients">> => [
#{<<"id">> => <<"2">>},
#{<<"id">> => <<"3">>}
]
}
}
},
Result = group_dm_recipients_from_state(State),
?assertEqual(#{100 => #{3 => true}}, Result).
extract_recipient_id_map_test() ->
?assertEqual(123, extract_recipient_id(#{<<"id">> => <<"123">>})),
?assertEqual(undefined, extract_recipient_id(#{})).
extract_recipient_id_binary_test() ->
?assertEqual(456, extract_recipient_id(<<"456">>)).
extract_recipient_id_integer_test() ->
?assertEqual(789, extract_recipient_id(789)).
extract_recipient_id_invalid_test() ->
?assertEqual(undefined, extract_recipient_id(undefined)),
?assertEqual(undefined, extract_recipient_id([1, 2, 3])).
extract_recipient_ids_deduplicates_test() ->
Channel = #{
<<"recipients">> => [
#{<<"id">> => <<"1">>},
#{<<"id">> => <<"1">>},
#{<<"id">> => <<"2">>}
]
},
Ids = extract_recipient_ids(Channel),
?assertEqual([1, 2], Ids).
map_from_ids_test() ->
?assertEqual(#{}, map_from_ids([])),
?assertEqual(#{1 => true, 2 => true}, map_from_ids([1, 2])).
-endif.

View File

@@ -28,6 +28,8 @@
-define(PRESENCE_BATCH_SIZE, 500).
-type user_id() :: integer().
-spec collect_guild_member_presences(map()) -> [map()].
collect_guild_member_presences(GuildState) ->
MemberIds = collect_guild_member_ids(GuildState),
@@ -39,13 +41,13 @@ collect_guild_member_presences(GuildState) ->
[P || P <- Presences, is_visible_presence(P)]
end.
-spec collect_guild_member_ids(map()) -> [integer()].
-spec collect_guild_member_ids(map()) -> [user_id()].
collect_guild_member_ids(GuildState) ->
Members = get_members_from_guild_state(GuildState),
MemberIds = [member_user_id(M) || M <- Members],
[Id || Id <- MemberIds, Id =/= undefined].
-spec filter_self_presence(integer(), [map()]) -> [map()].
-spec filter_self_presence(user_id(), [map()]) -> [map()].
filter_self_presence(UserId, Presences) ->
[P || P <- Presences, presence_user_id(P) =/= UserId].
@@ -60,13 +62,14 @@ batch_presences([]) ->
batch_presences(Presences) ->
batch_presences(Presences, []).
-spec batch_presences([map()], [[map()]]) -> [[map()]].
batch_presences([], Acc) ->
lists:reverse(Acc);
batch_presences(Presences, Acc) ->
{Batch, Rest} = take_batch(Presences, ?PRESENCE_BATCH_SIZE),
batch_presences(Rest, [Batch | Acc]).
-spec send_presence_bulk(pid(), integer(), integer(), [map()]) -> ok.
-spec send_presence_bulk(pid(), integer(), user_id(), [map()]) -> ok.
send_presence_bulk(_Pid, _GuildId, _UserId, []) ->
ok;
send_presence_bulk(Pid, GuildId, UserId, Presences) ->
@@ -88,24 +91,28 @@ send_presence_bulk(Pid, GuildId, UserId, Presences) ->
)
end.
-spec get_members_from_guild_state(map()) -> [map()].
get_members_from_guild_state(GuildState) ->
case maps:get(data, GuildState, undefined) of
undefined ->
map_utils:ensure_list(maps:get(<<"members">>, GuildState, []));
guild_data_index:member_values(GuildState);
Data ->
map_utils:ensure_list(maps:get(<<"members">>, Data, []))
guild_data_index:member_values(Data)
end.
-spec member_user_id(map()) -> user_id() | undefined.
member_user_id(Member) ->
User = maps:get(<<"user">>, Member, #{}),
map_utils:get_integer(User, <<"id">>, undefined).
-spec presence_user_id(map() | term()) -> user_id() | undefined.
presence_user_id(P) when is_map(P) ->
User = maps:get(<<"user">>, P, #{}),
map_utils:get_integer(User, <<"id">>, undefined);
presence_user_id(_) ->
undefined.
-spec take_batch([T], pos_integer()) -> {[T], [T]} when T :: term().
take_batch(List, N) when length(List) =< N ->
{List, []};
take_batch(List, N) ->
@@ -184,4 +191,24 @@ collect_guild_member_ids_external_format_test() ->
Ids = collect_guild_member_ids(GuildState),
?assertEqual([100, 200], lists:sort(Ids)).
take_batch_small_list_test() ->
{Batch, Rest} = take_batch([1, 2, 3], 10),
?assertEqual([1, 2, 3], Batch),
?assertEqual([], Rest).
take_batch_exact_test() ->
{Batch, Rest} = take_batch([1, 2, 3], 3),
?assertEqual([1, 2, 3], Batch),
?assertEqual([], Rest).
take_batch_split_test() ->
{Batch, Rest} = take_batch([1, 2, 3, 4, 5], 2),
?assertEqual([1, 2], Batch),
?assertEqual([3, 4, 5], Rest).
presence_user_id_test() ->
?assertEqual(123, presence_user_id(#{<<"user">> => #{<<"id">> => <<"123">>}})),
?assertEqual(undefined, presence_user_id(#{<<"user">> => #{}})),
?assertEqual(undefined, presence_user_id(#{})),
?assertEqual(undefined, presence_user_id(invalid)).
-endif.

View File

@@ -29,16 +29,6 @@
-export([get_cache_stats/0]).
-export_type([state/0]).
-import(push_eligibility, [is_eligible_for_push/8]).
-import(push_cache, [
update_lru/2,
get_user_push_subscriptions/2,
cache_user_subscriptions/3,
invalidate_user_badge_count/2
]).
-import(push_sender, [send_push_notifications/8]).
-import(push_logger_filter, [install_progress_filter/0]).
-type state() :: #{
user_guild_settings_cache := map(),
user_guild_settings_lru := list(),
@@ -59,11 +49,12 @@
badge_counts_ttl_seconds := non_neg_integer()
}.
-spec start_link() -> {ok, pid()} | {error, term()}.
start_link() ->
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
-spec init([]) -> {ok, state()}.
init([]) ->
install_progress_filter(),
PushEnabled = fluxer_gateway_env:get(push_enabled),
BaseState = #{
user_guild_settings_cache => #{},
@@ -102,7 +93,7 @@ init([]) ->
{ok, BaseState}
end.
-spec handle_call(term(), gen_server:from(), state()) -> {reply, term(), state()}.
handle_call(get_cache_stats, _From, State) ->
#{
user_guild_settings_cache := UgsCache,
@@ -120,6 +111,7 @@ handle_call(get_cache_stats, _From, State) ->
handle_call(_Request, _From, State) ->
{reply, ok, State}.
-spec handle_cast(term(), state()) -> {noreply, state()}.
handle_cast({handle_message_create, Params}, State) ->
{noreply, do_handle_message_create(Params, State)};
handle_cast({sync_user_guild_settings, UserId, GuildId, UserGuildSettings}, State) ->
@@ -129,7 +121,7 @@ handle_cast({sync_user_guild_settings, UserId, GuildId, UserGuildSettings}, Stat
} = State,
Key = {settings, UserId, GuildId},
NewCache = maps:put(Key, UserGuildSettings, UgsCache),
NewLru = update_lru(Key, UgsLru),
NewLru = push_cache:update_lru(Key, UgsLru),
{noreply, State#{
user_guild_settings_cache := NewCache,
user_guild_settings_lru := NewLru
@@ -141,7 +133,7 @@ handle_cast({sync_user_blocked_ids, UserId, BlockedIds}, State) ->
} = State,
Key = {blocked, UserId},
NewCache = maps:put(Key, BlockedIds, BiCache),
NewLru = update_lru(Key, BiLru),
NewLru = push_cache:update_lru(Key, BiLru),
{noreply, State#{
blocked_ids_cache := NewCache,
blocked_ids_lru := NewLru
@@ -153,24 +145,31 @@ handle_cast({cache_user_guild_settings, UserId, GuildId, Settings}, State) ->
} = State,
Key = {settings, UserId, GuildId},
NewCache = maps:put(Key, Settings, UgsCache),
NewLru = update_lru(Key, UgsLru),
NewLru = push_cache:update_lru(Key, UgsLru),
{noreply, State#{
user_guild_settings_cache := NewCache,
user_guild_settings_lru := NewLru
}};
handle_cast({invalidate_user_badge_count, UserId}, State) ->
{noreply, invalidate_user_badge_count(UserId, State)};
{noreply, push_cache:invalidate_user_badge_count(UserId, State)};
handle_cast(_Msg, State) ->
{noreply, State}.
-spec handle_info(term(), state()) -> {noreply, state()}.
handle_info(_Info, State) ->
{noreply, State}.
-spec terminate(term(), state()) -> ok.
terminate(_Reason, _State) ->
ok.
code_change(_OldVsn, {state, UgsCache, UgsLru, UgsSize, UgsMaxMb, PsCache, PsLru, PsSize, PsMaxMb,
BiCache, BiLru, BiSize, BiMaxMb, BcCache, BcLru, BcSize, BcMaxMb, BcTtl}, _Extra) ->
-spec code_change(term(), state() | tuple(), term()) -> {ok, state()}.
code_change(
_OldVsn,
{state, UgsCache, UgsLru, UgsSize, UgsMaxMb, PsCache, PsLru, PsSize, PsMaxMb, BiCache, BiLru,
BiSize, BiMaxMb, BcCache, BcLru, BcSize, BcMaxMb, BcTtl},
_Extra
) ->
{ok, #{
user_guild_settings_cache => UgsCache,
user_guild_settings_lru => UgsLru,
@@ -193,27 +192,41 @@ code_change(_OldVsn, {state, UgsCache, UgsLru, UgsSize, UgsMaxMb, PsCache, PsLru
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
-spec handle_message_create(map()) -> ok.
handle_message_create(Params) ->
PushEnabled = fluxer_gateway_env:get(push_enabled),
case PushEnabled of
true -> gen_server:cast(?MODULE, {handle_message_create, Params});
case fluxer_gateway_env:get(push_enabled) of
true ->
logger:debug(
"Push: handle_message_create dispatched",
#{
author_id => maps:get(author_id, Params, undefined),
guild_id => maps:get(guild_id, Params, undefined),
user_count => length(maps:get(user_ids, Params, []))
}
),
gen_server:cast(?MODULE, {handle_message_create, Params});
false ->
logger:debug("Push: push_enabled=false, skipping message_create"),
ok
end.
-spec sync_user_guild_settings(integer(), integer(), map()) -> ok.
sync_user_guild_settings(UserId, GuildId, UserGuildSettings) ->
gen_server:cast(?MODULE, {sync_user_guild_settings, UserId, GuildId, UserGuildSettings}).
-spec sync_user_blocked_ids(integer(), [integer()]) -> ok.
sync_user_blocked_ids(UserId, BlockedIds) ->
gen_server:cast(?MODULE, {sync_user_blocked_ids, UserId, BlockedIds}).
-spec invalidate_user_badge_count(integer()) -> ok.
invalidate_user_badge_count(UserId) ->
gen_server:cast(?MODULE, {invalidate_user_badge_count, UserId}).
-spec get_cache_stats() -> {ok, map()}.
get_cache_stats() ->
gen_server:call(?MODULE, get_cache_stats, 5000).
-spec do_handle_message_create(map(), state()) -> state().
do_handle_message_create(Params, State) ->
MessageData = maps:get(message_data, Params),
UserIds = maps:get(user_ids, Params),
@@ -226,12 +239,18 @@ do_handle_message_create(Params, State) ->
GuildName = maps:get(guild_name, Params, undefined),
ChannelName = maps:get(channel_name, Params, undefined),
logger:debug(
"[push] Processing message ~p in channel ~p, guild ~p for users ~p (author ~p, defaults ~p)",
[MessageId, ChannelId, GuildId, UserIds, AuthorId, GuildDefaultNotifications]
"Push: evaluating eligibility",
#{
message_id => MessageId,
channel_id => ChannelId,
guild_id => GuildId,
author_id => AuthorId,
candidate_count => length(UserIds)
}
),
EligibleUsers = lists:filter(
fun(UserId) ->
Eligible = is_eligible_for_push(
push_eligibility:is_eligible_for_push(
UserId,
AuthorId,
GuildId,
@@ -240,18 +259,24 @@ do_handle_message_create(Params, State) ->
GuildDefaultNotifications,
UserRolesMap,
State
),
logger:debug("[push] User ~p eligible: ~p", [UserId, Eligible]),
Eligible
)
end,
UserIds
),
logger:debug("[push] Eligible users: ~p", [EligibleUsers]),
logger:debug(
"Push: eligibility result",
#{
message_id => MessageId,
channel_id => ChannelId,
eligible_count => length(EligibleUsers),
eligible_user_ids => EligibleUsers
}
),
case EligibleUsers of
[] ->
State;
_ ->
send_push_notifications(
push_dispatcher:enqueue_send_notifications(
EligibleUsers,
MessageData,
GuildId,
@@ -260,5 +285,6 @@ do_handle_message_create(Params, State) ->
GuildName,
ChannelName,
State
)
),
State
end.

View File

@@ -26,23 +26,20 @@
-export([evict_if_needed/4]).
-export([invalidate_user_badge_count/2]).
-spec update_lru(term(), list()) -> list().
update_lru(Key, Lru) ->
NewLru = lists:delete(Key, Lru),
[Key | NewLru].
-spec get_user_push_subscriptions(integer(), map()) -> list().
get_user_push_subscriptions(UserId, State) ->
Key = {subscriptions, UserId},
PushSubscriptionsCache = maps:get(push_subscriptions_cache, State, #{}),
case maps:get(Key, PushSubscriptionsCache, undefined) of
undefined ->
[];
Subs ->
Subs
end.
maps:get(Key, PushSubscriptionsCache, []).
-spec cache_user_subscriptions(integer(), list(), map()) -> map().
cache_user_subscriptions(UserId, Subscriptions, State) ->
Key = {subscriptions, UserId},
NewSubsSize = estimate_subscriptions_size(Subscriptions),
OldSubsSize =
case maps:get(Key, maps:get(push_subscriptions_cache, State, #{}), undefined) of
@@ -50,44 +47,34 @@ cache_user_subscriptions(UserId, Subscriptions, State) ->
OldSubs -> estimate_subscriptions_size(OldSubs)
end,
SizeDelta = NewSubsSize - OldSubsSize,
PushSubscriptionsLru = maps:get(push_subscriptions_lru, State, []),
NewLru = update_lru(Key, PushSubscriptionsLru),
PushSubscriptionsCache = maps:get(push_subscriptions_cache, State, #{}),
NewCache = maps:put(Key, Subscriptions, PushSubscriptionsCache),
PushSubscriptionsSize = maps:get(push_subscriptions_size, State, 0),
NewSize = PushSubscriptionsSize + SizeDelta,
MaxBytes =
case maps:get(push_subscriptions_max_mb, State, undefined) of
undefined -> NewSize;
Mb -> Mb * 1024 * 1024
end,
{FinalCache, FinalLru, FinalSize} = evict_if_needed(
NewCache, NewLru, NewSize, MaxBytes
),
{FinalCache, FinalLru, FinalSize} = evict_if_needed(NewCache, NewLru, NewSize, MaxBytes),
State#{
push_subscriptions_cache => FinalCache,
push_subscriptions_lru => FinalLru,
push_subscriptions_size => FinalSize
}.
-spec get_user_badge_count(integer(), map()) -> {non_neg_integer(), integer()} | undefined.
get_user_badge_count(UserId, State) ->
Key = {badge_count, UserId},
BadgeCountsCache = maps:get(badge_counts_cache, State, #{}),
case maps:get(Key, BadgeCountsCache, undefined) of
undefined ->
undefined;
Badge ->
Badge
end.
maps:get(Key, BadgeCountsCache, undefined).
-spec cache_user_badge_count(integer(), non_neg_integer(), integer(), map()) -> map().
cache_user_badge_count(UserId, BadgeCount, CachedAt, State) ->
Key = {badge_count, UserId},
NewBadge = {BadgeCount, CachedAt},
OldBadgeSize =
case maps:get(Key, maps:get(badge_counts_cache, State, #{}), undefined) of
undefined -> 0;
@@ -95,40 +82,41 @@ cache_user_badge_count(UserId, BadgeCount, CachedAt, State) ->
end,
NewBadgeSize = estimate_badge_count_size(NewBadge),
SizeDelta = NewBadgeSize - OldBadgeSize,
BadgeCountsLru = maps:get(badge_counts_lru, State, []),
NewLru = update_lru(Key, BadgeCountsLru),
BadgeCountsCache = maps:get(badge_counts_cache, State, #{}),
NewCache = maps:put(Key, NewBadge, BadgeCountsCache),
BadgeCountsSize = maps:get(badge_counts_size, State, 0),
NewSize = BadgeCountsSize + SizeDelta,
MaxBytes =
case maps:get(badge_counts_max_mb, State, undefined) of
undefined -> NewSize;
Mb -> Mb * 1024 * 1024
end,
{FinalCache, FinalLru, FinalSize} = evict_if_needed(
NewCache, NewLru, NewSize, MaxBytes
),
{FinalCache, FinalLru, FinalSize} = evict_if_needed(NewCache, NewLru, NewSize, MaxBytes),
State#{
badge_counts_cache => FinalCache,
badge_counts_lru => FinalLru,
badge_counts_size => FinalSize
}.
-spec estimate_subscriptions_size(list()) -> non_neg_integer().
estimate_subscriptions_size(Subscriptions) ->
length(Subscriptions) * 200.
-spec estimate_badge_count_size({non_neg_integer(), integer()}) -> non_neg_integer().
estimate_badge_count_size({_Count, _Timestamp}) ->
64.
-spec evict_if_needed(map(), list(), non_neg_integer(), non_neg_integer()) ->
{map(), list(), non_neg_integer()}.
evict_if_needed(Cache, Lru, Size, MaxBytes) when Size > MaxBytes ->
evict_oldest(Cache, Lru, Size, MaxBytes, lists:reverse(Lru));
evict_if_needed(Cache, Lru, Size, _MaxBytes) ->
{Cache, Lru, Size}.
-spec evict_oldest(map(), list(), non_neg_integer(), non_neg_integer(), list()) ->
{map(), list(), non_neg_integer()}.
evict_oldest(Cache, Lru, Size, _MaxBytes, []) ->
{Cache, Lru, Size};
evict_oldest(Cache, Lru, Size, MaxBytes, [OldestKey | Remaining]) ->
@@ -142,6 +130,7 @@ evict_oldest(Cache, Lru, Size, MaxBytes, [OldestKey | Remaining]) ->
evict_if_needed(NewCache, NewLru, NewSize, MaxBytes)
end.
-spec invalidate_user_badge_count(integer(), map()) -> map().
invalidate_user_badge_count(UserId, State) ->
Key = {badge_count, UserId},
BadgeCountsCache = maps:get(badge_counts_cache, State, #{}),
@@ -159,3 +148,40 @@ invalidate_user_badge_count(UserId, State) ->
badge_counts_size => NewSize
}
end.
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
update_lru_test() ->
?assertEqual([a], update_lru(a, [])),
?assertEqual([b, a], update_lru(b, [a])),
?assertEqual([a, b, c], update_lru(a, [b, a, c])).
estimate_subscriptions_size_test() ->
?assertEqual(0, estimate_subscriptions_size([])),
?assertEqual(200, estimate_subscriptions_size([#{}])),
?assertEqual(400, estimate_subscriptions_size([#{}, #{}])).
estimate_badge_count_size_test() ->
?assertEqual(64, estimate_badge_count_size({0, 0})),
?assertEqual(64, estimate_badge_count_size({100, 12345})).
evict_if_needed_no_eviction_test() ->
Cache = #{a => [1, 2]},
Lru = [a],
{ResultCache, ResultLru, ResultSize} = evict_if_needed(Cache, Lru, 400, 1000),
?assertEqual(Cache, ResultCache),
?assertEqual(Lru, ResultLru),
?assertEqual(400, ResultSize).
get_user_push_subscriptions_test() ->
State = #{push_subscriptions_cache => #{{subscriptions, 123} => [sub1, sub2]}},
?assertEqual([sub1, sub2], get_user_push_subscriptions(123, State)),
?assertEqual([], get_user_push_subscriptions(999, State)).
get_user_badge_count_test() ->
State = #{badge_counts_cache => #{{badge_count, 123} => {5, 1000}}},
?assertEqual({5, 1000}, get_user_badge_count(123, State)),
?assertEqual(undefined, get_user_badge_count(999, State)).
-endif.

Some files were not shown because too many files have changed in this diff Show More