diff --git a/examples/generic_sock_client/main.go b/examples/generic_sock_client/main.go index 47ea0ec..ee42e8e 100644 --- a/examples/generic_sock_client/main.go +++ b/examples/generic_sock_client/main.go @@ -59,7 +59,7 @@ func main() { args = append(args, arg) } } - reqResult, reqError, err := conn.SendRequest(context.Background(), method, args) + reqResult, reqError, err := conn.SendRequest(context.Background(), method, args...) if err != nil { fmt.Println("Error sending request:", err) return diff --git a/examples/mult_server/main.go b/examples/mult_server/main.go index a5ac017..d6d8e53 100644 --- a/examples/mult_server/main.go +++ b/examples/mult_server/main.go @@ -61,7 +61,7 @@ func main() { // Register the ping method ctx := context.Background() - _, reqErr, err := conn.SendRequest(ctx, "$/register", []any{"mult"}) + _, reqErr, err := conn.SendRequest(ctx, "$/register", "mult") if err != nil { slog.Error("Failed to send register request for ping method", "err", err) return diff --git a/examples/ping_client/main.go b/examples/ping_client/main.go index b9457c2..11fcc62 100644 --- a/examples/ping_client/main.go +++ b/examples/ping_client/main.go @@ -34,7 +34,7 @@ func main() { go conn.Run() // Client - reqResult, reqError, err := conn.SendRequest(context.Background(), "ping", []any{"HELLO", 1, true, 5.0}) + reqResult, reqError, err := conn.SendRequest(context.Background(), "ping", "HELLO", 1, true, 5.0) if err != nil { panic(err) } diff --git a/examples/ping_server/main.go b/examples/ping_server/main.go index 3759c23..c1f58be 100644 --- a/examples/ping_server/main.go +++ b/examples/ping_server/main.go @@ -50,7 +50,7 @@ func main() { // Register the ping method ctx := context.Background() - _, reqErr, err := conn.SendRequest(ctx, "$/register", []any{"ping"}) + _, reqErr, err := conn.SendRequest(ctx, "$/register", "ping") if err != nil { slog.Error("Failed to send register request for ping method", "err", reqErr) return diff --git a/internal/msgpackrouter/router.go b/internal/msgpackrouter/router.go index 5651f36..06cd59a 100644 --- a/internal/msgpackrouter/router.go +++ b/internal/msgpackrouter/router.go @@ -129,7 +129,7 @@ func (r *Router) connectionLoop(conn io.ReadWriteCloser) { } // Forward the call to the registered client - reqResult, reqError, err := client.SendRequest(ctx, method, params) + reqResult, reqError, err := client.SendRequest(ctx, method, params...) if err != nil { slog.Error("Failed to send request", "method", method, "err", err) return nil, routerError(ErrCodeFailedToSendRequests, fmt.Sprintf("failed to send request: %s", err)) @@ -157,7 +157,7 @@ func (r *Router) connectionLoop(conn io.ReadWriteCloser) { } // Forward the notification to the registered client - if err := client.SendNotification(method, params); err != nil { + if err := client.SendNotification(method, params...); err != nil { slog.Error("Failed to send notification", "method", method, "err", err) return } diff --git a/internal/msgpackrouter/router_test.go b/internal/msgpackrouter/router_test.go index a9c0704..78a4b27 100644 --- a/internal/msgpackrouter/router_test.go +++ b/internal/msgpackrouter/router_test.go @@ -98,59 +98,59 @@ func TestBasicRouterFunctionality(t *testing.T) { { // Register a method on the first client - result, reqErr, err := cl1.SendRequest(context.Background(), "$/register", []any{"ping"}) + result, reqErr, err := cl1.SendRequest(context.Background(), "$/register", "ping") require.Equal(t, true, result) require.Nil(t, reqErr) require.NoError(t, err) } { // Try to re-register the same method - result, reqErr, err := cl1.SendRequest(context.Background(), "$/register", []any{"ping"}) + result, reqErr, err := cl1.SendRequest(context.Background(), "$/register", "ping") require.Nil(t, result) require.Equal(t, []any{int8(msgpackrouter.ErrCodeRouteAlreadyExists), "route already exists: ping"}, reqErr) require.NoError(t, err) } { // Register a method on the second client - result, reqErr, err := cl2.SendRequest(context.Background(), "$/register", []any{"temperature"}) + result, reqErr, err := cl2.SendRequest(context.Background(), "$/register", "temperature") require.Equal(t, true, result) require.Nil(t, reqErr) require.NoError(t, err) } { // Call from client2 the registered method on client1 - result, reqErr, err := cl2.SendRequest(context.Background(), "ping", []any{"1", 2, true}) + result, reqErr, err := cl2.SendRequest(context.Background(), "ping", "1", 2, true) require.Equal(t, []any{"1", int8(2), true}, result) require.Nil(t, reqErr) require.NoError(t, err) } { // Self-call from client1 - result, reqErr, err := cl1.SendRequest(context.Background(), "ping", []any{"c", 12, false}) + result, reqErr, err := cl1.SendRequest(context.Background(), "ping", "c", 12, false) require.Equal(t, []any{"c", int8(12), false}, result) require.Nil(t, reqErr) require.NoError(t, err) } { // Call from client2 an un-registered method - result, reqErr, err := cl2.SendRequest(context.Background(), "not-existent-method", []any{"1", 2, true}) + result, reqErr, err := cl2.SendRequest(context.Background(), "not-existent-method", "1", 2, true) require.Nil(t, result) require.Equal(t, []any{int8(msgpackrouter.ErrCodeMethodNotAvailable), "method not-existent-method not available"}, reqErr) require.NoError(t, err) } { // Send notification to client1 - err := cl2.SendNotification("ping", []any{"a", int16(4), false}) + err := cl2.SendNotification("ping", "a", int16(4), false) require.NoError(t, err) } { // Send notification to unregistered method - err := cl2.SendNotification("notexistent", []any{"a", int16(4), false}) + err := cl2.SendNotification("notexistent", "a", int16(4), false) require.NoError(t, err) } { // Self-send notification - err := cl1.SendNotification("ping", []any{"b", int16(14), true, true}) + err := cl1.SendNotification("ping", "b", int16(14), true, true) require.NoError(t, err) } time.Sleep(100 * time.Millisecond) // Give some time for the notifications to be processed @@ -190,7 +190,7 @@ func TestMessageForwarderCongestionControl(t *testing.T) { { // Register a method on the first client - result, reqErr, err := cl1.SendRequest(context.Background(), "$/register", []any{"test"}) + result, reqErr, err := cl1.SendRequest(context.Background(), "$/register", "test") require.Equal(t, true, result) require.Nil(t, reqErr) require.NoError(t, err) @@ -201,7 +201,7 @@ func TestMessageForwarderCongestionControl(t *testing.T) { var wg sync.WaitGroup for range batchSize { wg.Go(func() { - _, _, err := cl2.SendRequest(t.Context(), "test", []any{}) + _, _, err := cl2.SendRequest(t.Context(), "test") require.NoError(t, err) }) } diff --git a/msgpackrpc/connection.go b/msgpackrpc/connection.go index 35c051a..ddd95f0 100644 --- a/msgpackrpc/connection.go +++ b/msgpackrpc/connection.go @@ -327,7 +327,10 @@ func (c *Connection) Close() { _ = c.out.Close() } -func (c *Connection) SendRequest(ctx context.Context, method string, params []any) (reqResult any, reqError any, err error) { +func (c *Connection) SendRequest(ctx context.Context, method string, params ...any) (reqResult any, reqError any, err error) { + if params == nil { + params = []any{} + } id := MessageID(c.lastOutRequestsIndex.Add(1)) c.loggerMutex.Lock() @@ -364,7 +367,7 @@ func (c *Connection) SendRequest(ctx context.Context, method string, params []an c.logger.LogOutgoingCancelRequest(id) c.loggerMutex.Unlock() - _ = c.SendNotification("$/cancelRequest", []any{id}) // ignore error (it won't matter anyway) + _ = c.SendNotification("$/cancelRequest", id) // ignore error (it won't matter anyway) } // After cancelation wait for result... @@ -378,7 +381,11 @@ func (c *Connection) SendRequest(ctx context.Context, method string, params []an return result.reqResult, result.reqError, nil } -func (c *Connection) SendNotification(method string, params []any) error { +func (c *Connection) SendNotification(method string, params ...any) error { + if params == nil { + params = []any{} + } + c.loggerMutex.Lock() c.logger.LogOutgoingNotification(method, params) c.loggerMutex.Unlock() diff --git a/msgpackrpc/connection_rpc_test.go b/msgpackrpc/connection_rpc_test.go index 40a7b2c..dd62389 100644 --- a/msgpackrpc/connection_rpc_test.go +++ b/msgpackrpc/connection_rpc_test.go @@ -130,7 +130,7 @@ func TestRPCConnection(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - respRes, respErr, err := conn.SendRequest(t.Context(), "helloworld", []any{true}) + respRes, respErr, err := conn.SendRequest(t.Context(), "helloworld", true) require.NoError(t, err) require.Nil(t, respErr) require.Equal(t, map[string]any{"fakedata": int8(99)}, respRes)