@@ -58,14 +58,17 @@ type Server struct {
58
58
RequestHandlers map [string ]RequestHandler
59
59
60
60
listenerWg sync.WaitGroup
61
- mu sync.Mutex
61
+ mu sync.RWMutex
62
62
listeners map [net.Listener ]struct {}
63
63
conns map [* gossh.ServerConn ]struct {}
64
64
connWg sync.WaitGroup
65
65
doneChan chan struct {}
66
66
}
67
67
68
68
func (srv * Server ) ensureHostSigner () error {
69
+ srv .mu .Lock ()
70
+ defer srv .mu .Unlock ()
71
+
69
72
if len (srv .HostSigners ) == 0 {
70
73
signer , err := generateSigner ()
71
74
if err != nil {
@@ -79,6 +82,7 @@ func (srv *Server) ensureHostSigner() error {
79
82
func (srv * Server ) ensureHandlers () {
80
83
srv .mu .Lock ()
81
84
defer srv .mu .Unlock ()
85
+
82
86
if srv .RequestHandlers == nil {
83
87
srv .RequestHandlers = map [string ]RequestHandler {}
84
88
for k , v := range DefaultRequestHandlers {
@@ -94,6 +98,9 @@ func (srv *Server) ensureHandlers() {
94
98
}
95
99
96
100
func (srv * Server ) config (ctx Context ) * gossh.ServerConfig {
101
+ srv .mu .RLock ()
102
+ defer srv .mu .RUnlock ()
103
+
97
104
var config * gossh.ServerConfig
98
105
if srv .ServerConfigCallback == nil {
99
106
config = & gossh.ServerConfig {}
@@ -142,6 +149,9 @@ func (srv *Server) config(ctx Context) *gossh.ServerConfig {
142
149
143
150
// Handle sets the Handler for the server.
144
151
func (srv * Server ) Handle (fn Handler ) {
152
+ srv .mu .Lock ()
153
+ defer srv .mu .Unlock ()
154
+
145
155
srv .Handler = fn
146
156
}
147
157
@@ -153,6 +163,7 @@ func (srv *Server) Handle(fn Handler) {
153
163
func (srv * Server ) Close () error {
154
164
srv .mu .Lock ()
155
165
defer srv .mu .Unlock ()
166
+
156
167
srv .closeDoneChanLocked ()
157
168
err := srv .closeListenersLocked ()
158
169
for c := range srv .conns {
@@ -313,19 +324,42 @@ func (srv *Server) ListenAndServe() error {
313
324
// with the same algorithm, it is overwritten. Each server config must have at
314
325
// least one host key.
315
326
func (srv * Server ) AddHostKey (key Signer ) {
327
+ srv .mu .Lock ()
328
+ defer srv .mu .Unlock ()
329
+
316
330
// these are later added via AddHostKey on ServerConfig, which performs the
317
331
// check for one of every algorithm.
332
+
333
+ // This check is based on the AddHostKey method from the x/crypto/ssh
334
+ // library. This allows us to only keep one active key for each type on a
335
+ // server at once. So, if you're dynamically updating keys at runtime, this
336
+ // list will not keep growing.
337
+ for i , k := range srv .HostSigners {
338
+ if k .PublicKey ().Type () == key .PublicKey ().Type () {
339
+ srv .HostSigners [i ] = key
340
+ return
341
+ }
342
+ }
343
+
318
344
srv .HostSigners = append (srv .HostSigners , key )
319
345
}
320
346
321
347
// SetOption runs a functional option against the server.
322
348
func (srv * Server ) SetOption (option Option ) error {
349
+ // NOTE: there is a potential race here for any option that doesn't call an
350
+ // internal method. We can't actually lock here because if something calls
351
+ // (as an example) AddHostKey, it will deadlock.
352
+
353
+ //srv.mu.Lock()
354
+ //defer srv.mu.Unlock()
355
+
323
356
return option (srv )
324
357
}
325
358
326
359
func (srv * Server ) getDoneChan () <- chan struct {} {
327
360
srv .mu .Lock ()
328
361
defer srv .mu .Unlock ()
362
+
329
363
return srv .getDoneChanLocked ()
330
364
}
331
365
@@ -362,6 +396,7 @@ func (srv *Server) closeListenersLocked() error {
362
396
func (srv * Server ) trackListener (ln net.Listener , add bool ) {
363
397
srv .mu .Lock ()
364
398
defer srv .mu .Unlock ()
399
+
365
400
if srv .listeners == nil {
366
401
srv .listeners = make (map [net.Listener ]struct {})
367
402
}
@@ -382,6 +417,7 @@ func (srv *Server) trackListener(ln net.Listener, add bool) {
382
417
func (srv * Server ) trackConn (c * gossh.ServerConn , add bool ) {
383
418
srv .mu .Lock ()
384
419
defer srv .mu .Unlock ()
420
+
385
421
if srv .conns == nil {
386
422
srv .conns = make (map [* gossh.ServerConn ]struct {})
387
423
}
0 commit comments