fix: various fixes to sentry-reported errors and more
This commit is contained in:
@@ -23,11 +23,12 @@
|
||||
start(_StartType, _StartArgs) ->
|
||||
fluxer_gateway_env:load(),
|
||||
otel_metrics:init(),
|
||||
passive_sync_registry:init(),
|
||||
guild_counts_cache:init(),
|
||||
Port = fluxer_gateway_env:get(port),
|
||||
Dispatch = cowboy_router:compile([
|
||||
{'_', [
|
||||
{<<"/_health">>, health_handler, []},
|
||||
{<<"/_rpc">>, gateway_rpc_http_handler, []},
|
||||
{<<"/_admin/reload">>, hot_reload_handler, []},
|
||||
{<<"/">>, gateway_handler, []}
|
||||
]}
|
||||
|
||||
@@ -43,17 +43,15 @@ load_from(Path) when is_list(Path) ->
|
||||
-spec build_config(map()) -> config().
|
||||
build_config(Json) ->
|
||||
Service = get_map(Json, [<<"services">>, <<"gateway">>]),
|
||||
Gateway = get_map(Json, [<<"gateway">>]),
|
||||
Nats = get_map(Json, [<<"services">>, <<"nats">>]),
|
||||
Telemetry = get_map(Json, [<<"telemetry">>]),
|
||||
Sentry = get_map(Json, [<<"sentry">>]),
|
||||
Vapid = get_map(Json, [<<"auth">>, <<"vapid">>]),
|
||||
#{
|
||||
port => get_int(Service, <<"port">>, 8080),
|
||||
rpc_tcp_port => get_int(Service, <<"rpc_tcp_port">>, 8772),
|
||||
api_host => get_env_or_string("FLUXER_GATEWAY_API_HOST", Service, <<"api_host">>, "api"),
|
||||
api_canary_host => get_optional_string(Service, <<"api_canary_host">>),
|
||||
admin_reload_secret => get_optional_binary(Service, <<"admin_reload_secret">>),
|
||||
rpc_secret_key => get_binary(Gateway, <<"rpc_secret">>, undefined),
|
||||
nats_core_url => get_string(Nats, <<"core_url">>, "nats://127.0.0.1:4222"),
|
||||
nats_auth_token => get_string(Nats, <<"auth_token">>, ""),
|
||||
identify_rate_limit_enabled => get_bool(Service, <<"identify_rate_limit_enabled">>, false),
|
||||
push_enabled => get_bool(Service, <<"push_enabled">>, true),
|
||||
push_user_guild_settings_cache_mb => get_int(
|
||||
@@ -73,18 +71,12 @@ build_config(Json) ->
|
||||
get_int(Service, <<"push_badge_counts_cache_ttl_seconds">>, 60),
|
||||
push_dispatcher_max_inflight => get_int(Service, <<"push_dispatcher_max_inflight">>, 16),
|
||||
push_dispatcher_max_queue => get_int(Service, <<"push_dispatcher_max_queue">>, 2048),
|
||||
gateway_http_rpc_connect_timeout_ms =>
|
||||
get_int(Service, <<"gateway_http_rpc_connect_timeout_ms">>, 5000),
|
||||
gateway_http_rpc_recv_timeout_ms =>
|
||||
get_int(Service, <<"gateway_http_rpc_recv_timeout_ms">>, 30000),
|
||||
gateway_http_push_connect_timeout_ms =>
|
||||
get_int(Service, <<"gateway_http_push_connect_timeout_ms">>, 3000),
|
||||
gateway_http_push_recv_timeout_ms =>
|
||||
get_int(Service, <<"gateway_http_push_recv_timeout_ms">>, 5000),
|
||||
gateway_http_rpc_max_concurrency =>
|
||||
get_int(Service, <<"gateway_http_rpc_max_concurrency">>, 512),
|
||||
gateway_rpc_tcp_max_input_buffer_bytes =>
|
||||
get_int(Service, <<"gateway_rpc_tcp_max_input_buffer_bytes">>, 2097152),
|
||||
gateway_http_push_max_concurrency =>
|
||||
get_int(Service, <<"gateway_http_push_max_concurrency">>, 256),
|
||||
gateway_http_failure_threshold =>
|
||||
@@ -148,27 +140,6 @@ get_optional_bool(Map, Key) ->
|
||||
get_string(Map, Key, Default) when is_list(Default) ->
|
||||
to_string(get_value(Map, Key), Default).
|
||||
|
||||
-spec get_env_or_string(string(), map(), binary(), string()) -> string().
|
||||
get_env_or_string(EnvVar, Map, Key, Default) when is_list(EnvVar), is_list(Default) ->
|
||||
case os:getenv(EnvVar) of
|
||||
false -> get_string(Map, Key, Default);
|
||||
"" -> get_string(Map, Key, Default);
|
||||
Value -> Value
|
||||
end.
|
||||
|
||||
-spec get_optional_string(map(), binary()) -> string() | undefined.
|
||||
get_optional_string(Map, Key) ->
|
||||
case get_value(Map, Key) of
|
||||
undefined ->
|
||||
undefined;
|
||||
Value ->
|
||||
Clean = string:trim(to_string(Value, "")),
|
||||
case Clean of
|
||||
"" -> undefined;
|
||||
_ -> Clean
|
||||
end
|
||||
end.
|
||||
|
||||
-spec get_binary(map(), binary(), binary() | undefined) -> binary() | undefined.
|
||||
get_binary(Map, Key, Default) ->
|
||||
to_binary(get_value(Map, Key), Default).
|
||||
|
||||
@@ -32,7 +32,7 @@ init([]) ->
|
||||
},
|
||||
Children = [
|
||||
child_spec(gateway_http_client, gateway_http_client),
|
||||
child_spec(gateway_rpc_tcp_server, gateway_rpc_tcp_server),
|
||||
child_spec(gateway_nats_rpc, gateway_nats_rpc),
|
||||
child_spec(session_manager, session_manager),
|
||||
child_spec(presence_cache, presence_cache),
|
||||
child_spec(presence_bus, presence_bus),
|
||||
|
||||
@@ -35,7 +35,6 @@
|
||||
-spec parse_compression(binary() | undefined) -> compression().
|
||||
parse_compression(<<"none">>) ->
|
||||
none;
|
||||
%% TODO: temporarily disabled – re-enable zstd-stream once compression issues are resolved
|
||||
parse_compression(<<"zstd-stream">>) ->
|
||||
none;
|
||||
parse_compression(_) ->
|
||||
@@ -123,7 +122,6 @@ parse_compression_test_() ->
|
||||
?_assertEqual(none, parse_compression(<<>>)),
|
||||
?_assertEqual(none, parse_compression(<<"none">>)),
|
||||
?_assertEqual(none, parse_compression(<<"invalid">>)),
|
||||
%% zstd-stream temporarily disabled – always returns none
|
||||
?_assertEqual(none, parse_compression(<<"zstd-stream">>))
|
||||
].
|
||||
|
||||
|
||||
@@ -517,11 +517,16 @@ handle_resume(Data, State) ->
|
||||
Token = maps:get(<<"token">>, Data),
|
||||
SessionId = maps:get(<<"session_id">>, Data),
|
||||
Seq = maps:get(<<"seq">>, Data),
|
||||
case session_manager:lookup(SessionId) of
|
||||
{ok, Pid} when is_pid(Pid) ->
|
||||
handle_resume_with_session(Pid, Token, SessionId, Seq, State);
|
||||
{error, not_found} ->
|
||||
handle_resume_session_not_found(State)
|
||||
case is_binary(SessionId) of
|
||||
false ->
|
||||
handle_resume_session_not_found(State);
|
||||
true ->
|
||||
case session_manager:lookup(SessionId) of
|
||||
{ok, Pid} when is_pid(Pid) ->
|
||||
handle_resume_with_session(Pid, Token, SessionId, Seq, State);
|
||||
{error, _} ->
|
||||
handle_resume_session_not_found(State)
|
||||
end
|
||||
end.
|
||||
|
||||
-spec handle_voice_state_update(pid(), map(), state()) -> ws_result().
|
||||
|
||||
260
fluxer_gateway/src/gateway/gateway_nats_rpc.erl
Normal file
260
fluxer_gateway/src/gateway/gateway_nats_rpc.erl
Normal file
@@ -0,0 +1,260 @@
|
||||
%% Copyright (C) 2026 Fluxer Contributors
|
||||
%%
|
||||
%% This file is part of Fluxer.
|
||||
%%
|
||||
%% Fluxer is free software: you can redistribute it and/or modify
|
||||
%% it under the terms of the GNU Affero General Public License as published by
|
||||
%% the Free Software Foundation, either version 3 of the License, or
|
||||
%% (at your option) any later version.
|
||||
%%
|
||||
%% Fluxer is distributed in the hope that it will be useful,
|
||||
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
%% GNU Affero General Public License for more details.
|
||||
%%
|
||||
%% You should have received a copy of the GNU Affero General Public License
|
||||
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
-module(gateway_nats_rpc).
|
||||
-behaviour(gen_server).
|
||||
|
||||
-export([start_link/0, get_connection/0]).
|
||||
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
|
||||
|
||||
-define(DEFAULT_MAX_HANDLERS, 1024).
|
||||
-define(RECONNECT_DELAY_MS, 2000).
|
||||
-define(RPC_SUBJECT_PREFIX, <<"rpc.gateway.">>).
|
||||
-define(RPC_SUBJECT_WILDCARD, <<"rpc.gateway.>">>).
|
||||
-define(QUEUE_GROUP, <<"gateway">>).
|
||||
|
||||
-spec start_link() -> {ok, pid()} | {error, term()}.
|
||||
start_link() ->
|
||||
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
|
||||
|
||||
-spec get_connection() -> {ok, nats:conn() | undefined} | {error, term()}.
|
||||
get_connection() ->
|
||||
gen_server:call(?MODULE, get_connection).
|
||||
|
||||
-spec init([]) -> {ok, map()}.
|
||||
init([]) ->
|
||||
process_flag(trap_exit, true),
|
||||
self() ! connect,
|
||||
{ok, #{
|
||||
conn => undefined,
|
||||
sub => undefined,
|
||||
handler_count => 0,
|
||||
max_handlers => max_handlers(),
|
||||
monitor_ref => undefined
|
||||
}}.
|
||||
|
||||
-spec handle_call(term(), gen_server:from(), map()) -> {reply, term(), map()}.
|
||||
handle_call(get_connection, _From, #{conn := Conn} = State) ->
|
||||
{reply, {ok, Conn}, State};
|
||||
handle_call(_Request, _From, State) ->
|
||||
{reply, ok, State}.
|
||||
|
||||
-spec handle_cast(term(), map()) -> {noreply, map()}.
|
||||
handle_cast(_Msg, State) ->
|
||||
{noreply, State}.
|
||||
|
||||
-spec handle_info(term(), map()) -> {noreply, map()}.
|
||||
handle_info(connect, State) ->
|
||||
{noreply, do_connect(State)};
|
||||
handle_info({Conn, ready}, #{conn := Conn} = State) ->
|
||||
{noreply, do_subscribe(State)};
|
||||
handle_info({Conn, closed}, #{conn := Conn} = State) ->
|
||||
logger:warning("Gateway NATS RPC connection closed, reconnecting"),
|
||||
{noreply, schedule_reconnect(State#{conn => undefined, sub => undefined, monitor_ref => undefined})};
|
||||
handle_info({Conn, {error, Reason}}, #{conn := Conn} = State) ->
|
||||
logger:warning("Gateway NATS RPC connection error: ~p, reconnecting", [Reason]),
|
||||
{noreply, schedule_reconnect(State#{conn => undefined, sub => undefined, monitor_ref => undefined})};
|
||||
handle_info({Conn, _Sid, {msg, Subject, Payload, MsgOpts}},
|
||||
#{conn := Conn, handler_count := HandlerCount, max_handlers := MaxHandlers} = State) ->
|
||||
case maps:get(reply_to, MsgOpts, undefined) of
|
||||
undefined ->
|
||||
{noreply, State};
|
||||
ReplyTo ->
|
||||
case HandlerCount >= MaxHandlers of
|
||||
true ->
|
||||
ErrorResponse = iolist_to_binary(json:encode(#{
|
||||
<<"ok">> => false,
|
||||
<<"error">> => <<"overloaded">>
|
||||
})),
|
||||
nats:pub(Conn, ReplyTo, ErrorResponse),
|
||||
{noreply, State};
|
||||
false ->
|
||||
Parent = self(),
|
||||
spawn(fun() ->
|
||||
try
|
||||
handle_rpc_request(Conn, Subject, Payload, ReplyTo)
|
||||
after
|
||||
Parent ! {handler_done, self()}
|
||||
end
|
||||
end),
|
||||
{noreply, State#{handler_count => HandlerCount + 1}}
|
||||
end
|
||||
end;
|
||||
handle_info({handler_done, _Pid}, #{handler_count := HandlerCount} = State) when HandlerCount > 0 ->
|
||||
{noreply, State#{handler_count => HandlerCount - 1}};
|
||||
handle_info({handler_done, _Pid}, State) ->
|
||||
{noreply, State};
|
||||
handle_info({'DOWN', MRef, process, Conn, Reason}, #{conn := Conn, monitor_ref := MRef} = State) ->
|
||||
logger:warning("Gateway NATS RPC connection process died: ~p, reconnecting", [Reason]),
|
||||
{noreply, schedule_reconnect(State#{conn => undefined, sub => undefined, monitor_ref => undefined})};
|
||||
handle_info(_Info, State) ->
|
||||
{noreply, State}.
|
||||
|
||||
-spec terminate(term(), map()) -> ok.
|
||||
terminate(_Reason, #{conn := Conn}) when Conn =/= undefined ->
|
||||
catch nats:disconnect(Conn),
|
||||
logger:info("Gateway NATS RPC subscriber stopped"),
|
||||
ok;
|
||||
terminate(_Reason, _State) ->
|
||||
ok.
|
||||
|
||||
-spec code_change(term(), map(), term()) -> {ok, map()}.
|
||||
code_change(_OldVsn, State, _Extra) ->
|
||||
{ok, State}.
|
||||
|
||||
-spec do_connect(map()) -> map().
|
||||
do_connect(State) ->
|
||||
NatsUrl = fluxer_gateway_env:get(nats_core_url),
|
||||
AuthToken = fluxer_gateway_env:get(nats_auth_token),
|
||||
case parse_nats_url(NatsUrl) of
|
||||
{ok, Host, Port} ->
|
||||
Opts = build_connect_opts(AuthToken),
|
||||
case nats:connect(Host, Port, Opts) of
|
||||
{ok, Conn} ->
|
||||
MRef = nats:monitor(Conn),
|
||||
logger:info("Gateway NATS RPC connected to ~s:~p", [Host, Port]),
|
||||
State#{conn => Conn, monitor_ref => MRef};
|
||||
{error, Reason} ->
|
||||
logger:error("Gateway NATS RPC failed to connect: ~p", [Reason]),
|
||||
schedule_reconnect(State)
|
||||
end;
|
||||
{error, Reason} ->
|
||||
logger:error("Gateway NATS RPC failed to parse URL: ~p", [Reason]),
|
||||
schedule_reconnect(State)
|
||||
end.
|
||||
|
||||
-spec do_subscribe(map()) -> map().
|
||||
do_subscribe(#{conn := Conn} = State) when Conn =/= undefined ->
|
||||
case nats:sub(Conn, ?RPC_SUBJECT_WILDCARD, #{queue_group => ?QUEUE_GROUP}) of
|
||||
{ok, Sid} ->
|
||||
logger:info("Gateway NATS RPC subscribed to ~s with queue group ~s",
|
||||
[?RPC_SUBJECT_WILDCARD, ?QUEUE_GROUP]),
|
||||
State#{sub => Sid};
|
||||
{error, Reason} ->
|
||||
logger:error("Gateway NATS RPC failed to subscribe: ~p", [Reason]),
|
||||
State
|
||||
end;
|
||||
do_subscribe(State) ->
|
||||
State.
|
||||
|
||||
-spec handle_rpc_request(nats:conn(), binary(), binary(), binary()) -> ok.
|
||||
handle_rpc_request(Conn, Subject, Payload, ReplyTo) ->
|
||||
Method = strip_rpc_prefix(Subject),
|
||||
Response = execute_rpc_method(Method, Payload),
|
||||
ResponseBin = iolist_to_binary(json:encode(Response)),
|
||||
nats:pub(Conn, ReplyTo, ResponseBin),
|
||||
ok.
|
||||
|
||||
-spec strip_rpc_prefix(binary()) -> binary().
|
||||
strip_rpc_prefix(<<"rpc.gateway.", Method/binary>>) ->
|
||||
Method;
|
||||
strip_rpc_prefix(Subject) ->
|
||||
Subject.
|
||||
|
||||
-spec execute_rpc_method(binary(), binary()) -> map().
|
||||
execute_rpc_method(Method, PayloadBin) ->
|
||||
try
|
||||
Params = json:decode(PayloadBin),
|
||||
Result = gateway_rpc_router:execute(Method, Params),
|
||||
#{<<"ok">> => true, <<"result">> => Result}
|
||||
catch
|
||||
throw:{error, Message} ->
|
||||
#{<<"ok">> => false, <<"error">> => error_binary(Message)};
|
||||
exit:timeout ->
|
||||
#{<<"ok">> => false, <<"error">> => <<"timeout">>};
|
||||
exit:{timeout, _} ->
|
||||
#{<<"ok">> => false, <<"error">> => <<"timeout">>};
|
||||
Class:Reason ->
|
||||
logger:error(
|
||||
"Gateway NATS RPC method execution failed. method=~ts class=~p reason=~p",
|
||||
[Method, Class, Reason]
|
||||
),
|
||||
#{<<"ok">> => false, <<"error">> => <<"internal_error">>}
|
||||
end.
|
||||
|
||||
-spec error_binary(term()) -> binary().
|
||||
error_binary(Value) when is_binary(Value) ->
|
||||
Value;
|
||||
error_binary(Value) when is_list(Value) ->
|
||||
unicode:characters_to_binary(Value);
|
||||
error_binary(Value) when is_atom(Value) ->
|
||||
atom_to_binary(Value, utf8);
|
||||
error_binary(Value) ->
|
||||
unicode:characters_to_binary(io_lib:format("~p", [Value])).
|
||||
|
||||
-spec parse_nats_url(term()) -> {ok, string(), inet:port_number()} | {error, term()}.
|
||||
parse_nats_url(Url) when is_list(Url) ->
|
||||
parse_nats_url(list_to_binary(Url));
|
||||
parse_nats_url(<<"nats://", Rest/binary>>) ->
|
||||
parse_host_port(Rest);
|
||||
parse_nats_url(<<"tls://", Rest/binary>>) ->
|
||||
parse_host_port(Rest);
|
||||
parse_nats_url(Url) when is_binary(Url) ->
|
||||
parse_host_port(Url);
|
||||
parse_nats_url(_) ->
|
||||
{error, invalid_nats_url}.
|
||||
|
||||
-spec parse_host_port(binary()) -> {ok, string(), inet:port_number()} | {error, term()}.
|
||||
parse_host_port(HostPort) ->
|
||||
case binary:split(HostPort, <<":">>) of
|
||||
[Host, PortBin] ->
|
||||
try
|
||||
Port = binary_to_integer(PortBin),
|
||||
{ok, binary_to_list(Host), Port}
|
||||
catch
|
||||
_:_ -> {error, invalid_port}
|
||||
end;
|
||||
[Host] ->
|
||||
{ok, binary_to_list(Host), 4222}
|
||||
end.
|
||||
|
||||
-spec build_connect_opts(term()) -> map().
|
||||
build_connect_opts(AuthToken) when is_binary(AuthToken), byte_size(AuthToken) > 0 ->
|
||||
#{auth_token => AuthToken, buffer_size => 0};
|
||||
build_connect_opts(AuthToken) when is_list(AuthToken) ->
|
||||
case AuthToken of
|
||||
"" -> #{buffer_size => 0};
|
||||
_ -> #{auth_token => list_to_binary(AuthToken), buffer_size => 0}
|
||||
end;
|
||||
build_connect_opts(_) ->
|
||||
#{buffer_size => 0}.
|
||||
|
||||
-spec schedule_reconnect(map()) -> map().
|
||||
schedule_reconnect(State) ->
|
||||
erlang:send_after(?RECONNECT_DELAY_MS, self(), connect),
|
||||
State.
|
||||
|
||||
-spec max_handlers() -> pos_integer().
|
||||
max_handlers() ->
|
||||
case fluxer_gateway_env:get(gateway_http_rpc_max_concurrency) of
|
||||
Value when is_integer(Value), Value > 0 ->
|
||||
Value;
|
||||
_ ->
|
||||
?DEFAULT_MAX_HANDLERS
|
||||
end.
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
|
||||
parse_nats_url_test() ->
|
||||
?assertEqual({ok, "127.0.0.1", 4222}, parse_nats_url(<<"nats://127.0.0.1:4222">>)),
|
||||
?assertEqual({ok, "localhost", 4222}, parse_nats_url(<<"nats://localhost:4222">>)),
|
||||
?assertEqual({ok, "localhost", 4222}, parse_nats_url(<<"nats://localhost">>)),
|
||||
?assertEqual({ok, "127.0.0.1", 4222}, parse_nats_url("nats://127.0.0.1:4222")),
|
||||
?assertEqual({error, invalid_nats_url}, parse_nats_url(undefined)).
|
||||
|
||||
-endif.
|
||||
@@ -35,15 +35,11 @@ execute_method(<<"guild.dispatch">>, #{
|
||||
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
|
||||
with_guild(GuildId, fun(Pid) ->
|
||||
EventAtom = constants:dispatch_event_atom(Event),
|
||||
case
|
||||
gen_server:call(
|
||||
Pid, {dispatch, #{event => EventAtom, data => Data}}, ?GUILD_CALL_TIMEOUT
|
||||
)
|
||||
of
|
||||
ok ->
|
||||
true;
|
||||
_ -> throw({error, <<"dispatch_error">>})
|
||||
end
|
||||
IsAlive = erlang:is_process_alive(Pid),
|
||||
logger:info("rpc guild.dispatch: guild_id=~p event=~p pid=~p alive=~p",
|
||||
[GuildId, EventAtom, Pid, IsAlive]),
|
||||
gen_server:cast(Pid, {dispatch, #{event => EventAtom, data => Data}}),
|
||||
true
|
||||
end);
|
||||
execute_method(<<"guild.get_counts">>, #{<<"guild_id">> := GuildIdBin}) ->
|
||||
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
|
||||
@@ -78,29 +74,23 @@ execute_method(<<"guild.get_data">>, #{<<"guild_id">> := GuildIdBin, <<"user_id"
|
||||
execute_method(<<"guild.get_member">>, #{<<"guild_id">> := GuildIdBin, <<"user_id">> := UserIdBin}) ->
|
||||
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
|
||||
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
|
||||
with_guild(GuildId, fun(Pid) ->
|
||||
Request = #{user_id => UserId},
|
||||
case gen_server:call(Pid, {get_guild_member, Request}, ?GUILD_CALL_TIMEOUT) of
|
||||
#{success := true, member_data := MemberData} ->
|
||||
#{<<"success">> => true, <<"member_data">> => MemberData};
|
||||
#{success := false} ->
|
||||
#{<<"success">> => false};
|
||||
_ ->
|
||||
throw({error, <<"guild_member_error">>})
|
||||
end
|
||||
end);
|
||||
case get_member_cached_or_rpc(GuildId, UserId) of
|
||||
{ok, MemberData} when is_map(MemberData) ->
|
||||
#{<<"success">> => true, <<"member_data">> => MemberData};
|
||||
{ok, undefined} ->
|
||||
#{<<"success">> => false};
|
||||
error ->
|
||||
throw({error, <<"guild_member_error">>})
|
||||
end;
|
||||
execute_method(<<"guild.has_member">>, #{<<"guild_id">> := GuildIdBin, <<"user_id">> := UserIdBin}) ->
|
||||
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
|
||||
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
|
||||
with_guild(GuildId, fun(Pid) ->
|
||||
Request = #{user_id => UserId},
|
||||
case gen_server:call(Pid, {has_member, Request}, ?GUILD_CALL_TIMEOUT) of
|
||||
#{has_member := HasMember} when is_boolean(HasMember) ->
|
||||
#{<<"has_member">> => HasMember};
|
||||
_ ->
|
||||
throw({error, <<"membership_check_error">>})
|
||||
end
|
||||
end);
|
||||
case get_has_member_cached_or_rpc(GuildId, UserId) of
|
||||
{ok, HasMember} ->
|
||||
#{<<"has_member">> => HasMember};
|
||||
error ->
|
||||
throw({error, <<"membership_check_error">>})
|
||||
end;
|
||||
execute_method(<<"guild.list_members">>, #{
|
||||
<<"guild_id">> := GuildIdBin, <<"limit">> := Limit, <<"offset">> := Offset
|
||||
}) ->
|
||||
@@ -429,9 +419,9 @@ execute_method(<<"guild.update_member_voice">>, #{
|
||||
}) ->
|
||||
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
|
||||
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
|
||||
with_guild(GuildId, fun(Pid) ->
|
||||
with_voice_server(GuildId, fun(VoicePid, _GuildPid) ->
|
||||
Request = #{user_id => UserId, mute => Mute, deaf => Deaf},
|
||||
case gen_server:call(Pid, {update_member_voice, Request}, ?GUILD_CALL_TIMEOUT) of
|
||||
case gen_server:call(VoicePid, {update_member_voice, Request}, ?GUILD_CALL_TIMEOUT) of
|
||||
#{success := true} -> #{<<"success">> => true};
|
||||
#{error := Error} -> throw({error, normalize_voice_rpc_error(Error)})
|
||||
end
|
||||
@@ -443,9 +433,9 @@ execute_method(
|
||||
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
|
||||
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
|
||||
ConnectionId = maps:get(<<"connection_id">>, Params, null),
|
||||
with_guild(GuildId, fun(Pid) ->
|
||||
with_voice_server(GuildId, fun(VoicePid, _GuildPid) ->
|
||||
Request = #{user_id => UserId, connection_id => ConnectionId},
|
||||
case gen_server:call(Pid, {disconnect_voice_user, Request}, ?GUILD_CALL_TIMEOUT) of
|
||||
case gen_server:call(VoicePid, {disconnect_voice_user, Request}, ?GUILD_CALL_TIMEOUT) of
|
||||
#{success := true} -> #{<<"success">> => true};
|
||||
#{error := Error} -> throw({error, normalize_voice_rpc_error(Error)})
|
||||
end
|
||||
@@ -464,11 +454,11 @@ execute_method(
|
||||
<<"expected_channel_id">>, ExpectedChannelIdBin
|
||||
),
|
||||
ConnectionId = maps:get(<<"connection_id">>, Params, undefined),
|
||||
with_guild(GuildId, fun(Pid) ->
|
||||
with_voice_server(GuildId, fun(VoicePid, _GuildPid) ->
|
||||
Request = build_disconnect_request(UserId, ExpectedChannelId, ConnectionId),
|
||||
case
|
||||
gen_server:call(
|
||||
Pid, {disconnect_voice_user_if_in_channel, Request}, ?GUILD_CALL_TIMEOUT
|
||||
VoicePid, {disconnect_voice_user_if_in_channel, Request}, ?GUILD_CALL_TIMEOUT
|
||||
)
|
||||
of
|
||||
#{success := true, ignored := true} -> #{<<"success">> => true, <<"ignored">> => true};
|
||||
@@ -481,11 +471,11 @@ execute_method(<<"guild.disconnect_all_voice_users_in_channel">>, #{
|
||||
}) ->
|
||||
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
|
||||
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
|
||||
with_guild(GuildId, fun(Pid) ->
|
||||
with_voice_server(GuildId, fun(VoicePid, _GuildPid) ->
|
||||
Request = #{channel_id => ChannelId},
|
||||
case
|
||||
gen_server:call(
|
||||
Pid, {disconnect_all_voice_users_in_channel, Request}, ?GUILD_CALL_TIMEOUT
|
||||
VoicePid, {disconnect_all_voice_users_in_channel, Request}, ?GUILD_CALL_TIMEOUT
|
||||
)
|
||||
of
|
||||
#{success := true, disconnected_count := Count} ->
|
||||
@@ -499,11 +489,11 @@ execute_method(<<"guild.confirm_voice_connection_from_livekit">>, Params) ->
|
||||
ConnectionId = maps:get(<<"connection_id">>, Params),
|
||||
TokenNonce = maps:get(<<"token_nonce">>, Params, undefined),
|
||||
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
|
||||
with_guild(GuildId, fun(Pid) ->
|
||||
with_voice_server(GuildId, fun(VoicePid, _GuildPid) ->
|
||||
Request = #{connection_id => ConnectionId, token_nonce => TokenNonce},
|
||||
case
|
||||
gen_server:call(
|
||||
Pid, {confirm_voice_connection_from_livekit, Request}, ?GUILD_CALL_TIMEOUT
|
||||
VoicePid, {confirm_voice_connection_from_livekit, Request}, ?GUILD_CALL_TIMEOUT
|
||||
)
|
||||
of
|
||||
#{success := true} -> #{<<"success">> => true};
|
||||
@@ -518,8 +508,8 @@ execute_method(<<"guild.get_voice_states_for_channel">>, Params) ->
|
||||
GuildIdBin = maps:get(<<"guild_id">>, Params),
|
||||
ChannelIdBin = maps:get(<<"channel_id">>, Params),
|
||||
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
|
||||
with_guild(GuildId, fun(Pid) ->
|
||||
case gen_server:call(Pid, {get_voice_states_for_channel, ChannelIdBin}, 10000) of
|
||||
with_voice_server(GuildId, fun(VoicePid, _GuildPid) ->
|
||||
case gen_server:call(VoicePid, {get_voice_states_for_channel, ChannelIdBin}, 10000) of
|
||||
#{voice_states := VoiceStates} ->
|
||||
#{<<"voice_states">> => VoiceStates};
|
||||
_ ->
|
||||
@@ -530,8 +520,8 @@ execute_method(<<"guild.get_pending_joins_for_channel">>, Params) ->
|
||||
GuildIdBin = maps:get(<<"guild_id">>, Params),
|
||||
ChannelIdBin = maps:get(<<"channel_id">>, Params),
|
||||
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
|
||||
with_guild(GuildId, fun(Pid) ->
|
||||
case gen_server:call(Pid, {get_pending_joins_for_channel, ChannelIdBin}, 10000) of
|
||||
with_voice_server(GuildId, fun(VoicePid, _GuildPid) ->
|
||||
case gen_server:call(VoicePid, {get_pending_joins_for_channel, ChannelIdBin}, 10000) of
|
||||
#{pending_joins := PendingJoins} ->
|
||||
#{<<"pending_joins">> => PendingJoins};
|
||||
_ ->
|
||||
@@ -559,7 +549,7 @@ execute_method(<<"guild.move_member">>, #{
|
||||
connection_id => ConnectionId
|
||||
}
|
||||
),
|
||||
with_guild(GuildId, fun(Pid) ->
|
||||
with_voice_server(GuildId, fun(VoicePid, GuildPid) ->
|
||||
Request = #{
|
||||
user_id => UserId,
|
||||
moderator_id => ModeratorId,
|
||||
@@ -567,10 +557,10 @@ execute_method(<<"guild.move_member">>, #{
|
||||
connection_id => ConnectionId
|
||||
},
|
||||
handle_move_member_result(
|
||||
gen_server:call(Pid, {move_member, Request}, ?GUILD_CALL_TIMEOUT),
|
||||
gen_server:call(VoicePid, {move_member, Request}, ?GUILD_CALL_TIMEOUT),
|
||||
GuildId,
|
||||
ChannelId,
|
||||
Pid
|
||||
GuildPid
|
||||
)
|
||||
end);
|
||||
execute_method(<<"guild.get_voice_state">>, #{
|
||||
@@ -578,9 +568,9 @@ execute_method(<<"guild.get_voice_state">>, #{
|
||||
}) ->
|
||||
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
|
||||
UserId = validation:snowflake_or_throw(<<"user_id">>, UserIdBin),
|
||||
with_guild(GuildId, fun(Pid) ->
|
||||
with_voice_server(GuildId, fun(VoicePid, _GuildPid) ->
|
||||
Request = #{user_id => UserId},
|
||||
case gen_server:call(Pid, {get_voice_state, Request}, ?GUILD_CALL_TIMEOUT) of
|
||||
case gen_server:call(VoicePid, {get_voice_state, Request}, ?GUILD_CALL_TIMEOUT) of
|
||||
#{voice_state := null} -> #{<<"voice_state">> => null};
|
||||
#{voice_state := VoiceState} -> #{<<"voice_state">> => VoiceState};
|
||||
_ -> throw({error, <<"voice_state_error">>})
|
||||
@@ -591,11 +581,11 @@ execute_method(<<"guild.switch_voice_region">>, #{
|
||||
}) ->
|
||||
GuildId = validation:snowflake_or_throw(<<"guild_id">>, GuildIdBin),
|
||||
ChannelId = validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin),
|
||||
with_guild(GuildId, fun(Pid) ->
|
||||
with_voice_server(GuildId, fun(VoicePid, GuildPid) ->
|
||||
Request = #{channel_id => ChannelId},
|
||||
case gen_server:call(Pid, {switch_voice_region, Request}, ?GUILD_CALL_TIMEOUT) of
|
||||
case gen_server:call(VoicePid, {switch_voice_region, Request}, ?GUILD_CALL_TIMEOUT) of
|
||||
#{success := true} ->
|
||||
spawn(fun() -> guild_voice:switch_voice_region(GuildId, ChannelId, Pid) end),
|
||||
spawn(fun() -> guild_voice:switch_voice_region(GuildId, ChannelId, GuildPid) end),
|
||||
#{<<"success">> => true};
|
||||
#{error := Error} ->
|
||||
throw({error, normalize_voice_rpc_error(Error)})
|
||||
@@ -644,12 +634,26 @@ execute_method(<<"guild.batch_voice_state_update">>, #{<<"updates">> := UpdatesB
|
||||
|
||||
-spec fetch_online_count_entry(integer()) -> map() | undefined.
|
||||
fetch_online_count_entry(GuildId) ->
|
||||
case guild_counts_cache:get(GuildId) of
|
||||
{ok, MemberCount, OnlineCount} ->
|
||||
#{
|
||||
<<"guild_id">> => integer_to_binary(GuildId),
|
||||
<<"member_count">> => MemberCount,
|
||||
<<"online_count">> => OnlineCount
|
||||
};
|
||||
miss ->
|
||||
fetch_online_count_entry_from_process(GuildId)
|
||||
end.
|
||||
|
||||
-spec fetch_online_count_entry_from_process(integer()) -> map() | undefined.
|
||||
fetch_online_count_entry_from_process(GuildId) ->
|
||||
case get_guild_pid(GuildId) of
|
||||
{ok, Pid} ->
|
||||
case gen_server:call(Pid, {get_counts}, ?GUILD_CALL_TIMEOUT) of
|
||||
#{presence_count := PresenceCount} ->
|
||||
#{member_count := MemberCount, presence_count := PresenceCount} ->
|
||||
#{
|
||||
<<"guild_id">> => integer_to_binary(GuildId),
|
||||
<<"member_count">> => MemberCount,
|
||||
<<"online_count">> => PresenceCount
|
||||
};
|
||||
_ ->
|
||||
@@ -670,6 +674,23 @@ with_guild(GuildId, Fun, NotFoundError) ->
|
||||
_ -> throw({error, NotFoundError})
|
||||
end.
|
||||
|
||||
-spec with_voice_server(integer(), fun((pid(), pid()) -> T)) -> T when T :: term().
|
||||
with_voice_server(GuildId, Fun) ->
|
||||
case get_guild_pid(GuildId) of
|
||||
{ok, GuildPid} ->
|
||||
VoicePid = resolve_voice_pid(GuildId, GuildPid),
|
||||
Fun(VoicePid, GuildPid);
|
||||
_ ->
|
||||
throw({error, <<"guild_not_found">>})
|
||||
end.
|
||||
|
||||
-spec resolve_voice_pid(integer(), pid()) -> pid().
|
||||
resolve_voice_pid(GuildId, FallbackGuildPid) ->
|
||||
case guild_voice_server:lookup(GuildId) of
|
||||
{ok, VoicePid} -> VoicePid;
|
||||
{error, not_found} -> FallbackGuildPid
|
||||
end.
|
||||
|
||||
-spec get_guild_pid(integer()) -> {ok, pid()} | error.
|
||||
get_guild_pid(GuildId) ->
|
||||
case lookup_guild_pid_from_cache(GuildId) of
|
||||
@@ -792,6 +813,56 @@ get_viewable_channels_via_rpc(GuildId, UserId) ->
|
||||
error
|
||||
end.
|
||||
|
||||
-spec get_has_member_cached_or_rpc(integer(), integer()) -> {ok, boolean()} | error.
|
||||
get_has_member_cached_or_rpc(GuildId, UserId) ->
|
||||
case guild_permission_cache:has_member(GuildId, UserId) of
|
||||
{ok, HasMember} ->
|
||||
{ok, HasMember};
|
||||
{error, not_found} ->
|
||||
get_has_member_via_rpc(GuildId, UserId)
|
||||
end.
|
||||
|
||||
-spec get_has_member_via_rpc(integer(), integer()) -> {ok, boolean()} | error.
|
||||
get_has_member_via_rpc(GuildId, UserId) ->
|
||||
case get_guild_pid(GuildId) of
|
||||
{ok, Pid} ->
|
||||
Request = #{user_id => UserId},
|
||||
case gen_server:call(Pid, {has_member, Request}, ?GUILD_CALL_TIMEOUT) of
|
||||
#{has_member := HasMember} when is_boolean(HasMember) ->
|
||||
{ok, HasMember};
|
||||
_ ->
|
||||
error
|
||||
end;
|
||||
error ->
|
||||
error
|
||||
end.
|
||||
|
||||
-spec get_member_cached_or_rpc(integer(), integer()) -> {ok, map() | undefined} | error.
|
||||
get_member_cached_or_rpc(GuildId, UserId) ->
|
||||
case guild_permission_cache:get_member(GuildId, UserId) of
|
||||
{ok, MemberOrUndefined} ->
|
||||
{ok, MemberOrUndefined};
|
||||
{error, not_found} ->
|
||||
get_member_via_rpc(GuildId, UserId)
|
||||
end.
|
||||
|
||||
-spec get_member_via_rpc(integer(), integer()) -> {ok, map() | undefined} | error.
|
||||
get_member_via_rpc(GuildId, UserId) ->
|
||||
case get_guild_pid(GuildId) of
|
||||
{ok, Pid} ->
|
||||
Request = #{user_id => UserId},
|
||||
case gen_server:call(Pid, {get_guild_member, Request}, ?GUILD_CALL_TIMEOUT) of
|
||||
#{success := true, member_data := MemberData} ->
|
||||
{ok, MemberData};
|
||||
#{success := false} ->
|
||||
{ok, undefined};
|
||||
_ ->
|
||||
error
|
||||
end;
|
||||
error ->
|
||||
error
|
||||
end.
|
||||
|
||||
-spec parse_channel_id(binary()) -> integer() | undefined.
|
||||
parse_channel_id(<<"0">>) -> undefined;
|
||||
parse_channel_id(ChannelIdBin) -> validation:snowflake_or_throw(<<"channel_id">>, ChannelIdBin).
|
||||
@@ -892,11 +963,12 @@ parse_voice_update(
|
||||
-spec process_voice_update({integer(), integer(), boolean(), boolean(), term()}) -> map().
|
||||
process_voice_update({GuildId, UserId, Mute, Deaf, ConnectionId}) ->
|
||||
case gen_server:call(guild_manager, {start_or_lookup, GuildId}, ?GUILD_LOOKUP_TIMEOUT) of
|
||||
{ok, Pid} ->
|
||||
{ok, GuildPid} ->
|
||||
VoicePid = resolve_voice_pid(GuildId, GuildPid),
|
||||
Request = #{
|
||||
user_id => UserId, mute => Mute, deaf => Deaf, connection_id => ConnectionId
|
||||
},
|
||||
case gen_server:call(Pid, {update_member_voice, Request}, ?GUILD_CALL_TIMEOUT) of
|
||||
case gen_server:call(VoicePid, {update_member_voice, Request}, ?GUILD_CALL_TIMEOUT) of
|
||||
#{success := true} ->
|
||||
#{
|
||||
<<"guild_id">> => integer_to_binary(GuildId),
|
||||
@@ -1062,4 +1134,50 @@ get_viewable_channels_cached_or_rpc_prefers_cache_test() ->
|
||||
ok = guild_permission_cache:delete(GuildId)
|
||||
end.
|
||||
|
||||
get_has_member_cached_or_rpc_prefers_cache_test() ->
|
||||
GuildId = 12348,
|
||||
UserId = 502,
|
||||
Data = #{
|
||||
<<"guild">> => #{<<"owner_id">> => <<"999">>},
|
||||
<<"roles">> => [],
|
||||
<<"members">> => #{
|
||||
UserId => #{
|
||||
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
|
||||
<<"roles">> => []
|
||||
}
|
||||
},
|
||||
<<"channels">> => []
|
||||
},
|
||||
ok = guild_permission_cache:put_data(GuildId, Data),
|
||||
try
|
||||
?assertEqual({ok, true}, get_has_member_cached_or_rpc(GuildId, UserId)),
|
||||
?assertEqual({ok, false}, get_has_member_cached_or_rpc(GuildId, 99999))
|
||||
after
|
||||
ok = guild_permission_cache:delete(GuildId)
|
||||
end.
|
||||
|
||||
get_member_cached_or_rpc_prefers_cache_test() ->
|
||||
GuildId = 12349,
|
||||
UserId = 503,
|
||||
Data = #{
|
||||
<<"guild">> => #{<<"owner_id">> => <<"999">>},
|
||||
<<"roles">> => [],
|
||||
<<"members">> => #{
|
||||
UserId => #{
|
||||
<<"user">> => #{<<"id">> => integer_to_binary(UserId)},
|
||||
<<"roles">> => [],
|
||||
<<"nick">> => <<"CacheNick">>
|
||||
}
|
||||
},
|
||||
<<"channels">> => []
|
||||
},
|
||||
ok = guild_permission_cache:put_data(GuildId, Data),
|
||||
try
|
||||
{ok, MemberData} = get_member_cached_or_rpc(GuildId, UserId),
|
||||
?assertEqual(<<"CacheNick">>, maps:get(<<"nick">>, MemberData)),
|
||||
?assertEqual({ok, undefined}, get_member_cached_or_rpc(GuildId, 99999))
|
||||
after
|
||||
ok = guild_permission_cache:delete(GuildId)
|
||||
end.
|
||||
|
||||
-endif.
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
%% Copyright (C) 2026 Fluxer Contributors
|
||||
%%
|
||||
%% This file is part of Fluxer.
|
||||
%%
|
||||
%% Fluxer is free software: you can redistribute it and/or modify
|
||||
%% it under the terms of the GNU Affero General Public License as published by
|
||||
%% the Free Software Foundation, either version 3 of the License, or
|
||||
%% (at your option) any later version.
|
||||
%%
|
||||
%% Fluxer is distributed in the hope that it will be useful,
|
||||
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
%% GNU Affero General Public License for more details.
|
||||
%%
|
||||
%% You should have received a copy of the GNU Affero General Public License
|
||||
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
-module(gateway_rpc_http_handler).
|
||||
|
||||
-export([init/2]).
|
||||
|
||||
-define(JSON_HEADERS, #{<<"content-type">> => <<"application/json">>}).
|
||||
|
||||
-spec init(cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
|
||||
init(Req0, State) ->
|
||||
case cowboy_req:method(Req0) of
|
||||
<<"POST">> ->
|
||||
handle_post(Req0, State);
|
||||
_ ->
|
||||
Req = cowboy_req:reply(405, #{<<"allow">> => <<"POST">>}, <<>>, Req0),
|
||||
{ok, Req, State}
|
||||
end.
|
||||
|
||||
-spec handle_post(cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
|
||||
handle_post(Req0, State) ->
|
||||
case authorize(Req0) of
|
||||
ok ->
|
||||
case read_body(Req0) of
|
||||
{ok, Decoded, Req1} ->
|
||||
handle_decoded_body(Decoded, Req1, State);
|
||||
{error, ErrorBody, Req1} ->
|
||||
respond(400, ErrorBody, Req1, State)
|
||||
end;
|
||||
{error, Req1} ->
|
||||
{ok, Req1, State}
|
||||
end.
|
||||
|
||||
-spec handle_decoded_body(map(), cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
|
||||
handle_decoded_body(Decoded, Req0, State) ->
|
||||
case maps:get(<<"method">>, Decoded, undefined) of
|
||||
undefined ->
|
||||
respond(400, #{<<"error">> => <<"Missing method">>}, Req0, State);
|
||||
Method when is_binary(Method) ->
|
||||
ParamsValue = maps:get(<<"params">>, Decoded, #{}),
|
||||
case is_map(ParamsValue) of
|
||||
true ->
|
||||
execute_method(Method, ParamsValue, Req0, State);
|
||||
false ->
|
||||
respond(400, #{<<"error">> => <<"Invalid params">>}, Req0, State)
|
||||
end;
|
||||
_ ->
|
||||
respond(400, #{<<"error">> => <<"Invalid method">>}, Req0, State)
|
||||
end.
|
||||
|
||||
-spec authorize(cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
|
||||
authorize(Req0) ->
|
||||
case cowboy_req:header(<<"authorization">>, Req0) of
|
||||
undefined ->
|
||||
Req = cowboy_req:reply(
|
||||
401,
|
||||
?JSON_HEADERS,
|
||||
json:encode(#{<<"error">> => <<"Unauthorized">>}),
|
||||
Req0
|
||||
),
|
||||
{error, Req};
|
||||
AuthHeader ->
|
||||
authorize_with_secret(AuthHeader, Req0)
|
||||
end.
|
||||
|
||||
-spec authorize_with_secret(binary(), cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
|
||||
authorize_with_secret(AuthHeader, Req0) ->
|
||||
case fluxer_gateway_env:get(rpc_secret_key) of
|
||||
undefined ->
|
||||
Req = cowboy_req:reply(
|
||||
500,
|
||||
?JSON_HEADERS,
|
||||
json:encode(#{<<"error">> => <<"RPC secret not configured">>}),
|
||||
Req0
|
||||
),
|
||||
{error, Req};
|
||||
Secret when is_binary(Secret) ->
|
||||
Expected = <<"Bearer ", Secret/binary>>,
|
||||
check_auth_header(AuthHeader, Expected, Req0)
|
||||
end.
|
||||
|
||||
-spec check_auth_header(binary(), binary(), cowboy_req:req()) -> ok | {error, cowboy_req:req()}.
|
||||
check_auth_header(AuthHeader, Expected, Req0) ->
|
||||
case secure_compare(AuthHeader, Expected) of
|
||||
true ->
|
||||
ok;
|
||||
false ->
|
||||
Req = cowboy_req:reply(
|
||||
401,
|
||||
?JSON_HEADERS,
|
||||
json:encode(#{<<"error">> => <<"Unauthorized">>}),
|
||||
Req0
|
||||
),
|
||||
{error, Req}
|
||||
end.
|
||||
|
||||
-spec secure_compare(binary(), binary()) -> boolean().
|
||||
secure_compare(Left, Right) when is_binary(Left), is_binary(Right) ->
|
||||
case byte_size(Left) =:= byte_size(Right) of
|
||||
true ->
|
||||
crypto:hash_equals(Left, Right);
|
||||
false ->
|
||||
false
|
||||
end.
|
||||
|
||||
-spec read_body(cowboy_req:req()) ->
|
||||
{ok, map(), cowboy_req:req()} | {error, map(), cowboy_req:req()}.
|
||||
read_body(Req0) ->
|
||||
read_body_chunks(Req0, <<>>).
|
||||
|
||||
-spec read_body_chunks(cowboy_req:req(), binary()) ->
|
||||
{ok, map(), cowboy_req:req()} | {error, map(), cowboy_req:req()}.
|
||||
read_body_chunks(Req0, Acc) ->
|
||||
case cowboy_req:read_body(Req0) of
|
||||
{ok, Body, Req1} ->
|
||||
FullBody = <<Acc/binary, Body/binary>>,
|
||||
decode_body(FullBody, Req1);
|
||||
{more, Body, Req1} ->
|
||||
read_body_chunks(Req1, <<Acc/binary, Body/binary>>)
|
||||
end.
|
||||
|
||||
-spec decode_body(binary(), cowboy_req:req()) ->
|
||||
{ok, map(), cowboy_req:req()} | {error, map(), cowboy_req:req()}.
|
||||
decode_body(Body, Req0) ->
|
||||
case catch json:decode(Body) of
|
||||
{'EXIT', _Reason} ->
|
||||
{error, #{<<"error">> => <<"Invalid JSON payload">>}, Req0};
|
||||
Decoded when is_map(Decoded) ->
|
||||
{ok, Decoded, Req0};
|
||||
_ ->
|
||||
{error, #{<<"error">> => <<"Invalid request body">>}, Req0}
|
||||
end.
|
||||
|
||||
-spec execute_method(binary(), map(), cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
|
||||
execute_method(Method, Params, Req0, State) ->
|
||||
try
|
||||
Result = gateway_rpc_router:execute(Method, Params),
|
||||
respond(200, #{<<"result">> => Result}, Req0, State)
|
||||
catch
|
||||
throw:{error, Message} ->
|
||||
respond(400, #{<<"error">> => Message}, Req0, State);
|
||||
exit:timeout ->
|
||||
respond(504, #{<<"error">> => <<"timeout">>}, Req0, State);
|
||||
exit:{timeout, _} ->
|
||||
respond(504, #{<<"error">> => <<"timeout">>}, Req0, State);
|
||||
_:_ ->
|
||||
respond(500, #{<<"error">> => <<"Internal error">>}, Req0, State)
|
||||
end.
|
||||
|
||||
-spec respond(pos_integer(), map(), cowboy_req:req(), term()) -> {ok, cowboy_req:req(), term()}.
|
||||
respond(Status, Body, Req0, State) ->
|
||||
Req = cowboy_req:reply(Status, ?JSON_HEADERS, json:encode(Body), Req0),
|
||||
{ok, Req, State}.
|
||||
@@ -1,446 +0,0 @@
|
||||
%% Copyright (C) 2026 Fluxer Contributors
|
||||
%%
|
||||
%% This file is part of Fluxer.
|
||||
%%
|
||||
%% Fluxer is free software: you can redistribute it and/or modify
|
||||
%% it under the terms of the GNU Affero General Public License as published by
|
||||
%% the Free Software Foundation, either version 3 of the License, or
|
||||
%% (at your option) any later version.
|
||||
%%
|
||||
%% Fluxer is distributed in the hope that it will be useful,
|
||||
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
%% GNU Affero General Public License for more details.
|
||||
%%
|
||||
%% You should have received a copy of the GNU Affero General Public License
|
||||
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
-module(gateway_rpc_tcp_connection).
|
||||
|
||||
-export([serve/1]).
|
||||
|
||||
-define(DEFAULT_MAX_INFLIGHT, 1024).
|
||||
-define(DEFAULT_MAX_INPUT_BUFFER_BYTES, 2097152).
|
||||
-define(DEFAULT_DISPATCH_RESERVE_DIVISOR, 8).
|
||||
-define(MAX_FRAME_BYTES, 1048576).
|
||||
-define(PROTOCOL_VERSION, <<"fluxer.rpc.tcp.v1">>).
|
||||
|
||||
-type state() :: #{
|
||||
socket := inet:socket(),
|
||||
buffer := binary(),
|
||||
authenticated := boolean(),
|
||||
inflight := non_neg_integer(),
|
||||
max_inflight := pos_integer(),
|
||||
max_input_buffer_bytes := pos_integer()
|
||||
}.
|
||||
|
||||
-type rpc_result() :: {ok, term()} | {error, binary()}.
|
||||
|
||||
-spec serve(inet:socket()) -> ok.
|
||||
serve(Socket) ->
|
||||
ok = inet:setopts(Socket, [{active, once}, {nodelay, true}, {keepalive, true}]),
|
||||
State = #{
|
||||
socket => Socket,
|
||||
buffer => <<>>,
|
||||
authenticated => false,
|
||||
inflight => 0,
|
||||
max_inflight => max_inflight(),
|
||||
max_input_buffer_bytes => max_input_buffer_bytes()
|
||||
},
|
||||
loop(State).
|
||||
|
||||
-spec loop(state()) -> ok.
|
||||
loop(#{socket := Socket} = State) ->
|
||||
receive
|
||||
{tcp, Socket, Data} ->
|
||||
case handle_tcp_data(Data, State) of
|
||||
{ok, NewState} ->
|
||||
ok = inet:setopts(Socket, [{active, once}]),
|
||||
loop(NewState);
|
||||
{stop, Reason, _NewState} ->
|
||||
logger:debug("Gateway TCP RPC connection closed: ~p", [Reason]),
|
||||
close_socket(Socket),
|
||||
ok
|
||||
end;
|
||||
{tcp_closed, Socket} ->
|
||||
ok;
|
||||
{tcp_error, Socket, Reason} ->
|
||||
logger:warning("Gateway TCP RPC socket error: ~p", [Reason]),
|
||||
close_socket(Socket),
|
||||
ok;
|
||||
{rpc_response, RequestId, Result} ->
|
||||
NewState = handle_rpc_response(RequestId, Result, State),
|
||||
loop(NewState);
|
||||
_Other ->
|
||||
loop(State)
|
||||
end.
|
||||
|
||||
-spec handle_tcp_data(binary(), state()) -> {ok, state()} | {stop, term(), state()}.
|
||||
handle_tcp_data(Data, #{buffer := Buffer, max_input_buffer_bytes := MaxInputBufferBytes} = State) ->
|
||||
case byte_size(Buffer) + byte_size(Data) =< MaxInputBufferBytes of
|
||||
false ->
|
||||
_ = send_error_frame(State, protocol_error_binary(input_buffer_limit_exceeded)),
|
||||
{stop, input_buffer_limit_exceeded, State};
|
||||
true ->
|
||||
Combined = <<Buffer/binary, Data/binary>>,
|
||||
decode_tcp_frames(Combined, State)
|
||||
end.
|
||||
|
||||
-spec decode_tcp_frames(binary(), state()) -> {ok, state()} | {stop, term(), state()}.
|
||||
decode_tcp_frames(Combined, State) ->
|
||||
case decode_frames(Combined, []) of
|
||||
{ok, Frames, Rest} ->
|
||||
process_frames(Frames, State#{buffer => Rest});
|
||||
{error, Reason} ->
|
||||
_ = send_error_frame(State, protocol_error_binary(Reason)),
|
||||
{stop, Reason, State}
|
||||
end.
|
||||
|
||||
-spec process_frames([map()], state()) -> {ok, state()} | {stop, term(), state()}.
|
||||
process_frames([], State) ->
|
||||
{ok, State};
|
||||
process_frames([Frame | Rest], State) ->
|
||||
case process_frame(Frame, State) of
|
||||
{ok, NewState} ->
|
||||
process_frames(Rest, NewState);
|
||||
{stop, Reason, NewState} ->
|
||||
{stop, Reason, NewState}
|
||||
end.
|
||||
|
||||
-spec process_frame(map(), state()) -> {ok, state()} | {stop, term(), state()}.
|
||||
process_frame(#{<<"type">> := <<"hello">>} = Frame, #{authenticated := false} = State) ->
|
||||
handle_hello_frame(Frame, State);
|
||||
process_frame(#{<<"type">> := <<"hello">>}, State) ->
|
||||
_ = send_error_frame(State, <<"duplicate_hello">>),
|
||||
{stop, duplicate_hello, State};
|
||||
process_frame(#{<<"type">> := <<"request">>} = Frame, #{authenticated := true} = State) ->
|
||||
handle_request_frame(Frame, State);
|
||||
process_frame(#{<<"type">> := <<"request">>}, State) ->
|
||||
_ = send_error_frame(State, <<"unauthorized">>),
|
||||
{stop, unauthorized, State};
|
||||
process_frame(#{<<"type">> := <<"ping">>}, State) ->
|
||||
_ = send_frame(State, #{<<"type">> => <<"pong">>}),
|
||||
{ok, State};
|
||||
process_frame(#{<<"type">> := <<"pong">>}, State) ->
|
||||
{ok, State};
|
||||
process_frame(#{<<"type">> := <<"close">>}, State) ->
|
||||
{stop, client_close, State};
|
||||
process_frame(_Frame, State) ->
|
||||
_ = send_error_frame(State, <<"unknown_frame_type">>),
|
||||
{stop, unknown_frame_type, State}.
|
||||
|
||||
-spec handle_hello_frame(map(), state()) -> {ok, state()} | {stop, term(), state()}.
|
||||
handle_hello_frame(Frame, State) ->
|
||||
case {maps:get(<<"protocol">>, Frame, undefined), maps:get(<<"authorization">>, Frame, undefined)} of
|
||||
{?PROTOCOL_VERSION, AuthHeader} when is_binary(AuthHeader) ->
|
||||
authorize_hello(AuthHeader, State);
|
||||
_ ->
|
||||
_ = send_error_frame(State, <<"invalid_hello">>),
|
||||
{stop, invalid_hello, State}
|
||||
end.
|
||||
|
||||
-spec authorize_hello(binary(), state()) -> {ok, state()} | {stop, term(), state()}.
|
||||
authorize_hello(AuthHeader, State) ->
|
||||
case fluxer_gateway_env:get(rpc_secret_key) of
|
||||
Secret when is_binary(Secret) ->
|
||||
Expected = <<"Bearer ", Secret/binary>>,
|
||||
case secure_compare(AuthHeader, Expected) of
|
||||
true ->
|
||||
HelloAck = #{
|
||||
<<"type">> => <<"hello_ack">>,
|
||||
<<"protocol">> => ?PROTOCOL_VERSION,
|
||||
<<"max_in_flight">> => maps:get(max_inflight, State),
|
||||
<<"ping_interval_ms">> => 15000
|
||||
},
|
||||
_ = send_frame(State, HelloAck),
|
||||
{ok, State#{authenticated => true}};
|
||||
false ->
|
||||
_ = send_error_frame(State, <<"unauthorized">>),
|
||||
{stop, unauthorized, State}
|
||||
end;
|
||||
_ ->
|
||||
_ = send_error_frame(State, <<"rpc_secret_not_configured">>),
|
||||
{stop, rpc_secret_not_configured, State}
|
||||
end.
|
||||
|
||||
-spec handle_request_frame(map(), state()) -> {ok, state()}.
|
||||
handle_request_frame(Frame, State) ->
|
||||
RequestId = request_id_from_frame(Frame),
|
||||
Method = maps:get(<<"method">>, Frame, undefined),
|
||||
case should_reject_request(Method, State) of
|
||||
true ->
|
||||
_ =
|
||||
send_response_frame(
|
||||
State,
|
||||
RequestId,
|
||||
false,
|
||||
undefined,
|
||||
<<"overloaded">>
|
||||
),
|
||||
{ok, State};
|
||||
false ->
|
||||
case {Method, maps:get(<<"params">>, Frame, undefined)} of
|
||||
{MethodName, Params} when is_binary(RequestId), is_binary(MethodName), is_map(Params) ->
|
||||
Parent = self(),
|
||||
_ = spawn(fun() ->
|
||||
Parent ! {rpc_response, RequestId, execute_method(MethodName, Params)}
|
||||
end),
|
||||
{ok, increment_inflight(State)};
|
||||
_ ->
|
||||
_ =
|
||||
send_response_frame(
|
||||
State,
|
||||
RequestId,
|
||||
false,
|
||||
undefined,
|
||||
<<"invalid_request">>
|
||||
),
|
||||
{ok, State}
|
||||
end
|
||||
end.
|
||||
|
||||
-spec should_reject_request(term(), state()) -> boolean().
|
||||
should_reject_request(Method, #{inflight := Inflight, max_inflight := MaxInflight}) ->
|
||||
case is_dispatch_method(Method) of
|
||||
true ->
|
||||
Inflight >= MaxInflight;
|
||||
false ->
|
||||
Inflight >= non_dispatch_inflight_limit(MaxInflight)
|
||||
end.
|
||||
|
||||
-spec non_dispatch_inflight_limit(pos_integer()) -> pos_integer().
|
||||
non_dispatch_inflight_limit(MaxInflight) ->
|
||||
Reserve = dispatch_reserve_slots(MaxInflight),
|
||||
max(1, MaxInflight - Reserve).
|
||||
|
||||
-spec dispatch_reserve_slots(pos_integer()) -> pos_integer().
|
||||
dispatch_reserve_slots(MaxInflight) ->
|
||||
max(1, MaxInflight div ?DEFAULT_DISPATCH_RESERVE_DIVISOR).
|
||||
|
||||
-spec is_dispatch_method(term()) -> boolean().
|
||||
is_dispatch_method(Method) when is_binary(Method) ->
|
||||
Suffix = <<".dispatch">>,
|
||||
MethodSize = byte_size(Method),
|
||||
SuffixSize = byte_size(Suffix),
|
||||
MethodSize >= SuffixSize andalso
|
||||
binary:part(Method, MethodSize - SuffixSize, SuffixSize) =:= Suffix;
|
||||
is_dispatch_method(_) ->
|
||||
false.
|
||||
|
||||
-spec execute_method(binary(), map()) -> rpc_result().
|
||||
execute_method(Method, Params) ->
|
||||
try
|
||||
Result = gateway_rpc_router:execute(Method, Params),
|
||||
{ok, Result}
|
||||
catch
|
||||
throw:{error, Message} ->
|
||||
{error, error_binary(Message)};
|
||||
exit:timeout ->
|
||||
{error, <<"timeout">>};
|
||||
exit:{timeout, _} ->
|
||||
{error, <<"timeout">>};
|
||||
Class:Reason ->
|
||||
logger:error(
|
||||
"Gateway TCP RPC method execution failed. method=~ts class=~p reason=~p",
|
||||
[Method, Class, Reason]
|
||||
),
|
||||
{error, <<"internal_error">>}
|
||||
end.
|
||||
|
||||
-spec handle_rpc_response(binary(), rpc_result(), state()) -> state().
|
||||
handle_rpc_response(RequestId, {ok, Result}, State) ->
|
||||
_ = send_response_frame(State, RequestId, true, Result, undefined),
|
||||
decrement_inflight(State);
|
||||
handle_rpc_response(RequestId, {error, Error}, State) ->
|
||||
_ = send_response_frame(State, RequestId, false, undefined, Error),
|
||||
decrement_inflight(State).
|
||||
|
||||
-spec send_response_frame(state(), binary(), boolean(), term(), binary() | undefined) -> ok | {error, term()}.
|
||||
send_response_frame(State, RequestId, true, Result, _Error) ->
|
||||
send_frame(State, #{
|
||||
<<"type">> => <<"response">>,
|
||||
<<"id">> => RequestId,
|
||||
<<"ok">> => true,
|
||||
<<"result">> => Result
|
||||
});
|
||||
send_response_frame(State, RequestId, false, _Result, Error) ->
|
||||
send_frame(State, #{
|
||||
<<"type">> => <<"response">>,
|
||||
<<"id">> => RequestId,
|
||||
<<"ok">> => false,
|
||||
<<"error">> => Error
|
||||
}).
|
||||
|
||||
-spec send_error_frame(state(), binary()) -> ok | {error, term()}.
|
||||
send_error_frame(State, Error) ->
|
||||
send_frame(State, #{
|
||||
<<"type">> => <<"error">>,
|
||||
<<"error">> => Error
|
||||
}).
|
||||
|
||||
-spec send_frame(state(), map()) -> ok | {error, term()}.
|
||||
send_frame(#{socket := Socket}, Frame) ->
|
||||
gen_tcp:send(Socket, encode_frame(Frame)).
|
||||
|
||||
-spec encode_frame(map()) -> binary().
|
||||
encode_frame(Frame) ->
|
||||
Payload = iolist_to_binary(json:encode(Frame)),
|
||||
Length = integer_to_binary(byte_size(Payload)),
|
||||
<<Length/binary, "\n", Payload/binary>>.
|
||||
|
||||
-spec decode_frames(binary(), [map()]) -> {ok, [map()], binary()} | {error, term()}.
|
||||
decode_frames(Buffer, Acc) ->
|
||||
case binary:match(Buffer, <<"\n">>) of
|
||||
nomatch ->
|
||||
{ok, lists:reverse(Acc), Buffer};
|
||||
{Pos, 1} ->
|
||||
LengthBin = binary:part(Buffer, 0, Pos),
|
||||
case parse_length(LengthBin) of
|
||||
{ok, Length} ->
|
||||
HeaderSize = Pos + 1,
|
||||
RequiredSize = HeaderSize + Length,
|
||||
case byte_size(Buffer) >= RequiredSize of
|
||||
false ->
|
||||
{ok, lists:reverse(Acc), Buffer};
|
||||
true ->
|
||||
Payload = binary:part(Buffer, HeaderSize, Length),
|
||||
RestSize = byte_size(Buffer) - RequiredSize,
|
||||
Rest = binary:part(Buffer, RequiredSize, RestSize),
|
||||
case decode_payload(Payload) of
|
||||
{ok, Frame} ->
|
||||
decode_frames(Rest, [Frame | Acc]);
|
||||
{error, Reason} ->
|
||||
{error, Reason}
|
||||
end
|
||||
end;
|
||||
{error, Reason} ->
|
||||
{error, Reason}
|
||||
end
|
||||
end.
|
||||
|
||||
-spec decode_payload(binary()) -> {ok, map()} | {error, term()}.
|
||||
decode_payload(Payload) ->
|
||||
case catch json:decode(Payload) of
|
||||
{'EXIT', _} ->
|
||||
{error, invalid_json};
|
||||
Frame when is_map(Frame) ->
|
||||
{ok, Frame};
|
||||
_ ->
|
||||
{error, invalid_json}
|
||||
end.
|
||||
|
||||
-spec parse_length(binary()) -> {ok, non_neg_integer()} | {error, term()}.
|
||||
parse_length(<<>>) ->
|
||||
{error, invalid_frame_length};
|
||||
parse_length(LengthBin) ->
|
||||
try
|
||||
Length = binary_to_integer(LengthBin),
|
||||
case Length >= 0 andalso Length =< ?MAX_FRAME_BYTES of
|
||||
true -> {ok, Length};
|
||||
false -> {error, invalid_frame_length}
|
||||
end
|
||||
catch
|
||||
_:_ ->
|
||||
{error, invalid_frame_length}
|
||||
end.
|
||||
|
||||
-spec secure_compare(binary(), binary()) -> boolean().
|
||||
secure_compare(Left, Right) when is_binary(Left), is_binary(Right) ->
|
||||
case byte_size(Left) =:= byte_size(Right) of
|
||||
true ->
|
||||
crypto:hash_equals(Left, Right);
|
||||
false ->
|
||||
false
|
||||
end.
|
||||
|
||||
-spec request_id_from_frame(map()) -> binary().
|
||||
request_id_from_frame(Frame) ->
|
||||
case maps:get(<<"id">>, Frame, <<>>) of
|
||||
Id when is_binary(Id) ->
|
||||
Id;
|
||||
Id when is_integer(Id) ->
|
||||
integer_to_binary(Id);
|
||||
_ ->
|
||||
<<>>
|
||||
end.
|
||||
|
||||
-spec increment_inflight(state()) -> state().
|
||||
increment_inflight(#{inflight := Inflight} = State) ->
|
||||
State#{inflight => Inflight + 1}.
|
||||
|
||||
-spec decrement_inflight(state()) -> state().
|
||||
decrement_inflight(#{inflight := Inflight} = State) when Inflight > 0 ->
|
||||
State#{inflight => Inflight - 1};
|
||||
decrement_inflight(State) ->
|
||||
State.
|
||||
|
||||
-spec error_binary(term()) -> binary().
|
||||
error_binary(Value) when is_binary(Value) ->
|
||||
Value;
|
||||
error_binary(Value) when is_list(Value) ->
|
||||
unicode:characters_to_binary(Value);
|
||||
error_binary(Value) when is_atom(Value) ->
|
||||
atom_to_binary(Value, utf8);
|
||||
error_binary(Value) ->
|
||||
unicode:characters_to_binary(io_lib:format("~p", [Value])).
|
||||
|
||||
-spec protocol_error_binary(term()) -> binary().
|
||||
protocol_error_binary(invalid_json) ->
|
||||
<<"invalid_json">>;
|
||||
protocol_error_binary(invalid_frame_length) ->
|
||||
<<"invalid_frame_length">>;
|
||||
protocol_error_binary(input_buffer_limit_exceeded) ->
|
||||
<<"input_buffer_limit_exceeded">>.
|
||||
|
||||
-spec close_socket(inet:socket()) -> ok.
|
||||
close_socket(Socket) ->
|
||||
catch gen_tcp:close(Socket),
|
||||
ok.
|
||||
|
||||
-spec max_inflight() -> pos_integer().
|
||||
max_inflight() ->
|
||||
case fluxer_gateway_env:get(gateway_http_rpc_max_concurrency) of
|
||||
Value when is_integer(Value), Value > 0 ->
|
||||
Value;
|
||||
_ ->
|
||||
?DEFAULT_MAX_INFLIGHT
|
||||
end.
|
||||
|
||||
-spec max_input_buffer_bytes() -> pos_integer().
|
||||
max_input_buffer_bytes() ->
|
||||
case fluxer_gateway_env:get(gateway_rpc_tcp_max_input_buffer_bytes) of
|
||||
Value when is_integer(Value), Value > 0 ->
|
||||
Value;
|
||||
_ ->
|
||||
?DEFAULT_MAX_INPUT_BUFFER_BYTES
|
||||
end.
|
||||
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
|
||||
decode_single_frame_test() ->
|
||||
Frame = #{<<"type">> => <<"ping">>},
|
||||
Encoded = encode_frame(Frame),
|
||||
?assertEqual({ok, [Frame], <<>>}, decode_frames(Encoded, [])).
|
||||
|
||||
decode_multiple_frames_test() ->
|
||||
FrameA = #{<<"type">> => <<"ping">>},
|
||||
FrameB = #{<<"type">> => <<"pong">>},
|
||||
Encoded = <<(encode_frame(FrameA))/binary, (encode_frame(FrameB))/binary>>,
|
||||
?assertEqual({ok, [FrameA, FrameB], <<>>}, decode_frames(Encoded, [])).
|
||||
|
||||
decode_partial_frame_test() ->
|
||||
Frame = #{<<"type">> => <<"ping">>},
|
||||
Encoded = encode_frame(Frame),
|
||||
Prefix = binary:part(Encoded, 0, 3),
|
||||
?assertEqual({ok, [], Prefix}, decode_frames(Prefix, [])).
|
||||
|
||||
invalid_length_test() ->
|
||||
?assertEqual({error, invalid_frame_length}, decode_frames(<<"x\n{}">>, [])).
|
||||
|
||||
secure_compare_test() ->
|
||||
?assert(secure_compare(<<"abc">>, <<"abc">>)),
|
||||
?assertNot(secure_compare(<<"abc">>, <<"abd">>)),
|
||||
?assertNot(secure_compare(<<"abc">>, <<"abcd">>)).
|
||||
|
||||
-endif.
|
||||
@@ -1,108 +0,0 @@
|
||||
%% Copyright (C) 2026 Fluxer Contributors
|
||||
%%
|
||||
%% This file is part of Fluxer.
|
||||
%%
|
||||
%% Fluxer is free software: you can redistribute it and/or modify
|
||||
%% it under the terms of the GNU Affero General Public License as published by
|
||||
%% the Free Software Foundation, either version 3 of the License, or
|
||||
%% (at your option) any later version.
|
||||
%%
|
||||
%% Fluxer is distributed in the hope that it will be useful,
|
||||
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
%% GNU Affero General Public License for more details.
|
||||
%%
|
||||
%% You should have received a copy of the GNU Affero General Public License
|
||||
%% along with Fluxer. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
-module(gateway_rpc_tcp_server).
|
||||
-behaviour(gen_server).
|
||||
|
||||
-export([start_link/0, accept_loop/1]).
|
||||
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
|
||||
|
||||
-type state() :: #{
|
||||
listen_socket := inet:socket(),
|
||||
acceptor_pid := pid(),
|
||||
port := inet:port_number()
|
||||
}.
|
||||
|
||||
-spec start_link() -> {ok, pid()} | {error, term()}.
|
||||
start_link() ->
|
||||
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
|
||||
|
||||
-spec init([]) -> {ok, state()} | {stop, term()}.
|
||||
init([]) ->
|
||||
process_flag(trap_exit, true),
|
||||
Port = fluxer_gateway_env:get(rpc_tcp_port),
|
||||
case gen_tcp:listen(Port, listen_options()) of
|
||||
{ok, ListenSocket} ->
|
||||
AcceptorPid = spawn_link(?MODULE, accept_loop, [ListenSocket]),
|
||||
logger:info("Gateway TCP RPC listener started on port ~p", [Port]),
|
||||
{ok, #{
|
||||
listen_socket => ListenSocket,
|
||||
acceptor_pid => AcceptorPid,
|
||||
port => Port
|
||||
}};
|
||||
{error, Reason} ->
|
||||
{stop, {rpc_tcp_listen_failed, Port, Reason}}
|
||||
end.
|
||||
|
||||
-spec handle_call(term(), gen_server:from(), state()) -> {reply, ok, state()}.
|
||||
handle_call(_Request, _From, State) ->
|
||||
{reply, ok, State}.
|
||||
|
||||
-spec handle_cast(term(), state()) -> {noreply, state()}.
|
||||
handle_cast(_Msg, State) ->
|
||||
{noreply, State}.
|
||||
|
||||
-spec handle_info(term(), state()) -> {noreply, state()}.
|
||||
handle_info({'EXIT', Pid, Reason}, #{acceptor_pid := Pid, listen_socket := ListenSocket} = State) ->
|
||||
case Reason of
|
||||
normal ->
|
||||
{noreply, State};
|
||||
shutdown ->
|
||||
{noreply, State};
|
||||
_ ->
|
||||
logger:error("Gateway TCP RPC acceptor crashed: ~p", [Reason]),
|
||||
NewAcceptorPid = spawn_link(?MODULE, accept_loop, [ListenSocket]),
|
||||
{noreply, State#{acceptor_pid => NewAcceptorPid}}
|
||||
end;
|
||||
handle_info(_Info, State) ->
|
||||
{noreply, State}.
|
||||
|
||||
-spec terminate(term(), state()) -> ok.
|
||||
terminate(_Reason, #{listen_socket := ListenSocket, port := Port}) ->
|
||||
catch gen_tcp:close(ListenSocket),
|
||||
logger:info("Gateway TCP RPC listener stopped on port ~p", [Port]),
|
||||
ok.
|
||||
|
||||
-spec code_change(term(), state(), term()) -> {ok, state()}.
|
||||
code_change(_OldVsn, State, _Extra) ->
|
||||
{ok, State}.
|
||||
|
||||
-spec accept_loop(inet:socket()) -> ok.
|
||||
accept_loop(ListenSocket) ->
|
||||
case gen_tcp:accept(ListenSocket) of
|
||||
{ok, Socket} ->
|
||||
_ = spawn_link(?MODULE, accept_loop, [ListenSocket]),
|
||||
gateway_rpc_tcp_connection:serve(Socket);
|
||||
{error, closed} ->
|
||||
ok;
|
||||
{error, Reason} ->
|
||||
logger:error("Gateway TCP RPC accept failed: ~p", [Reason]),
|
||||
timer:sleep(200),
|
||||
accept_loop(ListenSocket)
|
||||
end.
|
||||
|
||||
-spec listen_options() -> [gen_tcp:listen_option()].
|
||||
listen_options() ->
|
||||
[
|
||||
binary,
|
||||
{packet, raw},
|
||||
{active, false},
|
||||
{reuseaddr, true},
|
||||
{nodelay, true},
|
||||
{backlog, 4096},
|
||||
{keepalive, true}
|
||||
].
|
||||
@@ -316,6 +316,7 @@ is_fluxer_module(Module) ->
|
||||
lists:prefix("gateway_http_", ModuleStr) orelse
|
||||
lists:prefix("session", ModuleStr) orelse
|
||||
lists:prefix("guild", ModuleStr) orelse
|
||||
lists:prefix("passive_sync_registry", ModuleStr) orelse
|
||||
lists:prefix("presence", ModuleStr) orelse
|
||||
lists:prefix("push", ModuleStr) orelse
|
||||
lists:prefix("push_dispatcher", ModuleStr) orelse
|
||||
|
||||
@@ -17,96 +17,72 @@
|
||||
|
||||
-module(rpc_client).
|
||||
|
||||
-export([
|
||||
call/1,
|
||||
call/2,
|
||||
get_rpc_url/0,
|
||||
get_rpc_url/1,
|
||||
get_rpc_headers/0
|
||||
]).
|
||||
-export([call/1]).
|
||||
|
||||
-define(NATS_RPC_SUBJECT, <<"rpc.api">>).
|
||||
-define(NATS_RPC_TIMEOUT_MS, 10000).
|
||||
|
||||
-type rpc_request() :: map().
|
||||
-type rpc_response() :: {ok, map()} | {error, term()}.
|
||||
-type rpc_options() :: map().
|
||||
|
||||
-spec call(rpc_request()) -> rpc_response().
|
||||
call(Request) ->
|
||||
call(Request, #{}).
|
||||
case gateway_nats_rpc:get_connection() of
|
||||
{ok, undefined} ->
|
||||
{error, not_connected};
|
||||
{ok, Conn} ->
|
||||
do_request(Conn, Request);
|
||||
{error, Reason} ->
|
||||
{error, {not_connected, Reason}}
|
||||
end.
|
||||
|
||||
-spec call(rpc_request(), rpc_options()) -> rpc_response().
|
||||
call(Request, _Options) ->
|
||||
Url = get_rpc_url(),
|
||||
Headers = get_rpc_headers(),
|
||||
Body = json:encode(Request),
|
||||
case gateway_http_client:request(rpc, post, Url, Headers, Body) of
|
||||
{ok, 200, _RespHeaders, RespBody} ->
|
||||
handle_success_response(RespBody);
|
||||
{ok, StatusCode, _RespHeaders, RespBody} ->
|
||||
handle_error_response(StatusCode, RespBody);
|
||||
-spec do_request(nats:conn(), rpc_request()) -> rpc_response().
|
||||
do_request(Conn, Request) ->
|
||||
Payload = iolist_to_binary(json:encode(Request)),
|
||||
case nats:request(Conn, ?NATS_RPC_SUBJECT, Payload, #{timeout => ?NATS_RPC_TIMEOUT_MS}) of
|
||||
{ok, {ResponseBin, _MsgOpts}} ->
|
||||
handle_nats_response(ResponseBin);
|
||||
{error, timeout} ->
|
||||
{error, timeout};
|
||||
{error, no_responders} ->
|
||||
{error, no_responders};
|
||||
{error, Reason} ->
|
||||
{error, Reason}
|
||||
end.
|
||||
|
||||
-spec handle_success_response(binary()) -> rpc_response().
|
||||
handle_success_response(RespBody) ->
|
||||
Response = json:decode(RespBody),
|
||||
Data = maps:get(<<"data">>, Response, #{}),
|
||||
{ok, Data}.
|
||||
|
||||
-spec handle_error_response(pos_integer(), binary()) -> {error, term()}.
|
||||
handle_error_response(StatusCode, RespBody) ->
|
||||
{error, {http_error, StatusCode, RespBody}}.
|
||||
|
||||
-spec get_rpc_url() -> string().
|
||||
get_rpc_url() ->
|
||||
ApiHost = fluxer_gateway_env:get(api_host),
|
||||
get_rpc_url(ApiHost).
|
||||
|
||||
-spec get_rpc_url(string() | binary()) -> string().
|
||||
get_rpc_url(ApiHost) ->
|
||||
BaseUrl = api_host_base_url(ApiHost),
|
||||
BaseUrl ++ "/_rpc".
|
||||
|
||||
-spec api_host_base_url(string() | binary()) -> string().
|
||||
api_host_base_url(ApiHost) ->
|
||||
HostString = ensure_string(ApiHost),
|
||||
Normalized = normalize_api_host(HostString),
|
||||
strip_trailing_slash(Normalized).
|
||||
|
||||
-spec ensure_string(binary() | string()) -> string().
|
||||
ensure_string(Value) when is_binary(Value) ->
|
||||
binary_to_list(Value);
|
||||
ensure_string(Value) when is_list(Value) ->
|
||||
Value.
|
||||
|
||||
-spec normalize_api_host(string()) -> string().
|
||||
normalize_api_host(Host) ->
|
||||
Lower = string:lowercase(Host),
|
||||
case {has_protocol_prefix(Lower, "http://"), has_protocol_prefix(Lower, "https://")} of
|
||||
{true, _} -> Host;
|
||||
{_, true} -> Host;
|
||||
_ -> "http://" ++ Host
|
||||
-spec handle_nats_response(iodata()) -> rpc_response().
|
||||
handle_nats_response(ResponseBin) ->
|
||||
Response = json:decode(iolist_to_binary(ResponseBin)),
|
||||
case maps:get(<<"_error">>, Response, undefined) of
|
||||
undefined ->
|
||||
Data = maps:get(<<"data">>, Response, #{}),
|
||||
{ok, Data};
|
||||
_ ->
|
||||
Status = maps:get(<<"status">>, Response, 500),
|
||||
Message = maps:get(<<"message">>, Response, <<"unknown error">>),
|
||||
{error, {rpc_error, Status, Message}}
|
||||
end.
|
||||
|
||||
-spec has_protocol_prefix(string(), string()) -> boolean().
|
||||
has_protocol_prefix(Str, Prefix) ->
|
||||
case string:prefix(Str, Prefix) of
|
||||
nomatch -> false;
|
||||
_ -> true
|
||||
end.
|
||||
-ifdef(TEST).
|
||||
-include_lib("eunit/include/eunit.hrl").
|
||||
|
||||
-spec strip_trailing_slash(string()) -> string().
|
||||
strip_trailing_slash([]) ->
|
||||
"";
|
||||
strip_trailing_slash(Url) ->
|
||||
case lists:last(Url) of
|
||||
$/ -> strip_trailing_slash(lists:droplast(Url));
|
||||
_ -> Url
|
||||
end.
|
||||
handle_nats_response_ok_test() ->
|
||||
Response = json:encode(#{
|
||||
<<"type">> => <<"session">>,
|
||||
<<"data">> => #{<<"user">> => <<"test">>}
|
||||
}),
|
||||
?assertEqual({ok, #{<<"user">> => <<"test">>}}, handle_nats_response(Response)).
|
||||
|
||||
-spec get_rpc_headers() -> [{binary() | string(), binary() | string()}].
|
||||
get_rpc_headers() ->
|
||||
RpcSecretKey = fluxer_gateway_env:get(rpc_secret_key),
|
||||
AuthHeader = {<<"Authorization">>, <<"Bearer ", RpcSecretKey/binary>>},
|
||||
InitialHeaders = [AuthHeader],
|
||||
gateway_tracing:inject_rpc_headers(InitialHeaders).
|
||||
handle_nats_response_error_401_test() ->
|
||||
Response = json:encode(#{<<"_error">> => true, <<"status">> => 401, <<"message">> => <<"Unauthorized">>}),
|
||||
?assertEqual({error, {rpc_error, 401, <<"Unauthorized">>}}, handle_nats_response(Response)).
|
||||
|
||||
handle_nats_response_error_429_test() ->
|
||||
Response = json:encode(#{<<"_error">> => true, <<"status">> => 429, <<"message">> => <<"Rate limited">>}),
|
||||
?assertEqual({error, {rpc_error, 429, <<"Rate limited">>}}, handle_nats_response(Response)).
|
||||
|
||||
handle_nats_response_error_500_test() ->
|
||||
Response = json:encode(#{<<"_error">> => true, <<"status">> => 500, <<"message">> => <<"Internal error">>}),
|
||||
?assertEqual({error, {rpc_error, 500, <<"Internal error">>}}, handle_nats_response(Response)).
|
||||
|
||||
-endif.
|
||||
|
||||
Reference in New Issue
Block a user