Skip to content

Commit 59d6e45

Browse files
authored
Merge pull request #123 from gliderlabs/optimize-add-host-key
Update AddHostKey to avoid always appending
2 parents 63518b5 + 1db07d8 commit 59d6e45

File tree

3 files changed

+93
-11
lines changed

3 files changed

+93
-11
lines changed

server.go

+37-1
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,17 @@ type Server struct {
5858
RequestHandlers map[string]RequestHandler
5959

6060
listenerWg sync.WaitGroup
61-
mu sync.Mutex
61+
mu sync.RWMutex
6262
listeners map[net.Listener]struct{}
6363
conns map[*gossh.ServerConn]struct{}
6464
connWg sync.WaitGroup
6565
doneChan chan struct{}
6666
}
6767

6868
func (srv *Server) ensureHostSigner() error {
69+
srv.mu.Lock()
70+
defer srv.mu.Unlock()
71+
6972
if len(srv.HostSigners) == 0 {
7073
signer, err := generateSigner()
7174
if err != nil {
@@ -79,6 +82,7 @@ func (srv *Server) ensureHostSigner() error {
7982
func (srv *Server) ensureHandlers() {
8083
srv.mu.Lock()
8184
defer srv.mu.Unlock()
85+
8286
if srv.RequestHandlers == nil {
8387
srv.RequestHandlers = map[string]RequestHandler{}
8488
for k, v := range DefaultRequestHandlers {
@@ -94,6 +98,9 @@ func (srv *Server) ensureHandlers() {
9498
}
9599

96100
func (srv *Server) config(ctx Context) *gossh.ServerConfig {
101+
srv.mu.RLock()
102+
defer srv.mu.RUnlock()
103+
97104
var config *gossh.ServerConfig
98105
if srv.ServerConfigCallback == nil {
99106
config = &gossh.ServerConfig{}
@@ -142,6 +149,9 @@ func (srv *Server) config(ctx Context) *gossh.ServerConfig {
142149

143150
// Handle sets the Handler for the server.
144151
func (srv *Server) Handle(fn Handler) {
152+
srv.mu.Lock()
153+
defer srv.mu.Unlock()
154+
145155
srv.Handler = fn
146156
}
147157

@@ -153,6 +163,7 @@ func (srv *Server) Handle(fn Handler) {
153163
func (srv *Server) Close() error {
154164
srv.mu.Lock()
155165
defer srv.mu.Unlock()
166+
156167
srv.closeDoneChanLocked()
157168
err := srv.closeListenersLocked()
158169
for c := range srv.conns {
@@ -313,19 +324,42 @@ func (srv *Server) ListenAndServe() error {
313324
// with the same algorithm, it is overwritten. Each server config must have at
314325
// least one host key.
315326
func (srv *Server) AddHostKey(key Signer) {
327+
srv.mu.Lock()
328+
defer srv.mu.Unlock()
329+
316330
// these are later added via AddHostKey on ServerConfig, which performs the
317331
// check for one of every algorithm.
332+
333+
// This check is based on the AddHostKey method from the x/crypto/ssh
334+
// library. This allows us to only keep one active key for each type on a
335+
// server at once. So, if you're dynamically updating keys at runtime, this
336+
// list will not keep growing.
337+
for i, k := range srv.HostSigners {
338+
if k.PublicKey().Type() == key.PublicKey().Type() {
339+
srv.HostSigners[i] = key
340+
return
341+
}
342+
}
343+
318344
srv.HostSigners = append(srv.HostSigners, key)
319345
}
320346

321347
// SetOption runs a functional option against the server.
322348
func (srv *Server) SetOption(option Option) error {
349+
// NOTE: there is a potential race here for any option that doesn't call an
350+
// internal method. We can't actually lock here because if something calls
351+
// (as an example) AddHostKey, it will deadlock.
352+
353+
//srv.mu.Lock()
354+
//defer srv.mu.Unlock()
355+
323356
return option(srv)
324357
}
325358

326359
func (srv *Server) getDoneChan() <-chan struct{} {
327360
srv.mu.Lock()
328361
defer srv.mu.Unlock()
362+
329363
return srv.getDoneChanLocked()
330364
}
331365

@@ -362,6 +396,7 @@ func (srv *Server) closeListenersLocked() error {
362396
func (srv *Server) trackListener(ln net.Listener, add bool) {
363397
srv.mu.Lock()
364398
defer srv.mu.Unlock()
399+
365400
if srv.listeners == nil {
366401
srv.listeners = make(map[net.Listener]struct{})
367402
}
@@ -382,6 +417,7 @@ func (srv *Server) trackListener(ln net.Listener, add bool) {
382417
func (srv *Server) trackConn(c *gossh.ServerConn, add bool) {
383418
srv.mu.Lock()
384419
defer srv.mu.Unlock()
420+
385421
if srv.conns == nil {
386422
srv.conns = make(map[*gossh.ServerConn]struct{})
387423
}

server_test.go

+20
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,26 @@ import (
88
"time"
99
)
1010

11+
func TestAddHostKey(t *testing.T) {
12+
s := Server{}
13+
signer, err := generateSigner()
14+
if err != nil {
15+
t.Fatal(err)
16+
}
17+
s.AddHostKey(signer)
18+
if len(s.HostSigners) != 1 {
19+
t.Fatal("Key was not properly added")
20+
}
21+
signer, err = generateSigner()
22+
if err != nil {
23+
t.Fatal(err)
24+
}
25+
s.AddHostKey(signer)
26+
if len(s.HostSigners) != 1 {
27+
t.Fatal("Key was not properly replaced")
28+
}
29+
}
30+
1131
func TestServerShutdown(t *testing.T) {
1232
l := newLocalListener()
1333
testBytes := []byte("Hello world\n")

session_test.go

+36-10
Original file line numberDiff line numberDiff line change
@@ -289,20 +289,40 @@ func TestPtyResize(t *testing.T) {
289289
func TestSignals(t *testing.T) {
290290
t.Parallel()
291291

292+
// errChan lets us get errors back from the session
293+
errChan := make(chan error, 5)
294+
295+
// doneChan lets us specify that we should exit.
296+
doneChan := make(chan interface{})
297+
292298
session, _, cleanup := newTestSession(t, &Server{
293299
Handler: func(s Session) {
294-
signals := make(chan Signal)
300+
// We need to use a buffered channel here, otherwise it's possible for the
301+
// second call to Signal to get discarded.
302+
signals := make(chan Signal, 2)
295303
s.Signals(signals)
296-
if sig := <-signals; sig != SIGINT {
297-
t.Fatalf("expected signal %v but got %v", SIGINT, sig)
304+
305+
select {
306+
case sig := <-signals:
307+
if sig != SIGINT {
308+
errChan <- fmt.Errorf("expected signal %v but got %v", SIGINT, sig)
309+
return
310+
}
311+
case <-doneChan:
312+
errChan <- fmt.Errorf("Unexpected done")
313+
return
298314
}
299-
exiter := make(chan bool)
300-
go func() {
301-
if sig := <-signals; sig == SIGKILL {
302-
close(exiter)
315+
316+
select {
317+
case sig := <-signals:
318+
if sig != SIGKILL {
319+
errChan <- fmt.Errorf("expected signal %v but got %v", SIGKILL, sig)
320+
return
303321
}
304-
}()
305-
<-exiter
322+
case <-doneChan:
323+
errChan <- fmt.Errorf("Unexpected done")
324+
return
325+
}
306326
},
307327
}, nil)
308328
defer cleanup()
@@ -312,7 +332,13 @@ func TestSignals(t *testing.T) {
312332
session.Signal(gossh.SIGKILL)
313333
}()
314334

315-
err := session.Run("")
335+
go func() {
336+
errChan <- session.Run("")
337+
}()
338+
339+
err := <-errChan
340+
close(doneChan)
341+
316342
if err != nil {
317343
t.Fatalf("expected nil but got %v", err)
318344
}

0 commit comments

Comments
 (0)