rpc stuff

This commit is contained in:
saji 2023-05-20 14:53:34 -05:00
parent 992c0cac13
commit c7473a8d02
6 changed files with 274 additions and 18 deletions

View file

@ -47,7 +47,8 @@ encoding the arguments and decoding the response for a remote procedure.
package mprpc package mprpc
import ( import (
"net" "errors"
"io"
"github.com/tinylib/msgp/msgp" "github.com/tinylib/msgp/msgp"
"golang.org/x/exp/slog" "golang.org/x/exp/slog"
@ -66,12 +67,33 @@ type ServiceFunc func(params msgp.Raw) (res msgp.Raw, err error)
// "server" aka listener, and client. // "server" aka listener, and client.
type RPCConn struct { type RPCConn struct {
// TODO: use io.readwritecloser? // TODO: use io.readwritecloser?
conn net.Conn rwc io.ReadWriteCloser
handlers map[string]ServiceFunc handlers map[string]ServiceFunc
ct rpcConnTrack ct rpcConnTrack
slog.Logger logger slog.Logger
}
// creates a new RPC connection on top of an io.ReadWriteCloser. Can be
// pre-seeded with handlers.
func NewRPC(rwc io.ReadWriteCloser, logger *slog.Logger, initialHandlers map[string]ServiceFunc) (rpc *RPCConn, err error) {
rpc = &RPCConn{
rwc: rwc,
handlers: make(map[string]ServiceFunc),
ct: NewRPCConnTrack(),
}
if initialHandlers != nil {
for k,v := range initialHandlers {
rpc.handlers[k] = v
}
}
return
} }
// Call intiates an RPC call to a remote method and returns the // Call intiates an RPC call to a remote method and returns the
@ -85,7 +107,7 @@ func (rpc *RPCConn) Call(method string, params msgp.Raw) (msgp.Raw, error) {
req := NewRequest(id, method, params) req := NewRequest(id, method, params)
w := msgp.NewWriter(rpc.conn) w := msgp.NewWriter(rpc.rwc)
req.EncodeMsg(w) req.EncodeMsg(w)
// block and wait for response. // block and wait for response.
@ -96,14 +118,13 @@ func (rpc *RPCConn) Call(method string, params msgp.Raw) (msgp.Raw, error) {
// Notify initiates a notification to a remote method. It does not // Notify initiates a notification to a remote method. It does not
// return any information. There is no response from the server. // return any information. There is no response from the server.
// This method will not block. An error is returned if there is a local // This method will not block nor will it inform the caller if any errors occur.
// problem.
func (rpc *RPCConn) Notify(method string, params msgp.Raw) { func (rpc *RPCConn) Notify(method string, params msgp.Raw) {
// TODO: return an error if there's a local problem? // TODO: return an error if there's a local problem?
req := NewNotification(method, params) req := NewNotification(method, params)
w := msgp.NewWriter(rpc.conn) w := msgp.NewWriter(rpc.rwc)
req.EncodeMsg(w) req.EncodeMsg(w)
} }
@ -115,30 +136,46 @@ func (rpc *RPCConn) RegisterHandler(name string, fn ServiceFunc) error {
// TODO: mutex lock for sync (or use sync.map? // TODO: mutex lock for sync (or use sync.map?
rpc.handlers[name] = fn rpc.handlers[name] = fn
rpc.Logger.Info("registered a new handler", "name", name, "fn", fn) rpc.logger.Info("registered a new handler", "name", name, "fn", fn)
return nil return nil
} }
// Serve runs the server. It will dispatch goroutines to handle each
// method call. This can (and should in most cases) be run in the background to allow for // Removes a handler, if it exists. Never errors. No-op if the name
// sending and receving on the same connection. // is not a registered handler.
func (rpc *RPCConn) RemoveHandler(name string) error {
delete(rpc.handlers, name)
return nil
}
// Serve runs the server. It will dispatch goroutines to handle each method
// call. This can (and should in most cases) be run in the background to allow
// for sending and receving on the same connection.
func (rpc *RPCConn) Serve() { func (rpc *RPCConn) Serve() {
// construct a stream reader. // construct a stream reader.
msgReader := msgp.NewReader(rpc.conn) msgReader := msgp.NewReader(rpc.rwc)
// read a request/notification from the connection. // read a request/notification from the connection.
var rawmsg msgp.Raw = make(msgp.Raw, 0, 4) var rawmsg msgp.Raw = make(msgp.Raw, 0, 4)
for { for {
rawmsg.DecodeMsg(msgReader) err := rawmsg.DecodeMsg(msgReader)
if err != nil {
if errors.Is(err, io.EOF) {
rpc.logger.Info("reached EOF, stopping server")
return
}
rpc.logger.Warn("error decoding message", "err", err)
continue
}
rpcIntf, err := parseRPC(rawmsg) rpcIntf, err := parseRPC(rawmsg)
if err != nil { if err != nil {
rpc.Logger.Warn("Could not parse RPC message", "err", err) rpc.logger.Warn("Could not parse RPC message", "err", err)
continue continue
} }
@ -152,26 +189,33 @@ func (rpc *RPCConn) Serve() {
case Response: case Response:
cbCh, err := rpc.ct.Clear(rpcObject.MsgId) cbCh, err := rpc.ct.Clear(rpcObject.MsgId)
if err != nil { if err != nil {
rpc.Logger.Warn("could not get rpc callback", "msgid", rpcObject.MsgId, "err", err) rpc.logger.Warn("could not get rpc callback", "msgid", rpcObject.MsgId, "err", err)
continue continue
} }
cbCh <- rpcObject cbCh <- rpcObject
default:
panic("invalid rpcObject!")
} }
} }
} }
// INTERNAL functions for rpcConn
// dispatch is an internal method used to execute a Request sent by the remote:w // dispatch is an internal method used to execute a Request sent by the remote:w
func (rpc *RPCConn) dispatch(req Request) { func (rpc *RPCConn) dispatch(req Request) {
result, err := rpc.handlers[req.Method](req.Params) result, err := rpc.handlers[req.Method](req.Params)
if err != nil { if err != nil {
rpc.Logger.Warn("error dispatching rpc function", "method", req.Method, "err", err) rpc.logger.Warn("error dispatching rpc function", "method", req.Method, "err", err)
} }
// construct the response frame. // construct the response frame.
var rpcE *RPCError = MakeRPCError(err) var rpcE *RPCError = MakeRPCError(err)
w := msgp.NewWriter(rpc.conn) w := msgp.NewWriter(rpc.rwc)
response := NewResponse(req.MsgId, *rpcE, result) response := NewResponse(req.MsgId, *rpcE, result)
@ -187,7 +231,7 @@ func (rpc *RPCConn) dispatchNotif(req Notification) {
if err != nil { if err != nil {
// log the error, but don't do anything about it. // log the error, but don't do anything about it.
rpc.Logger.Warn("error dispatching rpc function", "method", req.Method, "err", err) rpc.logger.Warn("error dispatching rpc function", "method", req.Method, "err", err)
} }
} }

View file

@ -160,3 +160,11 @@ func MakeRPCError(err error) *RPCError {
func (r *RPCError) Error() string { func (r *RPCError) Error() string {
return r.Desc return r.Desc
} }
// we need to describe an empty data that will be excluded in the msgp
// for functions without an argument or return value.
type RPCEmpty struct {
}

View file

@ -97,6 +97,89 @@ func (z *Notification) Msgsize() (s int) {
return return
} }
// DecodeMsg implements msgp.Decodable
func (z *RPCEmpty) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z RPCEmpty) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 0
err = en.Append(0x80)
if err != nil {
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z RPCEmpty) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 0
o = append(o, 0x80)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *RPCEmpty) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z RPCEmpty) Msgsize() (s int) {
s = 1
return
}
// DecodeMsg implements msgp.Decodable // DecodeMsg implements msgp.Decodable
func (z *RPCError) DecodeMsg(dc *msgp.Reader) (err error) { func (z *RPCError) DecodeMsg(dc *msgp.Reader) (err error) {
var zb0001 uint32 var zb0001 uint32

View file

@ -122,6 +122,119 @@ func BenchmarkDecodeNotification(b *testing.B) {
} }
} }
func TestMarshalUnmarshalRPCEmpty(t *testing.T) {
v := RPCEmpty{}
bts, err := v.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
left, err := v.UnmarshalMsg(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
}
left, err = msgp.Skip(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
}
}
func BenchmarkMarshalMsgRPCEmpty(b *testing.B) {
v := RPCEmpty{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.MarshalMsg(nil)
}
}
func BenchmarkAppendMsgRPCEmpty(b *testing.B) {
v := RPCEmpty{}
bts := make([]byte, 0, v.Msgsize())
bts, _ = v.MarshalMsg(bts[0:0])
b.SetBytes(int64(len(bts)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
bts, _ = v.MarshalMsg(bts[0:0])
}
}
func BenchmarkUnmarshalRPCEmpty(b *testing.B) {
v := RPCEmpty{}
bts, _ := v.MarshalMsg(nil)
b.ReportAllocs()
b.SetBytes(int64(len(bts)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := v.UnmarshalMsg(bts)
if err != nil {
b.Fatal(err)
}
}
}
func TestEncodeDecodeRPCEmpty(t *testing.T) {
v := RPCEmpty{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
m := v.Msgsize()
if buf.Len() > m {
t.Log("WARNING: TestEncodeDecodeRPCEmpty Msgsize() is inaccurate")
}
vn := RPCEmpty{}
err := msgp.Decode(&buf, &vn)
if err != nil {
t.Error(err)
}
buf.Reset()
msgp.Encode(&buf, &v)
err = msgp.NewReader(&buf).Skip()
if err != nil {
t.Error(err)
}
}
func BenchmarkEncodeRPCEmpty(b *testing.B) {
v := RPCEmpty{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
en := msgp.NewWriter(msgp.Nowhere)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.EncodeMsg(en)
}
en.Flush()
}
func BenchmarkDecodeRPCEmpty(b *testing.B) {
v := RPCEmpty{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
rd := msgp.NewEndlessReader(buf.Bytes(), b)
dc := msgp.NewReader(rd)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err := v.DecodeMsg(dc)
if err != nil {
b.Fatal(err)
}
}
}
func TestMarshalUnmarshalRPCError(t *testing.T) { func TestMarshalUnmarshalRPCError(t *testing.T) {
v := RPCError{} v := RPCError{}
bts, err := v.MarshalMsg(nil) bts, err := v.MarshalMsg(nil)

1
mprpc/rpc_test.go Normal file
View file

@ -0,0 +1 @@
package mprpc_test

View file

@ -13,6 +13,13 @@ type rpcConnTrack struct {
mu sync.RWMutex mu sync.RWMutex
} }
func NewRPCConnTrack() rpcConnTrack {
return rpcConnTrack{
ct: make(map[uint32]chan Response),
}
}
// Get attempts to get a random mark from the mutex. // Get attempts to get a random mark from the mutex.
func (c *rpcConnTrack) Claim() (uint32, chan Response) { func (c *rpcConnTrack) Claim() (uint32, chan Response) {
var val uint32 var val uint32