refactor progress
This commit is contained in:
@@ -19,36 +19,25 @@
|
||||
-behaviour(application).
|
||||
-export([start/2, stop/1]).
|
||||
|
||||
-spec start(application:start_type(), term()) -> {ok, pid()} | {error, term()}.
|
||||
start(_StartType, _StartArgs) ->
|
||||
fluxer_gateway_env:load(),
|
||||
|
||||
WsPort = fluxer_gateway_env:get(ws_port),
|
||||
RpcPort = fluxer_gateway_env:get(rpc_port),
|
||||
|
||||
otel_metrics:init(),
|
||||
Port = fluxer_gateway_env:get(port),
|
||||
Dispatch = cowboy_router:compile([
|
||||
{'_', [
|
||||
{<<"/_health">>, health_handler, []},
|
||||
{<<"/_rpc">>, gateway_rpc_http_handler, []},
|
||||
{<<"/_admin/reload">>, hot_reload_handler, []},
|
||||
{<<"/">>, gateway_handler, []}
|
||||
]}
|
||||
]),
|
||||
|
||||
{ok, _} = cowboy:start_clear(http, [{port, WsPort}], #{
|
||||
{ok, _} = cowboy:start_clear(http, [{port, Port}], #{
|
||||
env => #{dispatch => Dispatch},
|
||||
max_frame_size => 4096
|
||||
}),
|
||||
|
||||
RpcDispatch = cowboy_router:compile([
|
||||
{'_', [
|
||||
{<<"/_rpc">>, gateway_rpc_http_handler, []},
|
||||
{<<"/_admin/reload">>, hot_reload_handler, []}
|
||||
]}
|
||||
]),
|
||||
|
||||
{ok, _} = cowboy:start_clear(rpc_http, [{port, RpcPort}], #{
|
||||
env => #{dispatch => RpcDispatch}
|
||||
}),
|
||||
|
||||
fluxer_gateway_sup:start_link().
|
||||
|
||||
-spec stop(term()) -> ok.
|
||||
stop(_State) ->
|
||||
ok.
|
||||
|
||||
311
fluxer_gateway/src/gateway/fluxer_gateway_config.erl
Normal file
311
fluxer_gateway/src/gateway/fluxer_gateway_config.erl
Normal file
@@ -0,0 +1,311 @@
|
||||
%% Copyright (C) 2026 Fluxer Contributors
|
||||
%%
|
||||
%% This file is part of Fluxer.
|
||||
%%
|
||||
%% Fluxer is free software: you can redistribute it and/or modify
|
||||
%% it under the terms of the GNU Affero General Public License as published by
|
||||
%% the Free Software Foundation, either version 3 of the License, or
|
||||
%% (at your option) any later version.
|
||||
%%
|
||||
%% Fluxer is distributed in the hope that it will be useful,
|
||||
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
%% GNU Affero General Public License for more details.
|
||||
%%
|
||||
%% You should have received a copy of the GNU Affero General Public License
|
||||
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
-module(fluxer_gateway_config).
|
||||
|
||||
-export([load/0, load_from/1]).
|
||||
|
||||
-type config() :: map().
|
||||
-type log_level() :: debug | info | notice | warning | error | critical | alert | emergency.
|
||||
|
||||
-spec load() -> config().
|
||||
load() ->
|
||||
case os:getenv("FLUXER_CONFIG") of
|
||||
false -> erlang:error({missing_env, "FLUXER_CONFIG"});
|
||||
"" -> erlang:error({missing_env, "FLUXER_CONFIG"});
|
||||
Path -> load_from(Path)
|
||||
end.
|
||||
|
||||
-spec load_from(string()) -> config().
|
||||
load_from(Path) when is_list(Path) ->
|
||||
case file:read_file(Path) of
|
||||
{ok, Content} ->
|
||||
Json = json:decode(Content),
|
||||
build_config(Json);
|
||||
{error, Reason} ->
|
||||
erlang:error({json_read_failed, Path, Reason})
|
||||
end.
|
||||
|
||||
-spec build_config(map()) -> config().
|
||||
build_config(Json) ->
|
||||
Service = get_map(Json, [<<"services">>, <<"gateway">>]),
|
||||
Gateway = get_map(Json, [<<"gateway">>]),
|
||||
Telemetry = get_map(Json, [<<"telemetry">>]),
|
||||
Sentry = get_map(Json, [<<"sentry">>]),
|
||||
Vapid = get_map(Json, [<<"auth">>, <<"vapid">>]),
|
||||
#{
|
||||
port => get_int(Service, <<"port">>, 8080),
|
||||
rpc_tcp_port => get_int(Service, <<"rpc_tcp_port">>, 8772),
|
||||
api_host => get_env_or_string("FLUXER_GATEWAY_API_HOST", Service, <<"api_host">>, "api"),
|
||||
api_canary_host => get_optional_string(Service, <<"api_canary_host">>),
|
||||
admin_reload_secret => get_optional_binary(Service, <<"admin_reload_secret">>),
|
||||
rpc_secret_key => get_binary(Gateway, <<"rpc_secret">>, undefined),
|
||||
identify_rate_limit_enabled => get_bool(Service, <<"identify_rate_limit_enabled">>, false),
|
||||
push_enabled => get_bool(Service, <<"push_enabled">>, true),
|
||||
push_user_guild_settings_cache_mb => get_int(
|
||||
Service,
|
||||
<<"push_user_guild_settings_cache_mb">>,
|
||||
1024
|
||||
),
|
||||
push_subscriptions_cache_mb => get_int(Service, <<"push_subscriptions_cache_mb">>, 1024),
|
||||
push_blocked_ids_cache_mb => get_int(Service, <<"push_blocked_ids_cache_mb">>, 1024),
|
||||
presence_cache_shards => get_optional_int(Service, <<"presence_cache_shards">>),
|
||||
presence_bus_shards => get_optional_int(Service, <<"presence_bus_shards">>),
|
||||
presence_shards => get_optional_int(Service, <<"presence_shards">>),
|
||||
guild_shards => get_optional_int(Service, <<"guild_shards">>),
|
||||
session_shards => get_optional_int(Service, <<"session_shards">>),
|
||||
push_badge_counts_cache_mb => get_int(Service, <<"push_badge_counts_cache_mb">>, 256),
|
||||
push_badge_counts_cache_ttl_seconds =>
|
||||
get_int(Service, <<"push_badge_counts_cache_ttl_seconds">>, 60),
|
||||
push_dispatcher_max_inflight => get_int(Service, <<"push_dispatcher_max_inflight">>, 16),
|
||||
push_dispatcher_max_queue => get_int(Service, <<"push_dispatcher_max_queue">>, 2048),
|
||||
gateway_http_rpc_connect_timeout_ms =>
|
||||
get_int(Service, <<"gateway_http_rpc_connect_timeout_ms">>, 5000),
|
||||
gateway_http_rpc_recv_timeout_ms =>
|
||||
get_int(Service, <<"gateway_http_rpc_recv_timeout_ms">>, 30000),
|
||||
gateway_http_push_connect_timeout_ms =>
|
||||
get_int(Service, <<"gateway_http_push_connect_timeout_ms">>, 3000),
|
||||
gateway_http_push_recv_timeout_ms =>
|
||||
get_int(Service, <<"gateway_http_push_recv_timeout_ms">>, 5000),
|
||||
gateway_http_rpc_max_concurrency =>
|
||||
get_int(Service, <<"gateway_http_rpc_max_concurrency">>, 512),
|
||||
gateway_rpc_tcp_max_input_buffer_bytes =>
|
||||
get_int(Service, <<"gateway_rpc_tcp_max_input_buffer_bytes">>, 2097152),
|
||||
gateway_http_push_max_concurrency =>
|
||||
get_int(Service, <<"gateway_http_push_max_concurrency">>, 256),
|
||||
gateway_http_failure_threshold =>
|
||||
get_int(Service, <<"gateway_http_failure_threshold">>, 6),
|
||||
gateway_http_recovery_timeout_ms =>
|
||||
get_int(Service, <<"gateway_http_recovery_timeout_ms">>, 15000),
|
||||
gateway_http_cleanup_interval_ms =>
|
||||
get_int(Service, <<"gateway_http_cleanup_interval_ms">>, 30000),
|
||||
gateway_http_cleanup_max_age_ms =>
|
||||
get_int(Service, <<"gateway_http_cleanup_max_age_ms">>, 300000),
|
||||
media_proxy_endpoint => get_optional_binary(Service, <<"media_proxy_endpoint">>),
|
||||
vapid_email => get_binary(Vapid, <<"email">>, <<>>),
|
||||
vapid_public_key => get_optional_binary(Vapid, <<"public_key">>),
|
||||
vapid_private_key => get_optional_binary(Vapid, <<"private_key">>),
|
||||
gateway_metrics_enabled => get_optional_bool(Service, <<"gateway_metrics_enabled">>),
|
||||
gateway_metrics_report_interval_ms =>
|
||||
get_optional_int(Service, <<"gateway_metrics_report_interval_ms">>),
|
||||
release_node => get_string(Service, <<"release_node">>, "fluxer_gateway@127.0.0.1"),
|
||||
logger_level => get_log_level(Service, <<"logger_level">>, info),
|
||||
telemetry => #{
|
||||
enabled => get_bool(Telemetry, <<"enabled">>, true),
|
||||
otlp_endpoint => get_string(Telemetry, <<"otlp_endpoint">>, ""),
|
||||
api_key => get_string(Telemetry, <<"api_key">>, ""),
|
||||
service_name => get_string(Telemetry, <<"service_name">>, "fluxer-gateway"),
|
||||
environment => get_string(Telemetry, <<"environment">>, "development"),
|
||||
trace_sampling_ratio => get_float(Telemetry, <<"trace_sampling_ratio">>, 1.0)
|
||||
},
|
||||
sentry => #{
|
||||
build_sha => get_string(Sentry, <<"build_sha">>, ""),
|
||||
release_channel => get_string(Sentry, <<"release_channel">>, "")
|
||||
}
|
||||
}.
|
||||
|
||||
-spec get_map(map(), [binary()]) -> map().
|
||||
get_map(Map, Keys) ->
|
||||
case get_in(Map, Keys) of
|
||||
Value when is_map(Value) -> Value;
|
||||
_ -> #{}
|
||||
end.
|
||||
|
||||
-spec get_int(map(), binary(), integer()) -> integer().
|
||||
get_int(Map, Key, Default) when is_integer(Default) ->
|
||||
to_int(get_value(Map, Key), Default).
|
||||
|
||||
-spec get_optional_int(map(), binary()) -> integer() | undefined.
|
||||
get_optional_int(Map, Key) ->
|
||||
to_optional_int(get_value(Map, Key)).
|
||||
|
||||
-spec get_bool(map(), binary(), boolean()) -> boolean().
|
||||
get_bool(Map, Key, Default) when is_boolean(Default) ->
|
||||
to_bool(get_value(Map, Key), Default).
|
||||
|
||||
-spec get_optional_bool(map(), binary()) -> boolean() | undefined.
|
||||
get_optional_bool(Map, Key) ->
|
||||
case get_value(Map, Key) of
|
||||
undefined -> undefined;
|
||||
Value -> to_bool(Value, undefined)
|
||||
end.
|
||||
|
||||
-spec get_string(map(), binary(), string()) -> string().
|
||||
get_string(Map, Key, Default) when is_list(Default) ->
|
||||
to_string(get_value(Map, Key), Default).
|
||||
|
||||
-spec get_env_or_string(string(), map(), binary(), string()) -> string().
|
||||
get_env_or_string(EnvVar, Map, Key, Default) when is_list(EnvVar), is_list(Default) ->
|
||||
case os:getenv(EnvVar) of
|
||||
false -> get_string(Map, Key, Default);
|
||||
"" -> get_string(Map, Key, Default);
|
||||
Value -> Value
|
||||
end.
|
||||
|
||||
-spec get_optional_string(map(), binary()) -> string() | undefined.
|
||||
get_optional_string(Map, Key) ->
|
||||
case get_value(Map, Key) of
|
||||
undefined ->
|
||||
undefined;
|
||||
Value ->
|
||||
Clean = string:trim(to_string(Value, "")),
|
||||
case Clean of
|
||||
"" -> undefined;
|
||||
_ -> Clean
|
||||
end
|
||||
end.
|
||||
|
||||
-spec get_binary(map(), binary(), binary() | undefined) -> binary() | undefined.
|
||||
get_binary(Map, Key, Default) ->
|
||||
to_binary(get_value(Map, Key), Default).
|
||||
|
||||
-spec get_optional_binary(map(), binary()) -> binary() | undefined.
|
||||
get_optional_binary(Map, Key) ->
|
||||
case get_value(Map, Key) of
|
||||
undefined -> undefined;
|
||||
Value -> to_binary(Value, undefined)
|
||||
end.
|
||||
|
||||
-spec get_log_level(map(), binary(), log_level()) -> log_level().
|
||||
get_log_level(Map, Key, Default) when is_atom(Default) ->
|
||||
Value = get_value(Map, Key),
|
||||
case normalize_log_level(Value) of
|
||||
undefined -> Default;
|
||||
Level -> Level
|
||||
end.
|
||||
|
||||
-spec get_float(map(), binary(), number()) -> float().
|
||||
get_float(Map, Key, Default) when is_number(Default) ->
|
||||
to_float(get_value(Map, Key), Default).
|
||||
|
||||
-spec get_in(term(), [binary()]) -> term().
|
||||
get_in(Map, [Key | Rest]) when is_map(Map) ->
|
||||
case get_value(Map, Key) of
|
||||
undefined -> undefined;
|
||||
Value when Rest =:= [] -> Value;
|
||||
Value -> get_in(Value, Rest)
|
||||
end;
|
||||
get_in(_, _) ->
|
||||
undefined.
|
||||
|
||||
-spec get_value(term(), binary()) -> term().
|
||||
get_value(Map, Key) when is_map(Map) ->
|
||||
case maps:get(Key, Map, undefined) of
|
||||
undefined when is_binary(Key) ->
|
||||
maps:get(binary_to_list(Key), Map, undefined);
|
||||
Value ->
|
||||
Value
|
||||
end.
|
||||
|
||||
-spec to_int(term(), integer() | undefined) -> integer() | undefined.
|
||||
to_int(Value, _Default) when is_integer(Value) ->
|
||||
Value;
|
||||
to_int(Value, _Default) when is_float(Value) ->
|
||||
trunc(Value);
|
||||
to_int(Value, Default) ->
|
||||
case to_string(Value, "") of
|
||||
"" ->
|
||||
Default;
|
||||
Str ->
|
||||
case string:to_integer(Str) of
|
||||
{Int, _} when is_integer(Int) -> Int;
|
||||
{error, _} -> Default
|
||||
end
|
||||
end.
|
||||
|
||||
-spec to_optional_int(term()) -> integer() | undefined.
|
||||
to_optional_int(Value) ->
|
||||
case to_int(Value, undefined) of
|
||||
undefined -> undefined;
|
||||
Int -> Int
|
||||
end.
|
||||
|
||||
-spec to_bool(term(), boolean() | undefined) -> boolean() | undefined.
|
||||
to_bool(Value, _Default) when is_boolean(Value) ->
|
||||
Value;
|
||||
to_bool(Value, Default) when is_atom(Value) ->
|
||||
case Value of
|
||||
true -> true;
|
||||
false -> false;
|
||||
_ -> Default
|
||||
end;
|
||||
to_bool(Value, Default) ->
|
||||
case string:lowercase(to_string(Value, "")) of
|
||||
"true" -> true;
|
||||
"1" -> true;
|
||||
"false" -> false;
|
||||
"0" -> false;
|
||||
_ -> Default
|
||||
end.
|
||||
|
||||
-spec to_string(term(), string()) -> string().
|
||||
to_string(Value, Default) when is_list(Default) ->
|
||||
case Value of
|
||||
undefined -> Default;
|
||||
Bin when is_binary(Bin) -> binary_to_list(Bin);
|
||||
Str when is_list(Str) -> Str;
|
||||
Atom when is_atom(Atom) -> atom_to_list(Atom);
|
||||
_ -> Default
|
||||
end.
|
||||
|
||||
-spec to_binary(term(), binary() | undefined) -> binary() | undefined.
|
||||
to_binary(Value, Default) ->
|
||||
case Value of
|
||||
undefined -> Default;
|
||||
Bin when is_binary(Bin) -> Bin;
|
||||
Str when is_list(Str) -> list_to_binary(Str);
|
||||
Atom when is_atom(Atom) -> list_to_binary(atom_to_list(Atom));
|
||||
_ -> Default
|
||||
end.
|
||||
|
||||
-spec to_float(term(), float()) -> float().
|
||||
to_float(Value, _Default) when is_float(Value) ->
|
||||
Value;
|
||||
to_float(Value, _Default) when is_integer(Value) ->
|
||||
float(Value);
|
||||
to_float(Value, Default) ->
|
||||
case to_string(Value, "") of
|
||||
"" ->
|
||||
Default;
|
||||
Str ->
|
||||
case string:to_float(Str) of
|
||||
{Float, _} when is_float(Float) -> Float;
|
||||
{error, _} -> Default
|
||||
end
|
||||
end.
|
||||
|
||||
-spec normalize_log_level(term()) -> log_level() | undefined.
|
||||
normalize_log_level(undefined) ->
|
||||
undefined;
|
||||
normalize_log_level(Level) when is_atom(Level) ->
|
||||
normalize_log_level(atom_to_list(Level));
|
||||
normalize_log_level(Level) when is_binary(Level) ->
|
||||
normalize_log_level(binary_to_list(Level));
|
||||
normalize_log_level(Level) when is_list(Level) ->
|
||||
case string:lowercase(string:trim(Level)) of
|
||||
"debug" -> debug;
|
||||
"info" -> info;
|
||||
"notice" -> notice;
|
||||
"warning" -> warning;
|
||||
"error" -> error;
|
||||
"critical" -> critical;
|
||||
"alert" -> alert;
|
||||
"emergency" -> emergency;
|
||||
_ -> undefined
|
||||
end;
|
||||
normalize_log_level(_) ->
|
||||
undefined.
|
||||
302
fluxer_gateway/src/gateway/fluxer_gateway_crypto.erl
Normal file
302
fluxer_gateway/src/gateway/fluxer_gateway_crypto.erl
Normal file
@@ -0,0 +1,302 @@
|
||||
%% Copyright (C) 2026 Fluxer Contributors
|
||||
%%
|
||||
%% This file is part of Fluxer.
|
||||
%%
|
||||
%% Fluxer is free software: you can redistribute it and/or modify
|
||||
%% it under the terms of the GNU Affero General Public License as published by
|
||||
%% the Free Software Foundation, either version 3 of the License, or
|
||||
%% (at your option) any later version.
|
||||
%%
|
||||
%% Fluxer is distributed in the hope that it will be useful,
|
||||
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
%% GNU Affero General Public License for more details.
|
||||
%%
|
||||
%% You should have received a copy of the GNU Affero General Public License
|
||||
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
-module(fluxer_gateway_crypto).
|
||||
|
||||
-export([
|
||||
init/0,
|
||||
decrypt/2,
|
||||
encrypt/2,
|
||||
derive_shared_secret/2,
|
||||
generate_keypair/0,
|
||||
get_public_key/0,
|
||||
new_crypto_state/1,
|
||||
is_encrypted_frame/1,
|
||||
unwrap_encrypted_frame/1,
|
||||
wrap_encrypted_frame/1
|
||||
]).
|
||||
|
||||
-define(KEYPAIR_KEY, {?MODULE, instance_keypair}).
|
||||
-define(ENCRYPTED_FRAME_PREFIX, 16#FE).
|
||||
-define(NONCE_SIZE, 12).
|
||||
-define(TAG_SIZE, 16).
|
||||
-define(KEY_SIZE, 32).
|
||||
|
||||
-type keypair() :: #{public := binary(), private := binary()}.
|
||||
-type crypto_state() :: #{
|
||||
shared_secret := binary(),
|
||||
send_counter := non_neg_integer(),
|
||||
recv_counter := non_neg_integer()
|
||||
}.
|
||||
|
||||
-export_type([keypair/0, crypto_state/0]).
|
||||
|
||||
-spec init() -> ok.
|
||||
init() ->
|
||||
case persistent_term:get(?KEYPAIR_KEY, undefined) of
|
||||
undefined ->
|
||||
Keypair = generate_keypair(),
|
||||
persistent_term:put(?KEYPAIR_KEY, Keypair),
|
||||
ok;
|
||||
_ ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec generate_keypair() -> keypair().
|
||||
generate_keypair() ->
|
||||
{Public, Private} = crypto:generate_key(ecdh, x25519),
|
||||
#{public => Public, private => Private}.
|
||||
|
||||
-spec get_public_key() -> binary() | undefined.
|
||||
get_public_key() ->
|
||||
case persistent_term:get(?KEYPAIR_KEY, undefined) of
|
||||
undefined -> undefined;
|
||||
#{public := Public} -> Public
|
||||
end.
|
||||
|
||||
-spec derive_shared_secret(binary(), keypair()) -> {ok, binary()} | {error, term()}.
|
||||
derive_shared_secret(PeerPublic, #{private := Private}) when
|
||||
byte_size(PeerPublic) =:= ?KEY_SIZE
|
||||
->
|
||||
try
|
||||
SharedSecret = crypto:compute_key(ecdh, PeerPublic, Private, x25519),
|
||||
{ok, SharedSecret}
|
||||
catch
|
||||
error:Reason ->
|
||||
{error, {key_exchange_failed, Reason}}
|
||||
end;
|
||||
derive_shared_secret(PeerPublic, _Keypair) ->
|
||||
{error, {invalid_peer_key_size, byte_size(PeerPublic)}}.
|
||||
|
||||
-spec new_crypto_state(binary()) -> crypto_state().
|
||||
new_crypto_state(SharedSecret) when byte_size(SharedSecret) =:= ?KEY_SIZE ->
|
||||
#{
|
||||
shared_secret => SharedSecret,
|
||||
send_counter => 0,
|
||||
recv_counter => 0
|
||||
}.
|
||||
|
||||
-spec encrypt(binary(), crypto_state()) ->
|
||||
{ok, binary(), crypto_state()} | {error, term()}.
|
||||
encrypt(Plaintext, State = #{shared_secret := Key, send_counter := Counter}) ->
|
||||
try
|
||||
Nonce = counter_to_nonce(Counter),
|
||||
AAD = <<>>,
|
||||
{Ciphertext, Tag} = crypto:crypto_one_time_aead(
|
||||
aes_256_gcm,
|
||||
Key,
|
||||
Nonce,
|
||||
Plaintext,
|
||||
AAD,
|
||||
?TAG_SIZE,
|
||||
true
|
||||
),
|
||||
Encrypted = <<Nonce/binary, Tag/binary, Ciphertext/binary>>,
|
||||
NewState = State#{send_counter => Counter + 1},
|
||||
{ok, Encrypted, NewState}
|
||||
catch
|
||||
error:Reason ->
|
||||
{error, {encrypt_failed, Reason}}
|
||||
end.
|
||||
|
||||
-spec decrypt(binary(), crypto_state()) ->
|
||||
{ok, binary(), crypto_state()} | {error, term()}.
|
||||
decrypt(Data, State = #{shared_secret := Key, recv_counter := Counter}) ->
|
||||
MinSize = ?NONCE_SIZE + ?TAG_SIZE,
|
||||
case byte_size(Data) > MinSize of
|
||||
false ->
|
||||
{error, {invalid_encrypted_data, too_short}};
|
||||
true ->
|
||||
<<Nonce:?NONCE_SIZE/binary, Tag:?TAG_SIZE/binary, Ciphertext/binary>> = Data,
|
||||
ExpectedNonce = counter_to_nonce(Counter),
|
||||
case validate_nonce(Nonce, ExpectedNonce, Counter) of
|
||||
{ok, ActualCounter} ->
|
||||
do_decrypt(Ciphertext, Key, Nonce, Tag, State, ActualCounter);
|
||||
{error, Reason} ->
|
||||
{error, Reason}
|
||||
end
|
||||
end.
|
||||
|
||||
-spec do_decrypt(binary(), binary(), binary(), binary(), crypto_state(), non_neg_integer()) ->
|
||||
{ok, binary(), crypto_state()} | {error, term()}.
|
||||
do_decrypt(Ciphertext, Key, Nonce, Tag, State, ActualCounter) ->
|
||||
AAD = <<>>,
|
||||
try
|
||||
case crypto:crypto_one_time_aead(
|
||||
aes_256_gcm,
|
||||
Key,
|
||||
Nonce,
|
||||
Ciphertext,
|
||||
AAD,
|
||||
Tag,
|
||||
false
|
||||
) of
|
||||
Plaintext when is_binary(Plaintext) ->
|
||||
NewState = State#{recv_counter => ActualCounter + 1},
|
||||
{ok, Plaintext, NewState};
|
||||
error ->
|
||||
{error, authentication_failed}
|
||||
end
|
||||
catch
|
||||
error:Reason ->
|
||||
{error, {decrypt_failed, Reason}}
|
||||
end.
|
||||
|
||||
-spec counter_to_nonce(non_neg_integer()) -> binary().
|
||||
counter_to_nonce(Counter) ->
|
||||
<<0:32, Counter:64/big-unsigned-integer>>.
|
||||
|
||||
-spec validate_nonce(binary(), binary(), non_neg_integer()) ->
|
||||
{ok, non_neg_integer()} | {error, term()}.
|
||||
validate_nonce(Nonce, ExpectedNonce, Counter) when Nonce =:= ExpectedNonce ->
|
||||
{ok, Counter};
|
||||
validate_nonce(Nonce, _ExpectedNonce, Counter) ->
|
||||
<<_Prefix:4/binary, ReceivedCounter:64/big-unsigned-integer>> = Nonce,
|
||||
MaxWindow = 32,
|
||||
case ReceivedCounter > Counter andalso ReceivedCounter =< Counter + MaxWindow of
|
||||
true ->
|
||||
{ok, ReceivedCounter};
|
||||
false ->
|
||||
{error, {nonce_mismatch, Counter, ReceivedCounter}}
|
||||
end.
|
||||
|
||||
-spec is_encrypted_frame(binary()) -> boolean().
|
||||
is_encrypted_frame(<<?ENCRYPTED_FRAME_PREFIX, _Rest/binary>>) ->
|
||||
true;
|
||||
is_encrypted_frame(_) ->
|
||||
false.
|
||||
|
||||
-spec unwrap_encrypted_frame(binary()) -> {ok, binary()} | {error, not_encrypted}.
|
||||
unwrap_encrypted_frame(<<?ENCRYPTED_FRAME_PREFIX, Data/binary>>) ->
|
||||
{ok, Data};
|
||||
unwrap_encrypted_frame(_) ->
|
||||
{error, not_encrypted}.
|
||||
|
||||
-spec wrap_encrypted_frame(binary()) -> binary().
|
||||
wrap_encrypted_frame(Data) ->
|
||||
<<?ENCRYPTED_FRAME_PREFIX, Data/binary>>.
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
|
||||
generate_keypair_test() ->
|
||||
Keypair = generate_keypair(),
|
||||
?assert(is_map(Keypair)),
|
||||
?assertEqual(?KEY_SIZE, byte_size(maps:get(public, Keypair))),
|
||||
?assertEqual(?KEY_SIZE, byte_size(maps:get(private, Keypair))).
|
||||
|
||||
derive_shared_secret_test() ->
|
||||
Keypair1 = generate_keypair(),
|
||||
Keypair2 = generate_keypair(),
|
||||
{ok, Secret1} = derive_shared_secret(maps:get(public, Keypair2), Keypair1),
|
||||
{ok, Secret2} = derive_shared_secret(maps:get(public, Keypair1), Keypair2),
|
||||
?assertEqual(Secret1, Secret2),
|
||||
?assertEqual(?KEY_SIZE, byte_size(Secret1)).
|
||||
|
||||
derive_shared_secret_invalid_key_test() ->
|
||||
Keypair = generate_keypair(),
|
||||
Result = derive_shared_secret(<<"short">>, Keypair),
|
||||
?assertMatch({error, {invalid_peer_key_size, _}}, Result).
|
||||
|
||||
new_crypto_state_test() ->
|
||||
Secret = crypto:strong_rand_bytes(?KEY_SIZE),
|
||||
State = new_crypto_state(Secret),
|
||||
?assertEqual(Secret, maps:get(shared_secret, State)),
|
||||
?assertEqual(0, maps:get(send_counter, State)),
|
||||
?assertEqual(0, maps:get(recv_counter, State)).
|
||||
|
||||
encrypt_decrypt_roundtrip_test() ->
|
||||
Secret = crypto:strong_rand_bytes(?KEY_SIZE),
|
||||
State = new_crypto_state(Secret),
|
||||
Plaintext = <<"hello world">>,
|
||||
{ok, Ciphertext, State2} = encrypt(Plaintext, State),
|
||||
?assert(byte_size(Ciphertext) > byte_size(Plaintext)),
|
||||
?assertEqual(1, maps:get(send_counter, State2)),
|
||||
{ok, Decrypted, State3} = decrypt(Ciphertext, State),
|
||||
?assertEqual(Plaintext, Decrypted),
|
||||
?assertEqual(1, maps:get(recv_counter, State3)).
|
||||
|
||||
encrypt_multiple_messages_test() ->
|
||||
Secret = crypto:strong_rand_bytes(?KEY_SIZE),
|
||||
SendState = new_crypto_state(Secret),
|
||||
RecvState = new_crypto_state(Secret),
|
||||
Messages = [<<"msg1">>, <<"msg2">>, <<"msg3">>],
|
||||
{FinalSendState, FinalRecvState, DecryptedMsgs} = lists:foldl(
|
||||
fun(Msg, {SS, RS, Acc}) ->
|
||||
{ok, Cipher, SS2} = encrypt(Msg, SS),
|
||||
{ok, Plain, RS2} = decrypt(Cipher, RS),
|
||||
{SS2, RS2, [Plain | Acc]}
|
||||
end,
|
||||
{SendState, RecvState, []},
|
||||
Messages
|
||||
),
|
||||
?assertEqual(3, maps:get(send_counter, FinalSendState)),
|
||||
?assertEqual(3, maps:get(recv_counter, FinalRecvState)),
|
||||
?assertEqual(Messages, lists:reverse(DecryptedMsgs)).
|
||||
|
||||
decrypt_tampered_test() ->
|
||||
Secret = crypto:strong_rand_bytes(?KEY_SIZE),
|
||||
State = new_crypto_state(Secret),
|
||||
{ok, Ciphertext, _} = encrypt(<<"hello">>, State),
|
||||
Tampered = <<(binary:first(Ciphertext) bxor 1), (binary:part(Ciphertext, 1, byte_size(Ciphertext) - 1))/binary>>,
|
||||
Result = decrypt(Tampered, State),
|
||||
?assertMatch({error, _}, Result).
|
||||
|
||||
decrypt_too_short_test() ->
|
||||
Secret = crypto:strong_rand_bytes(?KEY_SIZE),
|
||||
State = new_crypto_state(Secret),
|
||||
Result = decrypt(<<"short">>, State),
|
||||
?assertMatch({error, {invalid_encrypted_data, too_short}}, Result).
|
||||
|
||||
is_encrypted_frame_test() ->
|
||||
?assertEqual(true, is_encrypted_frame(<<16#FE, "data">>)),
|
||||
?assertEqual(false, is_encrypted_frame(<<"data">>)),
|
||||
?assertEqual(false, is_encrypted_frame(<<16#FF, "data">>)),
|
||||
?assertEqual(false, is_encrypted_frame(<<>>)).
|
||||
|
||||
unwrap_encrypted_frame_test() ->
|
||||
?assertEqual({ok, <<"data">>}, unwrap_encrypted_frame(<<16#FE, "data">>)),
|
||||
?assertEqual({error, not_encrypted}, unwrap_encrypted_frame(<<"data">>)).
|
||||
|
||||
wrap_encrypted_frame_test() ->
|
||||
?assertEqual(<<16#FE, "data">>, wrap_encrypted_frame(<<"data">>)).
|
||||
|
||||
counter_to_nonce_test() ->
|
||||
Nonce0 = counter_to_nonce(0),
|
||||
?assertEqual(?NONCE_SIZE, byte_size(Nonce0)),
|
||||
?assertEqual(<<0:32, 0:64>>, Nonce0),
|
||||
Nonce1 = counter_to_nonce(1),
|
||||
?assertEqual(<<0:32, 1:64>>, Nonce1).
|
||||
|
||||
init_creates_keypair_test() ->
|
||||
persistent_term:erase(?KEYPAIR_KEY),
|
||||
ok = init(),
|
||||
Public = get_public_key(),
|
||||
?assert(is_binary(Public)),
|
||||
?assertEqual(?KEY_SIZE, byte_size(Public)),
|
||||
persistent_term:erase(?KEYPAIR_KEY).
|
||||
|
||||
init_idempotent_test() ->
|
||||
persistent_term:erase(?KEYPAIR_KEY),
|
||||
ok = init(),
|
||||
Public1 = get_public_key(),
|
||||
ok = init(),
|
||||
Public2 = get_public_key(),
|
||||
?assertEqual(Public1, Public2),
|
||||
persistent_term:erase(?KEYPAIR_KEY).
|
||||
|
||||
-endif.
|
||||
@@ -19,14 +19,15 @@
|
||||
|
||||
-export([load/0, get/1, get_optional/1, get_map/0, patch/1, update/1]).
|
||||
|
||||
-define(APP, fluxer_gateway).
|
||||
-define(CONFIG_TERM_KEY, {fluxer_gateway, runtime_config}).
|
||||
|
||||
-type config() :: map().
|
||||
|
||||
-spec load() -> config().
|
||||
load() ->
|
||||
set_config(build_config()).
|
||||
Config = build_config(),
|
||||
apply_system_config(Config),
|
||||
set_config(Config).
|
||||
|
||||
-spec get(atom()) -> term().
|
||||
get(Key) when is_atom(Key) ->
|
||||
@@ -67,202 +68,145 @@ ensure_loaded() ->
|
||||
|
||||
-spec build_config() -> config().
|
||||
build_config() ->
|
||||
#{
|
||||
ws_port => env_int("FLUXER_GATEWAY_WS_PORT", ws_port, 8080),
|
||||
rpc_port => env_int("FLUXER_GATEWAY_RPC_PORT", rpc_port, 8081),
|
||||
api_host => env_string("API_HOST", api_host, "api"),
|
||||
api_canary_host => env_optional_string("API_CANARY_HOST", api_canary_host),
|
||||
rpc_secret_key => env_binary("GATEWAY_RPC_SECRET", rpc_secret_key, undefined),
|
||||
identify_rate_limit_enabled => env_bool("FLUXER_GATEWAY_IDENTIFY_RATE_LIMIT_ENABLED", identify_rate_limit_enabled, false),
|
||||
push_enabled => env_bool("FLUXER_GATEWAY_PUSH_ENABLED", push_enabled, true),
|
||||
push_user_guild_settings_cache_mb => env_int("FLUXER_GATEWAY_PUSH_USER_GUILD_SETTINGS_CACHE_MB",
|
||||
push_user_guild_settings_cache_mb, 1024),
|
||||
push_subscriptions_cache_mb => env_int("FLUXER_GATEWAY_PUSH_SUBSCRIPTIONS_CACHE_MB",
|
||||
push_subscriptions_cache_mb, 1024),
|
||||
push_blocked_ids_cache_mb => env_int("FLUXER_GATEWAY_PUSH_BLOCKED_IDS_CACHE_MB",
|
||||
push_blocked_ids_cache_mb, 1024),
|
||||
presence_cache_shards => env_optional_int("FLUXER_GATEWAY_PRESENCE_CACHE_SHARDS", presence_cache_shards),
|
||||
presence_bus_shards => env_optional_int("FLUXER_GATEWAY_PRESENCE_BUS_SHARDS", presence_bus_shards),
|
||||
presence_shards => env_optional_int("FLUXER_GATEWAY_PRESENCE_SHARDS", presence_shards),
|
||||
guild_shards => env_optional_int("FLUXER_GATEWAY_GUILD_SHARDS", guild_shards),
|
||||
metrics_host => env_optional_string("FLUXER_METRICS_HOST", metrics_host),
|
||||
push_badge_counts_cache_mb => app_env_int(push_badge_counts_cache_mb, 256),
|
||||
push_badge_counts_cache_ttl_seconds => app_env_int(push_badge_counts_cache_ttl_seconds, 60),
|
||||
media_proxy_endpoint => env_optional_binary("MEDIA_PROXY_ENDPOINT", media_proxy_endpoint),
|
||||
vapid_email => env_binary("VAPID_EMAIL", vapid_email, <<"support@fluxer.app">>),
|
||||
vapid_public_key => env_binary("VAPID_PUBLIC_KEY", vapid_public_key, undefined),
|
||||
vapid_private_key => env_binary("VAPID_PRIVATE_KEY", vapid_private_key, undefined),
|
||||
gateway_metrics_enabled => app_env_optional_bool(gateway_metrics_enabled),
|
||||
gateway_metrics_report_interval_ms => app_env_optional_int(gateway_metrics_report_interval_ms)
|
||||
}.
|
||||
fluxer_gateway_config:load().
|
||||
|
||||
-spec env_int(string(), atom(), integer()) -> integer().
|
||||
env_int(EnvVar, AppKey, Default) when is_atom(AppKey), is_integer(Default) ->
|
||||
case os:getenv(EnvVar) of
|
||||
false ->
|
||||
app_env_int(AppKey, Default);
|
||||
Value ->
|
||||
parse_int(Value, Default)
|
||||
-spec apply_system_config(config()) -> ok.
|
||||
apply_system_config(Config) ->
|
||||
apply_logger_config(Config),
|
||||
apply_telemetry_config(Config).
|
||||
|
||||
-spec apply_logger_config(config()) -> ok.
|
||||
apply_logger_config(Config) ->
|
||||
LoggerLevel = resolve_logger_level(Config),
|
||||
logger:set_primary_config(level, LoggerLevel),
|
||||
logger:set_handler_config(default, level, LoggerLevel).
|
||||
|
||||
-spec apply_telemetry_config(config()) -> ok.
|
||||
apply_telemetry_config(Config) ->
|
||||
Telemetry = maps:get(telemetry, Config, #{}),
|
||||
apply_telemetry_config(Telemetry, Config).
|
||||
|
||||
-spec resolve_logger_level(config()) -> atom().
|
||||
resolve_logger_level(Config) ->
|
||||
Default = maps:get(logger_level, Config, info),
|
||||
case os:getenv("LOGGER_LEVEL") of
|
||||
false -> Default;
|
||||
"" -> Default;
|
||||
Value -> parse_logger_level(Value, Default)
|
||||
end.
|
||||
|
||||
-spec env_optional_int(string(), atom()) -> integer() | undefined.
|
||||
env_optional_int(EnvVar, AppKey) when is_atom(AppKey) ->
|
||||
case os:getenv(EnvVar) of
|
||||
false ->
|
||||
app_env_optional_int(AppKey);
|
||||
Value ->
|
||||
parse_int(Value, undefined)
|
||||
end.
|
||||
|
||||
-spec env_bool(string(), atom(), boolean()) -> boolean().
|
||||
env_bool(EnvVar, AppKey, Default) when is_atom(AppKey), is_boolean(Default) ->
|
||||
case os:getenv(EnvVar) of
|
||||
false ->
|
||||
app_env_bool(AppKey, Default);
|
||||
Value ->
|
||||
parse_bool(Value, Default)
|
||||
end.
|
||||
|
||||
-spec env_string(string(), atom(), string()) -> string().
|
||||
env_string(EnvVar, AppKey, Default) when is_atom(AppKey) ->
|
||||
case os:getenv(EnvVar) of
|
||||
false ->
|
||||
app_env_string(AppKey, Default);
|
||||
Value ->
|
||||
Value
|
||||
end.
|
||||
|
||||
-spec env_optional_string(string(), atom()) -> string() | undefined.
|
||||
env_optional_string(EnvVar, AppKey) when is_atom(AppKey) ->
|
||||
case os:getenv(EnvVar) of
|
||||
false ->
|
||||
app_env_optional_string(AppKey);
|
||||
Value ->
|
||||
Value
|
||||
end.
|
||||
|
||||
-spec env_binary(string(), atom(), binary() | undefined) -> binary() | undefined.
|
||||
env_binary(EnvVar, AppKey, Default) when is_atom(AppKey) ->
|
||||
case os:getenv(EnvVar) of
|
||||
false ->
|
||||
app_env_binary(AppKey, Default);
|
||||
Value ->
|
||||
to_binary(Value, Default)
|
||||
end.
|
||||
|
||||
-spec env_optional_binary(string(), atom()) -> binary() | undefined.
|
||||
env_optional_binary(EnvVar, AppKey) when is_atom(AppKey) ->
|
||||
case os:getenv(EnvVar) of
|
||||
false ->
|
||||
app_env_optional_binary(AppKey);
|
||||
Value ->
|
||||
to_binary(Value, undefined)
|
||||
end.
|
||||
|
||||
-spec parse_int(string(), integer() | undefined) -> integer() | undefined.
|
||||
parse_int(Value, Default) ->
|
||||
Str = string:trim(Value),
|
||||
try
|
||||
list_to_integer(Str)
|
||||
catch
|
||||
_:_ -> Default
|
||||
end.
|
||||
|
||||
-spec parse_bool(string(), boolean()) -> boolean().
|
||||
parse_bool(Value, Default) ->
|
||||
Str = string:lowercase(string:trim(Value)),
|
||||
case Str of
|
||||
"true" -> true;
|
||||
"1" -> true;
|
||||
"false" -> false;
|
||||
"0" -> false;
|
||||
-spec parse_logger_level(string(), atom()) -> atom().
|
||||
parse_logger_level(Value, Default) ->
|
||||
case string:lowercase(string:trim(Value)) of
|
||||
"debug" -> debug;
|
||||
"info" -> info;
|
||||
"notice" -> notice;
|
||||
"warning" -> warning;
|
||||
"error" -> error;
|
||||
"critical" -> critical;
|
||||
"alert" -> alert;
|
||||
"emergency" -> emergency;
|
||||
_ -> Default
|
||||
end.
|
||||
|
||||
-spec to_binary(string(), binary() | undefined) -> binary() | undefined.
|
||||
to_binary(Value, Default) ->
|
||||
try
|
||||
list_to_binary(Value)
|
||||
catch
|
||||
_:_ -> Default
|
||||
-ifdef(HAS_OPENTELEMETRY).
|
||||
-spec apply_telemetry_config(map(), config()) -> ok.
|
||||
apply_telemetry_config(Telemetry, Config) ->
|
||||
Sentry = maps:get(sentry, Config, #{}),
|
||||
ShouldEnable = otel_metrics:configure_enabled(Telemetry),
|
||||
case ShouldEnable of
|
||||
true ->
|
||||
set_opentelemetry_env(Telemetry, Sentry, Config);
|
||||
false ->
|
||||
application:set_env(opentelemetry_experimental, readers, []),
|
||||
application:set_env(opentelemetry, processors, []),
|
||||
application:set_env(opentelemetry, traces_exporter, none)
|
||||
end.
|
||||
|
||||
-spec app_env_int(atom(), integer()) -> integer().
|
||||
app_env_int(Key, Default) ->
|
||||
case application:get_env(?APP, Key) of
|
||||
{ok, Value} when is_integer(Value) ->
|
||||
Value;
|
||||
_ ->
|
||||
Default
|
||||
-spec set_opentelemetry_env(map(), map(), config()) -> ok.
|
||||
-ifdef(DEV_MODE).
|
||||
set_opentelemetry_env(_Telemetry, _Sentry, _Config) ->
|
||||
ok.
|
||||
-else.
|
||||
set_opentelemetry_env(Telemetry, Sentry, Config) ->
|
||||
Endpoint = maps:get(otlp_endpoint, Telemetry, ""),
|
||||
ApiKey = maps:get(api_key, Telemetry, ""),
|
||||
Headers = otlp_headers(ApiKey),
|
||||
ServiceName = maps:get(service_name, Telemetry, "fluxer-gateway"),
|
||||
Environment = maps:get(environment, Telemetry, "development"),
|
||||
Version = maps:get(build_sha, Sentry, ""),
|
||||
InstanceId = maps:get(release_node, Config, ""),
|
||||
Resource = [
|
||||
{service_name, ServiceName},
|
||||
{service_version, Version},
|
||||
{service_namespace, "fluxer"},
|
||||
{deployment_environment, Environment},
|
||||
{service_instance_id, InstanceId}
|
||||
],
|
||||
application:set_env(
|
||||
opentelemetry_experimental,
|
||||
readers,
|
||||
[
|
||||
{otel_periodic_reader, #{
|
||||
exporter =>
|
||||
{otel_otlp_metrics, #{
|
||||
protocol => http_protobuf,
|
||||
endpoint => Endpoint,
|
||||
headers => Headers
|
||||
}},
|
||||
interval => 30000
|
||||
}}
|
||||
]
|
||||
),
|
||||
application:set_env(opentelemetry_experimental, resource, Resource),
|
||||
application:set_env(
|
||||
opentelemetry,
|
||||
processors,
|
||||
[
|
||||
{otel_batch_processor, #{
|
||||
exporter => {opentelemetry_exporter, #{}},
|
||||
scheduled_delay_ms => 1000,
|
||||
max_queue_size => 2048,
|
||||
export_timeout_ms => 30000
|
||||
}}
|
||||
]
|
||||
),
|
||||
application:set_env(opentelemetry, traces_exporter, {opentelemetry_exporter, #{}}),
|
||||
application:set_env(
|
||||
opentelemetry,
|
||||
logger,
|
||||
[
|
||||
{handler, default, otel_log_handler, #{
|
||||
level => info,
|
||||
max_queue_size => 2048,
|
||||
scheduled_delay_ms => 1000,
|
||||
exporting_timeout_ms => 30000,
|
||||
exporter =>
|
||||
{otel_otlp_logs, #{
|
||||
protocol => http_protobuf,
|
||||
endpoint => Endpoint,
|
||||
headers => Headers
|
||||
}}
|
||||
}}
|
||||
]
|
||||
),
|
||||
application:set_env(opentelemetry_exporter, otlp_protocol, http_protobuf),
|
||||
application:set_env(opentelemetry_exporter, otlp_endpoint, Endpoint),
|
||||
application:set_env(opentelemetry_exporter, otlp_headers, Headers).
|
||||
|
||||
-spec otlp_headers(string()) -> [{string(), string()}].
|
||||
otlp_headers(ApiKey) ->
|
||||
ApiKeyStr = string:trim(ApiKey),
|
||||
case ApiKeyStr of
|
||||
"" -> [];
|
||||
_ -> [{"Authorization", "Bearer " ++ ApiKeyStr}]
|
||||
end.
|
||||
|
||||
-spec app_env_optional_int(atom()) -> integer() | undefined.
|
||||
app_env_optional_int(Key) ->
|
||||
case application:get_env(?APP, Key) of
|
||||
{ok, Value} when is_integer(Value) ->
|
||||
Value;
|
||||
_ ->
|
||||
undefined
|
||||
end.
|
||||
|
||||
-spec app_env_bool(atom(), boolean()) -> boolean().
|
||||
app_env_bool(Key, Default) ->
|
||||
case application:get_env(?APP, Key) of
|
||||
{ok, Value} when is_boolean(Value) ->
|
||||
Value;
|
||||
_ ->
|
||||
Default
|
||||
end.
|
||||
|
||||
-spec app_env_optional_bool(atom()) -> boolean() | undefined.
|
||||
app_env_optional_bool(Key) ->
|
||||
case application:get_env(?APP, Key) of
|
||||
{ok, Value} when is_boolean(Value) ->
|
||||
Value;
|
||||
_ ->
|
||||
undefined
|
||||
end.
|
||||
|
||||
-spec app_env_string(atom(), string()) -> string().
|
||||
app_env_string(Key, Default) ->
|
||||
case application:get_env(?APP, Key) of
|
||||
{ok, Value} when is_list(Value) ->
|
||||
Value;
|
||||
{ok, Value} when is_binary(Value) ->
|
||||
binary_to_list(Value);
|
||||
_ ->
|
||||
Default
|
||||
end.
|
||||
|
||||
-spec app_env_optional_string(atom()) -> string() | undefined.
|
||||
app_env_optional_string(Key) ->
|
||||
case application:get_env(?APP, Key) of
|
||||
{ok, Value} when is_list(Value) ->
|
||||
Value;
|
||||
{ok, Value} when is_binary(Value) ->
|
||||
binary_to_list(Value);
|
||||
_ ->
|
||||
undefined
|
||||
end.
|
||||
|
||||
-spec app_env_binary(atom(), binary() | undefined) -> binary() | undefined.
|
||||
app_env_binary(Key, Default) ->
|
||||
case application:get_env(?APP, Key) of
|
||||
{ok, Value} when is_binary(Value) ->
|
||||
Value;
|
||||
{ok, Value} when is_list(Value) ->
|
||||
list_to_binary(Value);
|
||||
_ ->
|
||||
Default
|
||||
end.
|
||||
|
||||
-spec app_env_optional_binary(atom()) -> binary() | undefined.
|
||||
app_env_optional_binary(Key) ->
|
||||
case application:get_env(?APP, Key) of
|
||||
{ok, Value} when is_binary(Value) ->
|
||||
Value;
|
||||
{ok, Value} when is_list(Value) ->
|
||||
list_to_binary(Value);
|
||||
_ ->
|
||||
undefined
|
||||
end.
|
||||
-endif.
|
||||
-else.
|
||||
-spec apply_telemetry_config(map(), config()) -> ok.
|
||||
apply_telemetry_config(_Telemetry, _Config) ->
|
||||
application:set_env(opentelemetry_experimental, readers, []),
|
||||
application:set_env(opentelemetry, processors, []),
|
||||
application:set_env(opentelemetry, traces_exporter, none).
|
||||
-endif.
|
||||
|
||||
@@ -19,74 +19,39 @@
|
||||
-behaviour(supervisor).
|
||||
-export([start_link/0, init/1]).
|
||||
|
||||
-spec start_link() -> {ok, pid()} | {error, term()}.
|
||||
start_link() ->
|
||||
supervisor:start_link({local, ?MODULE}, ?MODULE, []).
|
||||
|
||||
-spec init([]) -> {ok, {supervisor:sup_flags(), [supervisor:child_spec()]}}.
|
||||
init([]) ->
|
||||
SessionManager = #{
|
||||
id => session_manager,
|
||||
start => {session_manager, start_link, []},
|
||||
SupFlags = #{
|
||||
strategy => one_for_one,
|
||||
intensity => 5,
|
||||
period => 10
|
||||
},
|
||||
Children = [
|
||||
child_spec(gateway_http_client, gateway_http_client),
|
||||
child_spec(gateway_rpc_tcp_server, gateway_rpc_tcp_server),
|
||||
child_spec(session_manager, session_manager),
|
||||
child_spec(presence_cache, presence_cache),
|
||||
child_spec(presence_bus, presence_bus),
|
||||
child_spec(presence_manager, presence_manager),
|
||||
child_spec(guild_crash_logger, guild_crash_logger),
|
||||
child_spec(guild_manager, guild_manager),
|
||||
child_spec(call_manager, call_manager),
|
||||
child_spec(push_dispatcher, push_dispatcher),
|
||||
child_spec(push, push),
|
||||
child_spec(gateway_metrics_collector, gateway_metrics_collector)
|
||||
],
|
||||
{ok, {SupFlags, Children}}.
|
||||
|
||||
-spec child_spec(atom(), module()) -> supervisor:child_spec().
|
||||
child_spec(Id, Module) ->
|
||||
#{
|
||||
id => Id,
|
||||
start => {Module, start_link, []},
|
||||
restart => permanent,
|
||||
shutdown => 5000,
|
||||
type => worker
|
||||
},
|
||||
PresenceManager = #{
|
||||
id => presence_manager,
|
||||
start => {presence_manager, start_link, []},
|
||||
restart => permanent,
|
||||
shutdown => 5000,
|
||||
type => worker
|
||||
},
|
||||
GuildManager = #{
|
||||
id => guild_manager,
|
||||
start => {guild_manager, start_link, []},
|
||||
restart => permanent,
|
||||
shutdown => 5000,
|
||||
type => worker
|
||||
},
|
||||
Push = #{
|
||||
id => push,
|
||||
start => {push, start_link, []},
|
||||
restart => permanent,
|
||||
shutdown => 5000,
|
||||
type => worker
|
||||
},
|
||||
CallManager = #{
|
||||
id => call_manager,
|
||||
start => {call_manager, start_link, []},
|
||||
restart => permanent,
|
||||
shutdown => 5000,
|
||||
type => worker
|
||||
},
|
||||
PresenceBus = #{
|
||||
id => presence_bus,
|
||||
start => {presence_bus, start_link, []},
|
||||
restart => permanent,
|
||||
shutdown => 5000,
|
||||
type => worker
|
||||
},
|
||||
PresenceCache = #{
|
||||
id => presence_cache,
|
||||
start => {presence_cache, start_link, []},
|
||||
restart => permanent,
|
||||
shutdown => 5000,
|
||||
type => worker
|
||||
},
|
||||
GatewayMetricsCollector = #{
|
||||
id => gateway_metrics_collector,
|
||||
start => {gateway_metrics_collector, start_link, []},
|
||||
restart => permanent,
|
||||
shutdown => 5000,
|
||||
type => worker
|
||||
},
|
||||
{ok,
|
||||
{{one_for_one, 5, 10}, [
|
||||
SessionManager,
|
||||
PresenceCache,
|
||||
PresenceBus,
|
||||
PresenceManager,
|
||||
GuildManager,
|
||||
CallManager,
|
||||
Push,
|
||||
GatewayMetricsCollector
|
||||
]}}.
|
||||
}.
|
||||
|
||||
@@ -24,15 +24,18 @@
|
||||
]).
|
||||
|
||||
-type encoding() :: json.
|
||||
-type frame_type() :: text | binary.
|
||||
|
||||
-export_type([encoding/0]).
|
||||
|
||||
-spec parse_encoding(binary() | undefined) -> encoding().
|
||||
parse_encoding(_) -> json.
|
||||
parse_encoding(_) ->
|
||||
json.
|
||||
|
||||
-spec encode(map(), encoding()) -> {ok, iodata(), text | binary} | {error, term()}.
|
||||
-spec encode(map(), encoding()) -> {ok, iodata(), frame_type()} | {error, term()}.
|
||||
encode(Message, json) ->
|
||||
try
|
||||
Encoded = jsx:encode(Message),
|
||||
Encoded = iolist_to_binary(json:encode(Message)),
|
||||
{ok, Encoded, text}
|
||||
catch
|
||||
_:Reason ->
|
||||
@@ -42,7 +45,7 @@ encode(Message, json) ->
|
||||
-spec decode(binary(), encoding()) -> {ok, map()} | {error, term()}.
|
||||
decode(Data, json) ->
|
||||
try
|
||||
Decoded = jsx:decode(Data, [{return_maps, true}]),
|
||||
Decoded = json:decode(Data),
|
||||
{ok, Decoded}
|
||||
catch
|
||||
_:Reason ->
|
||||
@@ -52,26 +55,66 @@ decode(Data, json) ->
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
|
||||
parse_encoding_test() ->
|
||||
?assertEqual(json, parse_encoding(<<"json">>)),
|
||||
?assertEqual(json, parse_encoding(<<"etf">>)),
|
||||
?assertEqual(json, parse_encoding(undefined)),
|
||||
?assertEqual(json, parse_encoding(<<"invalid">>)).
|
||||
parse_encoding_test_() ->
|
||||
[
|
||||
?_assertEqual(json, parse_encoding(<<"json">>)),
|
||||
?_assertEqual(json, parse_encoding(<<"etf">>)),
|
||||
?_assertEqual(json, parse_encoding(undefined)),
|
||||
?_assertEqual(json, parse_encoding(<<"invalid">>)),
|
||||
?_assertEqual(json, parse_encoding(<<>>))
|
||||
].
|
||||
|
||||
encode_json_test() ->
|
||||
encode_json_test_() ->
|
||||
Message = #{<<"op">> => 0, <<"d">> => #{<<"test">> => true}},
|
||||
[
|
||||
?_assertMatch({ok, _, text}, encode(Message, json)),
|
||||
?_test(begin
|
||||
{ok, Encoded, text} = encode(Message, json),
|
||||
?assert(is_binary(Encoded))
|
||||
end)
|
||||
].
|
||||
|
||||
encode_empty_map_test() ->
|
||||
{ok, Encoded, text} = encode(#{}, json),
|
||||
?assertEqual(<<"{}">>, Encoded).
|
||||
|
||||
encode_nested_test() ->
|
||||
Message = #{<<"a">> => #{<<"b">> => #{<<"c">> => 1}}},
|
||||
{ok, Encoded, text} = encode(Message, json),
|
||||
?assert(is_binary(Encoded)).
|
||||
?assert(is_binary(Encoded)),
|
||||
{ok, Decoded} = decode(Encoded, json),
|
||||
?assertEqual(Message, Decoded).
|
||||
|
||||
decode_json_test() ->
|
||||
decode_json_test_() ->
|
||||
Data = <<"{\"op\":0,\"d\":{\"test\":true}}">>,
|
||||
{ok, Decoded} = decode(Data, json),
|
||||
?assertEqual(0, maps:get(<<"op">>, Decoded)).
|
||||
[
|
||||
?_assertMatch({ok, _}, decode(Data, json)),
|
||||
?_test(begin
|
||||
{ok, Decoded} = decode(Data, json),
|
||||
?assertEqual(0, maps:get(<<"op">>, Decoded))
|
||||
end)
|
||||
].
|
||||
|
||||
roundtrip_json_test() ->
|
||||
Original = #{<<"op">> => 10, <<"d">> => #{<<"heartbeat_interval">> => 41250}},
|
||||
{ok, Encoded, _} = encode(Original, json),
|
||||
{ok, Decoded} = decode(iolist_to_binary(Encoded), json),
|
||||
?assertEqual(Original, Decoded).
|
||||
decode_invalid_json_test() ->
|
||||
?assertMatch({error, {decode_failed, _}}, decode(<<"not json">>, json)).
|
||||
|
||||
decode_empty_object_test() ->
|
||||
{ok, Decoded} = decode(<<"{}">>, json),
|
||||
?assertEqual(#{}, Decoded).
|
||||
|
||||
roundtrip_json_test_() ->
|
||||
Messages = [
|
||||
#{<<"op">> => 10, <<"d">> => #{<<"heartbeat_interval">> => 41250}},
|
||||
#{<<"op">> => 0, <<"s">> => 1, <<"t">> => <<"READY">>, <<"d">> => #{}},
|
||||
#{<<"list">> => [1, 2, 3], <<"bool">> => true, <<"null">> => null}
|
||||
],
|
||||
[
|
||||
?_test(begin
|
||||
{ok, Encoded, _} = encode(Msg, json),
|
||||
{ok, Decoded} = decode(iolist_to_binary(Encoded), json),
|
||||
?assertEqual(Msg, Decoded)
|
||||
end)
|
||||
|| Msg <- Messages
|
||||
].
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -27,80 +27,166 @@
|
||||
]).
|
||||
|
||||
-type compression() :: none | zstd_stream.
|
||||
-export_type([compression/0]).
|
||||
|
||||
-record(compress_ctx, {type :: compression()}).
|
||||
-type compress_ctx() :: #compress_ctx{}.
|
||||
-export_type([compress_ctx/0]).
|
||||
-opaque compress_ctx() :: #{type := compression()}.
|
||||
|
||||
-export_type([compression/0, compress_ctx/0]).
|
||||
|
||||
-spec parse_compression(binary() | undefined) -> compression().
|
||||
parse_compression(<<"none">>) -> none;
|
||||
parse_compression(<<"zstd-stream">>) -> zstd_stream;
|
||||
parse_compression(_) -> none.
|
||||
parse_compression(<<"none">>) ->
|
||||
none;
|
||||
%% TODO: temporarily disabled – re-enable zstd-stream once compression issues are resolved
|
||||
parse_compression(<<"zstd-stream">>) ->
|
||||
none;
|
||||
parse_compression(_) ->
|
||||
none.
|
||||
|
||||
-spec new_context(compression()) -> compress_ctx().
|
||||
new_context(none) ->
|
||||
#compress_ctx{type = none};
|
||||
#{type => none};
|
||||
new_context(zstd_stream) ->
|
||||
#compress_ctx{type = zstd_stream}.
|
||||
#{type => zstd_stream}.
|
||||
|
||||
-spec close_context(compress_ctx()) -> ok.
|
||||
close_context(_Ctx) ->
|
||||
close_context(#{}) ->
|
||||
ok.
|
||||
|
||||
-spec get_type(compress_ctx()) -> compression().
|
||||
get_type(#compress_ctx{type = Type}) ->
|
||||
get_type(#{type := Type}) ->
|
||||
Type.
|
||||
|
||||
-spec compress(iodata(), compress_ctx()) -> {ok, binary(), compress_ctx()} | {error, term()}.
|
||||
compress(Data, Ctx = #compress_ctx{type = none}) ->
|
||||
compress(Data, Ctx = #{type := none}) ->
|
||||
{ok, iolist_to_binary(Data), Ctx};
|
||||
compress(Data, Ctx = #compress_ctx{type = zstd_stream}) ->
|
||||
try
|
||||
Binary = iolist_to_binary(Data),
|
||||
case ezstd:compress(Binary, 3) of
|
||||
Compressed when is_binary(Compressed) ->
|
||||
{ok, Compressed, Ctx};
|
||||
{error, Reason} ->
|
||||
{error, {compress_failed, Reason}}
|
||||
end
|
||||
catch
|
||||
_:Exception ->
|
||||
{error, {compress_failed, Exception}}
|
||||
end.
|
||||
compress(Data, Ctx = #{type := zstd_stream}) ->
|
||||
zstd_compress(Data, Ctx).
|
||||
|
||||
-spec decompress(binary(), compress_ctx()) -> {ok, binary(), compress_ctx()} | {error, term()}.
|
||||
decompress(Data, Ctx = #compress_ctx{type = none}) ->
|
||||
decompress(Data, Ctx = #{type := none}) ->
|
||||
{ok, Data, Ctx};
|
||||
decompress(Data, Ctx = #compress_ctx{type = zstd_stream}) ->
|
||||
try
|
||||
case ezstd:decompress(Data) of
|
||||
Decompressed when is_binary(Decompressed) ->
|
||||
{ok, Decompressed, Ctx};
|
||||
{error, Reason} ->
|
||||
{error, {decompress_failed, Reason}}
|
||||
end
|
||||
catch
|
||||
_:Exception ->
|
||||
{error, {decompress_failed, Exception}}
|
||||
decompress(Data, Ctx = #{type := zstd_stream}) ->
|
||||
zstd_decompress(Data, Ctx).
|
||||
|
||||
zstd_compress(Data, Ctx) ->
|
||||
case ezstd_available() of
|
||||
true ->
|
||||
try
|
||||
Binary = iolist_to_binary(Data),
|
||||
case erlang:apply(ezstd, compress, [Binary, 3]) of
|
||||
Compressed when is_binary(Compressed) ->
|
||||
{ok, Compressed, Ctx};
|
||||
{error, Reason} ->
|
||||
{error, {compress_failed, Reason}}
|
||||
end
|
||||
catch
|
||||
_:Exception ->
|
||||
{error, {compress_failed, Exception}}
|
||||
end;
|
||||
false ->
|
||||
{error, {compress_failed, zstd_not_available}}
|
||||
end.
|
||||
|
||||
zstd_decompress(Data, Ctx) ->
|
||||
case ezstd_available() of
|
||||
true ->
|
||||
try
|
||||
case erlang:apply(ezstd, decompress, [Data]) of
|
||||
Decompressed when is_binary(Decompressed) ->
|
||||
{ok, Decompressed, Ctx};
|
||||
{error, Reason} ->
|
||||
{error, {decompress_failed, Reason}}
|
||||
end
|
||||
catch
|
||||
_:Exception ->
|
||||
{error, {decompress_failed, Exception}}
|
||||
end;
|
||||
false ->
|
||||
{error, {decompress_failed, zstd_not_available}}
|
||||
end.
|
||||
|
||||
-spec ezstd_available() -> boolean().
|
||||
ezstd_available() ->
|
||||
case code:ensure_loaded(ezstd) of
|
||||
{module, ezstd} ->
|
||||
erlang:function_exported(ezstd, compress, 2) andalso
|
||||
erlang:function_exported(ezstd, decompress, 1);
|
||||
_ ->
|
||||
false
|
||||
end.
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
|
||||
parse_compression_test() ->
|
||||
?assertEqual(none, parse_compression(undefined)),
|
||||
?assertEqual(none, parse_compression(<<>>)),
|
||||
?assertEqual(zstd_stream, parse_compression(<<"zstd-stream">>)),
|
||||
?assertEqual(none, parse_compression(<<"none">>)).
|
||||
parse_compression_test_() ->
|
||||
[
|
||||
?_assertEqual(none, parse_compression(undefined)),
|
||||
?_assertEqual(none, parse_compression(<<>>)),
|
||||
?_assertEqual(none, parse_compression(<<"none">>)),
|
||||
?_assertEqual(none, parse_compression(<<"invalid">>)),
|
||||
%% zstd-stream temporarily disabled – always returns none
|
||||
?_assertEqual(none, parse_compression(<<"zstd-stream">>))
|
||||
].
|
||||
|
||||
zstd_roundtrip_test() ->
|
||||
Ctx = new_context(zstd_stream),
|
||||
Data = <<"hello world, this is a test message for zstd compression">>,
|
||||
new_context_test_() ->
|
||||
[
|
||||
?_assertEqual(none, get_type(new_context(none))),
|
||||
?_assertEqual(zstd_stream, get_type(new_context(zstd_stream)))
|
||||
].
|
||||
|
||||
close_context_test() ->
|
||||
Ctx = new_context(none),
|
||||
?assertEqual(ok, close_context(Ctx)).
|
||||
|
||||
compress_none_test() ->
|
||||
Ctx = new_context(none),
|
||||
Data = <<"hello world">>,
|
||||
{ok, Compressed, Ctx2} = compress(Data, Ctx),
|
||||
?assert(is_binary(Compressed)),
|
||||
{ok, Decompressed, _} = decompress(Compressed, Ctx2),
|
||||
?assertEqual(Data, Decompressed),
|
||||
ok = close_context(Ctx2).
|
||||
?assertEqual(Data, Compressed),
|
||||
?assertEqual(none, get_type(Ctx2)).
|
||||
|
||||
compress_none_iolist_test() ->
|
||||
Ctx = new_context(none),
|
||||
Data = [<<"hello">>, <<" ">>, <<"world">>],
|
||||
{ok, Compressed, _} = compress(Data, Ctx),
|
||||
?assertEqual(<<"hello world">>, Compressed).
|
||||
|
||||
decompress_none_test() ->
|
||||
Ctx = new_context(none),
|
||||
Data = <<"hello world">>,
|
||||
{ok, Decompressed, _} = decompress(Data, Ctx),
|
||||
?assertEqual(Data, Decompressed).
|
||||
|
||||
-ifdef(DEV_MODE).
|
||||
zstd_roundtrip_test() ->
|
||||
?assertEqual(skip, skip).
|
||||
|
||||
zstd_compression_ratio_test() ->
|
||||
?assertEqual(skip, skip).
|
||||
-else.
|
||||
zstd_roundtrip_test() ->
|
||||
case ezstd_available() of
|
||||
true ->
|
||||
Ctx = new_context(zstd_stream),
|
||||
Data = <<"hello world, this is a test message for zstd compression">>,
|
||||
{ok, Compressed, Ctx2} = compress(Data, Ctx),
|
||||
?assert(is_binary(Compressed)),
|
||||
{ok, Decompressed, _} = decompress(Compressed, Ctx2),
|
||||
?assertEqual(Data, Decompressed),
|
||||
ok = close_context(Ctx2);
|
||||
false ->
|
||||
?assertEqual(skip, skip)
|
||||
end.
|
||||
|
||||
zstd_compression_ratio_test() ->
|
||||
case ezstd_available() of
|
||||
true ->
|
||||
Ctx = new_context(zstd_stream),
|
||||
Data = binary:copy(<<"aaaaaaaaaa">>, 100),
|
||||
{ok, Compressed, _} = compress(Data, Ctx),
|
||||
?assert(byte_size(Compressed) < byte_size(Data));
|
||||
false ->
|
||||
?assertEqual(skip, skip)
|
||||
end.
|
||||
-endif.
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -17,6 +17,8 @@
|
||||
|
||||
-module(gateway_errors).
|
||||
|
||||
-compile({no_auto_import, [error/1]}).
|
||||
|
||||
-export([
|
||||
error/1,
|
||||
error_code/1,
|
||||
@@ -25,11 +27,64 @@
|
||||
is_recoverable/1
|
||||
]).
|
||||
|
||||
-spec error(atom()) -> {error, atom(), atom()}.
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-type error_atom() ::
|
||||
voice_connection_not_found
|
||||
| voice_channel_not_found
|
||||
| voice_channel_not_voice
|
||||
| voice_member_not_found
|
||||
| voice_user_not_in_voice
|
||||
| voice_guild_not_found
|
||||
| voice_permission_denied
|
||||
| voice_member_timed_out
|
||||
| voice_channel_full
|
||||
| voice_missing_connection_id
|
||||
| voice_invalid_user_id
|
||||
| voice_invalid_channel_id
|
||||
| voice_invalid_state
|
||||
| voice_user_mismatch
|
||||
| voice_token_failed
|
||||
| voice_guild_id_missing
|
||||
| voice_invalid_guild_id
|
||||
| voice_moderator_missing_connect
|
||||
| voice_unclaimed_account
|
||||
| voice_update_rate_limited
|
||||
| voice_nonce_mismatch
|
||||
| voice_pending_expired
|
||||
| voice_camera_user_limit
|
||||
| dm_channel_not_found
|
||||
| dm_not_recipient
|
||||
| dm_invalid_channel_type
|
||||
| validation_invalid_snowflake
|
||||
| validation_null_snowflake
|
||||
| validation_invalid_snowflake_list
|
||||
| validation_expected_list
|
||||
| validation_expected_map
|
||||
| validation_missing_field
|
||||
| validation_invalid_params
|
||||
| internal_error
|
||||
| timeout
|
||||
| unknown_error
|
||||
| atom().
|
||||
|
||||
-type error_category() ::
|
||||
not_found
|
||||
| validation_error
|
||||
| permission_denied
|
||||
| voice_error
|
||||
| rate_limited
|
||||
| timeout
|
||||
| unknown
|
||||
| auth_failed.
|
||||
|
||||
-spec error(error_atom()) -> {error, error_category(), error_atom()}.
|
||||
error(ErrorAtom) ->
|
||||
{error, error_category(ErrorAtom), ErrorAtom}.
|
||||
|
||||
-spec error_code(atom()) -> binary().
|
||||
-spec error_code(error_atom()) -> binary().
|
||||
error_code(voice_connection_not_found) -> <<"VOICE_CONNECTION_NOT_FOUND">>;
|
||||
error_code(voice_channel_not_found) -> <<"VOICE_CHANNEL_NOT_FOUND">>;
|
||||
error_code(voice_channel_not_voice) -> <<"VOICE_INVALID_CHANNEL_TYPE">>;
|
||||
@@ -49,6 +104,10 @@ error_code(voice_guild_id_missing) -> <<"VOICE_GUILD_ID_MISSING">>;
|
||||
error_code(voice_invalid_guild_id) -> <<"VOICE_INVALID_GUILD_ID">>;
|
||||
error_code(voice_moderator_missing_connect) -> <<"VOICE_PERMISSION_DENIED">>;
|
||||
error_code(voice_unclaimed_account) -> <<"VOICE_UNCLAIMED_ACCOUNT">>;
|
||||
error_code(voice_update_rate_limited) -> <<"VOICE_UPDATE_RATE_LIMITED">>;
|
||||
error_code(voice_nonce_mismatch) -> <<"VOICE_NONCE_MISMATCH">>;
|
||||
error_code(voice_pending_expired) -> <<"VOICE_PENDING_EXPIRED">>;
|
||||
error_code(voice_camera_user_limit) -> <<"VOICE_CAMERA_USER_LIMIT">>;
|
||||
error_code(dm_channel_not_found) -> <<"DM_CHANNEL_NOT_FOUND">>;
|
||||
error_code(dm_not_recipient) -> <<"DM_NOT_RECIPIENT">>;
|
||||
error_code(dm_invalid_channel_type) -> <<"DM_INVALID_CHANNEL_TYPE">>;
|
||||
@@ -64,7 +123,7 @@ error_code(timeout) -> <<"TIMEOUT">>;
|
||||
error_code(unknown_error) -> <<"UNKNOWN_ERROR">>;
|
||||
error_code(_) -> <<"UNKNOWN_ERROR">>.
|
||||
|
||||
-spec error_message(atom()) -> binary().
|
||||
-spec error_message(error_atom()) -> binary().
|
||||
error_message(voice_connection_not_found) -> <<"Voice connection not found">>;
|
||||
error_message(voice_channel_not_found) -> <<"Voice channel not found">>;
|
||||
error_message(voice_channel_not_voice) -> <<"Channel is not a voice channel">>;
|
||||
@@ -84,6 +143,10 @@ error_message(voice_guild_id_missing) -> <<"Guild ID is required">>;
|
||||
error_message(voice_invalid_guild_id) -> <<"Invalid guild ID">>;
|
||||
error_message(voice_moderator_missing_connect) -> <<"Moderator missing connect permission">>;
|
||||
error_message(voice_unclaimed_account) -> <<"Claim your account to join voice">>;
|
||||
error_message(voice_update_rate_limited) -> <<"Voice updates are rate limited">>;
|
||||
error_message(voice_nonce_mismatch) -> <<"Voice token nonce mismatch">>;
|
||||
error_message(voice_pending_expired) -> <<"Voice pending connection expired">>;
|
||||
error_message(voice_camera_user_limit) -> <<"Too many users in channel to enable camera">>;
|
||||
error_message(dm_channel_not_found) -> <<"DM channel not found">>;
|
||||
error_message(dm_not_recipient) -> <<"Not a recipient of this channel">>;
|
||||
error_message(dm_invalid_channel_type) -> <<"Not a DM or Group DM channel">>;
|
||||
@@ -99,7 +162,7 @@ error_message(timeout) -> <<"Request timed out">>;
|
||||
error_message(unknown_error) -> <<"An unknown error occurred">>;
|
||||
error_message(_) -> <<"An unknown error occurred">>.
|
||||
|
||||
-spec error_category(atom()) -> atom().
|
||||
-spec error_category(error_atom()) -> error_category().
|
||||
error_category(voice_connection_not_found) -> not_found;
|
||||
error_category(voice_channel_not_found) -> not_found;
|
||||
error_category(voice_channel_not_voice) -> validation_error;
|
||||
@@ -119,6 +182,10 @@ error_category(voice_guild_id_missing) -> validation_error;
|
||||
error_category(voice_invalid_guild_id) -> validation_error;
|
||||
error_category(voice_moderator_missing_connect) -> permission_denied;
|
||||
error_category(voice_unclaimed_account) -> permission_denied;
|
||||
error_category(voice_update_rate_limited) -> rate_limited;
|
||||
error_category(voice_nonce_mismatch) -> validation_error;
|
||||
error_category(voice_pending_expired) -> validation_error;
|
||||
error_category(voice_camera_user_limit) -> permission_denied;
|
||||
error_category(dm_channel_not_found) -> not_found;
|
||||
error_category(dm_not_recipient) -> permission_denied;
|
||||
error_category(dm_invalid_channel_type) -> validation_error;
|
||||
@@ -134,7 +201,7 @@ error_category(timeout) -> timeout;
|
||||
error_category(unknown_error) -> unknown;
|
||||
error_category(_) -> unknown.
|
||||
|
||||
-spec is_recoverable(atom()) -> boolean().
|
||||
-spec is_recoverable(error_category()) -> boolean().
|
||||
is_recoverable(not_found) -> true;
|
||||
is_recoverable(permission_denied) -> true;
|
||||
is_recoverable(voice_error) -> true;
|
||||
@@ -144,3 +211,122 @@ is_recoverable(unknown) -> true;
|
||||
is_recoverable(rate_limited) -> false;
|
||||
is_recoverable(auth_failed) -> false;
|
||||
is_recoverable(_) -> true.
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
error_test() ->
|
||||
?assertEqual({error, not_found, voice_connection_not_found}, error(voice_connection_not_found)),
|
||||
?assertEqual(
|
||||
{error, validation_error, voice_channel_not_voice}, error(voice_channel_not_voice)
|
||||
),
|
||||
?assertEqual(
|
||||
{error, permission_denied, voice_permission_denied}, error(voice_permission_denied)
|
||||
).
|
||||
|
||||
error_code_test() ->
|
||||
?assertEqual(<<"VOICE_CONNECTION_NOT_FOUND">>, error_code(voice_connection_not_found)),
|
||||
?assertEqual(<<"VOICE_CHANNEL_NOT_FOUND">>, error_code(voice_channel_not_found)),
|
||||
?assertEqual(<<"VOICE_PERMISSION_DENIED">>, error_code(voice_permission_denied)),
|
||||
?assertEqual(<<"VOICE_PERMISSION_DENIED">>, error_code(voice_moderator_missing_connect)),
|
||||
?assertEqual(<<"UNKNOWN_ERROR">>, error_code(some_random_error)),
|
||||
?assertEqual(<<"TIMEOUT">>, error_code(timeout)),
|
||||
?assertEqual(<<"INTERNAL_ERROR">>, error_code(internal_error)).
|
||||
|
||||
error_message_test() ->
|
||||
?assertEqual(<<"Voice connection not found">>, error_message(voice_connection_not_found)),
|
||||
?assertEqual(<<"Voice channel not found">>, error_message(voice_channel_not_found)),
|
||||
?assertEqual(<<"Missing voice permissions">>, error_message(voice_permission_denied)),
|
||||
?assertEqual(<<"Voice channel is full">>, error_message(voice_channel_full)),
|
||||
?assertEqual(<<"An unknown error occurred">>, error_message(some_random_error)),
|
||||
?assertEqual(<<"Request timed out">>, error_message(timeout)).
|
||||
|
||||
error_category_test() ->
|
||||
?assertEqual(not_found, error_category(voice_connection_not_found)),
|
||||
?assertEqual(not_found, error_category(voice_channel_not_found)),
|
||||
?assertEqual(not_found, error_category(dm_channel_not_found)),
|
||||
?assertEqual(validation_error, error_category(voice_channel_not_voice)),
|
||||
?assertEqual(validation_error, error_category(validation_invalid_snowflake)),
|
||||
?assertEqual(permission_denied, error_category(voice_permission_denied)),
|
||||
?assertEqual(permission_denied, error_category(voice_channel_full)),
|
||||
?assertEqual(voice_error, error_category(voice_token_failed)),
|
||||
?assertEqual(rate_limited, error_category(voice_update_rate_limited)),
|
||||
?assertEqual(timeout, error_category(timeout)),
|
||||
?assertEqual(unknown, error_category(unknown_error)),
|
||||
?assertEqual(unknown, error_category(some_random_error)).
|
||||
|
||||
is_recoverable_test() ->
|
||||
?assert(is_recoverable(not_found)),
|
||||
?assert(is_recoverable(permission_denied)),
|
||||
?assert(is_recoverable(voice_error)),
|
||||
?assert(is_recoverable(validation_error)),
|
||||
?assert(is_recoverable(timeout)),
|
||||
?assert(is_recoverable(unknown)),
|
||||
?assertNot(is_recoverable(rate_limited)),
|
||||
?assertNot(is_recoverable(auth_failed)).
|
||||
|
||||
all_voice_errors_have_codes_test() ->
|
||||
VoiceErrors = [
|
||||
voice_connection_not_found,
|
||||
voice_channel_not_found,
|
||||
voice_channel_not_voice,
|
||||
voice_member_not_found,
|
||||
voice_user_not_in_voice,
|
||||
voice_guild_not_found,
|
||||
voice_permission_denied,
|
||||
voice_member_timed_out,
|
||||
voice_channel_full,
|
||||
voice_missing_connection_id,
|
||||
voice_invalid_user_id,
|
||||
voice_invalid_channel_id,
|
||||
voice_invalid_state,
|
||||
voice_user_mismatch,
|
||||
voice_token_failed,
|
||||
voice_guild_id_missing,
|
||||
voice_invalid_guild_id,
|
||||
voice_moderator_missing_connect,
|
||||
voice_unclaimed_account,
|
||||
voice_update_rate_limited,
|
||||
voice_nonce_mismatch,
|
||||
voice_pending_expired,
|
||||
voice_camera_user_limit
|
||||
],
|
||||
lists:foreach(
|
||||
fun(Error) ->
|
||||
Code = error_code(Error),
|
||||
?assert(is_binary(Code)),
|
||||
?assertNotEqual(<<"UNKNOWN_ERROR">>, Code)
|
||||
end,
|
||||
VoiceErrors
|
||||
).
|
||||
|
||||
all_dm_errors_have_codes_test() ->
|
||||
DmErrors = [dm_channel_not_found, dm_not_recipient, dm_invalid_channel_type],
|
||||
lists:foreach(
|
||||
fun(Error) ->
|
||||
Code = error_code(Error),
|
||||
?assert(is_binary(Code)),
|
||||
?assertNotEqual(<<"UNKNOWN_ERROR">>, Code)
|
||||
end,
|
||||
DmErrors
|
||||
).
|
||||
|
||||
all_validation_errors_have_codes_test() ->
|
||||
ValidationErrors = [
|
||||
validation_invalid_snowflake,
|
||||
validation_null_snowflake,
|
||||
validation_invalid_snowflake_list,
|
||||
validation_expected_list,
|
||||
validation_expected_map,
|
||||
validation_missing_field,
|
||||
validation_invalid_params
|
||||
],
|
||||
lists:foreach(
|
||||
fun(Error) ->
|
||||
Code = error_code(Error),
|
||||
?assert(is_binary(Code)),
|
||||
?assertNotEqual(<<"UNKNOWN_ERROR">>, Code)
|
||||
end,
|
||||
ValidationErrors
|
||||
).
|
||||
|
||||
-endif.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
540
fluxer_gateway/src/gateway/gateway_http_client.erl
Normal file
540
fluxer_gateway/src/gateway/gateway_http_client.erl
Normal file
@@ -0,0 +1,540 @@
|
||||
%% Copyright (C) 2026 Fluxer Contributors
|
||||
%%
|
||||
%% This file is part of Fluxer.
|
||||
%%
|
||||
%% Fluxer is free software: you can redistribute it and/or modify
|
||||
%% it under the terms of the GNU Affero General Public License as published by
|
||||
%% the Free Software Foundation, either version 3 of the License, or
|
||||
%% (at your option) any later version.
|
||||
%%
|
||||
%% Fluxer is distributed in the hope that it will be useful,
|
||||
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
%% GNU Affero General Public License for more details.
|
||||
%%
|
||||
%% You should have received a copy of the GNU Affero General Public License
|
||||
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
-module(gateway_http_client).
|
||||
-behaviour(gen_server).
|
||||
|
||||
-export([start_link/0, request/5, request/6]).
|
||||
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
|
||||
|
||||
-define(SERVER, ?MODULE).
|
||||
-define(CIRCUIT_TABLE, gateway_http_circuit_breaker).
|
||||
-define(INFLIGHT_TABLE, gateway_http_inflight).
|
||||
|
||||
-define(DEFAULT_RPC_CONNECT_TIMEOUT_MS, 5000).
|
||||
-define(DEFAULT_RPC_RECV_TIMEOUT_MS, 30000).
|
||||
-define(DEFAULT_PUSH_CONNECT_TIMEOUT_MS, 3000).
|
||||
-define(DEFAULT_PUSH_RECV_TIMEOUT_MS, 5000).
|
||||
|
||||
-define(DEFAULT_RPC_MAX_CONCURRENCY, 512).
|
||||
-define(DEFAULT_PUSH_MAX_CONCURRENCY, 256).
|
||||
|
||||
-define(DEFAULT_FAILURE_THRESHOLD, 6).
|
||||
-define(DEFAULT_RECOVERY_TIMEOUT_MS, 15000).
|
||||
-define(DEFAULT_CLEANUP_INTERVAL_MS, 30000).
|
||||
-define(DEFAULT_CLEANUP_MAX_AGE_MS, 300000).
|
||||
|
||||
-type workload() :: rpc | push.
|
||||
-type method() :: get | post | put | patch | delete | head | options.
|
||||
-type request_headers() :: [{binary() | string(), binary() | string()}].
|
||||
-type request_options() :: #{
|
||||
connect_timeout => timeout(),
|
||||
recv_timeout => timeout(),
|
||||
max_concurrency => pos_integer(),
|
||||
failure_threshold => pos_integer(),
|
||||
recovery_timeout_ms => pos_integer(),
|
||||
content_type => binary() | string()
|
||||
}.
|
||||
-type response() :: {ok, non_neg_integer(), [{binary(), binary()}], binary()} | {error, term()}.
|
||||
|
||||
-type state() :: #{}.
|
||||
|
||||
-spec start_link() -> {ok, pid()} | {error, term()}.
|
||||
start_link() ->
|
||||
case whereis(?SERVER) of
|
||||
undefined ->
|
||||
case gen_server:start_link({local, ?SERVER}, ?MODULE, [], []) of
|
||||
{error, {already_started, Pid}} when is_pid(Pid) ->
|
||||
{ok, Pid};
|
||||
Other ->
|
||||
Other
|
||||
end;
|
||||
Pid when is_pid(Pid) ->
|
||||
{ok, Pid}
|
||||
end.
|
||||
|
||||
-spec request(workload(), method(), iodata(), request_headers(), iodata() | undefined) -> response().
|
||||
request(Workload, Method, Url, Headers, Body) ->
|
||||
request(Workload, Method, Url, Headers, Body, #{}).
|
||||
|
||||
-spec request(workload(), method(), iodata(), request_headers(), iodata() | undefined, request_options()) ->
|
||||
response().
|
||||
request(Workload, Method, Url, Headers, Body, Opts) when is_map(Opts) ->
|
||||
ensure_runtime(Workload),
|
||||
WorkloadOpts = merged_workload_options(Workload, Opts),
|
||||
MaxConcurrency = maps:get(max_concurrency, WorkloadOpts),
|
||||
FailureThreshold = maps:get(failure_threshold, WorkloadOpts),
|
||||
RecoveryTimeoutMs = maps:get(recovery_timeout_ms, WorkloadOpts),
|
||||
Host = extract_host_key(Url),
|
||||
CircuitKey = {Workload, Host},
|
||||
case allow_circuit_request(CircuitKey, RecoveryTimeoutMs) of
|
||||
ok ->
|
||||
case acquire_inflight_slot(Workload, MaxConcurrency) of
|
||||
ok ->
|
||||
Result = safe_do_request(Workload, Method, Url, Headers, Body, WorkloadOpts),
|
||||
release_inflight_slot(Workload),
|
||||
update_circuit_state(CircuitKey, Result, FailureThreshold),
|
||||
Result;
|
||||
{error, overloaded} ->
|
||||
{error, overloaded}
|
||||
end;
|
||||
{error, circuit_open} ->
|
||||
{error, circuit_open}
|
||||
end.
|
||||
|
||||
-spec safe_do_request(
|
||||
workload(), method(), iodata(), request_headers(), iodata() | undefined, request_options()
|
||||
) ->
|
||||
response().
|
||||
safe_do_request(Workload, Method, Url, Headers, Body, Opts) ->
|
||||
try do_request(Workload, Method, Url, Headers, Body, Opts) of
|
||||
Result ->
|
||||
Result
|
||||
catch
|
||||
Class:Reason:Stacktrace ->
|
||||
{error,
|
||||
{request_exception, #{
|
||||
class => Class,
|
||||
reason => Reason,
|
||||
frame => first_stack_frame(Stacktrace),
|
||||
workload => Workload,
|
||||
method => Method,
|
||||
url => ensure_binary(Url)
|
||||
}}}
|
||||
end.
|
||||
|
||||
-spec init([]) -> {ok, state()}.
|
||||
init([]) ->
|
||||
process_flag(trap_exit, true),
|
||||
ensure_table(
|
||||
?CIRCUIT_TABLE,
|
||||
[named_table, public, set, {read_concurrency, true}, {write_concurrency, true}]
|
||||
),
|
||||
ensure_table(
|
||||
?INFLIGHT_TABLE,
|
||||
[named_table, public, set, {read_concurrency, true}, {write_concurrency, true}]
|
||||
),
|
||||
ok = ensure_httpc_profile(profile_for(rpc), rpc),
|
||||
ok = ensure_httpc_profile(profile_for(push), push),
|
||||
schedule_cleanup(),
|
||||
{ok, #{}}.
|
||||
|
||||
-spec handle_call(term(), gen_server:from(), state()) -> {reply, ok, state()}.
|
||||
handle_call(_Request, _From, State) ->
|
||||
{reply, ok, State}.
|
||||
|
||||
-spec handle_cast(term(), state()) -> {noreply, state()}.
|
||||
handle_cast(_Msg, State) ->
|
||||
{noreply, State}.
|
||||
|
||||
-spec handle_info(term(), state()) -> {noreply, state()}.
|
||||
handle_info(cleanup_circuits, State) ->
|
||||
prune_circuit_table(),
|
||||
schedule_cleanup(),
|
||||
{noreply, State};
|
||||
handle_info(_Info, State) ->
|
||||
{noreply, State}.
|
||||
|
||||
-spec terminate(term(), state()) -> ok.
|
||||
terminate(_Reason, _State) ->
|
||||
ok.
|
||||
|
||||
-spec code_change(term(), state(), term()) -> {ok, state()}.
|
||||
code_change(_OldVsn, State, _Extra) ->
|
||||
{ok, State}.
|
||||
|
||||
-spec ensure_table(atom(), [term()]) -> ok.
|
||||
ensure_table(Name, Options) ->
|
||||
case ets:whereis(Name) of
|
||||
undefined ->
|
||||
try
|
||||
_ = ets:new(Name, Options),
|
||||
ok
|
||||
catch
|
||||
error:badarg -> ok
|
||||
end;
|
||||
_ ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec ensure_runtime(workload()) -> ok.
|
||||
ensure_runtime(Workload) ->
|
||||
ok = ensure_started(),
|
||||
_ = Workload,
|
||||
ok.
|
||||
|
||||
-spec ensure_started() -> ok.
|
||||
ensure_started() ->
|
||||
case start_link() of
|
||||
{ok, _Pid} ->
|
||||
ok;
|
||||
_ ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec schedule_cleanup() -> reference().
|
||||
schedule_cleanup() ->
|
||||
erlang:send_after(cleanup_interval_ms(), self(), cleanup_circuits).
|
||||
|
||||
-spec prune_circuit_table() -> ok.
|
||||
prune_circuit_table() ->
|
||||
Now = erlang:system_time(millisecond),
|
||||
MaxAgeMs = cleanup_max_age_ms(),
|
||||
_ =
|
||||
ets:foldl(
|
||||
fun({Key, CircuitState}, Acc) ->
|
||||
case is_stale_circuit(CircuitState, Now, MaxAgeMs) of
|
||||
true ->
|
||||
ets:delete(?CIRCUIT_TABLE, Key),
|
||||
Acc;
|
||||
false ->
|
||||
Acc
|
||||
end
|
||||
end,
|
||||
ok,
|
||||
?CIRCUIT_TABLE
|
||||
),
|
||||
ok.
|
||||
|
||||
-spec is_stale_circuit(map(), integer(), integer()) -> boolean().
|
||||
is_stale_circuit(#{state := open, opened_at := OpenedAt}, Now, MaxAgeMs) ->
|
||||
Now - OpenedAt > MaxAgeMs;
|
||||
is_stale_circuit(#{state := closed, failures := 0, updated_at := UpdatedAt}, Now, MaxAgeMs) ->
|
||||
Now - UpdatedAt > MaxAgeMs;
|
||||
is_stale_circuit(_, _, _) ->
|
||||
false.
|
||||
|
||||
-spec allow_circuit_request({workload(), binary()}, pos_integer()) -> ok | {error, circuit_open}.
|
||||
allow_circuit_request(CircuitKey, RecoveryTimeoutMs) ->
|
||||
Now = erlang:system_time(millisecond),
|
||||
case safe_lookup_circuit(CircuitKey) of
|
||||
[] ->
|
||||
ok;
|
||||
[{_, #{state := open, opened_at := OpenedAt}} = Entry] ->
|
||||
case Now - OpenedAt >= RecoveryTimeoutMs of
|
||||
true ->
|
||||
{_, State0} = Entry,
|
||||
NewState = State0#{state => half_open, updated_at => Now},
|
||||
ets:insert(?CIRCUIT_TABLE, {CircuitKey, NewState}),
|
||||
ok;
|
||||
false ->
|
||||
{error, circuit_open}
|
||||
end;
|
||||
_ ->
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec safe_lookup_circuit({workload(), binary()}) -> list().
|
||||
safe_lookup_circuit(Key) ->
|
||||
try ets:lookup(?CIRCUIT_TABLE, Key) of
|
||||
Result -> Result
|
||||
catch
|
||||
error:badarg -> []
|
||||
end.
|
||||
|
||||
-spec acquire_inflight_slot(workload(), pos_integer()) -> ok | {error, overloaded}.
|
||||
acquire_inflight_slot(Workload, MaxConcurrency) ->
|
||||
case safe_update_counter(?INFLIGHT_TABLE, Workload, {2, 1}) of
|
||||
{ok, Count} when Count =< MaxConcurrency ->
|
||||
ok;
|
||||
{ok, _Count} ->
|
||||
_ = safe_update_counter(?INFLIGHT_TABLE, Workload, {2, -1}),
|
||||
{error, overloaded};
|
||||
{error, _Reason} ->
|
||||
{error, overloaded}
|
||||
end.
|
||||
|
||||
-spec release_inflight_slot(workload()) -> ok.
|
||||
release_inflight_slot(Workload) ->
|
||||
_ = safe_update_counter(?INFLIGHT_TABLE, Workload, {2, -1}),
|
||||
ok.
|
||||
|
||||
-spec safe_update_counter(atom(), term(), {pos_integer(), integer()}) ->
|
||||
{ok, integer()} | {error, term()}.
|
||||
safe_update_counter(Table, Key, Op) ->
|
||||
try
|
||||
{ok, ets:update_counter(Table, Key, Op, {Key, 0})}
|
||||
catch
|
||||
error:badarg ->
|
||||
ok = ensure_started(),
|
||||
try
|
||||
{ok, ets:update_counter(Table, Key, Op, {Key, 0})}
|
||||
catch
|
||||
error:badarg ->
|
||||
{error, badarg}
|
||||
end
|
||||
end.
|
||||
|
||||
-spec update_circuit_state({workload(), binary()}, response(), pos_integer()) -> ok.
|
||||
update_circuit_state(CircuitKey, Result, FailureThreshold) ->
|
||||
Now = erlang:system_time(millisecond),
|
||||
case should_count_failure(Result) of
|
||||
true ->
|
||||
record_failure(CircuitKey, FailureThreshold, Now);
|
||||
false ->
|
||||
record_success(CircuitKey, Now)
|
||||
end.
|
||||
|
||||
-spec should_count_failure(response()) -> boolean().
|
||||
should_count_failure({error, _Reason}) ->
|
||||
true;
|
||||
should_count_failure({ok, StatusCode, _Headers, _Body}) when StatusCode >= 500 ->
|
||||
true;
|
||||
should_count_failure(_) ->
|
||||
false.
|
||||
|
||||
-spec record_failure({workload(), binary()}, pos_integer(), integer()) -> ok.
|
||||
record_failure(CircuitKey, Threshold, Now) ->
|
||||
case safe_lookup_circuit(CircuitKey) of
|
||||
[] ->
|
||||
ets:insert(?CIRCUIT_TABLE, {CircuitKey, #{
|
||||
state => closed,
|
||||
failures => 1,
|
||||
opened_at => undefined,
|
||||
updated_at => Now
|
||||
}}),
|
||||
ok;
|
||||
[{_, #{failures := Failures} = Existing}] ->
|
||||
NewFailures = Failures + 1,
|
||||
NewState =
|
||||
case NewFailures >= Threshold of
|
||||
true -> open;
|
||||
false -> maps:get(state, Existing, closed)
|
||||
end,
|
||||
OpenedAt =
|
||||
case NewState of
|
||||
open -> Now;
|
||||
_ -> maps:get(opened_at, Existing, undefined)
|
||||
end,
|
||||
ets:insert(?CIRCUIT_TABLE, {CircuitKey, Existing#{
|
||||
state => NewState,
|
||||
failures => NewFailures,
|
||||
opened_at => OpenedAt,
|
||||
updated_at => Now
|
||||
}}),
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec record_success({workload(), binary()}, integer()) -> ok.
|
||||
record_success(CircuitKey, Now) ->
|
||||
case safe_lookup_circuit(CircuitKey) of
|
||||
[] ->
|
||||
ok;
|
||||
[{_, Existing}] ->
|
||||
NewState = Existing#{
|
||||
state => closed,
|
||||
failures => 0,
|
||||
updated_at => Now
|
||||
},
|
||||
ets:insert(?CIRCUIT_TABLE, {CircuitKey, maps:remove(opened_at, NewState)}),
|
||||
ok
|
||||
end.
|
||||
|
||||
-spec do_request(workload(), method(), iodata(), request_headers(), iodata() | undefined, request_options()) ->
|
||||
response().
|
||||
do_request(Workload, Method, Url, Headers, Body, Opts) ->
|
||||
HttpMethod = normalize_method(Method),
|
||||
UrlString = ensure_list(Url),
|
||||
RequestHeaders = normalize_request_headers(Headers),
|
||||
RequestTuple = build_request_tuple(UrlString, RequestHeaders, Body, Opts),
|
||||
ConnectTimeout = maps:get(connect_timeout, Opts),
|
||||
RecvTimeout = maps:get(recv_timeout, Opts),
|
||||
HttpOptions = [
|
||||
{connect_timeout, ConnectTimeout},
|
||||
{timeout, RecvTimeout},
|
||||
{autoredirect, false}
|
||||
],
|
||||
RequestOptions = [{body_format, binary}],
|
||||
case httpc:request(HttpMethod, RequestTuple, HttpOptions, RequestOptions, profile_for(Workload)) of
|
||||
{ok, {{_HttpVersion, StatusCode, _ReasonPhrase}, RespHeaders, RespBody}} ->
|
||||
{ok, StatusCode, normalize_response_headers(RespHeaders), ensure_binary(RespBody)};
|
||||
{error, Reason} ->
|
||||
{error, Reason}
|
||||
end.
|
||||
|
||||
-spec normalize_method(method() | atom()) -> method().
|
||||
normalize_method(post) -> post;
|
||||
normalize_method(get) -> get;
|
||||
normalize_method(put) -> put;
|
||||
normalize_method(patch) -> patch;
|
||||
normalize_method(delete) -> delete;
|
||||
normalize_method(head) -> head;
|
||||
normalize_method(options) -> options;
|
||||
normalize_method(_) -> post.
|
||||
|
||||
-spec build_request_tuple(string(), [{string(), string()}], iodata() | undefined, request_options()) ->
|
||||
{string(), [{string(), string()}]}
|
||||
| {string(), [{string(), string()}], string(), iodata()}.
|
||||
build_request_tuple(Url, Headers, undefined, _Opts) ->
|
||||
{Url, Headers};
|
||||
build_request_tuple(Url, Headers, Body, Opts) ->
|
||||
ContentType = resolve_content_type(Headers, Opts),
|
||||
{Url, Headers, ContentType, Body}.
|
||||
|
||||
-spec resolve_content_type([{string(), string()}], request_options()) -> string().
|
||||
resolve_content_type(Headers, Opts) ->
|
||||
case maps:get(content_type, Opts, undefined) of
|
||||
undefined ->
|
||||
case find_content_type_header(Headers) of
|
||||
undefined -> "application/json";
|
||||
Value -> Value
|
||||
end;
|
||||
Value ->
|
||||
ensure_list(Value)
|
||||
end.
|
||||
|
||||
-spec find_content_type_header([{string(), string()}]) -> string() | undefined.
|
||||
find_content_type_header([]) ->
|
||||
undefined;
|
||||
find_content_type_header([{Name, Value} | Rest]) ->
|
||||
case string:lowercase(Name) of
|
||||
"content-type" -> Value;
|
||||
_ -> find_content_type_header(Rest)
|
||||
end.
|
||||
|
||||
-spec normalize_request_headers(request_headers()) -> [{string(), string()}].
|
||||
normalize_request_headers(Headers) ->
|
||||
[
|
||||
{ensure_list(Name), ensure_list(Value)}
|
||||
|| {Name, Value} <- Headers
|
||||
].
|
||||
|
||||
-spec normalize_response_headers([{string(), string()}]) -> [{binary(), binary()}].
|
||||
normalize_response_headers(Headers) ->
|
||||
[
|
||||
{list_to_binary(Name), list_to_binary(Value)}
|
||||
|| {Name, Value} <- Headers
|
||||
].
|
||||
|
||||
-spec extract_host_key(iodata()) -> binary().
|
||||
extract_host_key(Url) ->
|
||||
UrlString = ensure_list(Url),
|
||||
try
|
||||
Parsed = uri_string:parse(UrlString),
|
||||
case maps:get(host, Parsed, undefined) of
|
||||
undefined -> <<"unknown">>;
|
||||
Host when is_binary(Host) -> normalize_host(Host);
|
||||
Host when is_list(Host) -> normalize_host(list_to_binary(Host));
|
||||
_ -> <<"unknown">>
|
||||
end
|
||||
catch
|
||||
_:_ -> <<"unknown">>
|
||||
end.
|
||||
|
||||
-spec normalize_host(binary()) -> binary().
|
||||
normalize_host(Host) ->
|
||||
list_to_binary(string:lowercase(binary_to_list(Host))).
|
||||
|
||||
-spec ensure_binary(iodata()) -> binary().
|
||||
ensure_binary(Value) when is_binary(Value) ->
|
||||
Value;
|
||||
ensure_binary(Value) ->
|
||||
iolist_to_binary(Value).
|
||||
|
||||
-spec ensure_list(iodata()) -> string().
|
||||
ensure_list(Value) when is_binary(Value) ->
|
||||
binary_to_list(Value);
|
||||
ensure_list(Value) when is_list(Value) ->
|
||||
Value;
|
||||
ensure_list(Value) when is_atom(Value) ->
|
||||
atom_to_list(Value);
|
||||
ensure_list(Value) when is_integer(Value) ->
|
||||
integer_to_list(Value);
|
||||
ensure_list(_Value) ->
|
||||
"".
|
||||
|
||||
-spec ensure_httpc_profile(atom(), workload()) -> ok.
|
||||
ensure_httpc_profile(Profile, Workload) ->
|
||||
_ =
|
||||
case inets:start(httpc, [{profile, Profile}]) of
|
||||
{ok, _Pid} -> ok;
|
||||
{error, {already_started, _Pid}} -> ok;
|
||||
{error, {already_started, _Pid, _}} -> ok;
|
||||
{error, _Reason} -> ok
|
||||
end,
|
||||
Options = workload_httpc_options(Workload),
|
||||
_ = httpc:set_options(Options, Profile),
|
||||
ok.
|
||||
|
||||
-spec workload_httpc_options(workload()) -> list().
|
||||
workload_httpc_options(rpc) ->
|
||||
[
|
||||
{max_sessions, 1024},
|
||||
{max_keep_alive_length, 256}
|
||||
];
|
||||
workload_httpc_options(push) ->
|
||||
[
|
||||
{max_sessions, 2048},
|
||||
{max_keep_alive_length, 512}
|
||||
].
|
||||
|
||||
-spec merged_workload_options(workload(), request_options()) -> request_options().
|
||||
merged_workload_options(Workload, Opts) ->
|
||||
maps:merge(default_options(Workload), Opts).
|
||||
|
||||
-spec default_options(workload()) -> request_options().
|
||||
default_options(rpc) ->
|
||||
#{
|
||||
connect_timeout => get_int_or_default(gateway_http_rpc_connect_timeout_ms, ?DEFAULT_RPC_CONNECT_TIMEOUT_MS),
|
||||
recv_timeout => get_int_or_default(gateway_http_rpc_recv_timeout_ms, ?DEFAULT_RPC_RECV_TIMEOUT_MS),
|
||||
max_concurrency =>
|
||||
get_int_or_default(gateway_http_rpc_max_concurrency, ?DEFAULT_RPC_MAX_CONCURRENCY),
|
||||
failure_threshold =>
|
||||
get_int_or_default(gateway_http_failure_threshold, ?DEFAULT_FAILURE_THRESHOLD),
|
||||
recovery_timeout_ms =>
|
||||
get_int_or_default(gateway_http_recovery_timeout_ms, ?DEFAULT_RECOVERY_TIMEOUT_MS),
|
||||
content_type => <<"application/json">>
|
||||
};
|
||||
default_options(push) ->
|
||||
#{
|
||||
connect_timeout => get_int_or_default(gateway_http_push_connect_timeout_ms, ?DEFAULT_PUSH_CONNECT_TIMEOUT_MS),
|
||||
recv_timeout => get_int_or_default(gateway_http_push_recv_timeout_ms, ?DEFAULT_PUSH_RECV_TIMEOUT_MS),
|
||||
max_concurrency =>
|
||||
get_int_or_default(gateway_http_push_max_concurrency, ?DEFAULT_PUSH_MAX_CONCURRENCY),
|
||||
failure_threshold =>
|
||||
get_int_or_default(gateway_http_failure_threshold, ?DEFAULT_FAILURE_THRESHOLD),
|
||||
recovery_timeout_ms =>
|
||||
get_int_or_default(gateway_http_recovery_timeout_ms, ?DEFAULT_RECOVERY_TIMEOUT_MS),
|
||||
content_type => <<"application/octet-stream">>
|
||||
}.
|
||||
|
||||
-spec cleanup_interval_ms() -> pos_integer().
|
||||
cleanup_interval_ms() ->
|
||||
get_int_or_default(gateway_http_cleanup_interval_ms, ?DEFAULT_CLEANUP_INTERVAL_MS).
|
||||
|
||||
-spec cleanup_max_age_ms() -> pos_integer().
|
||||
cleanup_max_age_ms() ->
|
||||
get_int_or_default(gateway_http_cleanup_max_age_ms, ?DEFAULT_CLEANUP_MAX_AGE_MS).
|
||||
|
||||
-spec get_int_or_default(atom(), integer()) -> integer().
|
||||
get_int_or_default(Key, Default) ->
|
||||
case fluxer_gateway_env:get_optional(Key) of
|
||||
Value when is_integer(Value), Value > 0 -> Value;
|
||||
_ -> Default
|
||||
end.
|
||||
|
||||
-spec profile_for(workload()) -> atom().
|
||||
profile_for(rpc) ->
|
||||
gateway_http_rpc_profile;
|
||||
profile_for(push) ->
|
||||
gateway_http_push_profile.
|
||||
|
||||
-spec first_stack_frame(list()) -> term().
|
||||
first_stack_frame([Frame | _]) ->
|
||||
Frame;
|
||||
first_stack_frame([]) ->
|
||||
undefined.
|
||||
@@ -19,21 +19,38 @@
|
||||
|
||||
-export([execute_method/2]).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-define(CALL_LOOKUP_TIMEOUT, 2000).
|
||||
-define(CALL_CREATE_TIMEOUT, 10000).
|
||||
|
||||
-spec execute_method(binary(), map()) -> term().
|
||||
execute_method(<<"call.get">>, #{<<"channel_id">> := ChannelIdBin}) ->
|
||||
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
|
||||
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
|
||||
case lookup_call(ChannelId) of
|
||||
{ok, Pid} ->
|
||||
case gen_server:call(Pid, {get_state}, 5000) of
|
||||
{ok, CallData} ->
|
||||
CallData;
|
||||
_ ->
|
||||
throw({error, <<"Failed to get call state">>})
|
||||
case gen_server:call(Pid, {get_state}, ?CALL_LOOKUP_TIMEOUT) of
|
||||
{ok, CallData} -> CallData;
|
||||
_ -> throw({error, <<"call_state_error">>})
|
||||
end;
|
||||
{error, not_found} ->
|
||||
null;
|
||||
not_found ->
|
||||
null
|
||||
end;
|
||||
execute_method(<<"call.get_pending_joins">>, #{<<"channel_id">> := ChannelIdBin}) ->
|
||||
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
|
||||
case lookup_call(ChannelId) of
|
||||
{ok, Pid} ->
|
||||
case gen_server:call(Pid, {get_pending_connections}, ?CALL_LOOKUP_TIMEOUT) of
|
||||
#{pending_joins := PendingJoins} ->
|
||||
#{<<"pending_joins">> => PendingJoins};
|
||||
_ ->
|
||||
throw({error, <<"call_pending_joins_error">>})
|
||||
end;
|
||||
not_found ->
|
||||
#{<<"pending_joins">> => []}
|
||||
end;
|
||||
execute_method(<<"call.create">>, Params) ->
|
||||
#{
|
||||
<<"channel_id">> := ChannelIdBin,
|
||||
@@ -42,12 +59,10 @@ execute_method(<<"call.create">>, Params) ->
|
||||
<<"ringing">> := RingingBins,
|
||||
<<"recipients">> := RecipientsBins
|
||||
} = Params,
|
||||
|
||||
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
|
||||
MessageId = validation:snowflake_or_throw(<<"message_id">>, MessageIdBin),
|
||||
Ringing = validation:snowflake_list_or_throw(<<"ringing">>, RingingBins),
|
||||
Recipients = validation:snowflake_list_or_throw(<<"recipients">>, RecipientsBins),
|
||||
|
||||
CallData = #{
|
||||
channel_id => ChannelId,
|
||||
message_id => MessageId,
|
||||
@@ -55,161 +70,123 @@ execute_method(<<"call.create">>, Params) ->
|
||||
ringing => Ringing,
|
||||
recipients => Recipients
|
||||
},
|
||||
|
||||
case gen_server:call(call_manager, {create, ChannelId, CallData}, 10000) of
|
||||
case gen_server:call(call_manager, {create, ChannelId, CallData}, ?CALL_CREATE_TIMEOUT) of
|
||||
{ok, Pid} ->
|
||||
case gen_server:call(Pid, {get_state}, 5000) of
|
||||
{ok, CallState} ->
|
||||
CallState;
|
||||
_ ->
|
||||
throw({error, <<"Failed to get call state after creation">>})
|
||||
case gen_server:call(Pid, {get_state}, ?CALL_LOOKUP_TIMEOUT) of
|
||||
{ok, CallState} -> CallState;
|
||||
_ -> throw({error, <<"call_state_error">>})
|
||||
end;
|
||||
{error, already_exists} ->
|
||||
throw({error, <<"Call already exists">>});
|
||||
throw({error, <<"call_already_exists">>});
|
||||
{error, Reason} ->
|
||||
throw({error, iolist_to_binary(io_lib:format("Failed to create call: ~p", [Reason]))})
|
||||
throw({error, iolist_to_binary(io_lib:format("create_call_error: ~p", [Reason]))})
|
||||
end;
|
||||
execute_method(<<"call.update_region">>, #{
|
||||
<<"channel_id">> := ChannelIdBin, <<"region">> := Region
|
||||
execute_method(<<"call.update_region">>, #{<<"channel_id">> := ChannelIdBin, <<"region">> := Region}) ->
|
||||
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
|
||||
with_call(ChannelId, fun(Pid) ->
|
||||
case gen_server:call(Pid, {update_region, Region}, ?CALL_LOOKUP_TIMEOUT) of
|
||||
ok -> true;
|
||||
_ -> throw({error, <<"update_region_error">>})
|
||||
end
|
||||
end);
|
||||
execute_method(<<"call.ring">>, #{
|
||||
<<"channel_id">> := ChannelIdBin, <<"recipients">> := RecipientsBin
|
||||
}) ->
|
||||
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
|
||||
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
|
||||
{ok, Pid} ->
|
||||
case gen_server:call(Pid, {update_region, Region}, 5000) of
|
||||
ok ->
|
||||
true;
|
||||
_ ->
|
||||
throw({error, <<"Failed to update region">>})
|
||||
end;
|
||||
not_found ->
|
||||
throw({error, <<"Call not found">>})
|
||||
end;
|
||||
execute_method(<<"call.ring">>, Params) ->
|
||||
#{<<"channel_id">> := ChannelIdBin, <<"recipients">> := RecipientsBin} = Params,
|
||||
|
||||
Recipients = validation:snowflake_list_or_throw(<<"recipients">>, RecipientsBin),
|
||||
with_call(ChannelId, fun(Pid) ->
|
||||
case gen_server:call(Pid, {ring_recipients, Recipients}, ?CALL_LOOKUP_TIMEOUT) of
|
||||
ok -> true;
|
||||
_ -> throw({error, <<"ring_recipients_error">>})
|
||||
end
|
||||
end);
|
||||
execute_method(<<"call.stop_ringing">>, #{
|
||||
<<"channel_id">> := ChannelIdBin, <<"recipients">> := RecipientsBin
|
||||
}) ->
|
||||
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
|
||||
Recipients = validation:snowflake_list_or_throw(<<"recipients">>, RecipientsBin),
|
||||
|
||||
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
|
||||
{ok, Pid} ->
|
||||
case gen_server:call(Pid, {ring_recipients, Recipients}, 5000) of
|
||||
ok ->
|
||||
true;
|
||||
_ ->
|
||||
throw({error, <<"Failed to ring recipients">>})
|
||||
end;
|
||||
not_found ->
|
||||
throw({error, <<"Call not found">>})
|
||||
end;
|
||||
execute_method(<<"call.stop_ringing">>, Params) ->
|
||||
#{<<"channel_id">> := ChannelIdBin, <<"recipients">> := RecipientsBin} = Params,
|
||||
|
||||
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
|
||||
Recipients = validation:snowflake_list_or_throw(<<"recipients">>, RecipientsBin),
|
||||
|
||||
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
|
||||
{ok, Pid} ->
|
||||
case gen_server:call(Pid, {stop_ringing, Recipients}, 5000) of
|
||||
ok ->
|
||||
true;
|
||||
_ ->
|
||||
throw({error, <<"Failed to stop ringing">>})
|
||||
end;
|
||||
not_found ->
|
||||
throw({error, <<"Call not found">>})
|
||||
end;
|
||||
execute_method(<<"call.join">>, Params) ->
|
||||
#{
|
||||
<<"channel_id">> := ChannelIdBin,
|
||||
<<"user_id">> := UserIdBin,
|
||||
<<"session_id">> := SessionIdBin,
|
||||
<<"voice_state">> := VoiceState
|
||||
} = Params,
|
||||
|
||||
with_call(ChannelId, fun(Pid) ->
|
||||
case gen_server:call(Pid, {stop_ringing, Recipients}, ?CALL_LOOKUP_TIMEOUT) of
|
||||
ok -> true;
|
||||
_ -> throw({error, <<"stop_ringing_error">>})
|
||||
end
|
||||
end);
|
||||
execute_method(<<"call.join">>, #{
|
||||
<<"channel_id">> := ChannelIdBin,
|
||||
<<"user_id">> := UserIdBin,
|
||||
<<"session_id">> := SessionIdBin,
|
||||
<<"voice_state">> := VoiceState
|
||||
}) ->
|
||||
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
|
||||
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
|
||||
SessionId = SessionIdBin,
|
||||
|
||||
case gen_server:call(session_manager, {lookup, SessionId}, 5000) of
|
||||
case session_manager:lookup(SessionId) of
|
||||
{ok, SessionPid} ->
|
||||
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
|
||||
{ok, CallPid} ->
|
||||
case
|
||||
gen_server:call(
|
||||
CallPid, {join, UserId, VoiceState, SessionId, SessionPid}, 5000
|
||||
)
|
||||
of
|
||||
ok ->
|
||||
true;
|
||||
_ ->
|
||||
throw({error, <<"Failed to join call">>})
|
||||
end;
|
||||
not_found ->
|
||||
throw({error, <<"Call not found">>})
|
||||
end;
|
||||
not_found ->
|
||||
throw({error, <<"Session not found">>})
|
||||
with_call(ChannelId, fun(CallPid) ->
|
||||
gen_server:cast(CallPid, {join_async, UserId, VoiceState, SessionId, SessionPid}),
|
||||
true
|
||||
end);
|
||||
{error, not_found} ->
|
||||
throw({error, <<"session_not_found">>})
|
||||
end;
|
||||
execute_method(<<"call.leave">>, #{<<"channel_id">> := ChannelIdBin, <<"session_id">> := SessionId}) ->
|
||||
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
|
||||
|
||||
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
|
||||
{ok, Pid} ->
|
||||
case gen_server:call(Pid, {leave, SessionId}, 5000) of
|
||||
ok ->
|
||||
true;
|
||||
_ ->
|
||||
throw({error, <<"Failed to leave call">>})
|
||||
end;
|
||||
not_found ->
|
||||
throw({error, <<"Call not found">>})
|
||||
end;
|
||||
with_call(ChannelId, fun(Pid) ->
|
||||
case gen_server:call(Pid, {leave, SessionId}, ?CALL_LOOKUP_TIMEOUT) of
|
||||
ok -> true;
|
||||
_ -> throw({error, <<"leave_call_error">>})
|
||||
end
|
||||
end);
|
||||
execute_method(<<"call.delete">>, #{<<"channel_id">> := ChannelIdBin}) ->
|
||||
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
|
||||
case gen_server:call(call_manager, {terminate_call, ChannelId}, 5000) of
|
||||
ok ->
|
||||
true;
|
||||
{error, not_found} ->
|
||||
throw({error, <<"Call not found">>});
|
||||
_ ->
|
||||
throw({error, <<"Failed to delete call">>})
|
||||
case gen_server:call(call_manager, {terminate_call, ChannelId}, ?CALL_LOOKUP_TIMEOUT) of
|
||||
ok -> true;
|
||||
{error, not_found} -> throw({error, <<"call_not_found">>});
|
||||
_ -> throw({error, <<"delete_call_error">>})
|
||||
end;
|
||||
execute_method(<<"call.confirm_connection">>, Params) ->
|
||||
#{<<"channel_id">> := ChannelIdBin, <<"connection_id">> := ConnectionId} = Params,
|
||||
execute_method(<<"call.confirm_connection">>, #{
|
||||
<<"channel_id">> := ChannelIdBin, <<"connection_id">> := ConnectionId
|
||||
}) ->
|
||||
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
|
||||
logger:debug(
|
||||
"[gateway_rpc_call] call.confirm_connection channel_id=~p connection_id=~p",
|
||||
[ChannelId, ConnectionId]
|
||||
),
|
||||
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
|
||||
case lookup_call(ChannelId) of
|
||||
{ok, Pid} ->
|
||||
gen_server:call(Pid, {confirm_connection, ConnectionId}, 5000);
|
||||
{error, not_found} ->
|
||||
logger:debug(
|
||||
"[gateway_rpc_call] call.confirm_connection call not found for channel_id=~p", [
|
||||
ChannelId
|
||||
]
|
||||
),
|
||||
#{success => true, call_not_found => true};
|
||||
gen_server:call(Pid, {confirm_connection, ConnectionId}, ?CALL_LOOKUP_TIMEOUT);
|
||||
not_found ->
|
||||
logger:debug(
|
||||
"[gateway_rpc_call] call.confirm_connection call manager returned not_found for channel_id=~p",
|
||||
[ChannelId]
|
||||
),
|
||||
#{success => true, call_not_found => true}
|
||||
end;
|
||||
execute_method(<<"call.disconnect_user_if_in_channel">>, Params) ->
|
||||
#{<<"channel_id">> := ChannelIdBin, <<"user_id">> := UserIdBin} = Params,
|
||||
execute_method(
|
||||
<<"call.disconnect_user_if_in_channel">>,
|
||||
#{<<"channel_id">> := ChannelIdBin, <<"user_id">> := UserIdBin} = Params
|
||||
) ->
|
||||
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
|
||||
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
|
||||
ConnectionId = maps:get(<<"connection_id">>, Params, undefined),
|
||||
case gen_server:call(call_manager, {lookup, ChannelId}, 5000) of
|
||||
case lookup_call(ChannelId) of
|
||||
{ok, Pid} ->
|
||||
gen_server:call(
|
||||
Pid, {disconnect_user_if_in_channel, UserId, ChannelId, ConnectionId}, 5000
|
||||
Pid,
|
||||
{disconnect_user_if_in_channel, UserId, ChannelId, ConnectionId},
|
||||
?CALL_LOOKUP_TIMEOUT
|
||||
);
|
||||
{error, not_found} ->
|
||||
#{success => true, call_not_found => true};
|
||||
not_found ->
|
||||
#{success => true, call_not_found => true}
|
||||
end.
|
||||
|
||||
-spec lookup_call(integer()) -> {ok, pid()} | not_found.
|
||||
lookup_call(ChannelId) ->
|
||||
case gen_server:call(call_manager, {lookup, ChannelId}, ?CALL_LOOKUP_TIMEOUT) of
|
||||
{ok, Pid} -> {ok, Pid};
|
||||
{error, not_found} -> not_found;
|
||||
not_found -> not_found
|
||||
end.
|
||||
|
||||
-spec with_call(integer(), fun((pid()) -> T)) -> T when T :: term().
|
||||
with_call(ChannelId, Fun) ->
|
||||
case lookup_call(ChannelId) of
|
||||
{ok, Pid} -> Fun(Pid);
|
||||
not_found -> throw({error, <<"call_not_found">>})
|
||||
end.
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
-endif.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -21,6 +21,7 @@
|
||||
|
||||
-define(JSON_HEADERS, #{<<"content-type">> => <<"application/json">>}).
|
||||
|
||||
-spec init(cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
|
||||
init(Req0, State) ->
|
||||
case cowboy_req:method(Req0) of
|
||||
<<"POST">> ->
|
||||
@@ -30,27 +31,13 @@ init(Req0, State) ->
|
||||
{ok, Req, State}
|
||||
end.
|
||||
|
||||
-spec handle_post(cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
|
||||
handle_post(Req0, State) ->
|
||||
case authorize(Req0) of
|
||||
ok ->
|
||||
case read_body(Req0) of
|
||||
{ok, Decoded, Req1} ->
|
||||
case maps:get(<<"method">>, Decoded, undefined) of
|
||||
undefined ->
|
||||
respond(400, #{<<"error">> => <<"Missing method">>}, Req1, State);
|
||||
Method when is_binary(Method) ->
|
||||
ParamsValue = maps:get(<<"params">>, Decoded, #{}),
|
||||
case is_map(ParamsValue) of
|
||||
true ->
|
||||
execute_method(Method, ParamsValue, Req1, State);
|
||||
false ->
|
||||
respond(
|
||||
400, #{<<"error">> => <<"Invalid params">>}, Req1, State
|
||||
)
|
||||
end;
|
||||
_ ->
|
||||
respond(400, #{<<"error">> => <<"Invalid method">>}, Req1, State)
|
||||
end;
|
||||
handle_decoded_body(Decoded, Req1, State);
|
||||
{error, ErrorBody, Req1} ->
|
||||
respond(400, ErrorBody, Req1, State)
|
||||
end;
|
||||
@@ -58,57 +45,98 @@ handle_post(Req0, State) ->
|
||||
{ok, Req1, State}
|
||||
end.
|
||||
|
||||
-spec handle_decoded_body(map(), cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
|
||||
handle_decoded_body(Decoded, Req0, State) ->
|
||||
case maps:get(<<"method">>, Decoded, undefined) of
|
||||
undefined ->
|
||||
respond(400, #{<<"error">> => <<"Missing method">>}, Req0, State);
|
||||
Method when is_binary(Method) ->
|
||||
ParamsValue = maps:get(<<"params">>, Decoded, #{}),
|
||||
case is_map(ParamsValue) of
|
||||
true ->
|
||||
execute_method(Method, ParamsValue, Req0, State);
|
||||
false ->
|
||||
respond(400, #{<<"error">> => <<"Invalid params">>}, Req0, State)
|
||||
end;
|
||||
_ ->
|
||||
respond(400, #{<<"error">> => <<"Invalid method">>}, Req0, State)
|
||||
end.
|
||||
|
||||
-spec authorize(cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
|
||||
authorize(Req0) ->
|
||||
case cowboy_req:header(<<"authorization">>, Req0) of
|
||||
undefined ->
|
||||
Req = cowboy_req:reply(
|
||||
401,
|
||||
?JSON_HEADERS,
|
||||
jsx:encode(#{<<"error">> => <<"Unauthorized">>}),
|
||||
json:encode(#{<<"error">> => <<"Unauthorized">>}),
|
||||
Req0
|
||||
),
|
||||
{error, Req};
|
||||
AuthHeader ->
|
||||
case fluxer_gateway_env:get(rpc_secret_key) of
|
||||
undefined ->
|
||||
Req = cowboy_req:reply(
|
||||
500,
|
||||
?JSON_HEADERS,
|
||||
jsx:encode(#{<<"error">> => <<"RPC secret not configured">>}),
|
||||
Req0
|
||||
),
|
||||
{error, Req};
|
||||
Secret when is_binary(Secret) ->
|
||||
Expected = <<"Bearer ", Secret/binary>>,
|
||||
case AuthHeader of
|
||||
Expected ->
|
||||
ok;
|
||||
_ ->
|
||||
Req = cowboy_req:reply(
|
||||
401,
|
||||
?JSON_HEADERS,
|
||||
jsx:encode(#{<<"error">> => <<"Unauthorized">>}),
|
||||
Req0
|
||||
),
|
||||
{error, Req}
|
||||
end
|
||||
end
|
||||
authorize_with_secret(AuthHeader, Req0)
|
||||
end.
|
||||
|
||||
read_body(Req0) ->
|
||||
read_body(Req0, <<>>).
|
||||
-spec authorize_with_secret(binary(), cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
|
||||
authorize_with_secret(AuthHeader, Req0) ->
|
||||
case fluxer_gateway_env:get(rpc_secret_key) of
|
||||
undefined ->
|
||||
Req = cowboy_req:reply(
|
||||
500,
|
||||
?JSON_HEADERS,
|
||||
json:encode(#{<<"error">> => <<"RPC secret not configured">>}),
|
||||
Req0
|
||||
),
|
||||
{error, Req};
|
||||
Secret when is_binary(Secret) ->
|
||||
Expected = <<"Bearer ", Secret/binary>>,
|
||||
check_auth_header(AuthHeader, Expected, Req0)
|
||||
end.
|
||||
|
||||
read_body(Req0, Acc) ->
|
||||
-spec check_auth_header(binary(), binary(), cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
|
||||
check_auth_header(AuthHeader, Expected, Req0) ->
|
||||
case secure_compare(AuthHeader, Expected) of
|
||||
true ->
|
||||
ok;
|
||||
false ->
|
||||
Req = cowboy_req:reply(
|
||||
401,
|
||||
?JSON_HEADERS,
|
||||
json:encode(#{<<"error">> => <<"Unauthorized">>}),
|
||||
Req0
|
||||
),
|
||||
{error, Req}
|
||||
end.
|
||||
|
||||
-spec secure_compare(binary(), binary()) -> boolean().
|
||||
secure_compare(Left, Right) when is_binary(Left), is_binary(Right) ->
|
||||
case byte_size(Left) =:= byte_size(Right) of
|
||||
true ->
|
||||
crypto:hash_equals(Left, Right);
|
||||
false ->
|
||||
false
|
||||
end.
|
||||
|
||||
-spec read_body(cowboy_req:req()) ->
|
||||
{ok, map(), cowboy_req:req()} | {error, map(), cowboy_req:req()}.
|
||||
read_body(Req0) ->
|
||||
read_body_chunks(Req0, <<>>).
|
||||
|
||||
-spec read_body_chunks(cowboy_req:req(), binary()) ->
|
||||
{ok, map(), cowboy_req:req()} | {error, map(), cowboy_req:req()}.
|
||||
read_body_chunks(Req0, Acc) ->
|
||||
case cowboy_req:read_body(Req0) of
|
||||
{ok, Body, Req1} ->
|
||||
FullBody = <<Acc/binary, Body/binary>>,
|
||||
decode_body(FullBody, Req1);
|
||||
{more, Body, Req1} ->
|
||||
read_body(Req1, <<Acc/binary, Body/binary>>)
|
||||
read_body_chunks(Req1, <<Acc/binary, Body/binary>>)
|
||||
end.
|
||||
|
||||
-spec decode_body(binary(), cowboy_req:req()) ->
|
||||
{ok, map(), cowboy_req:req()} | {error, map(), cowboy_req:req()}.
|
||||
decode_body(Body, Req0) ->
|
||||
case catch jsx:decode(Body, [return_maps]) of
|
||||
case catch json:decode(Body) of
|
||||
{'EXIT', _Reason} ->
|
||||
{error, #{<<"error">> => <<"Invalid JSON payload">>}, Req0};
|
||||
Decoded when is_map(Decoded) ->
|
||||
@@ -117,6 +145,7 @@ decode_body(Body, Req0) ->
|
||||
{error, #{<<"error">> => <<"Invalid request body">>}, Req0}
|
||||
end.
|
||||
|
||||
-spec execute_method(binary(), map(), cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
|
||||
execute_method(Method, Params, Req0, State) ->
|
||||
try
|
||||
Result = gateway_rpc_router:execute(Method, Params),
|
||||
@@ -124,10 +153,15 @@ execute_method(Method, Params, Req0, State) ->
|
||||
catch
|
||||
throw:{error, Message} ->
|
||||
respond(400, #{<<"error">> => Message}, Req0, State);
|
||||
exit:timeout ->
|
||||
respond(504, #{<<"error">> => <<"timeout">>}, Req0, State);
|
||||
exit:{timeout, _} ->
|
||||
respond(504, #{<<"error">> => <<"timeout">>}, Req0, State);
|
||||
_:_ ->
|
||||
respond(500, #{<<"error">> => <<"Internal error">>}, Req0, State)
|
||||
end.
|
||||
|
||||
-spec respond(pos_integer(), map(), cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
|
||||
respond(Status, Body, Req0, State) ->
|
||||
Req = cowboy_req:reply(Status, ?JSON_HEADERS, jsx:encode(Body), Req0),
|
||||
Req = cowboy_req:reply(Status, ?JSON_HEADERS, json:encode(Body), Req0),
|
||||
{ok, Req, State}.
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
|
||||
-export([execute_method/2, get_local_node_stats/0]).
|
||||
|
||||
-spec execute_method(binary(), map()) -> map().
|
||||
execute_method(<<"process.memory_stats">>, Params) ->
|
||||
Limit =
|
||||
case maps:get(<<"limit">>, Params, undefined) of
|
||||
@@ -27,42 +28,22 @@ execute_method(<<"process.memory_stats">>, Params) ->
|
||||
LimitValue ->
|
||||
validation:snowflake_or_throw(<<"limit">>, LimitValue)
|
||||
end,
|
||||
|
||||
Guilds = process_memory_stats:get_guild_memory_stats(Limit),
|
||||
#{<<"guilds">> => Guilds};
|
||||
GuildsWithStringMemory = [G#{memory := integer_to_binary(maps:get(memory, G))} || G <- Guilds],
|
||||
#{<<"guilds">> => GuildsWithStringMemory};
|
||||
execute_method(<<"process.node_stats">>, _Params) ->
|
||||
get_local_node_stats().
|
||||
|
||||
-spec get_local_node_stats() -> map().
|
||||
get_local_node_stats() ->
|
||||
SessionCount =
|
||||
case gen_server:call(session_manager, get_global_count, 1000) of
|
||||
{ok, SC} -> SC;
|
||||
_ -> 0
|
||||
end,
|
||||
|
||||
GuildCount =
|
||||
case gen_server:call(guild_manager, get_global_count, 1000) of
|
||||
{ok, GC} -> GC;
|
||||
_ -> 0
|
||||
end,
|
||||
|
||||
PresenceCount =
|
||||
case gen_server:call(presence_manager, get_global_count, 1000) of
|
||||
{ok, PC} -> PC;
|
||||
_ -> 0
|
||||
end,
|
||||
|
||||
CallCount =
|
||||
case gen_server:call(call_manager, get_global_count, 1000) of
|
||||
{ok, CC} -> CC;
|
||||
_ -> 0
|
||||
end,
|
||||
|
||||
SessionCount = get_manager_count(session_manager),
|
||||
GuildCount = get_manager_count(guild_manager),
|
||||
PresenceCount = get_manager_count(presence_manager),
|
||||
CallCount = get_manager_count(call_manager),
|
||||
MemoryInfo = erlang:memory(),
|
||||
TotalMemory = proplists:get_value(total, MemoryInfo, 0),
|
||||
ProcessMemory = proplists:get_value(processes, MemoryInfo, 0),
|
||||
SystemMemory = proplists:get_value(system, MemoryInfo, 0),
|
||||
|
||||
#{
|
||||
<<"status">> => <<"healthy">>,
|
||||
<<"sessions">> => SessionCount,
|
||||
@@ -70,11 +51,18 @@ get_local_node_stats() ->
|
||||
<<"presences">> => PresenceCount,
|
||||
<<"calls">> => CallCount,
|
||||
<<"memory">> => #{
|
||||
<<"total">> => TotalMemory,
|
||||
<<"processes">> => ProcessMemory,
|
||||
<<"system">> => SystemMemory
|
||||
<<"total">> => integer_to_binary(TotalMemory),
|
||||
<<"processes">> => integer_to_binary(ProcessMemory),
|
||||
<<"system">> => integer_to_binary(SystemMemory)
|
||||
},
|
||||
<<"process_count">> => erlang:system_info(process_count),
|
||||
<<"process_limit">> => erlang:system_info(process_limit),
|
||||
<<"uptime_seconds">> => element(1, erlang:statistics(wall_clock)) div 1000
|
||||
}.
|
||||
|
||||
-spec get_manager_count(atom()) -> non_neg_integer().
|
||||
get_manager_count(Manager) ->
|
||||
case gen_server:call(Manager, get_global_count, 1000) of
|
||||
{ok, Count} -> Count;
|
||||
_ -> 0
|
||||
end.
|
||||
|
||||
@@ -19,6 +19,13 @@
|
||||
|
||||
-export([execute_method/2]).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-define(PRESENCE_LOOKUP_TIMEOUT, 2000).
|
||||
|
||||
-spec execute_method(binary(), map()) -> term().
|
||||
execute_method(<<"presence.dispatch">>, #{
|
||||
<<"user_id">> := UserIdBin, <<"event">> := Event, <<"data">> := Data
|
||||
}) ->
|
||||
@@ -35,132 +42,73 @@ execute_method(<<"presence.join_guild">>, #{
|
||||
}) ->
|
||||
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
|
||||
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
|
||||
case gen_server:call(presence_manager, {lookup, UserId}, 10000) of
|
||||
{ok, Pid} ->
|
||||
case gen_server:call(Pid, {join_guild, GuildId}, 10000) of
|
||||
ok -> true;
|
||||
_ -> throw({error, <<"Join guild failed">>})
|
||||
end;
|
||||
not_found ->
|
||||
true;
|
||||
{error, _} ->
|
||||
true;
|
||||
_ ->
|
||||
true
|
||||
end;
|
||||
presence_manager:lookup_async(UserId, {join_guild, GuildId}),
|
||||
true;
|
||||
execute_method(<<"presence.leave_guild">>, #{
|
||||
<<"user_id">> := UserIdBin, <<"guild_id">> := GuildIdBin
|
||||
}) ->
|
||||
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
|
||||
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
|
||||
case gen_server:call(presence_manager, {lookup, UserId}, 10000) of
|
||||
{ok, Pid} ->
|
||||
case gen_server:call(Pid, {leave_guild, GuildId}, 10000) of
|
||||
ok -> true;
|
||||
_ -> throw({error, <<"Leave guild failed">>})
|
||||
end;
|
||||
not_found ->
|
||||
true;
|
||||
{error, _} ->
|
||||
true;
|
||||
_ ->
|
||||
true
|
||||
end;
|
||||
presence_manager:lookup_async(UserId, {leave_guild, GuildId}),
|
||||
true;
|
||||
execute_method(<<"presence.terminate_sessions">>, #{
|
||||
<<"user_id">> := UserIdBin, <<"session_id_hashes">> := SessionIdHashes
|
||||
}) ->
|
||||
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
|
||||
case gen_server:call(presence_manager, {lookup, UserId}, 10000) of
|
||||
{ok, Pid} ->
|
||||
case gen_server:call(Pid, {terminate_session, SessionIdHashes}, 10000) of
|
||||
ok -> true;
|
||||
_ -> throw({error, <<"Terminate session failed">>})
|
||||
end;
|
||||
not_found ->
|
||||
true;
|
||||
{error, _} ->
|
||||
true;
|
||||
_ ->
|
||||
true
|
||||
end;
|
||||
execute_method(<<"presence.terminate_all_sessions">>, #{
|
||||
<<"user_id">> := UserIdBin
|
||||
}) ->
|
||||
presence_manager:lookup_async(UserId, {terminate_session, SessionIdHashes}),
|
||||
true;
|
||||
execute_method(<<"presence.terminate_all_sessions">>, #{<<"user_id">> := UserIdBin}) ->
|
||||
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
|
||||
case presence_manager:terminate_all_sessions(UserId) of
|
||||
ok -> true;
|
||||
_ -> throw({error, <<"Terminate all sessions failed">>})
|
||||
_ -> throw({error, <<"terminate_sessions_error">>})
|
||||
end;
|
||||
execute_method(<<"presence.has_active">>, #{<<"user_id">> := UserIdBin}) ->
|
||||
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
|
||||
case gen_server:call(presence_manager, {lookup, UserId}, 10000) of
|
||||
{ok, _Pid} ->
|
||||
#{<<"has_active">> => true};
|
||||
_ ->
|
||||
#{<<"has_active">> => false}
|
||||
case gen_server:call(presence_manager, {lookup, UserId}, ?PRESENCE_LOOKUP_TIMEOUT) of
|
||||
{ok, _Pid} -> #{<<"has_active">> => true};
|
||||
_ -> #{<<"has_active">> => false}
|
||||
end;
|
||||
execute_method(<<"presence.add_temporary_guild">>, #{
|
||||
<<"user_id">> := UserIdBin, <<"guild_id">> := GuildIdBin
|
||||
}) ->
|
||||
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
|
||||
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
|
||||
case gen_server:call(presence_manager, {lookup, UserId}, 10000) of
|
||||
{ok, Pid} ->
|
||||
case gen_server:call(Pid, {add_temporary_guild, GuildId}, 10000) of
|
||||
ok -> true;
|
||||
_ -> throw({error, <<"Add temporary guild failed">>})
|
||||
end;
|
||||
not_found ->
|
||||
true;
|
||||
{error, _} ->
|
||||
true;
|
||||
_ ->
|
||||
true
|
||||
end;
|
||||
presence_manager:lookup_async(UserId, {add_temporary_guild, GuildId}),
|
||||
true;
|
||||
execute_method(<<"presence.remove_temporary_guild">>, #{
|
||||
<<"user_id">> := UserIdBin, <<"guild_id">> := GuildIdBin
|
||||
}) ->
|
||||
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
|
||||
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
|
||||
case gen_server:call(presence_manager, {lookup, UserId}, 10000) of
|
||||
{ok, Pid} ->
|
||||
case gen_server:call(Pid, {remove_temporary_guild, GuildId}, 10000) of
|
||||
ok -> true;
|
||||
_ -> throw({error, <<"Remove temporary guild failed">>})
|
||||
end;
|
||||
not_found ->
|
||||
true;
|
||||
{error, _} ->
|
||||
true;
|
||||
_ ->
|
||||
true
|
||||
end;
|
||||
presence_manager:lookup_async(UserId, {remove_temporary_guild, GuildId}),
|
||||
true;
|
||||
execute_method(<<"presence.sync_group_dm_recipients">>, #{
|
||||
<<"user_id">> := UserIdBin, <<"recipients_by_channel">> := RecipientsByChannel
|
||||
}) ->
|
||||
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
|
||||
NormalizedRecipients =
|
||||
maps:from_list([
|
||||
{
|
||||
validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
|
||||
[validation:snowflake_or_throw(<<"recipient_id">>, RBin) || RBin <- Recipients]
|
||||
}
|
||||
|| {ChannelIdBin, Recipients} <- maps:to_list(RecipientsByChannel)
|
||||
]),
|
||||
case gen_server:call(presence_manager, {lookup, UserId}, 10000) of
|
||||
NormalizedRecipients = normalize_recipients(RecipientsByChannel),
|
||||
case gen_server:call(presence_manager, {lookup, UserId}, ?PRESENCE_LOOKUP_TIMEOUT) of
|
||||
{ok, Pid} ->
|
||||
gen_server:cast(Pid, {sync_group_dm_recipients, NormalizedRecipients}),
|
||||
true;
|
||||
not_found ->
|
||||
true;
|
||||
{error, _} ->
|
||||
true;
|
||||
_ ->
|
||||
true
|
||||
end.
|
||||
|
||||
-spec normalize_recipients(map()) -> map().
|
||||
normalize_recipients(RecipientsByChannel) ->
|
||||
maps:from_list([
|
||||
{
|
||||
validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
|
||||
[validation:snowflake_or_throw(<<"recipient_id">>, RBin) || RBin <- Recipients]
|
||||
}
|
||||
|| {ChannelIdBin, Recipients} <- maps:to_list(RecipientsByChannel)
|
||||
]).
|
||||
|
||||
-spec handle_offline_dispatch(atom(), integer(), map()) -> true.
|
||||
handle_offline_dispatch(message_create, UserId, Data) ->
|
||||
AuthorIdBin = maps:get(<<"id">>, maps:get(<<"author">>, Data, #{}), <<"0">>),
|
||||
AuthorIdBin = maps:get(<<"id">>, maps:get(<<"author">>, Data, #{}), undefined),
|
||||
AuthorId = validation:snowflake_or_throw(<<"author_id">>, AuthorIdBin),
|
||||
push:handle_message_create(#{
|
||||
message_data => Data,
|
||||
@@ -178,5 +126,16 @@ handle_offline_dispatch(relationship_remove, UserId, _Data) ->
|
||||
handle_offline_dispatch(_Event, _UserId, _Data) ->
|
||||
true.
|
||||
|
||||
-spec sync_blocked_ids_for_user(integer()) -> ok.
|
||||
sync_blocked_ids_for_user(_UserId) ->
|
||||
ok.
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
normalize_recipients_test() ->
|
||||
Input = #{<<"123">> => [<<"456">>, <<"789">>]},
|
||||
Result = normalize_recipients(Input),
|
||||
?assert(is_map(Result)),
|
||||
?assertEqual(1, maps:size(Result)).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
|
||||
-export([execute_method/2]).
|
||||
|
||||
-spec execute_method(binary(), map()) -> true.
|
||||
execute_method(<<"push.sync_user_guild_settings">>, #{
|
||||
<<"user_id">> := UserIdBin,
|
||||
<<"guild_id">> := GuildIdBin,
|
||||
|
||||
@@ -19,18 +19,33 @@
|
||||
|
||||
-export([execute/2]).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-spec execute(binary(), map()) -> term().
|
||||
execute(Method, Params) ->
|
||||
case Method of
|
||||
<<"guild.", _/binary>> ->
|
||||
gateway_rpc_guild:execute_method(Method, Params);
|
||||
<<"presence.", _/binary>> ->
|
||||
gateway_rpc_presence:execute_method(Method, Params);
|
||||
<<"push.", _/binary>> ->
|
||||
gateway_rpc_push:execute_method(Method, Params);
|
||||
<<"call.", _/binary>> ->
|
||||
gateway_rpc_call:execute_method(Method, Params);
|
||||
<<"process.", _/binary>> ->
|
||||
gateway_rpc_misc:execute_method(Method, Params);
|
||||
_ ->
|
||||
throw({error, <<"Unknown method: ", Method/binary>>})
|
||||
end.
|
||||
route_method(Method, Params).
|
||||
|
||||
-spec route_method(binary(), map()) -> term().
|
||||
route_method(<<"guild.", _/binary>> = Method, Params) ->
|
||||
gateway_rpc_guild:execute_method(Method, Params);
|
||||
route_method(<<"presence.", _/binary>> = Method, Params) ->
|
||||
gateway_rpc_presence:execute_method(Method, Params);
|
||||
route_method(<<"push.", _/binary>> = Method, Params) ->
|
||||
gateway_rpc_push:execute_method(Method, Params);
|
||||
route_method(<<"call.", _/binary>> = Method, Params) ->
|
||||
gateway_rpc_call:execute_method(Method, Params);
|
||||
route_method(<<"voice.", _/binary>> = Method, Params) ->
|
||||
gateway_rpc_voice:execute_method(Method, Params);
|
||||
route_method(<<"process.", _/binary>> = Method, Params) ->
|
||||
gateway_rpc_misc:execute_method(Method, Params);
|
||||
route_method(Method, _Params) ->
|
||||
throw({error, <<"Unknown method: ", Method/binary>>}).
|
||||
|
||||
-ifdef(TEST).
|
||||
|
||||
route_method_guild_test() ->
|
||||
?assertThrow({error, _}, route_method(<<"unknown.method">>, #{})).
|
||||
|
||||
-endif.
|
||||
|
||||
446
fluxer_gateway/src/gateway/gateway_rpc_tcp_connection.erl
Normal file
446
fluxer_gateway/src/gateway/gateway_rpc_tcp_connection.erl
Normal file
@@ -0,0 +1,446 @@
|
||||
%% Copyright (C) 2026 Fluxer Contributors
|
||||
%%
|
||||
%% This file is part of Fluxer.
|
||||
%%
|
||||
%% Fluxer is free software: you can redistribute it and/or modify
|
||||
%% it under the terms of the GNU Affero General Public License as published by
|
||||
%% the Free Software Foundation, either version 3 of the License, or
|
||||
%% (at your option) any later version.
|
||||
%%
|
||||
%% Fluxer is distributed in the hope that it will be useful,
|
||||
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
%% GNU Affero General Public License for more details.
|
||||
%%
|
||||
%% You should have received a copy of the GNU Affero General Public License
|
||||
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
-module(gateway_rpc_tcp_connection).
|
||||
|
||||
-export([serve/1]).
|
||||
|
||||
-define(DEFAULT_MAX_INFLIGHT, 1024).
|
||||
-define(DEFAULT_MAX_INPUT_BUFFER_BYTES, 2097152).
|
||||
-define(DEFAULT_DISPATCH_RESERVE_DIVISOR, 8).
|
||||
-define(MAX_FRAME_BYTES, 1048576).
|
||||
-define(PROTOCOL_VERSION, <<"fluxer.rpc.tcp.v1">>).
|
||||
|
||||
-type state() :: #{
|
||||
socket := inet:socket(),
|
||||
buffer := binary(),
|
||||
authenticated := boolean(),
|
||||
inflight := non_neg_integer(),
|
||||
max_inflight := pos_integer(),
|
||||
max_input_buffer_bytes := pos_integer()
|
||||
}.
|
||||
|
||||
-type rpc_result() :: {ok, term()} | {error, binary()}.
|
||||
|
||||
-spec serve(inet:socket()) -> ok.
|
||||
serve(Socket) ->
|
||||
ok = inet:setopts(Socket, [{active, once}, {nodelay, true}, {keepalive, true}]),
|
||||
State = #{
|
||||
socket => Socket,
|
||||
buffer => <<>>,
|
||||
authenticated => false,
|
||||
inflight => 0,
|
||||
max_inflight => max_inflight(),
|
||||
max_input_buffer_bytes => max_input_buffer_bytes()
|
||||
},
|
||||
loop(State).
|
||||
|
||||
-spec loop(state()) -> ok.
|
||||
loop(#{socket := Socket} = State) ->
|
||||
receive
|
||||
{tcp, Socket, Data} ->
|
||||
case handle_tcp_data(Data, State) of
|
||||
{ok, NewState} ->
|
||||
ok = inet:setopts(Socket, [{active, once}]),
|
||||
loop(NewState);
|
||||
{stop, Reason, _NewState} ->
|
||||
logger:debug("Gateway TCP RPC connection closed: ~p", [Reason]),
|
||||
close_socket(Socket),
|
||||
ok
|
||||
end;
|
||||
{tcp_closed, Socket} ->
|
||||
ok;
|
||||
{tcp_error, Socket, Reason} ->
|
||||
logger:warning("Gateway TCP RPC socket error: ~p", [Reason]),
|
||||
close_socket(Socket),
|
||||
ok;
|
||||
{rpc_response, RequestId, Result} ->
|
||||
NewState = handle_rpc_response(RequestId, Result, State),
|
||||
loop(NewState);
|
||||
_Other ->
|
||||
loop(State)
|
||||
end.
|
||||
|
||||
-spec handle_tcp_data(binary(), state()) -> {ok, state()} | {stop, term(), state()}.
|
||||
handle_tcp_data(Data, #{buffer := Buffer, max_input_buffer_bytes := MaxInputBufferBytes} = State) ->
|
||||
case byte_size(Buffer) + byte_size(Data) =< MaxInputBufferBytes of
|
||||
false ->
|
||||
_ = send_error_frame(State, protocol_error_binary(input_buffer_limit_exceeded)),
|
||||
{stop, input_buffer_limit_exceeded, State};
|
||||
true ->
|
||||
Combined = <<Buffer/binary, Data/binary>>,
|
||||
decode_tcp_frames(Combined, State)
|
||||
end.
|
||||
|
||||
-spec decode_tcp_frames(binary(), state()) -> {ok, state()} | {stop, term(), state()}.
|
||||
decode_tcp_frames(Combined, State) ->
|
||||
case decode_frames(Combined, []) of
|
||||
{ok, Frames, Rest} ->
|
||||
process_frames(Frames, State#{buffer => Rest});
|
||||
{error, Reason} ->
|
||||
_ = send_error_frame(State, protocol_error_binary(Reason)),
|
||||
{stop, Reason, State}
|
||||
end.
|
||||
|
||||
-spec process_frames([map()], state()) -> {ok, state()} | {stop, term(), state()}.
|
||||
process_frames([], State) ->
|
||||
{ok, State};
|
||||
process_frames([Frame | Rest], State) ->
|
||||
case process_frame(Frame, State) of
|
||||
{ok, NewState} ->
|
||||
process_frames(Rest, NewState);
|
||||
{stop, Reason, NewState} ->
|
||||
{stop, Reason, NewState}
|
||||
end.
|
||||
|
||||
-spec process_frame(map(), state()) -> {ok, state()} | {stop, term(), state()}.
|
||||
process_frame(#{<<"type">> := <<"hello">>} = Frame, #{authenticated := false} = State) ->
|
||||
handle_hello_frame(Frame, State);
|
||||
process_frame(#{<<"type">> := <<"hello">>}, State) ->
|
||||
_ = send_error_frame(State, <<"duplicate_hello">>),
|
||||
{stop, duplicate_hello, State};
|
||||
process_frame(#{<<"type">> := <<"request">>} = Frame, #{authenticated := true} = State) ->
|
||||
handle_request_frame(Frame, State);
|
||||
process_frame(#{<<"type">> := <<"request">>}, State) ->
|
||||
_ = send_error_frame(State, <<"unauthorized">>),
|
||||
{stop, unauthorized, State};
|
||||
process_frame(#{<<"type">> := <<"ping">>}, State) ->
|
||||
_ = send_frame(State, #{<<"type">> => <<"pong">>}),
|
||||
{ok, State};
|
||||
process_frame(#{<<"type">> := <<"pong">>}, State) ->
|
||||
{ok, State};
|
||||
process_frame(#{<<"type">> := <<"close">>}, State) ->
|
||||
{stop, client_close, State};
|
||||
process_frame(_Frame, State) ->
|
||||
_ = send_error_frame(State, <<"unknown_frame_type">>),
|
||||
{stop, unknown_frame_type, State}.
|
||||
|
||||
-spec handle_hello_frame(map(), state()) -> {ok, state()} | {stop, term(), state()}.
|
||||
handle_hello_frame(Frame, State) ->
|
||||
case {maps:get(<<"protocol">>, Frame, undefined), maps:get(<<"authorization">>, Frame, undefined)} of
|
||||
{?PROTOCOL_VERSION, AuthHeader} when is_binary(AuthHeader) ->
|
||||
authorize_hello(AuthHeader, State);
|
||||
_ ->
|
||||
_ = send_error_frame(State, <<"invalid_hello">>),
|
||||
{stop, invalid_hello, State}
|
||||
end.
|
||||
|
||||
-spec authorize_hello(binary(), state()) -> {ok, state()} | {stop, term(), state()}.
|
||||
authorize_hello(AuthHeader, State) ->
|
||||
case fluxer_gateway_env:get(rpc_secret_key) of
|
||||
Secret when is_binary(Secret) ->
|
||||
Expected = <<"Bearer ", Secret/binary>>,
|
||||
case secure_compare(AuthHeader, Expected) of
|
||||
true ->
|
||||
HelloAck = #{
|
||||
<<"type">> => <<"hello_ack">>,
|
||||
<<"protocol">> => ?PROTOCOL_VERSION,
|
||||
<<"max_in_flight">> => maps:get(max_inflight, State),
|
||||
<<"ping_interval_ms">> => 15000
|
||||
},
|
||||
_ = send_frame(State, HelloAck),
|
||||
{ok, State#{authenticated => true}};
|
||||
false ->
|
||||
_ = send_error_frame(State, <<"unauthorized">>),
|
||||
{stop, unauthorized, State}
|
||||
end;
|
||||
_ ->
|
||||
_ = send_error_frame(State, <<"rpc_secret_not_configured">>),
|
||||
{stop, rpc_secret_not_configured, State}
|
||||
end.
|
||||
|
||||
-spec handle_request_frame(map(), state()) -> {ok, state()}.
|
||||
handle_request_frame(Frame, State) ->
|
||||
RequestId = request_id_from_frame(Frame),
|
||||
Method = maps:get(<<"method">>, Frame, undefined),
|
||||
case should_reject_request(Method, State) of
|
||||
true ->
|
||||
_ =
|
||||
send_response_frame(
|
||||
State,
|
||||
RequestId,
|
||||
false,
|
||||
undefined,
|
||||
<<"overloaded">>
|
||||
),
|
||||
{ok, State};
|
||||
false ->
|
||||
case {Method, maps:get(<<"params">>, Frame, undefined)} of
|
||||
{MethodName, Params} when is_binary(RequestId), is_binary(MethodName), is_map(Params) ->
|
||||
Parent = self(),
|
||||
_ = spawn(fun() ->
|
||||
Parent ! {rpc_response, RequestId, execute_method(MethodName, Params)}
|
||||
end),
|
||||
{ok, increment_inflight(State)};
|
||||
_ ->
|
||||
_ =
|
||||
send_response_frame(
|
||||
State,
|
||||
RequestId,
|
||||
false,
|
||||
undefined,
|
||||
<<"invalid_request">>
|
||||
),
|
||||
{ok, State}
|
||||
end
|
||||
end.
|
||||
|
||||
-spec should_reject_request(term(), state()) -> boolean().
|
||||
should_reject_request(Method, #{inflight := Inflight, max_inflight := MaxInflight}) ->
|
||||
case is_dispatch_method(Method) of
|
||||
true ->
|
||||
Inflight >= MaxInflight;
|
||||
false ->
|
||||
Inflight >= non_dispatch_inflight_limit(MaxInflight)
|
||||
end.
|
||||
|
||||
-spec non_dispatch_inflight_limit(pos_integer()) -> pos_integer().
|
||||
non_dispatch_inflight_limit(MaxInflight) ->
|
||||
Reserve = dispatch_reserve_slots(MaxInflight),
|
||||
max(1, MaxInflight - Reserve).
|
||||
|
||||
-spec dispatch_reserve_slots(pos_integer()) -> pos_integer().
|
||||
dispatch_reserve_slots(MaxInflight) ->
|
||||
max(1, MaxInflight div ?DEFAULT_DISPATCH_RESERVE_DIVISOR).
|
||||
|
||||
-spec is_dispatch_method(term()) -> boolean().
|
||||
is_dispatch_method(Method) when is_binary(Method) ->
|
||||
Suffix = <<".dispatch">>,
|
||||
MethodSize = byte_size(Method),
|
||||
SuffixSize = byte_size(Suffix),
|
||||
MethodSize >= SuffixSize andalso
|
||||
binary:part(Method, MethodSize - SuffixSize, SuffixSize) =:= Suffix;
|
||||
is_dispatch_method(_) ->
|
||||
false.
|
||||
|
||||
-spec execute_method(binary(), map()) -> rpc_result().
|
||||
execute_method(Method, Params) ->
|
||||
try
|
||||
Result = gateway_rpc_router:execute(Method, Params),
|
||||
{ok, Result}
|
||||
catch
|
||||
throw:{error, Message} ->
|
||||
{error, error_binary(Message)};
|
||||
exit:timeout ->
|
||||
{error, <<"timeout">>};
|
||||
exit:{timeout, _} ->
|
||||
{error, <<"timeout">>};
|
||||
Class:Reason ->
|
||||
logger:error(
|
||||
"Gateway TCP RPC method execution failed. method=~ts class=~p reason=~p",
|
||||
[Method, Class, Reason]
|
||||
),
|
||||
{error, <<"internal_error">>}
|
||||
end.
|
||||
|
||||
-spec handle_rpc_response(binary(), rpc_result(), state()) -> state().
|
||||
handle_rpc_response(RequestId, {ok, Result}, State) ->
|
||||
_ = send_response_frame(State, RequestId, true, Result, undefined),
|
||||
decrement_inflight(State);
|
||||
handle_rpc_response(RequestId, {error, Error}, State) ->
|
||||
_ = send_response_frame(State, RequestId, false, undefined, Error),
|
||||
decrement_inflight(State).
|
||||
|
||||
-spec send_response_frame(state(), binary(), boolean(), term(), binary() | undefined) -> ok | {error, term()}.
|
||||
send_response_frame(State, RequestId, true, Result, _Error) ->
|
||||
send_frame(State, #{
|
||||
<<"type">> => <<"response">>,
|
||||
<<"id">> => RequestId,
|
||||
<<"ok">> => true,
|
||||
<<"result">> => Result
|
||||
});
|
||||
send_response_frame(State, RequestId, false, _Result, Error) ->
|
||||
send_frame(State, #{
|
||||
<<"type">> => <<"response">>,
|
||||
<<"id">> => RequestId,
|
||||
<<"ok">> => false,
|
||||
<<"error">> => Error
|
||||
}).
|
||||
|
||||
-spec send_error_frame(state(), binary()) -> ok | {error, term()}.
|
||||
send_error_frame(State, Error) ->
|
||||
send_frame(State, #{
|
||||
<<"type">> => <<"error">>,
|
||||
<<"error">> => Error
|
||||
}).
|
||||
|
||||
-spec send_frame(state(), map()) -> ok | {error, term()}.
|
||||
send_frame(#{socket := Socket}, Frame) ->
|
||||
gen_tcp:send(Socket, encode_frame(Frame)).
|
||||
|
||||
-spec encode_frame(map()) -> binary().
|
||||
encode_frame(Frame) ->
|
||||
Payload = iolist_to_binary(json:encode(Frame)),
|
||||
Length = integer_to_binary(byte_size(Payload)),
|
||||
<<Length/binary, "\n", Payload/binary>>.
|
||||
|
||||
-spec decode_frames(binary(), [map()]) -> {ok, [map()], binary()} | {error, term()}.
|
||||
decode_frames(Buffer, Acc) ->
|
||||
case binary:match(Buffer, <<"\n">>) of
|
||||
nomatch ->
|
||||
{ok, lists:reverse(Acc), Buffer};
|
||||
{Pos, 1} ->
|
||||
LengthBin = binary:part(Buffer, 0, Pos),
|
||||
case parse_length(LengthBin) of
|
||||
{ok, Length} ->
|
||||
HeaderSize = Pos + 1,
|
||||
RequiredSize = HeaderSize + Length,
|
||||
case byte_size(Buffer) >= RequiredSize of
|
||||
false ->
|
||||
{ok, lists:reverse(Acc), Buffer};
|
||||
true ->
|
||||
Payload = binary:part(Buffer, HeaderSize, Length),
|
||||
RestSize = byte_size(Buffer) - RequiredSize,
|
||||
Rest = binary:part(Buffer, RequiredSize, RestSize),
|
||||
case decode_payload(Payload) of
|
||||
{ok, Frame} ->
|
||||
decode_frames(Rest, [Frame | Acc]);
|
||||
{error, Reason} ->
|
||||
{error, Reason}
|
||||
end
|
||||
end;
|
||||
{error, Reason} ->
|
||||
{error, Reason}
|
||||
end
|
||||
end.
|
||||
|
||||
-spec decode_payload(binary()) -> {ok, map()} | {error, term()}.
|
||||
decode_payload(Payload) ->
|
||||
case catch json:decode(Payload) of
|
||||
{'EXIT', _} ->
|
||||
{error, invalid_json};
|
||||
Frame when is_map(Frame) ->
|
||||
{ok, Frame};
|
||||
_ ->
|
||||
{error, invalid_json}
|
||||
end.
|
||||
|
||||
-spec parse_length(binary()) -> {ok, non_neg_integer()} | {error, term()}.
|
||||
parse_length(<<>>) ->
|
||||
{error, invalid_frame_length};
|
||||
parse_length(LengthBin) ->
|
||||
try
|
||||
Length = binary_to_integer(LengthBin),
|
||||
case Length >= 0 andalso Length =< ?MAX_FRAME_BYTES of
|
||||
true -> {ok, Length};
|
||||
false -> {error, invalid_frame_length}
|
||||
end
|
||||
catch
|
||||
_:_ ->
|
||||
{error, invalid_frame_length}
|
||||
end.
|
||||
|
||||
-spec secure_compare(binary(), binary()) -> boolean().
|
||||
secure_compare(Left, Right) when is_binary(Left), is_binary(Right) ->
|
||||
case byte_size(Left) =:= byte_size(Right) of
|
||||
true ->
|
||||
crypto:hash_equals(Left, Right);
|
||||
false ->
|
||||
false
|
||||
end.
|
||||
|
||||
-spec request_id_from_frame(map()) -> binary().
|
||||
request_id_from_frame(Frame) ->
|
||||
case maps:get(<<"id">>, Frame, <<>>) of
|
||||
Id when is_binary(Id) ->
|
||||
Id;
|
||||
Id when is_integer(Id) ->
|
||||
integer_to_binary(Id);
|
||||
_ ->
|
||||
<<>>
|
||||
end.
|
||||
|
||||
-spec increment_inflight(state()) -> state().
|
||||
increment_inflight(#{inflight := Inflight} = State) ->
|
||||
State#{inflight => Inflight + 1}.
|
||||
|
||||
-spec decrement_inflight(state()) -> state().
|
||||
decrement_inflight(#{inflight := Inflight} = State) when Inflight > 0 ->
|
||||
State#{inflight => Inflight - 1};
|
||||
decrement_inflight(State) ->
|
||||
State.
|
||||
|
||||
-spec error_binary(term()) -> binary().
|
||||
error_binary(Value) when is_binary(Value) ->
|
||||
Value;
|
||||
error_binary(Value) when is_list(Value) ->
|
||||
unicode:characters_to_binary(Value);
|
||||
error_binary(Value) when is_atom(Value) ->
|
||||
atom_to_binary(Value, utf8);
|
||||
error_binary(Value) ->
|
||||
unicode:characters_to_binary(io_lib:format("~p", [Value])).
|
||||
|
||||
-spec protocol_error_binary(term()) -> binary().
|
||||
protocol_error_binary(invalid_json) ->
|
||||
<<"invalid_json">>;
|
||||
protocol_error_binary(invalid_frame_length) ->
|
||||
<<"invalid_frame_length">>;
|
||||
protocol_error_binary(input_buffer_limit_exceeded) ->
|
||||
<<"input_buffer_limit_exceeded">>.
|
||||
|
||||
-spec close_socket(inet:socket()) -> ok.
|
||||
close_socket(Socket) ->
|
||||
catch gen_tcp:close(Socket),
|
||||
ok.
|
||||
|
||||
-spec max_inflight() -> pos_integer().
|
||||
max_inflight() ->
|
||||
case fluxer_gateway_env:get(gateway_http_rpc_max_concurrency) of
|
||||
Value when is_integer(Value), Value > 0 ->
|
||||
Value;
|
||||
_ ->
|
||||
?DEFAULT_MAX_INFLIGHT
|
||||
end.
|
||||
|
||||
-spec max_input_buffer_bytes() -> pos_integer().
|
||||
max_input_buffer_bytes() ->
|
||||
case fluxer_gateway_env:get(gateway_rpc_tcp_max_input_buffer_bytes) of
|
||||
Value when is_integer(Value), Value > 0 ->
|
||||
Value;
|
||||
_ ->
|
||||
?DEFAULT_MAX_INPUT_BUFFER_BYTES
|
||||
end.
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
|
||||
decode_single_frame_test() ->
|
||||
Frame = #{<<"type">> => <<"ping">>},
|
||||
Encoded = encode_frame(Frame),
|
||||
?assertEqual({ok, [Frame], <<>>}, decode_frames(Encoded, [])).
|
||||
|
||||
decode_multiple_frames_test() ->
|
||||
FrameA = #{<<"type">> => <<"ping">>},
|
||||
FrameB = #{<<"type">> => <<"pong">>},
|
||||
Encoded = <<(encode_frame(FrameA))/binary, (encode_frame(FrameB))/binary>>,
|
||||
?assertEqual({ok, [FrameA, FrameB], <<>>}, decode_frames(Encoded, [])).
|
||||
|
||||
decode_partial_frame_test() ->
|
||||
Frame = #{<<"type">> => <<"ping">>},
|
||||
Encoded = encode_frame(Frame),
|
||||
Prefix = binary:part(Encoded, 0, 3),
|
||||
?assertEqual({ok, [], Prefix}, decode_frames(Prefix, [])).
|
||||
|
||||
invalid_length_test() ->
|
||||
?assertEqual({error, invalid_frame_length}, decode_frames(<<"x\n{}">>, [])).
|
||||
|
||||
secure_compare_test() ->
|
||||
?assert(secure_compare(<<"abc">>, <<"abc">>)),
|
||||
?assertNot(secure_compare(<<"abc">>, <<"abd">>)),
|
||||
?assertNot(secure_compare(<<"abc">>, <<"abcd">>)).
|
||||
|
||||
-endif.
|
||||
108
fluxer_gateway/src/gateway/gateway_rpc_tcp_server.erl
Normal file
108
fluxer_gateway/src/gateway/gateway_rpc_tcp_server.erl
Normal file
@@ -0,0 +1,108 @@
|
||||
%% Copyright (C) 2026 Fluxer Contributors
|
||||
%%
|
||||
%% This file is part of Fluxer.
|
||||
%%
|
||||
%% Fluxer is free software: you can redistribute it and/or modify
|
||||
%% it under the terms of the GNU Affero General Public License as published by
|
||||
%% the Free Software Foundation, either version 3 of the License, or
|
||||
%% (at your option) any later version.
|
||||
%%
|
||||
%% Fluxer is distributed in the hope that it will be useful,
|
||||
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
%% GNU Affero General Public License for more details.
|
||||
%%
|
||||
%% You should have received a copy of the GNU Affero General Public License
|
||||
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
-module(gateway_rpc_tcp_server).
|
||||
-behaviour(gen_server).
|
||||
|
||||
-export([start_link/0, accept_loop/1]).
|
||||
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
|
||||
|
||||
-type state() :: #{
|
||||
listen_socket := inet:socket(),
|
||||
acceptor_pid := pid(),
|
||||
port := inet:port_number()
|
||||
}.
|
||||
|
||||
-spec start_link() -> {ok, pid()} | {error, term()}.
|
||||
start_link() ->
|
||||
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
|
||||
|
||||
-spec init([]) -> {ok, state()} | {stop, term()}.
|
||||
init([]) ->
|
||||
process_flag(trap_exit, true),
|
||||
Port = fluxer_gateway_env:get(rpc_tcp_port),
|
||||
case gen_tcp:listen(Port, listen_options()) of
|
||||
{ok, ListenSocket} ->
|
||||
AcceptorPid = spawn_link(?MODULE, accept_loop, [ListenSocket]),
|
||||
logger:info("Gateway TCP RPC listener started on port ~p", [Port]),
|
||||
{ok, #{
|
||||
listen_socket => ListenSocket,
|
||||
acceptor_pid => AcceptorPid,
|
||||
port => Port
|
||||
}};
|
||||
{error, Reason} ->
|
||||
{stop, {rpc_tcp_listen_failed, Port, Reason}}
|
||||
end.
|
||||
|
||||
-spec handle_call(term(), gen_server:from(), state()) -> {reply, ok, state()}.
|
||||
handle_call(_Request, _From, State) ->
|
||||
{reply, ok, State}.
|
||||
|
||||
-spec handle_cast(term(), state()) -> {noreply, state()}.
|
||||
handle_cast(_Msg, State) ->
|
||||
{noreply, State}.
|
||||
|
||||
-spec handle_info(term(), state()) -> {noreply, state()}.
|
||||
handle_info({'EXIT', Pid, Reason}, #{acceptor_pid := Pid, listen_socket := ListenSocket} = State) ->
|
||||
case Reason of
|
||||
normal ->
|
||||
{noreply, State};
|
||||
shutdown ->
|
||||
{noreply, State};
|
||||
_ ->
|
||||
logger:error("Gateway TCP RPC acceptor crashed: ~p", [Reason]),
|
||||
NewAcceptorPid = spawn_link(?MODULE, accept_loop, [ListenSocket]),
|
||||
{noreply, State#{acceptor_pid => NewAcceptorPid}}
|
||||
end;
|
||||
handle_info(_Info, State) ->
|
||||
{noreply, State}.
|
||||
|
||||
-spec terminate(term(), state()) -> ok.
|
||||
terminate(_Reason, #{listen_socket := ListenSocket, port := Port}) ->
|
||||
catch gen_tcp:close(ListenSocket),
|
||||
logger:info("Gateway TCP RPC listener stopped on port ~p", [Port]),
|
||||
ok.
|
||||
|
||||
-spec code_change(term(), state(), term()) -> {ok, state()}.
|
||||
code_change(_OldVsn, State, _Extra) ->
|
||||
{ok, State}.
|
||||
|
||||
-spec accept_loop(inet:socket()) -> ok.
|
||||
accept_loop(ListenSocket) ->
|
||||
case gen_tcp:accept(ListenSocket) of
|
||||
{ok, Socket} ->
|
||||
_ = spawn_link(?MODULE, accept_loop, [ListenSocket]),
|
||||
gateway_rpc_tcp_connection:serve(Socket);
|
||||
{error, closed} ->
|
||||
ok;
|
||||
{error, Reason} ->
|
||||
logger:error("Gateway TCP RPC accept failed: ~p", [Reason]),
|
||||
timer:sleep(200),
|
||||
accept_loop(ListenSocket)
|
||||
end.
|
||||
|
||||
-spec listen_options() -> [gen_tcp:listen_option()].
|
||||
listen_options() ->
|
||||
[
|
||||
binary,
|
||||
{packet, raw},
|
||||
{active, false},
|
||||
{reuseaddr, true},
|
||||
{nodelay, true},
|
||||
{backlog, 4096},
|
||||
{keepalive, true}
|
||||
].
|
||||
232
fluxer_gateway/src/gateway/gateway_rpc_voice.erl
Normal file
232
fluxer_gateway/src/gateway/gateway_rpc_voice.erl
Normal file
@@ -0,0 +1,232 @@
|
||||
%% Copyright (C) 2026 Fluxer Contributors
|
||||
%%
|
||||
%% This file is part of Fluxer.
|
||||
%%
|
||||
%% Fluxer is free software: you can redistribute it and/or modify
|
||||
%% it under the terms of the GNU Affero General Public License as published by
|
||||
%% the Free Software Foundation, either version 3 of the License, or
|
||||
%% (at your option) any later version.
|
||||
%%
|
||||
%% Fluxer is distributed in the hope that it will be useful,
|
||||
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
%% GNU Affero General Public License for more details.
|
||||
%%
|
||||
%% You should have received a copy of the GNU Affero General Public License
|
||||
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
-module(gateway_rpc_voice).
|
||||
|
||||
-export([execute_method/2]).
|
||||
|
||||
-spec execute_method(binary(), map()) -> term().
|
||||
execute_method(<<"voice.confirm_connection">>, Params) ->
|
||||
ChannelIdBin = maps:get(<<"channel_id">>, Params),
|
||||
ConnectionId = maps:get(<<"connection_id">>, Params),
|
||||
case parse_optional_guild_id(Params) of
|
||||
undefined ->
|
||||
gateway_rpc_call:execute_method(
|
||||
<<"call.confirm_connection">>,
|
||||
#{
|
||||
<<"channel_id">> => ChannelIdBin,
|
||||
<<"connection_id">> => ConnectionId
|
||||
}
|
||||
);
|
||||
GuildId ->
|
||||
TokenNonce = maps:get(<<"token_nonce">>, Params, undefined),
|
||||
gateway_rpc_guild:execute_method(
|
||||
<<"guild.confirm_voice_connection_from_livekit">>,
|
||||
#{
|
||||
<<"guild_id">> => integer_to_binary(GuildId),
|
||||
<<"connection_id">> => ConnectionId,
|
||||
<<"token_nonce">> => TokenNonce
|
||||
}
|
||||
)
|
||||
end;
|
||||
execute_method(<<"voice.disconnect_user_if_in_channel">>, Params) ->
|
||||
ChannelIdBin = maps:get(<<"channel_id">>, Params),
|
||||
UserIdBin = maps:get(<<"user_id">>, Params),
|
||||
ConnectionId = maps:get(<<"connection_id">>, Params, undefined),
|
||||
case parse_optional_guild_id(Params) of
|
||||
undefined ->
|
||||
CallParams = #{
|
||||
<<"channel_id">> => ChannelIdBin,
|
||||
<<"user_id">> => UserIdBin
|
||||
},
|
||||
gateway_rpc_call:execute_method(
|
||||
<<"call.disconnect_user_if_in_channel">>,
|
||||
maybe_put_connection_id(ConnectionId, CallParams)
|
||||
);
|
||||
GuildId ->
|
||||
GuildParams = #{
|
||||
<<"guild_id">> => integer_to_binary(GuildId),
|
||||
<<"user_id">> => UserIdBin,
|
||||
<<"expected_channel_id">> => ChannelIdBin
|
||||
},
|
||||
gateway_rpc_guild:execute_method(
|
||||
<<"guild.disconnect_voice_user_if_in_channel">>,
|
||||
maybe_put_connection_id(ConnectionId, GuildParams)
|
||||
)
|
||||
end;
|
||||
execute_method(<<"voice.get_voice_states_for_channel">>, Params) ->
|
||||
ChannelIdBin = maps:get(<<"channel_id">>, Params),
|
||||
case parse_optional_guild_id(Params) of
|
||||
undefined ->
|
||||
build_dm_voice_states_response(ChannelIdBin);
|
||||
GuildId ->
|
||||
gateway_rpc_guild:execute_method(
|
||||
<<"guild.get_voice_states_for_channel">>,
|
||||
#{
|
||||
<<"guild_id">> => integer_to_binary(GuildId),
|
||||
<<"channel_id">> => ChannelIdBin
|
||||
}
|
||||
)
|
||||
end;
|
||||
execute_method(<<"voice.get_pending_joins_for_channel">>, Params) ->
|
||||
ChannelIdBin = maps:get(<<"channel_id">>, Params),
|
||||
case parse_optional_guild_id(Params) of
|
||||
undefined ->
|
||||
normalize_pending_joins_response(
|
||||
gateway_rpc_call:execute_method(
|
||||
<<"call.get_pending_joins">>,
|
||||
#{<<"channel_id">> => ChannelIdBin}
|
||||
)
|
||||
);
|
||||
GuildId ->
|
||||
gateway_rpc_guild:execute_method(
|
||||
<<"guild.get_pending_joins_for_channel">>,
|
||||
#{
|
||||
<<"guild_id">> => integer_to_binary(GuildId),
|
||||
<<"channel_id">> => ChannelIdBin
|
||||
}
|
||||
)
|
||||
end;
|
||||
execute_method(Method, _Params) ->
|
||||
throw({error, <<"Unknown method: ", Method/binary>>}).
|
||||
|
||||
-spec parse_optional_guild_id(map()) -> integer() | undefined.
|
||||
parse_optional_guild_id(Params) ->
|
||||
case maps:get(<<"guild_id">>, Params, undefined) of
|
||||
undefined ->
|
||||
undefined;
|
||||
null ->
|
||||
undefined;
|
||||
GuildIdBin ->
|
||||
validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin)
|
||||
end.
|
||||
|
||||
-spec maybe_put_connection_id(binary() | undefined, map()) -> map().
|
||||
maybe_put_connection_id(undefined, Params) ->
|
||||
Params;
|
||||
maybe_put_connection_id(ConnectionId, Params) ->
|
||||
Params#{<<"connection_id">> => ConnectionId}.
|
||||
|
||||
-spec build_dm_voice_states_response(binary()) -> map().
|
||||
build_dm_voice_states_response(ChannelIdBin) ->
|
||||
case gateway_rpc_call:execute_method(<<"call.get">>, #{<<"channel_id">> => ChannelIdBin}) of
|
||||
null ->
|
||||
#{<<"voice_states">> => []};
|
||||
CallData when is_map(CallData) ->
|
||||
VoiceStates = get_map_value(CallData, [<<"voice_states">>, voice_states]),
|
||||
#{<<"voice_states">> => normalize_voice_states(VoiceStates)}
|
||||
end.
|
||||
|
||||
-spec normalize_voice_states(term()) -> [map()].
|
||||
normalize_voice_states(VoiceStates) when is_list(VoiceStates) ->
|
||||
lists:reverse(
|
||||
lists:foldl(fun normalize_voice_state_entry/2, [], VoiceStates)
|
||||
);
|
||||
normalize_voice_states(_) ->
|
||||
[].
|
||||
|
||||
-spec normalize_voice_state_entry(map(), [map()]) -> [map()].
|
||||
normalize_voice_state_entry(VoiceState, Acc) ->
|
||||
ConnectionId = normalize_id(get_map_value(VoiceState, [<<"connection_id">>, connection_id])),
|
||||
UserId = normalize_id(get_map_value(VoiceState, [<<"user_id">>, user_id])),
|
||||
ChannelId = normalize_id(get_map_value(VoiceState, [<<"channel_id">>, channel_id])),
|
||||
case {ConnectionId, UserId, ChannelId} of
|
||||
{undefined, _, _} ->
|
||||
Acc;
|
||||
{_, undefined, _} ->
|
||||
Acc;
|
||||
{_, _, undefined} ->
|
||||
Acc;
|
||||
_ ->
|
||||
[#{
|
||||
<<"connection_id">> => ConnectionId,
|
||||
<<"user_id">> => UserId,
|
||||
<<"channel_id">> => ChannelId
|
||||
} | Acc]
|
||||
end.
|
||||
|
||||
-spec normalize_pending_joins_response(term()) -> map().
|
||||
normalize_pending_joins_response(Response) when is_map(Response) ->
|
||||
PendingJoins = get_map_value(Response, [<<"pending_joins">>, pending_joins]),
|
||||
#{<<"pending_joins">> => normalize_pending_joins(PendingJoins)};
|
||||
normalize_pending_joins_response(_) ->
|
||||
#{<<"pending_joins">> => []}.
|
||||
|
||||
-spec normalize_pending_joins(term()) -> [map()].
|
||||
normalize_pending_joins(PendingJoins) when is_list(PendingJoins) ->
|
||||
lists:reverse(
|
||||
lists:foldl(fun normalize_pending_join_entry/2, [], PendingJoins)
|
||||
);
|
||||
normalize_pending_joins(_) ->
|
||||
[].
|
||||
|
||||
-spec normalize_pending_join_entry(map(), [map()]) -> [map()].
|
||||
normalize_pending_join_entry(PendingJoin, Acc) ->
|
||||
ConnectionId = normalize_id(get_map_value(PendingJoin, [<<"connection_id">>, connection_id])),
|
||||
UserId = normalize_id(get_map_value(PendingJoin, [<<"user_id">>, user_id])),
|
||||
TokenNonce = normalize_token_nonce(get_map_value(PendingJoin, [<<"token_nonce">>, token_nonce])),
|
||||
ExpiresAt = normalize_expiry(get_map_value(PendingJoin, [<<"expires_at">>, expires_at])),
|
||||
case {ConnectionId, UserId} of
|
||||
{undefined, _} ->
|
||||
Acc;
|
||||
{_, undefined} ->
|
||||
Acc;
|
||||
_ ->
|
||||
[#{
|
||||
<<"connection_id">> => ConnectionId,
|
||||
<<"user_id">> => UserId,
|
||||
<<"token_nonce">> => TokenNonce,
|
||||
<<"expires_at">> => ExpiresAt
|
||||
} | Acc]
|
||||
end.
|
||||
|
||||
-spec normalize_id(term()) -> binary() | undefined.
|
||||
normalize_id(undefined) ->
|
||||
undefined;
|
||||
normalize_id(Value) when is_binary(Value) ->
|
||||
Value;
|
||||
normalize_id(Value) when is_integer(Value) ->
|
||||
integer_to_binary(Value);
|
||||
normalize_id(_) ->
|
||||
undefined.
|
||||
|
||||
-spec normalize_token_nonce(term()) -> binary().
|
||||
normalize_token_nonce(undefined) ->
|
||||
<<>>;
|
||||
normalize_token_nonce(Value) when is_binary(Value) ->
|
||||
Value;
|
||||
normalize_token_nonce(Value) when is_integer(Value) ->
|
||||
integer_to_binary(Value);
|
||||
normalize_token_nonce(_) ->
|
||||
<<>>.
|
||||
|
||||
-spec normalize_expiry(term()) -> integer().
|
||||
normalize_expiry(Value) when is_integer(Value) ->
|
||||
Value;
|
||||
normalize_expiry(_) ->
|
||||
0.
|
||||
|
||||
-spec get_map_value(map(), [term()]) -> term().
|
||||
get_map_value(_Map, []) ->
|
||||
undefined;
|
||||
get_map_value(Map, [Key | Rest]) ->
|
||||
case maps:find(Key, Map) of
|
||||
{ok, Value} ->
|
||||
Value;
|
||||
error ->
|
||||
get_map_value(Map, Rest)
|
||||
end.
|
||||
@@ -19,6 +19,7 @@
|
||||
|
||||
-export([init/2]).
|
||||
|
||||
-spec init(cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
|
||||
init(Req0, State) ->
|
||||
Req = cowboy_req:reply(
|
||||
200,
|
||||
|
||||
@@ -60,6 +60,7 @@
|
||||
|
||||
-type purge_mode() :: none | soft | hard.
|
||||
-type reload_opts() :: #{purge => purge_mode()}.
|
||||
-type reload_result() :: map().
|
||||
|
||||
-spec reload_module(atom()) -> {ok, map()} | {error, term()}.
|
||||
reload_module(Module) when is_atom(Module) ->
|
||||
@@ -79,11 +80,11 @@ reload_module(Module) when is_atom(Module) ->
|
||||
end
|
||||
end.
|
||||
|
||||
-spec reload_modules([atom()]) -> {ok, [map()]}.
|
||||
-spec reload_modules([atom()]) -> {ok, [reload_result()]}.
|
||||
reload_modules(Modules) when is_list(Modules) ->
|
||||
reload_modules(Modules, #{purge => soft}).
|
||||
|
||||
-spec reload_modules([atom()], reload_opts()) -> {ok, [map()]}.
|
||||
-spec reload_modules([atom()], reload_opts()) -> {ok, [reload_result()]}.
|
||||
reload_modules(Modules, Opts) when is_list(Modules), is_map(Opts) ->
|
||||
Purge = maps:get(purge, Opts, soft),
|
||||
Results = lists:map(
|
||||
@@ -94,7 +95,7 @@ reload_modules(Modules, Opts) when is_list(Modules), is_map(Opts) ->
|
||||
),
|
||||
{ok, Results}.
|
||||
|
||||
-spec reload_beams([{atom(), binary()}], reload_opts()) -> {ok, [map()]}.
|
||||
-spec reload_beams([{atom(), binary()}], reload_opts()) -> {ok, [reload_result()]}.
|
||||
reload_beams(Pairs, Opts) when is_list(Pairs), is_map(Opts) ->
|
||||
Purge = maps:get(purge, Opts, soft),
|
||||
Results =
|
||||
@@ -106,11 +107,11 @@ reload_beams(Pairs, Opts) when is_list(Pairs), is_map(Opts) ->
|
||||
),
|
||||
{ok, Results}.
|
||||
|
||||
-spec reload_all_changed() -> {ok, [map()]}.
|
||||
-spec reload_all_changed() -> {ok, [reload_result()]}.
|
||||
reload_all_changed() ->
|
||||
reload_all_changed(soft).
|
||||
|
||||
-spec reload_all_changed(purge_mode()) -> {ok, [map()]}.
|
||||
-spec reload_all_changed(purge_mode()) -> {ok, [reload_result()]}.
|
||||
reload_all_changed(Purge) ->
|
||||
ChangedModules = get_changed_modules(),
|
||||
reload_modules(ChangedModules, #{purge => Purge}).
|
||||
@@ -141,6 +142,7 @@ get_module_info(Module) when is_atom(Module) ->
|
||||
}}
|
||||
end.
|
||||
|
||||
-spec reload_one(atom(), purge_mode()) -> reload_result().
|
||||
reload_one(Module, Purge) ->
|
||||
case is_critical_module(Module) of
|
||||
true ->
|
||||
@@ -149,6 +151,7 @@ reload_one(Module, Purge) ->
|
||||
do_reload_one(Module, Purge)
|
||||
end.
|
||||
|
||||
-spec reload_one_beam(atom(), binary(), purge_mode()) -> reload_result().
|
||||
reload_one_beam(Module, BeamBin, Purge) ->
|
||||
case is_critical_module(Module) of
|
||||
true ->
|
||||
@@ -157,19 +160,21 @@ reload_one_beam(Module, BeamBin, Purge) ->
|
||||
do_reload_one_beam(Module, BeamBin, Purge)
|
||||
end.
|
||||
|
||||
-spec do_reload_one(atom(), purge_mode()) -> reload_result().
|
||||
do_reload_one(Module, Purge) ->
|
||||
OldLoadedMd5 = loaded_md5(Module),
|
||||
OldBeamPath = code:which(Module),
|
||||
OldDiskMd5 = disk_md5(OldBeamPath),
|
||||
|
||||
ok = maybe_purge_before_load(Module, Purge),
|
||||
|
||||
case code:load_file(Module) of
|
||||
{module, Module} ->
|
||||
NewLoadedMd5 = loaded_md5(Module),
|
||||
NewBeamPath = code:which(Module),
|
||||
NewDiskMd5 = disk_md5(NewBeamPath),
|
||||
Verified = (NewLoadedMd5 =/= undefined) andalso (NewDiskMd5 =/= undefined) andalso (NewLoadedMd5 =:= NewDiskMd5),
|
||||
Verified =
|
||||
(NewLoadedMd5 =/= undefined) andalso
|
||||
(NewDiskMd5 =/= undefined) andalso
|
||||
(NewLoadedMd5 =:= NewDiskMd5),
|
||||
{PurgedOld, LingeringCount} = maybe_purge_old_after_load(Module, Purge),
|
||||
#{
|
||||
module => Module,
|
||||
@@ -195,9 +200,9 @@ do_reload_one(Module, Purge) ->
|
||||
}
|
||||
end.
|
||||
|
||||
-spec do_reload_one_beam(atom(), binary(), purge_mode()) -> reload_result().
|
||||
do_reload_one_beam(Module, BeamBin, Purge) ->
|
||||
OldLoadedMd5 = loaded_md5(Module),
|
||||
|
||||
ExpectedMd5 =
|
||||
case beam_lib:md5(BeamBin) of
|
||||
{ok, {Module, Md5}} ->
|
||||
@@ -207,9 +212,7 @@ do_reload_one_beam(Module, BeamBin, Purge) ->
|
||||
_ ->
|
||||
erlang:error(invalid_beam)
|
||||
end,
|
||||
|
||||
ok = maybe_purge_before_load(Module, Purge),
|
||||
|
||||
Filename = atom_to_list(Module) ++ ".beam(hot)",
|
||||
case code:load_binary(Module, Filename, BeamBin) of
|
||||
{module, Module} ->
|
||||
@@ -239,6 +242,7 @@ do_reload_one_beam(Module, BeamBin, Purge) ->
|
||||
}
|
||||
end.
|
||||
|
||||
-spec maybe_purge_before_load(atom(), purge_mode()) -> ok.
|
||||
maybe_purge_before_load(_Module, none) ->
|
||||
ok;
|
||||
maybe_purge_before_load(_Module, soft) ->
|
||||
@@ -247,16 +251,28 @@ maybe_purge_before_load(Module, hard) ->
|
||||
_ = code:purge(Module),
|
||||
ok.
|
||||
|
||||
-spec maybe_purge_old_after_load(atom(), purge_mode()) -> {boolean(), non_neg_integer()}.
|
||||
maybe_purge_old_after_load(_Module, none) ->
|
||||
{false, 0};
|
||||
maybe_purge_old_after_load(Module, hard) ->
|
||||
_ = code:soft_purge(Module),
|
||||
Purged = code:purge(Module),
|
||||
{Purged, case Purged of true -> 0; false -> count_lingering(Module) end};
|
||||
LingeringCount =
|
||||
case Purged of
|
||||
true -> 0;
|
||||
false -> count_lingering(Module)
|
||||
end,
|
||||
{Purged, LingeringCount};
|
||||
maybe_purge_old_after_load(Module, soft) ->
|
||||
Purged = wait_soft_purge(Module, 40, 50),
|
||||
{Purged, case Purged of true -> 0; false -> count_lingering(Module) end}.
|
||||
LingeringCount =
|
||||
case Purged of
|
||||
true -> 0;
|
||||
false -> count_lingering(Module)
|
||||
end,
|
||||
{Purged, LingeringCount}.
|
||||
|
||||
-spec wait_soft_purge(atom(), non_neg_integer(), pos_integer()) -> boolean().
|
||||
wait_soft_purge(_Module, 0, _SleepMs) ->
|
||||
false;
|
||||
wait_soft_purge(Module, N, SleepMs) ->
|
||||
@@ -264,10 +280,13 @@ wait_soft_purge(Module, N, SleepMs) ->
|
||||
true ->
|
||||
true;
|
||||
false ->
|
||||
receive after SleepMs -> ok end,
|
||||
receive
|
||||
after SleepMs -> ok
|
||||
end,
|
||||
wait_soft_purge(Module, N - 1, SleepMs)
|
||||
end.
|
||||
|
||||
-spec count_lingering(atom()) -> non_neg_integer().
|
||||
count_lingering(Module) ->
|
||||
lists:foldl(
|
||||
fun(Pid, Acc) ->
|
||||
@@ -280,21 +299,26 @@ count_lingering(Module) ->
|
||||
processes()
|
||||
).
|
||||
|
||||
-spec get_changed_modules() -> [atom()].
|
||||
get_changed_modules() ->
|
||||
Modified = code:modified_modules(),
|
||||
[M || M <- Modified, is_fluxer_module(M), not is_critical_module(M)].
|
||||
|
||||
-spec is_critical_module(atom()) -> boolean().
|
||||
is_critical_module(Module) ->
|
||||
lists:member(Module, ?CRITICAL_MODULES).
|
||||
|
||||
-spec is_fluxer_module(atom()) -> boolean().
|
||||
is_fluxer_module(Module) ->
|
||||
ModuleStr = atom_to_list(Module),
|
||||
lists:prefix("fluxer_", ModuleStr) orelse
|
||||
lists:prefix("gateway", ModuleStr) orelse
|
||||
lists:prefix("gateway_http_", ModuleStr) orelse
|
||||
lists:prefix("session", ModuleStr) orelse
|
||||
lists:prefix("guild", ModuleStr) orelse
|
||||
lists:prefix("presence", ModuleStr) orelse
|
||||
lists:prefix("push", ModuleStr) orelse
|
||||
lists:prefix("push_dispatcher", ModuleStr) orelse
|
||||
lists:prefix("call", ModuleStr) orelse
|
||||
lists:prefix("health", ModuleStr) orelse
|
||||
lists:prefix("hot_reload", ModuleStr) orelse
|
||||
@@ -311,9 +335,13 @@ is_fluxer_module(Module) ->
|
||||
lists:prefix("map_utils", ModuleStr) orelse
|
||||
lists:prefix("type_conv", ModuleStr) orelse
|
||||
lists:prefix("utils", ModuleStr) orelse
|
||||
lists:prefix("snowflake_", ModuleStr) orelse
|
||||
lists:prefix("user_utils", ModuleStr) orelse
|
||||
lists:prefix("custom_status", ModuleStr).
|
||||
lists:prefix("custom_status", ModuleStr) orelse
|
||||
lists:prefix("otel_", ModuleStr) orelse
|
||||
lists:prefix("event_", ModuleStr).
|
||||
|
||||
-spec loaded_md5(atom()) -> binary() | undefined.
|
||||
loaded_md5(Module) ->
|
||||
try
|
||||
Module:module_info(md5)
|
||||
@@ -321,6 +349,7 @@ loaded_md5(Module) ->
|
||||
_:_ -> undefined
|
||||
end.
|
||||
|
||||
-spec disk_md5(string() | atom()) -> binary() | undefined.
|
||||
disk_md5(Path) when is_list(Path) ->
|
||||
case beam_lib:md5(Path) of
|
||||
{ok, {_M, Md5}} -> Md5;
|
||||
@@ -329,11 +358,13 @@ disk_md5(Path) when is_list(Path) ->
|
||||
disk_md5(_) ->
|
||||
undefined.
|
||||
|
||||
-spec hex_or_null(binary() | undefined) -> binary() | null.
|
||||
hex_or_null(undefined) ->
|
||||
null;
|
||||
hex_or_null(Bin) when is_binary(Bin) ->
|
||||
binary:encode_hex(Bin, lowercase).
|
||||
|
||||
-spec get_loaded_time(atom()) -> term().
|
||||
get_loaded_time(Module) ->
|
||||
try
|
||||
case Module:module_info(compile) of
|
||||
@@ -346,6 +377,7 @@ get_loaded_time(Module) ->
|
||||
_:_ -> undefined
|
||||
end.
|
||||
|
||||
-spec get_disk_time(string() | atom()) -> calendar:datetime() | undefined.
|
||||
get_disk_time(BeamPath) when is_list(BeamPath) ->
|
||||
case file:read_file_info(BeamPath) of
|
||||
{ok, FileInfo} ->
|
||||
|
||||
@@ -23,6 +23,9 @@
|
||||
-define(MAX_MODULES, 600).
|
||||
-define(MAX_BODY_BYTES, 26214400).
|
||||
|
||||
-type purge_mode() :: none | soft | hard.
|
||||
|
||||
-spec init(cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
|
||||
init(Req0, State) ->
|
||||
case cowboy_req:method(Req0) of
|
||||
<<"POST">> ->
|
||||
@@ -32,6 +35,7 @@ init(Req0, State) ->
|
||||
{ok, Req, State}
|
||||
end.
|
||||
|
||||
-spec handle_post(cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
|
||||
handle_post(Req0, State) ->
|
||||
case authorize(Req0) of
|
||||
ok ->
|
||||
@@ -45,52 +49,75 @@ handle_post(Req0, State) ->
|
||||
{ok, Req1, State}
|
||||
end.
|
||||
|
||||
-spec authorize(cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
|
||||
authorize(Req0) ->
|
||||
case cowboy_req:header(<<"authorization">>, Req0) of
|
||||
undefined ->
|
||||
Req = cowboy_req:reply(
|
||||
401,
|
||||
?JSON_HEADERS,
|
||||
jsx:encode(#{<<"error">> => <<"Unauthorized">>}),
|
||||
json:encode(#{<<"error">> => <<"Unauthorized">>}),
|
||||
Req0
|
||||
),
|
||||
{error, Req};
|
||||
AuthHeader ->
|
||||
case os:getenv("GATEWAY_ADMIN_SECRET") of
|
||||
false ->
|
||||
Req = cowboy_req:reply(
|
||||
500,
|
||||
?JSON_HEADERS,
|
||||
jsx:encode(#{<<"error">> => <<"GATEWAY_ADMIN_SECRET not configured">>}),
|
||||
Req0
|
||||
),
|
||||
{error, Req};
|
||||
Secret ->
|
||||
Expected = <<"Bearer ", (list_to_binary(Secret))/binary>>,
|
||||
case AuthHeader of
|
||||
Expected ->
|
||||
ok;
|
||||
_ ->
|
||||
Req = cowboy_req:reply(
|
||||
401,
|
||||
?JSON_HEADERS,
|
||||
jsx:encode(#{<<"error">> => <<"Unauthorized">>}),
|
||||
Req0
|
||||
),
|
||||
{error, Req}
|
||||
end
|
||||
end
|
||||
authorize_with_secret(AuthHeader, Req0)
|
||||
end.
|
||||
|
||||
-spec authorize_with_secret(binary(), cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
|
||||
authorize_with_secret(AuthHeader, Req0) ->
|
||||
case fluxer_gateway_env:get(admin_reload_secret) of
|
||||
undefined ->
|
||||
Req = cowboy_req:reply(
|
||||
500,
|
||||
?JSON_HEADERS,
|
||||
json:encode(#{<<"error">> => <<"admin reload secret not configured">>}),
|
||||
Req0
|
||||
),
|
||||
{error, Req};
|
||||
Secret when is_binary(Secret) ->
|
||||
check_auth_header(AuthHeader, <<"Bearer ", Secret/binary>>, Req0);
|
||||
Secret when is_list(Secret) ->
|
||||
check_auth_header(AuthHeader, <<"Bearer ", (list_to_binary(Secret))/binary>>, Req0)
|
||||
end.
|
||||
|
||||
-spec check_auth_header(binary(), binary(), cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
|
||||
check_auth_header(AuthHeader, Expected, Req0) ->
|
||||
case secure_compare(AuthHeader, Expected) of
|
||||
true ->
|
||||
ok;
|
||||
false ->
|
||||
Req = cowboy_req:reply(
|
||||
401,
|
||||
?JSON_HEADERS,
|
||||
json:encode(#{<<"error">> => <<"Unauthorized">>}),
|
||||
Req0
|
||||
),
|
||||
{error, Req}
|
||||
end.
|
||||
|
||||
-spec secure_compare(binary(), binary()) -> boolean().
|
||||
secure_compare(Left, Right) when is_binary(Left), is_binary(Right) ->
|
||||
case byte_size(Left) =:= byte_size(Right) of
|
||||
true ->
|
||||
crypto:hash_equals(Left, Right);
|
||||
false ->
|
||||
false
|
||||
end.
|
||||
|
||||
-spec read_body(cowboy_req:req()) ->
|
||||
{ok, map(), cowboy_req:req()} | {error, pos_integer(), map(), cowboy_req:req()}.
|
||||
read_body(Req0) ->
|
||||
case cowboy_req:body_length(Req0) of
|
||||
Length when is_integer(Length), Length > ?MAX_BODY_BYTES ->
|
||||
{error, 413, #{<<"error">> => <<"Request body too large">>}, Req0};
|
||||
_ ->
|
||||
read_body(Req0, <<>>)
|
||||
read_body_chunks(Req0, <<>>)
|
||||
end.
|
||||
|
||||
read_body(Req0, Acc) ->
|
||||
-spec read_body_chunks(cowboy_req:req(), binary()) ->
|
||||
{ok, map(), cowboy_req:req()} | {error, pos_integer(), map(), cowboy_req:req()}.
|
||||
read_body_chunks(Req0, Acc) ->
|
||||
case cowboy_req:read_body(Req0, #{length => 1048576}) of
|
||||
{ok, Body, Req1} ->
|
||||
FullBody = <<Acc/binary, Body/binary>>,
|
||||
@@ -101,14 +128,16 @@ read_body(Req0, Acc) ->
|
||||
true ->
|
||||
{error, 413, #{<<"error">> => <<"Request body too large">>}, Req1};
|
||||
false ->
|
||||
read_body(Req1, NewAcc)
|
||||
read_body_chunks(Req1, NewAcc)
|
||||
end
|
||||
end.
|
||||
|
||||
-spec decode_body(binary(), cowboy_req:req()) ->
|
||||
{ok, map(), cowboy_req:req()} | {error, pos_integer(), map(), cowboy_req:req()}.
|
||||
decode_body(<<>>, Req0) ->
|
||||
{ok, #{}, Req0};
|
||||
decode_body(Body, Req0) ->
|
||||
case catch jsx:decode(Body, [return_maps]) of
|
||||
case catch json:decode(Body) of
|
||||
{'EXIT', _Reason} ->
|
||||
{error, 400, #{<<"error">> => <<"Invalid JSON payload">>}, Req0};
|
||||
Decoded when is_map(Decoded) ->
|
||||
@@ -117,6 +146,7 @@ decode_body(Body, Req0) ->
|
||||
{error, 400, #{<<"error">> => <<"Invalid request body">>}, Req0}
|
||||
end.
|
||||
|
||||
-spec handle_reload(map(), cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
|
||||
handle_reload(Params, Req0, State) ->
|
||||
try
|
||||
Purge = parse_purge(maps:get(<<"purge">>, Params, <<"soft">>)),
|
||||
@@ -124,14 +154,7 @@ handle_reload(Params, Req0, State) ->
|
||||
undefined ->
|
||||
handle_modules_reload(Params, Purge, Req0, State);
|
||||
Beams when is_list(Beams) ->
|
||||
case length(Beams) =< ?MAX_MODULES of
|
||||
true ->
|
||||
Pairs = decode_beams(Beams),
|
||||
{ok, Results} = hot_reload:reload_beams(Pairs, #{purge => Purge}),
|
||||
respond(200, #{<<"results">> => Results}, Req0, State);
|
||||
false ->
|
||||
respond(400, #{<<"error">> => <<"Too many modules">>}, Req0, State)
|
||||
end;
|
||||
handle_beams_reload(Beams, Purge, Req0, State);
|
||||
_ ->
|
||||
respond(400, #{<<"error">> => <<"beams must be an array">>}, Req0, State)
|
||||
end
|
||||
@@ -142,11 +165,24 @@ handle_reload(Params, Req0, State) ->
|
||||
respond(400, #{<<"error">> => <<"Invalid module name or beam payload">>}, Req0, State);
|
||||
error:{beam_module_mismatch, _, _} ->
|
||||
respond(400, #{<<"error">> => <<"Invalid module name or beam payload">>}, Req0, State);
|
||||
_:Reason ->
|
||||
logger:error("hot_reload_handler: Error during reload: ~p", [Reason]),
|
||||
_:_Reason ->
|
||||
respond(500, #{<<"error">> => <<"Internal error">>}, Req0, State)
|
||||
end.
|
||||
|
||||
-spec handle_beams_reload([map()], purge_mode(), cowboy_req:req(), term()) ->
|
||||
{ok, cowboy_req:req(), term()}.
|
||||
handle_beams_reload(Beams, Purge, Req0, State) ->
|
||||
case length(Beams) =< ?MAX_MODULES of
|
||||
true ->
|
||||
Pairs = decode_beams(Beams),
|
||||
{ok, Results} = hot_reload:reload_beams(Pairs, #{purge => Purge}),
|
||||
respond(200, #{<<"results">> => Results}, Req0, State);
|
||||
false ->
|
||||
respond(400, #{<<"error">> => <<"Too many modules">>}, Req0, State)
|
||||
end.
|
||||
|
||||
-spec handle_modules_reload(map(), purge_mode(), cowboy_req:req(), term()) ->
|
||||
{ok, cowboy_req:req(), term()}.
|
||||
handle_modules_reload(Params, Purge, Req0, State) ->
|
||||
case maps:get(<<"modules">>, Params, []) of
|
||||
[] ->
|
||||
@@ -165,6 +201,7 @@ handle_modules_reload(Params, Purge, Req0, State) ->
|
||||
respond(400, #{<<"error">> => <<"modules must be an array">>}, Req0, State)
|
||||
end.
|
||||
|
||||
-spec decode_beams([map()]) -> [{atom(), binary()}].
|
||||
decode_beams(Beams) ->
|
||||
lists:map(
|
||||
fun(Elem) ->
|
||||
@@ -187,6 +224,7 @@ decode_beams(Beams) ->
|
||||
Beams
|
||||
).
|
||||
|
||||
-spec to_binary(binary() | list()) -> binary().
|
||||
to_binary(B) when is_binary(B) ->
|
||||
B;
|
||||
to_binary(L) when is_list(L) ->
|
||||
@@ -194,6 +232,7 @@ to_binary(L) when is_list(L) ->
|
||||
to_binary(_) ->
|
||||
erlang:error(badarg).
|
||||
|
||||
-spec parse_purge(binary() | atom()) -> purge_mode().
|
||||
parse_purge(<<"none">>) -> none;
|
||||
parse_purge(<<"soft">>) -> soft;
|
||||
parse_purge(<<"hard">>) -> hard;
|
||||
@@ -202,6 +241,7 @@ parse_purge(soft) -> soft;
|
||||
parse_purge(hard) -> hard;
|
||||
parse_purge(_) -> soft.
|
||||
|
||||
-spec to_module_atom(binary() | list()) -> atom().
|
||||
to_module_atom(B) when is_binary(B) ->
|
||||
case is_allowed_module_name(B) of
|
||||
true -> erlang:binary_to_atom(B, utf8);
|
||||
@@ -212,10 +252,12 @@ to_module_atom(L) when is_list(L) ->
|
||||
to_module_atom(_) ->
|
||||
erlang:error(badarg).
|
||||
|
||||
-spec is_allowed_module_name(binary()) -> boolean().
|
||||
is_allowed_module_name(Bin) when is_binary(Bin) ->
|
||||
byte_size(Bin) > 0 andalso byte_size(Bin) < 128 andalso
|
||||
is_safe_chars(Bin) andalso has_allowed_prefix(Bin).
|
||||
|
||||
-spec is_safe_chars(binary()) -> boolean().
|
||||
is_safe_chars(Bin) ->
|
||||
lists:all(
|
||||
fun(C) ->
|
||||
@@ -226,14 +268,17 @@ is_safe_chars(Bin) ->
|
||||
binary_to_list(Bin)
|
||||
).
|
||||
|
||||
-spec has_allowed_prefix(binary()) -> boolean().
|
||||
has_allowed_prefix(Bin) ->
|
||||
Prefixes = [
|
||||
<<"fluxer_">>,
|
||||
<<"gateway">>,
|
||||
<<"gateway_http_">>,
|
||||
<<"session">>,
|
||||
<<"guild">>,
|
||||
<<"presence">>,
|
||||
<<"push">>,
|
||||
<<"push_dispatcher">>,
|
||||
<<"call">>,
|
||||
<<"health">>,
|
||||
<<"hot_reload">>,
|
||||
@@ -251,7 +296,10 @@ has_allowed_prefix(Bin) ->
|
||||
<<"type_conv">>,
|
||||
<<"utils">>,
|
||||
<<"user_utils">>,
|
||||
<<"custom_status">>
|
||||
<<"snowflake_">>,
|
||||
<<"custom_status">>,
|
||||
<<"otel_">>,
|
||||
<<"event_">>
|
||||
],
|
||||
lists:any(
|
||||
fun(P) ->
|
||||
@@ -261,6 +309,7 @@ has_allowed_prefix(Bin) ->
|
||||
Prefixes
|
||||
).
|
||||
|
||||
-spec respond(pos_integer(), map(), cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
|
||||
respond(Status, Body, Req0, State) ->
|
||||
Req = cowboy_req:reply(Status, ?JSON_HEADERS, jsx:encode(Body), Req0),
|
||||
Req = cowboy_req:reply(Status, ?JSON_HEADERS, json:encode(Body), Req0),
|
||||
{ok, Req, State}.
|
||||
|
||||
@@ -19,10 +19,6 @@
|
||||
|
||||
-export([select/2, group_keys/2]).
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
-endif.
|
||||
|
||||
-define(HASH_LIMIT, 16#FFFFFFFF).
|
||||
|
||||
-spec select(term(), pos_integer()) -> non_neg_integer().
|
||||
@@ -51,22 +47,21 @@ select(_Key, _ShardCount) ->
|
||||
|
||||
-spec group_keys([term()], pos_integer()) -> [{non_neg_integer(), [term()]}].
|
||||
group_keys(Keys, ShardCount) when is_list(Keys), ShardCount > 0 ->
|
||||
Sorted =
|
||||
maps:to_list(
|
||||
lists:foldl(
|
||||
fun(Key, Acc) ->
|
||||
Index = select(Key, ShardCount),
|
||||
Existing = maps:get(Index, Acc, []),
|
||||
maps:put(Index, [Key | Existing], Acc)
|
||||
end,
|
||||
#{},
|
||||
Keys
|
||||
)
|
||||
Grouped =
|
||||
lists:foldl(
|
||||
fun(Key, Acc) ->
|
||||
Index = select(Key, ShardCount),
|
||||
Existing = maps:get(Index, Acc, []),
|
||||
maps:put(Index, [Key | Existing], Acc)
|
||||
end,
|
||||
#{},
|
||||
Keys
|
||||
),
|
||||
lists:sort(
|
||||
Sorted = lists:sort(
|
||||
fun({IdxA, _}, {IdxB, _}) -> IdxA =< IdxB end,
|
||||
[{Index, lists:usort(Group)} || {Index, Group} <- Sorted]
|
||||
);
|
||||
[{Index, lists:usort(Group)} || {Index, Group} <- maps:to_list(Grouped)]
|
||||
),
|
||||
Sorted;
|
||||
group_keys(_Keys, _ShardCount) ->
|
||||
[].
|
||||
|
||||
@@ -75,20 +70,55 @@ weight(Key, Index) ->
|
||||
erlang:phash2({Key, Index}, ?HASH_LIMIT).
|
||||
|
||||
-ifdef(TEST).
|
||||
select_returns_valid_index_test() ->
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
|
||||
select_single_shard_test() ->
|
||||
?assertEqual(0, select(test_key, 1)),
|
||||
Index = select(test_key, 5),
|
||||
?assert(Index >= 0),
|
||||
?assert(Index < 5).
|
||||
?assertEqual(0, select(any_key, 1)),
|
||||
?assertEqual(0, select(12345, 1)).
|
||||
|
||||
select_is_stable_for_same_inputs_test() ->
|
||||
?assertEqual(select(<<"abc">>, 8), select(<<"abc">>, 8)),
|
||||
?assertEqual(select(12345, 3), select(12345, 3)).
|
||||
select_valid_index_test_() ->
|
||||
[
|
||||
?_test(begin
|
||||
Index = select(test_key, N),
|
||||
?assert(Index >= 0),
|
||||
?assert(Index < N)
|
||||
end)
|
||||
|| N <- [2, 5, 10, 100]
|
||||
].
|
||||
|
||||
group_keys_sorts_and_deduplicates_test() ->
|
||||
select_stability_test_() ->
|
||||
[
|
||||
?_assertEqual(select(<<"abc">>, 8), select(<<"abc">>, 8)),
|
||||
?_assertEqual(select(12345, 3), select(12345, 3)),
|
||||
?_assertEqual(select({user, 1}, 10), select({user, 1}, 10))
|
||||
].
|
||||
|
||||
select_distribution_test() ->
|
||||
Keys = lists:seq(1, 1000),
|
||||
ShardCount = 10,
|
||||
Distribution = lists:foldl(
|
||||
fun(Key, Acc) ->
|
||||
Index = select(Key, ShardCount),
|
||||
maps:update_with(Index, fun(V) -> V + 1 end, 1, Acc)
|
||||
end,
|
||||
#{},
|
||||
Keys
|
||||
),
|
||||
Counts = maps:values(Distribution),
|
||||
?assertEqual(ShardCount, maps:size(Distribution)),
|
||||
lists:foreach(fun(Count) -> ?assert(Count > 0) end, Counts).
|
||||
|
||||
group_keys_empty_test() ->
|
||||
?assertEqual([], group_keys([], 4)).
|
||||
|
||||
group_keys_single_test() ->
|
||||
Groups = group_keys([key1], 4),
|
||||
?assertEqual(1, length(Groups)).
|
||||
|
||||
group_keys_deduplicates_test() ->
|
||||
Keys = [1, 2, 3, 1, 2],
|
||||
Groups = group_keys(Keys, 2),
|
||||
?assertMatch([{_, _}, {_, _}], Groups),
|
||||
lists:foreach(
|
||||
fun({_Index, GroupKeys}) ->
|
||||
?assertEqual(GroupKeys, lists:usort(GroupKeys))
|
||||
@@ -96,6 +126,16 @@ group_keys_sorts_and_deduplicates_test() ->
|
||||
Groups
|
||||
).
|
||||
|
||||
group_keys_handles_empty_test() ->
|
||||
?assertEqual([], group_keys([], 4)).
|
||||
group_keys_sorted_indices_test() ->
|
||||
Keys = lists:seq(1, 100),
|
||||
Groups = group_keys(Keys, 5),
|
||||
Indices = [I || {I, _} <- Groups],
|
||||
?assertEqual(Indices, lists:sort(Indices)).
|
||||
|
||||
group_keys_all_keys_present_test() ->
|
||||
Keys = [a, b, c, d, e],
|
||||
Groups = group_keys(Keys, 3),
|
||||
AllGroupedKeys = lists:flatten([K || {_, K} <- Groups]),
|
||||
?assertEqual(lists:sort(Keys), lists:sort(AllGroupedKeys)).
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -27,59 +27,86 @@
|
||||
|
||||
-type rpc_request() :: map().
|
||||
-type rpc_response() :: {ok, map()} | {error, term()}.
|
||||
-type rpc_options() :: map().
|
||||
|
||||
-spec call(rpc_request()) -> rpc_response().
|
||||
call(Request) ->
|
||||
call(Request, #{}).
|
||||
|
||||
-spec call(rpc_request(), map()) -> rpc_response().
|
||||
-spec call(rpc_request(), rpc_options()) -> rpc_response().
|
||||
call(Request, _Options) ->
|
||||
Url = get_rpc_url(),
|
||||
Headers = get_rpc_headers(),
|
||||
Body = jsx:encode(Request),
|
||||
|
||||
case
|
||||
hackney:request(post, Url, Headers, Body, [{recv_timeout, 30000}, {connect_timeout, 5000}])
|
||||
of
|
||||
{ok, 200, _RespHeaders, ClientRef} ->
|
||||
case hackney:body(ClientRef) of
|
||||
{ok, RespBody} ->
|
||||
Response = jsx:decode(RespBody, [return_maps]),
|
||||
Data = maps:get(<<"data">>, Response, #{}),
|
||||
{ok, Data};
|
||||
{error, Reason} ->
|
||||
logger:error("[rpc_client] Failed to read response body: ~p", [Reason]),
|
||||
{error, {body_read_failed, Reason}}
|
||||
end;
|
||||
{ok, StatusCode, _RespHeaders, ClientRef} ->
|
||||
case hackney:body(ClientRef) of
|
||||
{ok, RespBody} ->
|
||||
hackney:close(ClientRef),
|
||||
logger:error("[rpc_client] RPC request failed with status ~p: ~s", [
|
||||
StatusCode, RespBody
|
||||
]),
|
||||
{error, {http_error, StatusCode, RespBody}};
|
||||
{error, Reason} ->
|
||||
hackney:close(ClientRef),
|
||||
logger:error(
|
||||
"[rpc_client] Failed to read error response body (status ~p): ~p", [
|
||||
StatusCode, Reason
|
||||
]
|
||||
),
|
||||
{error, {http_error, StatusCode, body_read_failed}}
|
||||
end;
|
||||
Body = json:encode(Request),
|
||||
case gateway_http_client:request(rpc, post, Url, Headers, Body) of
|
||||
{ok, 200, _RespHeaders, RespBody} ->
|
||||
handle_success_response(RespBody);
|
||||
{ok, StatusCode, _RespHeaders, RespBody} ->
|
||||
handle_error_response(StatusCode, RespBody);
|
||||
{error, Reason} ->
|
||||
logger:error("[rpc_client] RPC request failed: ~p", [Reason]),
|
||||
{error, Reason}
|
||||
end.
|
||||
|
||||
-spec handle_success_response(binary()) -> rpc_response().
|
||||
handle_success_response(RespBody) ->
|
||||
Response = json:decode(RespBody),
|
||||
Data = maps:get(<<"data">>, Response, #{}),
|
||||
{ok, Data}.
|
||||
|
||||
-spec handle_error_response(pos_integer(), binary()) -> {error, term()}.
|
||||
handle_error_response(StatusCode, RespBody) ->
|
||||
{error, {http_error, StatusCode, RespBody}}.
|
||||
|
||||
-spec get_rpc_url() -> string().
|
||||
get_rpc_url() ->
|
||||
ApiHost = fluxer_gateway_env:get(api_host),
|
||||
get_rpc_url(ApiHost).
|
||||
|
||||
-spec get_rpc_url(string() | binary()) -> string().
|
||||
get_rpc_url(ApiHost) ->
|
||||
"http://" ++ ApiHost ++ "/_rpc".
|
||||
BaseUrl = api_host_base_url(ApiHost),
|
||||
BaseUrl ++ "/_rpc".
|
||||
|
||||
-spec api_host_base_url(string() | binary()) -> string().
|
||||
api_host_base_url(ApiHost) ->
|
||||
HostString = ensure_string(ApiHost),
|
||||
Normalized = normalize_api_host(HostString),
|
||||
strip_trailing_slash(Normalized).
|
||||
|
||||
-spec ensure_string(binary() | string()) -> string().
|
||||
ensure_string(Value) when is_binary(Value) ->
|
||||
binary_to_list(Value);
|
||||
ensure_string(Value) when is_list(Value) ->
|
||||
Value.
|
||||
|
||||
-spec normalize_api_host(string()) -> string().
|
||||
normalize_api_host(Host) ->
|
||||
Lower = string:lowercase(Host),
|
||||
case {has_protocol_prefix(Lower, "http://"), has_protocol_prefix(Lower, "https://")} of
|
||||
{true, _} -> Host;
|
||||
{_, true} -> Host;
|
||||
_ -> "http://" ++ Host
|
||||
end.
|
||||
|
||||
-spec has_protocol_prefix(string(), string()) -> boolean().
|
||||
has_protocol_prefix(Str, Prefix) ->
|
||||
case string:prefix(Str, Prefix) of
|
||||
nomatch -> false;
|
||||
_ -> true
|
||||
end.
|
||||
|
||||
-spec strip_trailing_slash(string()) -> string().
|
||||
strip_trailing_slash([]) ->
|
||||
"";
|
||||
strip_trailing_slash(Url) ->
|
||||
case lists:last(Url) of
|
||||
$/ -> strip_trailing_slash(lists:droplast(Url));
|
||||
_ -> Url
|
||||
end.
|
||||
|
||||
-spec get_rpc_headers() -> [{binary() | string(), binary() | string()}].
|
||||
get_rpc_headers() ->
|
||||
RpcSecretKey = fluxer_gateway_env:get(rpc_secret_key),
|
||||
[{<<"Authorization">>, <<"Bearer ", RpcSecretKey/binary>>}].
|
||||
AuthHeader = {<<"Authorization">>, <<"Bearer ", RpcSecretKey/binary>>},
|
||||
InitialHeaders = [AuthHeader],
|
||||
gateway_tracing:inject_rpc_headers(InitialHeaders).
|
||||
|
||||
Reference in New Issue
Block a user