rpc stuff
This commit is contained in:
parent
992c0cac13
commit
c7473a8d02
80
mprpc/rpc.go
80
mprpc/rpc.go
|
@ -47,7 +47,8 @@ encoding the arguments and decoding the response for a remote procedure.
|
|||
package mprpc
|
||||
|
||||
import (
|
||||
"net"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/tinylib/msgp/msgp"
|
||||
"golang.org/x/exp/slog"
|
||||
|
@ -66,12 +67,33 @@ type ServiceFunc func(params msgp.Raw) (res msgp.Raw, err error)
|
|||
// "server" aka listener, and client.
|
||||
type RPCConn struct {
|
||||
// TODO: use io.readwritecloser?
|
||||
conn net.Conn
|
||||
rwc io.ReadWriteCloser
|
||||
handlers map[string]ServiceFunc
|
||||
|
||||
ct rpcConnTrack
|
||||
|
||||
slog.Logger
|
||||
logger slog.Logger
|
||||
}
|
||||
|
||||
|
||||
// creates a new RPC connection on top of an io.ReadWriteCloser. Can be
|
||||
// pre-seeded with handlers.
|
||||
func NewRPC(rwc io.ReadWriteCloser, logger *slog.Logger, initialHandlers map[string]ServiceFunc) (rpc *RPCConn, err error) {
|
||||
|
||||
rpc = &RPCConn{
|
||||
rwc: rwc,
|
||||
handlers: make(map[string]ServiceFunc),
|
||||
ct: NewRPCConnTrack(),
|
||||
}
|
||||
if initialHandlers != nil {
|
||||
for k,v := range initialHandlers {
|
||||
rpc.handlers[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
// Call intiates an RPC call to a remote method and returns the
|
||||
|
@ -85,7 +107,7 @@ func (rpc *RPCConn) Call(method string, params msgp.Raw) (msgp.Raw, error) {
|
|||
|
||||
req := NewRequest(id, method, params)
|
||||
|
||||
w := msgp.NewWriter(rpc.conn)
|
||||
w := msgp.NewWriter(rpc.rwc)
|
||||
req.EncodeMsg(w)
|
||||
|
||||
// block and wait for response.
|
||||
|
@ -96,14 +118,13 @@ func (rpc *RPCConn) Call(method string, params msgp.Raw) (msgp.Raw, error) {
|
|||
|
||||
// Notify initiates a notification to a remote method. It does not
|
||||
// return any information. There is no response from the server.
|
||||
// This method will not block. An error is returned if there is a local
|
||||
// problem.
|
||||
// This method will not block nor will it inform the caller if any errors occur.
|
||||
func (rpc *RPCConn) Notify(method string, params msgp.Raw) {
|
||||
// TODO: return an error if there's a local problem?
|
||||
|
||||
req := NewNotification(method, params)
|
||||
|
||||
w := msgp.NewWriter(rpc.conn)
|
||||
w := msgp.NewWriter(rpc.rwc)
|
||||
req.EncodeMsg(w)
|
||||
|
||||
}
|
||||
|
@ -115,30 +136,46 @@ func (rpc *RPCConn) RegisterHandler(name string, fn ServiceFunc) error {
|
|||
// TODO: mutex lock for sync (or use sync.map?
|
||||
rpc.handlers[name] = fn
|
||||
|
||||
rpc.Logger.Info("registered a new handler", "name", name, "fn", fn)
|
||||
rpc.logger.Info("registered a new handler", "name", name, "fn", fn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve runs the server. It will dispatch goroutines to handle each
|
||||
// method call. This can (and should in most cases) be run in the background to allow for
|
||||
// sending and receving on the same connection.
|
||||
|
||||
// Removes a handler, if it exists. Never errors. No-op if the name
|
||||
// is not a registered handler.
|
||||
func (rpc *RPCConn) RemoveHandler(name string) error {
|
||||
delete(rpc.handlers, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve runs the server. It will dispatch goroutines to handle each method
|
||||
// call. This can (and should in most cases) be run in the background to allow
|
||||
// for sending and receving on the same connection.
|
||||
func (rpc *RPCConn) Serve() {
|
||||
|
||||
// construct a stream reader.
|
||||
msgReader := msgp.NewReader(rpc.conn)
|
||||
msgReader := msgp.NewReader(rpc.rwc)
|
||||
|
||||
// read a request/notification from the connection.
|
||||
|
||||
var rawmsg msgp.Raw = make(msgp.Raw, 0, 4)
|
||||
|
||||
for {
|
||||
rawmsg.DecodeMsg(msgReader)
|
||||
err := rawmsg.DecodeMsg(msgReader)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
rpc.logger.Info("reached EOF, stopping server")
|
||||
return
|
||||
}
|
||||
rpc.logger.Warn("error decoding message", "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
rpcIntf, err := parseRPC(rawmsg)
|
||||
|
||||
if err != nil {
|
||||
rpc.Logger.Warn("Could not parse RPC message", "err", err)
|
||||
rpc.logger.Warn("Could not parse RPC message", "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -152,26 +189,33 @@ func (rpc *RPCConn) Serve() {
|
|||
case Response:
|
||||
cbCh, err := rpc.ct.Clear(rpcObject.MsgId)
|
||||
if err != nil {
|
||||
rpc.Logger.Warn("could not get rpc callback", "msgid", rpcObject.MsgId, "err", err)
|
||||
rpc.logger.Warn("could not get rpc callback", "msgid", rpcObject.MsgId, "err", err)
|
||||
continue
|
||||
}
|
||||
cbCh <- rpcObject
|
||||
default:
|
||||
panic("invalid rpcObject!")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
// INTERNAL functions for rpcConn
|
||||
|
||||
// dispatch is an internal method used to execute a Request sent by the remote:w
|
||||
func (rpc *RPCConn) dispatch(req Request) {
|
||||
|
||||
result, err := rpc.handlers[req.Method](req.Params)
|
||||
|
||||
if err != nil {
|
||||
rpc.Logger.Warn("error dispatching rpc function", "method", req.Method, "err", err)
|
||||
rpc.logger.Warn("error dispatching rpc function", "method", req.Method, "err", err)
|
||||
}
|
||||
// construct the response frame.
|
||||
var rpcE *RPCError = MakeRPCError(err)
|
||||
|
||||
w := msgp.NewWriter(rpc.conn)
|
||||
w := msgp.NewWriter(rpc.rwc)
|
||||
|
||||
response := NewResponse(req.MsgId, *rpcE, result)
|
||||
|
||||
|
@ -187,7 +231,7 @@ func (rpc *RPCConn) dispatchNotif(req Notification) {
|
|||
|
||||
if err != nil {
|
||||
// log the error, but don't do anything about it.
|
||||
rpc.Logger.Warn("error dispatching rpc function", "method", req.Method, "err", err)
|
||||
rpc.logger.Warn("error dispatching rpc function", "method", req.Method, "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -160,3 +160,11 @@ func MakeRPCError(err error) *RPCError {
|
|||
func (r *RPCError) Error() string {
|
||||
return r.Desc
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
// we need to describe an empty data that will be excluded in the msgp
|
||||
// for functions without an argument or return value.
|
||||
type RPCEmpty struct {
|
||||
}
|
||||
|
|
|
@ -97,6 +97,89 @@ func (z *Notification) Msgsize() (s int) {
|
|||
return
|
||||
}
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *RPCEmpty) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, err = dc.ReadMapKeyPtr()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z RPCEmpty) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 0
|
||||
err = en.Append(0x80)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z RPCEmpty) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 0
|
||||
o = append(o, 0x80)
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *RPCEmpty) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
o = bts
|
||||
return
|
||||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z RPCEmpty) Msgsize() (s int) {
|
||||
s = 1
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *RPCError) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var zb0001 uint32
|
||||
|
|
|
@ -122,6 +122,119 @@ func BenchmarkDecodeNotification(b *testing.B) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestMarshalUnmarshalRPCEmpty(t *testing.T) {
|
||||
v := RPCEmpty{}
|
||||
bts, err := v.MarshalMsg(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
left, err := v.UnmarshalMsg(bts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(left) > 0 {
|
||||
t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
|
||||
}
|
||||
|
||||
left, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(left) > 0 {
|
||||
t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMarshalMsgRPCEmpty(b *testing.B) {
|
||||
v := RPCEmpty{}
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
v.MarshalMsg(nil)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAppendMsgRPCEmpty(b *testing.B) {
|
||||
v := RPCEmpty{}
|
||||
bts := make([]byte, 0, v.Msgsize())
|
||||
bts, _ = v.MarshalMsg(bts[0:0])
|
||||
b.SetBytes(int64(len(bts)))
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bts, _ = v.MarshalMsg(bts[0:0])
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnmarshalRPCEmpty(b *testing.B) {
|
||||
v := RPCEmpty{}
|
||||
bts, _ := v.MarshalMsg(nil)
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(bts)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := v.UnmarshalMsg(bts)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDecodeRPCEmpty(t *testing.T) {
|
||||
v := RPCEmpty{}
|
||||
var buf bytes.Buffer
|
||||
msgp.Encode(&buf, &v)
|
||||
|
||||
m := v.Msgsize()
|
||||
if buf.Len() > m {
|
||||
t.Log("WARNING: TestEncodeDecodeRPCEmpty Msgsize() is inaccurate")
|
||||
}
|
||||
|
||||
vn := RPCEmpty{}
|
||||
err := msgp.Decode(&buf, &vn)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
buf.Reset()
|
||||
msgp.Encode(&buf, &v)
|
||||
err = msgp.NewReader(&buf).Skip()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodeRPCEmpty(b *testing.B) {
|
||||
v := RPCEmpty{}
|
||||
var buf bytes.Buffer
|
||||
msgp.Encode(&buf, &v)
|
||||
b.SetBytes(int64(buf.Len()))
|
||||
en := msgp.NewWriter(msgp.Nowhere)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
v.EncodeMsg(en)
|
||||
}
|
||||
en.Flush()
|
||||
}
|
||||
|
||||
func BenchmarkDecodeRPCEmpty(b *testing.B) {
|
||||
v := RPCEmpty{}
|
||||
var buf bytes.Buffer
|
||||
msgp.Encode(&buf, &v)
|
||||
b.SetBytes(int64(buf.Len()))
|
||||
rd := msgp.NewEndlessReader(buf.Bytes(), b)
|
||||
dc := msgp.NewReader(rd)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
err := v.DecodeMsg(dc)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalUnmarshalRPCError(t *testing.T) {
|
||||
v := RPCError{}
|
||||
bts, err := v.MarshalMsg(nil)
|
||||
|
|
1
mprpc/rpc_test.go
Normal file
1
mprpc/rpc_test.go
Normal file
|
@ -0,0 +1 @@
|
|||
package mprpc_test
|
|
@ -13,6 +13,13 @@ type rpcConnTrack struct {
|
|||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
|
||||
func NewRPCConnTrack() rpcConnTrack {
|
||||
return rpcConnTrack{
|
||||
ct: make(map[uint32]chan Response),
|
||||
}
|
||||
}
|
||||
|
||||
// Get attempts to get a random mark from the mutex.
|
||||
func (c *rpcConnTrack) Claim() (uint32, chan Response) {
|
||||
var val uint32
|
||||
|
|
Loading…
Reference in a new issue