Skip to content

Commit dcff890

Browse files
committed
rpc: abstract client and server encodings
R=r CC=golang-dev, rog https://golang.org/cl/811046
1 parent 72f9b2e commit dcff890

File tree

2 files changed

+131
-41
lines changed

2 files changed

+131
-41
lines changed

src/pkg/rpc/client.go

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,25 @@ type Client struct {
3333
shutdown os.Error // non-nil if the client is shut down
3434
sending sync.Mutex
3535
seq uint64
36-
conn io.ReadWriteCloser
37-
enc *gob.Encoder
38-
dec *gob.Decoder
36+
codec ClientCodec
3937
pending map[uint64]*Call
4038
closing bool
4139
}
4240

41+
// A ClientCodec implements writing of RPC requests and
42+
// reading of RPC responses for the client side of an RPC session.
43+
// The client calls WriteRequest to write a request to the connection
44+
// and calls ReadResponseHeader and ReadResponseBody in pairs
45+
// to read responses. The client calls Close when finished with the
46+
// connection.
47+
type ClientCodec interface {
48+
WriteRequest(*Request, interface{}) os.Error
49+
ReadResponseHeader(*Response) os.Error
50+
ReadResponseBody(interface{}) os.Error
51+
52+
Close() os.Error
53+
}
54+
4355
func (client *Client) send(c *Call) {
4456
// Register this call.
4557
client.mutex.Lock()
@@ -59,9 +71,7 @@ func (client *Client) send(c *Call) {
5971
client.sending.Lock()
6072
request.Seq = c.seq
6173
request.ServiceMethod = c.ServiceMethod
62-
client.enc.Encode(request)
63-
err := client.enc.Encode(c.Args)
64-
if err != nil {
74+
if err := client.codec.WriteRequest(request, c.Args); err != nil {
6575
panic("rpc: client encode error: " + err.String())
6676
}
6777
client.sending.Unlock()
@@ -71,7 +81,7 @@ func (client *Client) input() {
7181
var err os.Error
7282
for err == nil {
7383
response := new(Response)
74-
err = client.dec.Decode(response)
84+
err = client.codec.ReadResponseHeader(response)
7585
if err != nil {
7686
if err == os.EOF && !client.closing {
7787
err = io.ErrUnexpectedEOF
@@ -83,7 +93,7 @@ func (client *Client) input() {
8393
c := client.pending[seq]
8494
client.pending[seq] = c, false
8595
client.mutex.Unlock()
86-
err = client.dec.Decode(c.Reply)
96+
err = client.codec.ReadResponseBody(c.Reply)
8797
// Empty strings should turn into nil os.Errors
8898
if response.Error != "" {
8999
c.Error = os.ErrorString(response.Error)
@@ -110,17 +120,49 @@ func (client *Client) input() {
110120
// NewClient returns a new Client to handle requests to the
111121
// set of services at the other end of the connection.
112122
func NewClient(conn io.ReadWriteCloser) *Client {
113-
client := new(Client)
114-
client.conn = conn
115-
client.enc = gob.NewEncoder(conn)
116-
client.dec = gob.NewDecoder(conn)
117-
client.pending = make(map[uint64]*Call)
123+
return NewClientWithCodec(&gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)})
124+
}
125+
126+
// NewClientWithCodec is like NewClient but uses the specified
127+
// codec to encode requests and decode responses.
128+
func NewClientWithCodec(codec ClientCodec) *Client {
129+
client := &Client{
130+
codec: codec,
131+
pending: make(map[uint64]*Call),
132+
}
118133
go client.input()
119134
return client
120135
}
121136

137+
type gobClientCodec struct {
138+
rwc io.ReadWriteCloser
139+
dec *gob.Decoder
140+
enc *gob.Encoder
141+
}
142+
143+
func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) os.Error {
144+
if err := c.enc.Encode(r); err != nil {
145+
return err
146+
}
147+
return c.enc.Encode(body)
148+
}
149+
150+
func (c *gobClientCodec) ReadResponseHeader(r *Response) os.Error {
151+
return c.dec.Decode(r)
152+
}
153+
154+
func (c *gobClientCodec) ReadResponseBody(body interface{}) os.Error {
155+
return c.dec.Decode(body)
156+
}
157+
158+
func (c *gobClientCodec) Close() os.Error {
159+
return c.rwc.Close()
160+
}
161+
162+
122163
// DialHTTP connects to an HTTP RPC server at the specified network address.
123164
func DialHTTP(network, address string) (*Client, os.Error) {
165+
var err os.Error
124166
conn, err := net.Dial(network, "", address)
125167
if err != nil {
126168
return nil, err
@@ -156,7 +198,7 @@ func (client *Client) Close() os.Error {
156198
client.mutex.Lock()
157199
client.closing = true
158200
client.mutex.Unlock()
159-
return client.conn.Close()
201+
return client.codec.Close()
160202
}
161203

162204
// Go invokes the function asynchronously. It returns the Call structure representing

src/pkg/rpc/server.go

Lines changed: 75 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ func _new(t *reflect.PtrType) *reflect.PtrValue {
272272
return v
273273
}
274274

275-
func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, enc *gob.Encoder, errmsg string) {
275+
func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
276276
resp := new(Response)
277277
// Encode the response header
278278
resp.ServiceMethod = req.ServiceMethod
@@ -281,13 +281,14 @@ func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, enc *gob
281281
}
282282
resp.Seq = req.Seq
283283
sending.Lock()
284-
enc.Encode(resp)
285-
// Encode the reply value.
286-
enc.Encode(reply)
284+
err := codec.WriteResponse(resp, reply)
285+
if err != nil {
286+
log.Stderr("rpc: writing response: ", err)
287+
}
287288
sending.Unlock()
288289
}
289290

290-
func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, enc *gob.Encoder) {
291+
func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
291292
mtype.Lock()
292293
mtype.numCalls++
293294
mtype.Unlock()
@@ -300,17 +301,40 @@ func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, arg
300301
if errInter != nil {
301302
errmsg = errInter.(os.Error).String()
302303
}
303-
sendResponse(sending, req, replyv.Interface(), enc, errmsg)
304+
sendResponse(sending, req, replyv.Interface(), codec, errmsg)
305+
}
306+
307+
type gobServerCodec struct {
308+
rwc io.ReadWriteCloser
309+
dec *gob.Decoder
310+
enc *gob.Encoder
311+
}
312+
313+
func (c *gobServerCodec) ReadRequestHeader(r *Request) os.Error {
314+
return c.dec.Decode(r)
315+
}
316+
317+
func (c *gobServerCodec) ReadRequestBody(body interface{}) os.Error {
318+
return c.dec.Decode(body)
319+
}
320+
321+
func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) os.Error {
322+
if err := c.enc.Encode(r); err != nil {
323+
return err
324+
}
325+
return c.enc.Encode(body)
304326
}
305327

306-
func (server *serverType) input(conn io.ReadWriteCloser) {
307-
dec := gob.NewDecoder(conn)
308-
enc := gob.NewEncoder(conn)
328+
func (c *gobServerCodec) Close() os.Error {
329+
return c.rwc.Close()
330+
}
331+
332+
func (server *serverType) input(codec ServerCodec) {
309333
sending := new(sync.Mutex)
310334
for {
311335
// Grab the request header.
312336
req := new(Request)
313-
err := dec.Decode(req)
337+
err := codec.ReadRequestHeader(req)
314338
if err != nil {
315339
if err == os.EOF || err == io.ErrUnexpectedEOF {
316340
if err == io.ErrUnexpectedEOF {
@@ -319,13 +343,13 @@ func (server *serverType) input(conn io.ReadWriteCloser) {
319343
break
320344
}
321345
s := "rpc: server cannot decode request: " + err.String()
322-
sendResponse(sending, req, invalidRequest, enc, s)
323-
continue
346+
sendResponse(sending, req, invalidRequest, codec, s)
347+
break
324348
}
325349
serviceMethod := strings.Split(req.ServiceMethod, ".", 0)
326350
if len(serviceMethod) != 2 {
327-
s := "rpc: service/method request ill:formed: " + req.ServiceMethod
328-
sendResponse(sending, req, invalidRequest, enc, s)
351+
s := "rpc: service/method request ill-formed: " + req.ServiceMethod
352+
sendResponse(sending, req, invalidRequest, codec, s)
329353
continue
330354
}
331355
// Look up the request.
@@ -334,27 +358,27 @@ func (server *serverType) input(conn io.ReadWriteCloser) {
334358
server.Unlock()
335359
if !ok {
336360
s := "rpc: can't find service " + req.ServiceMethod
337-
sendResponse(sending, req, invalidRequest, enc, s)
361+
sendResponse(sending, req, invalidRequest, codec, s)
338362
continue
339363
}
340364
mtype, ok := service.method[serviceMethod[1]]
341365
if !ok {
342366
s := "rpc: can't find method " + req.ServiceMethod
343-
sendResponse(sending, req, invalidRequest, enc, s)
367+
sendResponse(sending, req, invalidRequest, codec, s)
344368
continue
345369
}
346370
// Decode the argument value.
347371
argv := _new(mtype.argType)
348372
replyv := _new(mtype.replyType)
349-
err = dec.Decode(argv.Interface())
373+
err = codec.ReadRequestBody(argv.Interface())
350374
if err != nil {
351375
log.Stderr("rpc: tearing down", serviceMethod[0], "connection:", err)
352-
sendResponse(sending, req, replyv.Interface(), enc, err.String())
353-
continue
376+
sendResponse(sending, req, replyv.Interface(), codec, err.String())
377+
break
354378
}
355-
go service.call(sending, mtype, req, argv, replyv, enc)
379+
go service.call(sending, mtype, req, argv, replyv, codec)
356380
}
357-
conn.Close()
381+
codec.Close()
358382
}
359383

360384
func (server *serverType) accept(lis net.Listener) {
@@ -363,7 +387,7 @@ func (server *serverType) accept(lis net.Listener) {
363387
if err != nil {
364388
log.Exit("rpc.Serve: accept:", err.String()) // TODO(r): exit?
365389
}
366-
go server.input(conn)
390+
go ServeConn(conn)
367391
}
368392
}
369393

@@ -376,10 +400,34 @@ func (server *serverType) accept(lis net.Listener) {
376400
// suitable methods.
377401
func Register(rcvr interface{}) os.Error { return server.register(rcvr) }
378402

379-
// ServeConn runs the server on a single connection. When the connection
380-
// completes, service terminates. ServeConn blocks; the caller typically
381-
// invokes it in a go statement.
382-
func ServeConn(conn io.ReadWriteCloser) { server.input(conn) }
403+
// A ServerCodec implements reading of RPC requests and writing of
404+
// RPC responses for the server side of an RPC session.
405+
// The server calls ReadRequestHeader and ReadRequestBody in pairs
406+
// to read requests from the connection, and it calls WriteResponse to
407+
// write a response back. The server calls Close when finished with the
408+
// connection.
409+
type ServerCodec interface {
410+
ReadRequestHeader(*Request) os.Error
411+
ReadRequestBody(interface{}) os.Error
412+
WriteResponse(*Response, interface{}) os.Error
413+
414+
Close() os.Error
415+
}
416+
417+
// ServeConn runs the server on a single connection.
418+
// ServeConn blocks, serving the connection until the client hangs up.
419+
// The caller typically invokes ServeConn in a go statement.
420+
// ServeConn uses the gob wire format (see package gob) on the
421+
// connection. To use an alternate codec, use ServeCodec.
422+
func ServeConn(conn io.ReadWriteCloser) {
423+
ServeCodec(&gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)})
424+
}
425+
426+
// ServeCodec is like ServeConn but uses the specified codec to
427+
// decode requests and encode responses.
428+
func ServeCodec(codec ServerCodec) {
429+
server.input(codec)
430+
}
383431

384432
// Accept accepts connections on the listener and serves requests
385433
// for each incoming connection. Accept blocks; the caller typically
@@ -404,7 +452,7 @@ func serveHTTP(c *http.Conn, req *http.Request) {
404452
return
405453
}
406454
io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
407-
server.input(conn)
455+
ServeConn(conn)
408456
}
409457

410458
// HandleHTTP registers an HTTP handler for RPC messages.

0 commit comments

Comments
 (0)