rpc stuff
This commit is contained in:
parent
992c0cac13
commit
c7473a8d02
80
mprpc/rpc.go
80
mprpc/rpc.go
|
@ -47,7 +47,8 @@ encoding the arguments and decoding the response for a remote procedure.
|
||||||
package mprpc
|
package mprpc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"errors"
|
||||||
|
"io"
|
||||||
|
|
||||||
"github.com/tinylib/msgp/msgp"
|
"github.com/tinylib/msgp/msgp"
|
||||||
"golang.org/x/exp/slog"
|
"golang.org/x/exp/slog"
|
||||||
|
@ -66,12 +67,33 @@ type ServiceFunc func(params msgp.Raw) (res msgp.Raw, err error)
|
||||||
// "server" aka listener, and client.
|
// "server" aka listener, and client.
|
||||||
type RPCConn struct {
|
type RPCConn struct {
|
||||||
// TODO: use io.readwritecloser?
|
// TODO: use io.readwritecloser?
|
||||||
conn net.Conn
|
rwc io.ReadWriteCloser
|
||||||
handlers map[string]ServiceFunc
|
handlers map[string]ServiceFunc
|
||||||
|
|
||||||
ct rpcConnTrack
|
ct rpcConnTrack
|
||||||
|
|
||||||
slog.Logger
|
logger slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// creates a new RPC connection on top of an io.ReadWriteCloser. Can be
|
||||||
|
// pre-seeded with handlers.
|
||||||
|
func NewRPC(rwc io.ReadWriteCloser, logger *slog.Logger, initialHandlers map[string]ServiceFunc) (rpc *RPCConn, err error) {
|
||||||
|
|
||||||
|
rpc = &RPCConn{
|
||||||
|
rwc: rwc,
|
||||||
|
handlers: make(map[string]ServiceFunc),
|
||||||
|
ct: NewRPCConnTrack(),
|
||||||
|
}
|
||||||
|
if initialHandlers != nil {
|
||||||
|
for k,v := range initialHandlers {
|
||||||
|
rpc.handlers[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call intiates an RPC call to a remote method and returns the
|
// Call intiates an RPC call to a remote method and returns the
|
||||||
|
@ -85,7 +107,7 @@ func (rpc *RPCConn) Call(method string, params msgp.Raw) (msgp.Raw, error) {
|
||||||
|
|
||||||
req := NewRequest(id, method, params)
|
req := NewRequest(id, method, params)
|
||||||
|
|
||||||
w := msgp.NewWriter(rpc.conn)
|
w := msgp.NewWriter(rpc.rwc)
|
||||||
req.EncodeMsg(w)
|
req.EncodeMsg(w)
|
||||||
|
|
||||||
// block and wait for response.
|
// block and wait for response.
|
||||||
|
@ -96,14 +118,13 @@ func (rpc *RPCConn) Call(method string, params msgp.Raw) (msgp.Raw, error) {
|
||||||
|
|
||||||
// Notify initiates a notification to a remote method. It does not
|
// Notify initiates a notification to a remote method. It does not
|
||||||
// return any information. There is no response from the server.
|
// return any information. There is no response from the server.
|
||||||
// This method will not block. An error is returned if there is a local
|
// This method will not block nor will it inform the caller if any errors occur.
|
||||||
// problem.
|
|
||||||
func (rpc *RPCConn) Notify(method string, params msgp.Raw) {
|
func (rpc *RPCConn) Notify(method string, params msgp.Raw) {
|
||||||
// TODO: return an error if there's a local problem?
|
// TODO: return an error if there's a local problem?
|
||||||
|
|
||||||
req := NewNotification(method, params)
|
req := NewNotification(method, params)
|
||||||
|
|
||||||
w := msgp.NewWriter(rpc.conn)
|
w := msgp.NewWriter(rpc.rwc)
|
||||||
req.EncodeMsg(w)
|
req.EncodeMsg(w)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -115,30 +136,46 @@ func (rpc *RPCConn) RegisterHandler(name string, fn ServiceFunc) error {
|
||||||
// TODO: mutex lock for sync (or use sync.map?
|
// TODO: mutex lock for sync (or use sync.map?
|
||||||
rpc.handlers[name] = fn
|
rpc.handlers[name] = fn
|
||||||
|
|
||||||
rpc.Logger.Info("registered a new handler", "name", name, "fn", fn)
|
rpc.logger.Info("registered a new handler", "name", name, "fn", fn)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve runs the server. It will dispatch goroutines to handle each
|
|
||||||
// method call. This can (and should in most cases) be run in the background to allow for
|
// Removes a handler, if it exists. Never errors. No-op if the name
|
||||||
// sending and receving on the same connection.
|
// is not a registered handler.
|
||||||
|
func (rpc *RPCConn) RemoveHandler(name string) error {
|
||||||
|
delete(rpc.handlers, name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve runs the server. It will dispatch goroutines to handle each method
|
||||||
|
// call. This can (and should in most cases) be run in the background to allow
|
||||||
|
// for sending and receving on the same connection.
|
||||||
func (rpc *RPCConn) Serve() {
|
func (rpc *RPCConn) Serve() {
|
||||||
|
|
||||||
// construct a stream reader.
|
// construct a stream reader.
|
||||||
msgReader := msgp.NewReader(rpc.conn)
|
msgReader := msgp.NewReader(rpc.rwc)
|
||||||
|
|
||||||
// read a request/notification from the connection.
|
// read a request/notification from the connection.
|
||||||
|
|
||||||
var rawmsg msgp.Raw = make(msgp.Raw, 0, 4)
|
var rawmsg msgp.Raw = make(msgp.Raw, 0, 4)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
rawmsg.DecodeMsg(msgReader)
|
err := rawmsg.DecodeMsg(msgReader)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
rpc.logger.Info("reached EOF, stopping server")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rpc.logger.Warn("error decoding message", "err", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
rpcIntf, err := parseRPC(rawmsg)
|
rpcIntf, err := parseRPC(rawmsg)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rpc.Logger.Warn("Could not parse RPC message", "err", err)
|
rpc.logger.Warn("Could not parse RPC message", "err", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,26 +189,33 @@ func (rpc *RPCConn) Serve() {
|
||||||
case Response:
|
case Response:
|
||||||
cbCh, err := rpc.ct.Clear(rpcObject.MsgId)
|
cbCh, err := rpc.ct.Clear(rpcObject.MsgId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rpc.Logger.Warn("could not get rpc callback", "msgid", rpcObject.MsgId, "err", err)
|
rpc.logger.Warn("could not get rpc callback", "msgid", rpcObject.MsgId, "err", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
cbCh <- rpcObject
|
cbCh <- rpcObject
|
||||||
|
default:
|
||||||
|
panic("invalid rpcObject!")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// INTERNAL functions for rpcConn
|
||||||
|
|
||||||
// dispatch is an internal method used to execute a Request sent by the remote:w
|
// dispatch is an internal method used to execute a Request sent by the remote:w
|
||||||
func (rpc *RPCConn) dispatch(req Request) {
|
func (rpc *RPCConn) dispatch(req Request) {
|
||||||
|
|
||||||
result, err := rpc.handlers[req.Method](req.Params)
|
result, err := rpc.handlers[req.Method](req.Params)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rpc.Logger.Warn("error dispatching rpc function", "method", req.Method, "err", err)
|
rpc.logger.Warn("error dispatching rpc function", "method", req.Method, "err", err)
|
||||||
}
|
}
|
||||||
// construct the response frame.
|
// construct the response frame.
|
||||||
var rpcE *RPCError = MakeRPCError(err)
|
var rpcE *RPCError = MakeRPCError(err)
|
||||||
|
|
||||||
w := msgp.NewWriter(rpc.conn)
|
w := msgp.NewWriter(rpc.rwc)
|
||||||
|
|
||||||
response := NewResponse(req.MsgId, *rpcE, result)
|
response := NewResponse(req.MsgId, *rpcE, result)
|
||||||
|
|
||||||
|
@ -187,7 +231,7 @@ func (rpc *RPCConn) dispatchNotif(req Notification) {
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// log the error, but don't do anything about it.
|
// log the error, but don't do anything about it.
|
||||||
rpc.Logger.Warn("error dispatching rpc function", "method", req.Method, "err", err)
|
rpc.logger.Warn("error dispatching rpc function", "method", req.Method, "err", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -160,3 +160,11 @@ func MakeRPCError(err error) *RPCError {
|
||||||
func (r *RPCError) Error() string {
|
func (r *RPCError) Error() string {
|
||||||
return r.Desc
|
return r.Desc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// we need to describe an empty data that will be excluded in the msgp
|
||||||
|
// for functions without an argument or return value.
|
||||||
|
type RPCEmpty struct {
|
||||||
|
}
|
||||||
|
|
|
@ -97,6 +97,89 @@ func (z *Notification) Msgsize() (s int) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DecodeMsg implements msgp.Decodable
|
||||||
|
func (z *RPCEmpty) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||||
|
var field []byte
|
||||||
|
_ = field
|
||||||
|
var zb0001 uint32
|
||||||
|
zb0001, err = dc.ReadMapHeader()
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for zb0001 > 0 {
|
||||||
|
zb0001--
|
||||||
|
field, err = dc.ReadMapKeyPtr()
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch msgp.UnsafeString(field) {
|
||||||
|
default:
|
||||||
|
err = dc.Skip()
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeMsg implements msgp.Encodable
|
||||||
|
func (z RPCEmpty) EncodeMsg(en *msgp.Writer) (err error) {
|
||||||
|
// map header, size 0
|
||||||
|
err = en.Append(0x80)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalMsg implements msgp.Marshaler
|
||||||
|
func (z RPCEmpty) MarshalMsg(b []byte) (o []byte, err error) {
|
||||||
|
o = msgp.Require(b, z.Msgsize())
|
||||||
|
// map header, size 0
|
||||||
|
o = append(o, 0x80)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalMsg implements msgp.Unmarshaler
|
||||||
|
func (z *RPCEmpty) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||||
|
var field []byte
|
||||||
|
_ = field
|
||||||
|
var zb0001 uint32
|
||||||
|
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for zb0001 > 0 {
|
||||||
|
zb0001--
|
||||||
|
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch msgp.UnsafeString(field) {
|
||||||
|
default:
|
||||||
|
bts, err = msgp.Skip(bts)
|
||||||
|
if err != nil {
|
||||||
|
err = msgp.WrapError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
o = bts
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||||
|
func (z RPCEmpty) Msgsize() (s int) {
|
||||||
|
s = 1
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// DecodeMsg implements msgp.Decodable
|
// DecodeMsg implements msgp.Decodable
|
||||||
func (z *RPCError) DecodeMsg(dc *msgp.Reader) (err error) {
|
func (z *RPCError) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||||
var zb0001 uint32
|
var zb0001 uint32
|
||||||
|
|
|
@ -122,6 +122,119 @@ func BenchmarkDecodeNotification(b *testing.B) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMarshalUnmarshalRPCEmpty(t *testing.T) {
|
||||||
|
v := RPCEmpty{}
|
||||||
|
bts, err := v.MarshalMsg(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
left, err := v.UnmarshalMsg(bts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(left) > 0 {
|
||||||
|
t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
|
||||||
|
}
|
||||||
|
|
||||||
|
left, err = msgp.Skip(bts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(left) > 0 {
|
||||||
|
t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkMarshalMsgRPCEmpty(b *testing.B) {
|
||||||
|
v := RPCEmpty{}
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
v.MarshalMsg(nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkAppendMsgRPCEmpty(b *testing.B) {
|
||||||
|
v := RPCEmpty{}
|
||||||
|
bts := make([]byte, 0, v.Msgsize())
|
||||||
|
bts, _ = v.MarshalMsg(bts[0:0])
|
||||||
|
b.SetBytes(int64(len(bts)))
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bts, _ = v.MarshalMsg(bts[0:0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUnmarshalRPCEmpty(b *testing.B) {
|
||||||
|
v := RPCEmpty{}
|
||||||
|
bts, _ := v.MarshalMsg(nil)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(len(bts)))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := v.UnmarshalMsg(bts)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeDecodeRPCEmpty(t *testing.T) {
|
||||||
|
v := RPCEmpty{}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
msgp.Encode(&buf, &v)
|
||||||
|
|
||||||
|
m := v.Msgsize()
|
||||||
|
if buf.Len() > m {
|
||||||
|
t.Log("WARNING: TestEncodeDecodeRPCEmpty Msgsize() is inaccurate")
|
||||||
|
}
|
||||||
|
|
||||||
|
vn := RPCEmpty{}
|
||||||
|
err := msgp.Decode(&buf, &vn)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.Reset()
|
||||||
|
msgp.Encode(&buf, &v)
|
||||||
|
err = msgp.NewReader(&buf).Skip()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncodeRPCEmpty(b *testing.B) {
|
||||||
|
v := RPCEmpty{}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
msgp.Encode(&buf, &v)
|
||||||
|
b.SetBytes(int64(buf.Len()))
|
||||||
|
en := msgp.NewWriter(msgp.Nowhere)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
v.EncodeMsg(en)
|
||||||
|
}
|
||||||
|
en.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecodeRPCEmpty(b *testing.B) {
|
||||||
|
v := RPCEmpty{}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
msgp.Encode(&buf, &v)
|
||||||
|
b.SetBytes(int64(buf.Len()))
|
||||||
|
rd := msgp.NewEndlessReader(buf.Bytes(), b)
|
||||||
|
dc := msgp.NewReader(rd)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
err := v.DecodeMsg(dc)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMarshalUnmarshalRPCError(t *testing.T) {
|
func TestMarshalUnmarshalRPCError(t *testing.T) {
|
||||||
v := RPCError{}
|
v := RPCError{}
|
||||||
bts, err := v.MarshalMsg(nil)
|
bts, err := v.MarshalMsg(nil)
|
||||||
|
|
1
mprpc/rpc_test.go
Normal file
1
mprpc/rpc_test.go
Normal file
|
@ -0,0 +1 @@
|
||||||
|
package mprpc_test
|
|
@ -13,6 +13,13 @@ type rpcConnTrack struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func NewRPCConnTrack() rpcConnTrack {
|
||||||
|
return rpcConnTrack{
|
||||||
|
ct: make(map[uint32]chan Response),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Get attempts to get a random mark from the mutex.
|
// Get attempts to get a random mark from the mutex.
|
||||||
func (c *rpcConnTrack) Claim() (uint32, chan Response) {
|
func (c *rpcConnTrack) Claim() (uint32, chan Response) {
|
||||||
var val uint32
|
var val uint32
|
||||||
|
|
Loading…
Reference in a new issue