From c7473a8d024856108881351835fb8240102a9a7a Mon Sep 17 00:00:00 2001 From: saji Date: Sat, 20 May 2023 14:53:34 -0500 Subject: [PATCH] rpc stuff --- mprpc/rpc.go | 80 +++++++++++++++++++++------ mprpc/rpc_msg.go | 8 +++ mprpc/rpc_msg_gen.go | 83 ++++++++++++++++++++++++++++ mprpc/rpc_msg_gen_test.go | 113 ++++++++++++++++++++++++++++++++++++++ mprpc/rpc_test.go | 1 + mprpc/rpcconntrack.go | 7 +++ 6 files changed, 274 insertions(+), 18 deletions(-) create mode 100644 mprpc/rpc_test.go diff --git a/mprpc/rpc.go b/mprpc/rpc.go index 543198f..49546ad 100644 --- a/mprpc/rpc.go +++ b/mprpc/rpc.go @@ -47,7 +47,8 @@ encoding the arguments and decoding the response for a remote procedure. package mprpc import ( - "net" + "errors" + "io" "github.com/tinylib/msgp/msgp" "golang.org/x/exp/slog" @@ -66,12 +67,33 @@ type ServiceFunc func(params msgp.Raw) (res msgp.Raw, err error) // "server" aka listener, and client. type RPCConn struct { // TODO: use io.readwritecloser? - conn net.Conn + rwc io.ReadWriteCloser handlers map[string]ServiceFunc ct rpcConnTrack - slog.Logger + logger slog.Logger +} + + +// creates a new RPC connection on top of an io.ReadWriteCloser. Can be +// pre-seeded with handlers. +func NewRPC(rwc io.ReadWriteCloser, logger *slog.Logger, initialHandlers map[string]ServiceFunc) (rpc *RPCConn, err error) { + + rpc = &RPCConn{ + rwc: rwc, + handlers: make(map[string]ServiceFunc), + ct: NewRPCConnTrack(), + } + if initialHandlers != nil { + for k,v := range initialHandlers { + rpc.handlers[k] = v + } + } + + + return + } // Call intiates an RPC call to a remote method and returns the @@ -85,7 +107,7 @@ func (rpc *RPCConn) Call(method string, params msgp.Raw) (msgp.Raw, error) { req := NewRequest(id, method, params) - w := msgp.NewWriter(rpc.conn) + w := msgp.NewWriter(rpc.rwc) req.EncodeMsg(w) // block and wait for response. @@ -96,14 +118,13 @@ func (rpc *RPCConn) Call(method string, params msgp.Raw) (msgp.Raw, error) { // Notify initiates a notification to a remote method. It does not // return any information. There is no response from the server. -// This method will not block. An error is returned if there is a local -// problem. +// This method will not block nor will it inform the caller if any errors occur. func (rpc *RPCConn) Notify(method string, params msgp.Raw) { // TODO: return an error if there's a local problem? req := NewNotification(method, params) - w := msgp.NewWriter(rpc.conn) + w := msgp.NewWriter(rpc.rwc) req.EncodeMsg(w) } @@ -115,30 +136,46 @@ func (rpc *RPCConn) RegisterHandler(name string, fn ServiceFunc) error { // TODO: mutex lock for sync (or use sync.map? rpc.handlers[name] = fn - rpc.Logger.Info("registered a new handler", "name", name, "fn", fn) + rpc.logger.Info("registered a new handler", "name", name, "fn", fn) return nil } -// Serve runs the server. It will dispatch goroutines to handle each -// method call. This can (and should in most cases) be run in the background to allow for -// sending and receving on the same connection. + +// Removes a handler, if it exists. Never errors. No-op if the name +// is not a registered handler. +func (rpc *RPCConn) RemoveHandler(name string) error { + delete(rpc.handlers, name) + return nil +} + +// Serve runs the server. It will dispatch goroutines to handle each method +// call. This can (and should in most cases) be run in the background to allow +// for sending and receving on the same connection. func (rpc *RPCConn) Serve() { // construct a stream reader. - msgReader := msgp.NewReader(rpc.conn) + msgReader := msgp.NewReader(rpc.rwc) // read a request/notification from the connection. var rawmsg msgp.Raw = make(msgp.Raw, 0, 4) for { - rawmsg.DecodeMsg(msgReader) + err := rawmsg.DecodeMsg(msgReader) + if err != nil { + if errors.Is(err, io.EOF) { + rpc.logger.Info("reached EOF, stopping server") + return + } + rpc.logger.Warn("error decoding message", "err", err) + continue + } rpcIntf, err := parseRPC(rawmsg) if err != nil { - rpc.Logger.Warn("Could not parse RPC message", "err", err) + rpc.logger.Warn("Could not parse RPC message", "err", err) continue } @@ -152,26 +189,33 @@ func (rpc *RPCConn) Serve() { case Response: cbCh, err := rpc.ct.Clear(rpcObject.MsgId) if err != nil { - rpc.Logger.Warn("could not get rpc callback", "msgid", rpcObject.MsgId, "err", err) + rpc.logger.Warn("could not get rpc callback", "msgid", rpcObject.MsgId, "err", err) continue } cbCh <- rpcObject + default: + panic("invalid rpcObject!") } } } + + + +// INTERNAL functions for rpcConn + // dispatch is an internal method used to execute a Request sent by the remote:w func (rpc *RPCConn) dispatch(req Request) { result, err := rpc.handlers[req.Method](req.Params) if err != nil { - rpc.Logger.Warn("error dispatching rpc function", "method", req.Method, "err", err) + rpc.logger.Warn("error dispatching rpc function", "method", req.Method, "err", err) } // construct the response frame. var rpcE *RPCError = MakeRPCError(err) - w := msgp.NewWriter(rpc.conn) + w := msgp.NewWriter(rpc.rwc) response := NewResponse(req.MsgId, *rpcE, result) @@ -187,7 +231,7 @@ func (rpc *RPCConn) dispatchNotif(req Notification) { if err != nil { // log the error, but don't do anything about it. - rpc.Logger.Warn("error dispatching rpc function", "method", req.Method, "err", err) + rpc.logger.Warn("error dispatching rpc function", "method", req.Method, "err", err) } } diff --git a/mprpc/rpc_msg.go b/mprpc/rpc_msg.go index 49ad03b..b4071ab 100644 --- a/mprpc/rpc_msg.go +++ b/mprpc/rpc_msg.go @@ -160,3 +160,11 @@ func MakeRPCError(err error) *RPCError { func (r *RPCError) Error() string { return r.Desc } + + + + +// we need to describe an empty data that will be excluded in the msgp +// for functions without an argument or return value. +type RPCEmpty struct { +} diff --git a/mprpc/rpc_msg_gen.go b/mprpc/rpc_msg_gen.go index 0524814..32c5ec3 100644 --- a/mprpc/rpc_msg_gen.go +++ b/mprpc/rpc_msg_gen.go @@ -97,6 +97,89 @@ func (z *Notification) Msgsize() (s int) { return } +// DecodeMsg implements msgp.Decodable +func (z *RPCEmpty) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z RPCEmpty) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 0 + err = en.Append(0x80) + if err != nil { + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z RPCEmpty) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 0 + o = append(o, 0x80) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *RPCEmpty) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z RPCEmpty) Msgsize() (s int) { + s = 1 + return +} + // DecodeMsg implements msgp.Decodable func (z *RPCError) DecodeMsg(dc *msgp.Reader) (err error) { var zb0001 uint32 diff --git a/mprpc/rpc_msg_gen_test.go b/mprpc/rpc_msg_gen_test.go index f451535..aa6493a 100644 --- a/mprpc/rpc_msg_gen_test.go +++ b/mprpc/rpc_msg_gen_test.go @@ -122,6 +122,119 @@ func BenchmarkDecodeNotification(b *testing.B) { } } +func TestMarshalUnmarshalRPCEmpty(t *testing.T) { + v := RPCEmpty{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgRPCEmpty(b *testing.B) { + v := RPCEmpty{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgRPCEmpty(b *testing.B) { + v := RPCEmpty{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalRPCEmpty(b *testing.B) { + v := RPCEmpty{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodeRPCEmpty(t *testing.T) { + v := RPCEmpty{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodeRPCEmpty Msgsize() is inaccurate") + } + + vn := RPCEmpty{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodeRPCEmpty(b *testing.B) { + v := RPCEmpty{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodeRPCEmpty(b *testing.B) { + v := RPCEmpty{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + func TestMarshalUnmarshalRPCError(t *testing.T) { v := RPCError{} bts, err := v.MarshalMsg(nil) diff --git a/mprpc/rpc_test.go b/mprpc/rpc_test.go new file mode 100644 index 0000000..967bbe9 --- /dev/null +++ b/mprpc/rpc_test.go @@ -0,0 +1 @@ +package mprpc_test diff --git a/mprpc/rpcconntrack.go b/mprpc/rpcconntrack.go index 708c21d..c2ad098 100644 --- a/mprpc/rpcconntrack.go +++ b/mprpc/rpcconntrack.go @@ -13,6 +13,13 @@ type rpcConnTrack struct { mu sync.RWMutex } + +func NewRPCConnTrack() rpcConnTrack { + return rpcConnTrack{ + ct: make(map[uint32]chan Response), + } +} + // Get attempts to get a random mark from the mutex. func (c *rpcConnTrack) Claim() (uint32, chan Response) { var val uint32