diff --git a/example_ssh_test.go b/example_ssh_test.go index 544f717..6a707cf 100644 --- a/example_ssh_test.go +++ b/example_ssh_test.go @@ -7,6 +7,7 @@ import ( "golang.org/x/crypto/ssh" "nemith.io/netconf" + "nemith.io/netconf/rpc" ncssh "nemith.io/netconf/transport/ssh" ) @@ -37,7 +38,7 @@ func Example_ssh() { // timeout for the call itself. ctx, cancel = context.WithTimeout(ctx, 5*time.Second) defer cancel() - deviceConfig, err := session.GetConfig(ctx, "running") + deviceConfig, err := rpc.GetConfig{Source: rpc.Running}.Exec(ctx, session) if err != nil { log.Fatalf("failed to get config: %v", err) } diff --git a/example_tls_test.go b/example_tls_test.go index c7c9ab7..b68f185 100644 --- a/example_tls_test.go +++ b/example_tls_test.go @@ -11,6 +11,7 @@ import ( "time" "nemith.io/netconf" + "nemith.io/netconf/rpc" nctls "nemith.io/netconf/transport/tls" ) @@ -67,7 +68,7 @@ func Example_tls() { ctx, cancel = context.WithTimeout(ctx, 5*time.Second) defer cancel() - cfg, err := session.GetConfig(ctx, "running") + cfg, err := rpc.GetConfig{Source: rpc.Running}.Exec(ctx, session) if err != nil { panic(err) } diff --git a/inttest/ssh_test.go b/inttest/ssh_test.go index 0115705..93a8fbc 100644 --- a/inttest/ssh_test.go +++ b/inttest/ssh_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "nemith.io/netconf" + "nemith.io/netconf/rpc" ncssh "nemith.io/netconf/transport/ssh" ) @@ -79,8 +80,8 @@ func setupSSH(t *testing.T) *netconf.Session { require.NoErrorf(t, err, "failed to connect to dut %q", addr) // capture the framed communication - inCap := newLogWriter("<<<", t) - outCap := newLogWriter(">>>", t) + inCap := newLogWriter("S: ", t) + outCap := newLogWriter("C: ", t) tr.DebugCapture(inCap, outCap) @@ -101,7 +102,7 @@ func TestSSHGetConfig(t *testing.T) { session := setupSSH(t) ctx := context.Background() - config, err := session.GetConfig(ctx, "running") + config, err := rpc.GetConfig{Source: rpc.Running}.Exec(ctx, session) assert.NoError(t, err) t.Logf("configuration: %s", config) @@ -114,7 +115,7 @@ func TestBadGetConfig(t *testing.T) { session := setupSSH(t) ctx := context.Background() - cfg, err := session.GetConfig(ctx, "non-exist") + cfg, err := rpc.GetConfig{Source: "non-exist"}.Exec(ctx, session) assert.Nil(t, cfg) var rpcErr netconf.RPCError assert.ErrorAs(t, err, &rpcErr) @@ -131,8 +132,13 @@ func TestJunosCommand(t *testing.T) { Command: "show version", } + var reply struct { + netconf.RPCReply + Result string `xml:"command-output>result"` + } + ctx := context.Background() - reply, err := session.Do(ctx, &cmd) + err := session.Exec(ctx, &cmd, &reply) assert.NoError(t, err) - assert.NoError(t, reply.Err()) + assert.Empty(t, reply.RPCErrors) } diff --git a/msg.go b/msg.go index ba0856f..6a8d071 100644 --- a/msg.go +++ b/msg.go @@ -2,131 +2,39 @@ package netconf import ( "encoding/xml" + "errors" "fmt" + "io" "slices" "strings" "time" ) -// RawXML captures the raw xml for the given element. Used to process certain -// elements later. -type RawXML []byte - -func (x *RawXML) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { - var inner struct { - Data []byte `xml:",innerxml"` - } - - if err := d.DecodeElement(&inner, &start); err != nil { - return err - } - - *x = inner.Data - return nil -} +// RPC maps the xml value of in RFC6241 +type RPC struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 rpc"` -// MarshalXML implements xml.Marshaller. Raw XML is passed verbatim, errors and -// all. -func (x *RawXML) MarshalXML(e *xml.Encoder, start xml.StartElement) error { - inner := struct { - Data []byte `xml:",innerxml"` - }{ - Data: []byte(*x), - } - return e.EncodeElement(&inner, start) -} + // Managed by the session. Will be overwritten when sent on the wire. + MessageID string `xml:"message-id,attr"` -// helloMsg maps the xml value of the message in RFC6241 -type helloMsg struct { - XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 hello"` - SessionID uint64 `xml:"session-id,omitempty"` - Capabilities []string `xml:"capabilities>capability"` -} + // User-defined attributes (e.g. xmlns:ex="..."). Per RFC6241 sec 7.3, these + // must be preserved and reflected in the associated . + Attributes []xml.Attr `xml:",attr"` -// request maps the xml value of in RFC6241 -type request struct { - XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 rpc"` - MessageID uint64 `xml:"message-id,attr"` - Operation any `xml:",innerxml"` + // The inner XML of the RPC message (e.g. , ) + Operation any `xml:",innerxml"` // The operation payload (e.g. ) } -func (msg *request) MarshalXML(e *xml.Encoder, start xml.StartElement) error { - if msg.Operation == nil { - return fmt.Errorf("operation cannot be nil") - } - - // TODO: validate operation is named? - - // alias the type to not cause recursion calling e.Encode - type rpcMsg request - inner := rpcMsg(*msg) - return e.Encode(&inner) -} - -// Reply maps the xml value of in RFC6241 -type Reply struct { - XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 rpc-reply"` - MessageID uint64 `xml:"message-id,attr"` - Errors RPCErrors `xml:"rpc-error,omitempty"` - Body []byte `xml:",innerxml"` -} - -// Decode will decode the body of a reply into a value pointed to by v. This is -// a simple wrapper around xml.Unmarshal. -func (r Reply) Decode(v interface{}) error { - return xml.Unmarshal(r.Body, v) -} - -// Err will return go error(s) from a Reply that are of the given severities. If -// no severity is given then it defaults to `ErrSevError`. -// -// If one error is present then the underlyign type is `RPCError`. If more than -// one error exists than the underlying type is `[]RPCError` -// -// Example - -// get all errors with severity of error -// -// if err := reply.Err(ErrSevError); err != nil { /* ... */ } -// -// or -// -// if err := reply.Err(); err != nil { /* ... */ } -// -// get all errors with severity of only warning -// -// if err := reply.Err(ErrSevWarning); err != nil { /* ... */ } -// -// get all errors -// -// if err := reply.Err(ErrSevWarning, ErrSevError); err != nil { /* ... */ } -func (r Reply) Err(severity ...ErrSeverity) error { - // fast escape for no errors - if len(r.Errors) == 0 { - return nil - } +type RPCReply struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 rpc-reply"` - errs := r.Errors.Filter(severity...) - switch len(errs) { - case 0: - return nil - case 1: - return errs[0] - default: - return errs - } -} + // The message-id must match that of the associated + MessageID string `xml:"message-id,attr"` -type Notification struct { - XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:notification:1.0 notification"` - EventTime time.Time `xml:"eventTime"` - Body []byte `xml:",innerxml"` -} + // Additional attributes on the . + Attributes []xml.Attr `xml:",attr"` -// Decode will decode the body of a noticiation into a value pointed to by v. -// This is a simple wrapper around xml.Unmarshal. -func (r Notification) Decode(v interface{}) error { - return xml.Unmarshal(r.Body, v) + RPCErrors RPCErrors `xml:"rpc-error,omitempty"` } type ErrSeverity string @@ -206,7 +114,16 @@ func (errs RPCErrors) Filter(severity ...ErrSeverity) RPCErrors { } func (errs RPCErrors) Error() string { + if len(errs) == 0 { + return "" + } + + if len(errs) == 1 { + return errs[0].Error() + } + var sb strings.Builder + sb.WriteString("multiple netconf errors:\n") for i, err := range errs { if i > 0 { sb.WriteRune('\n') @@ -216,10 +133,94 @@ func (errs RPCErrors) Error() string { return sb.String() } -func (errs RPCErrors) Unwrap() []error { - boxedErrs := make([]error, len(errs)) +func (errs RPCErrors) Unwrap() error { + if len(errs) == 0 { + return nil + } + if len(errs) == 1 { + return errs[0] + } + + unboxedErrs := make([]error, len(errs)) for i, err := range errs { - boxedErrs[i] = err + unboxedErrs[i] = err + } + return errors.Join(unboxedErrs...) +} + +type Notification struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:notification:1.0 notification"` + EventTime time.Time `xml:"eventTime"` +} + +// HelloMsg maps the xml value of the message in RFC6241 +type HelloMsg struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 hello"` + SessionID uint64 `xml:"session-id,omitempty"` + Capabilities []string `xml:"capabilities>capability"` +} + +type Request struct { + RPC RPC +} + +func NewRequest(op any) *Request { + return &Request{ + RPC: RPC{ + Operation: op, + }, + } +} + +type Response struct { + io.ReadCloser + + MessageID string // Captured from the message-id attribute + Attributes []xml.Attr // Any other attributes on the envelope +} + +// Decode will decode the response XML into the provided value v and then close +// the message releasing the session to process new messages. +func (d *Response) Decode(v any) (err error) { + defer func() { + err = errors.Join(err, d.Close()) + }() + + if err := xml.NewDecoder(d).Decode(v); err != nil { + return fmt.Errorf("failed to decode response: %w", err) + } + + return err +} + +func (d *Response) Close() error { + if d.ReadCloser == nil { + return nil + } + return d.ReadCloser.Close() +} + +// RawXML is a helper type for getting innerxml content as a byte slice. +type RawXML []byte + +func (x *RawXML) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { + var inner struct { + Data []byte `xml:",innerxml"` + } + + if err := d.DecodeElement(&inner, &start); err != nil { + return err + } + + *x = inner.Data + return nil +} + +func (x RawXML) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + inner := struct { + Data []byte `xml:",innerxml"` + }{ + Data: []byte(x), } - return boxedErrs + return e.EncodeElement(&inner, start) } diff --git a/msg_test.go b/msg_test.go index 279dfb4..99ed4e6 100644 --- a/msg_test.go +++ b/msg_test.go @@ -65,12 +65,12 @@ func TestRawXMLMarshal(t *testing.T) { var helloMsgTestTable = []struct { name string raw []byte - msg helloMsg + msg HelloMsg }{ { name: "basic", raw: []byte(`urn:ietf:params:netconf:base:1.0urn:ietf:params:netconf:base:1.1`), - msg: helloMsg{ + msg: HelloMsg{ XMLName: xml.Name{ Local: "hello", Space: "urn:ietf:params:xml:ns:netconf:base:1.0", @@ -100,7 +100,7 @@ var helloMsgTestTable = []struct { 410 `), - msg: helloMsg{ + msg: HelloMsg{ XMLName: xml.Name{ Local: "hello", Space: "urn:ietf:params:xml:ns:netconf:base:1.0", @@ -127,7 +127,7 @@ var helloMsgTestTable = []struct { func TestUnmarshalHelloMsg(t *testing.T) { for _, tc := range helloMsgTestTable { t.Run(tc.name, func(t *testing.T) { - var got helloMsg + var got HelloMsg err := xml.Unmarshal(tc.raw, &got) assert.NoError(t, err) assert.Equal(t, got, tc.msg) @@ -151,11 +151,6 @@ func TestMarshalRPCMsg(t *testing.T) { err bool want []byte }{ - { - name: "nil", - operation: nil, - err: true, - }, { name: "string", operation: "", @@ -166,11 +161,6 @@ func TestMarshalRPCMsg(t *testing.T) { operation: []byte(""), want: []byte(``), }, - { - name: "validate", - operation: ValidateReq{Source: Running}, - want: []byte(``), - }, { name: "namedStruct", operation: struct { @@ -194,8 +184,8 @@ func TestMarshalRPCMsg(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { - out, err := xml.Marshal(&request{ - MessageID: 1, + out, err := xml.Marshal(&RPC{ + MessageID: "1", Operation: tc.operation, }) t.Logf("out: %s", out) @@ -229,46 +219,33 @@ func TestUnmarshalRPCReply(t *testing.T) { tt := []struct { name string reply []byte - want Reply + want any }{ { name: "error", reply: replyJunosGetConfigError, - want: Reply{ + want: RPCReply{ XMLName: xml.Name{ - Space: "urn:ietf:params:xml:ns:netconf:base:1.0", Local: "rpc-reply", + Space: "urn:ietf:params:xml:ns:netconf:base:1.0", }, - MessageID: 1, - Errors: []RPCError{ + MessageID: "1", + RPCErrors: []RPCError{ { Type: ErrTypeProtocol, Tag: ErrOperationFailed, Severity: SevError, Message: "syntax error, expecting or ", - Info: []byte(` -non-exist -`), + Info: []byte("\nnon-exist\n"), }, }, - Body: []byte(` - -protocol -operation-failed -error -syntax error, expecting <candidate/> or <running/> - -non-exist - - -`), }, }, } for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { - var got Reply + var got RPCReply err := xml.Unmarshal(tc.reply, &got) assert.NoError(t, err) assert.Equal(t, tc.want, got) diff --git a/ops.go b/ops.go deleted file mode 100644 index b7ed6cf..0000000 --- a/ops.go +++ /dev/null @@ -1,514 +0,0 @@ -package netconf - -import ( - "context" - "encoding/xml" - "fmt" - "strings" - "time" -) - -type ExtantBool bool - -func (b ExtantBool) MarshalXML(e *xml.Encoder, start xml.StartElement) error { - if !b { - return nil - } - // This produces a empty start/end tag (i.e ) vs a self-closing - // tag (() which should be the same in XML, however I know certain - // vendors may have issues with this format. We may have to process this - // after xml encoding. - // - // See https://github.com/golang/go/issues/21399 - // or https://github.com/golang/go/issues/26756 for a different hack. - return e.EncodeElement(struct{}{}, start) -} - -func (b *ExtantBool) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { - v := &struct{}{} - if err := d.DecodeElement(v, &start); err != nil { - return err - } - *b = v != nil - return nil -} - -type OKResp struct { - OK ExtantBool `xml:"ok"` -} - -type Datastore string - -func (s Datastore) MarshalXML(e *xml.Encoder, start xml.StartElement) error { - if s == "" { - return fmt.Errorf("datastores cannot be empty") - } - - // XXX: it would be nice to actually just block names with crap in them - // instead of escaping them, but we need to find a list of what is allowed - // in an xml tag. - escaped, err := escapeXML(string(s)) - if err != nil { - return fmt.Errorf("invalid string element: %w", err) - } - - v := struct { - Elem string `xml:",innerxml"` - }{Elem: "<" + escaped + "/>"} - return e.EncodeElement(&v, start) -} - -func escapeXML(input string) (string, error) { - buf := &strings.Builder{} - if err := xml.EscapeText(buf, []byte(input)); err != nil { - return "", err - } - return buf.String(), nil -} - -type URL string - -func (u URL) MarshalXML(e *xml.Encoder, start xml.StartElement) error { - v := struct { - URL string `xml:"url"` - }{string(u)} - return e.EncodeElement(&v, start) -} - -const ( - // Running configuration datastore. Required by RFC6241 - Running Datastore = "running" - - // Candidate configuration configuration datastore. Supported with the - // `:candidate` capability defined in RFC6241 section 8.3 - Candidate Datastore = "candidate" - - // Startup configuration configuration datastore. Supported with the - // `:startup` capability defined in RFC6241 section 8.7 - Startup Datastore = "startup" // -) - -type GetConfigReq struct { - XMLName xml.Name `xml:"get-config"` - Source Datastore `xml:"source"` - // Filter -} - -type GetConfigReply struct { - XMLName xml.Name `xml:"data"` - Config []byte `xml:",innerxml"` -} - -// GetConfig implements the rpc operation defined in [RFC6241 7.1]. -// `source` is the datastore to query. -// -// [RFC6241 7.1]: https://www.rfc-editor.org/rfc/rfc6241.html#section-7.1 -func (s *Session) GetConfig(ctx context.Context, source Datastore) ([]byte, error) { - req := GetConfigReq{ - Source: source, - } - - var resp GetConfigReply - if err := s.Call(ctx, &req, &resp); err != nil { - return nil, err - } - - return resp.Config, nil -} - -// MergeStrategy defines the strategies for merging configuration in a -// ` operation`. -// -// *Note*: in RFC6241 7.2 this is called the `operation` attribute and -// `default-operation` parameter. Since the `operation` term is already -// overloaded this was changed to `MergeStrategy` for a cleaner API. -type MergeStrategy string - -const ( - // MergeConfig configuration elements are merged together at the level at - // which this specified. Can be used for config elements as well as default - // defined with [WithDefaultMergeStrategy] option. - MergeConfig MergeStrategy = "merge" - - // ReplaceConfig defines that the incoming config change should replace the - // existing config at the level which it is specified. This can be - // specified on individual config elements or set as the default strategy set - // with [WithDefaultMergeStrategy] option. - ReplaceConfig MergeStrategy = "replace" - - // NoMergeStrategy is only used as a default strategy defined in - // [WithDefaultMergeStrategy]. Elements must specific one of the other - // strategies with the `operation` Attribute on elements in the `` - // subtree. Elements without the `operation` attribute are ignored. - NoMergeStrategy MergeStrategy = "none" - - // CreateConfig allows a subtree element to be created only if it doesn't - // already exist. - // This strategy is only used as the `operation` attribute of - // a `` element and cannot be used as the default strategy. - CreateConfig MergeStrategy = "create" - - // DeleteConfig will completely delete subtree from the config only if it - // already exists. This strategy is only used as the `operation` attribute - // of a `` element and cannot be used as the default strategy. - DeleteConfig MergeStrategy = "delete" - - // RemoveConfig will remove subtree from the config. If the subtree doesn't - // exist in the datastore then it is silently skipped. This strategy is - // only used as the `operation` attribute of a `` element and cannot - // be used as the default strategy. - RemoveConfig MergeStrategy = "remove" -) - -// TestStrategy defines the beahvior for testing configuration before applying it in a `` operation. -// -// *Note*: in RFC6241 7.2 this is called the `test-option` parameter. Since the `option` term is already -// overloaded this was changed to `TestStrategy` for a cleaner API. -type TestStrategy string - -const ( - // TestThenSet will validate the configuration and only if is is valid then - // apply the configuration to the datastore. - TestThenSet TestStrategy = "test-then-set" - - // SetOnly will not do any testing before applying it. - SetOnly TestStrategy = "set" - - // Test only will validation the incoming configuration and return the - // results without modifying the underlying store. - TestOnly TestStrategy = "test-only" -) - -// ErrorStrategy defines the behavior when an error is encountered during a `` operation. -// -// *Note*: in RFC6241 7.2 this is called the `error-option` parameter. Since the `option` term is already -// overloaded this was changed to `ErrorStrategy` for a cleaner API. -type ErrorStrategy string - -const ( - // StopOnError will about the `` operation on the first error. - StopOnError ErrorStrategy = "stop-on-error" - - // ContinueOnError will continue to parse the configuration data even if an - // error is encountered. Errors are still recorded and reported in the - // reply. - ContinueOnError ErrorStrategy = "continue-on-error" - - // RollbackOnError will restore the configuration back to before the - // `` operation took place. This requires the device to - // support the `:rollback-on-error` capabilitiy. - RollbackOnError ErrorStrategy = "rollback-on-error" -) - -type ( - defaultMergeStrategy MergeStrategy - testStrategy TestStrategy - errorStrategy ErrorStrategy -) - -func (o defaultMergeStrategy) apply(req *EditConfigReq) { req.DefaultMergeStrategy = MergeStrategy(o) } -func (o testStrategy) apply(req *EditConfigReq) { req.TestStrategy = TestStrategy(o) } -func (o errorStrategy) apply(req *EditConfigReq) { req.ErrorStrategy = ErrorStrategy(o) } - -// WithDefaultMergeStrategy sets the default config merging strategy for the -// operation. Only [Merge], [Replace], and [None] are supported -// (the rest of the strategies are for defining as attributed in individual -// elements inside the `` subtree). -func WithDefaultMergeStrategy(op MergeStrategy) EditConfigOption { return defaultMergeStrategy(op) } - -// WithTestStrategy sets the `test-option` in the `“ operation. -// This defines what testing should be done the supplied configuration. See the -// documentation on [TestStrategy] for details on each strategy. -func WithTestStrategy(op TestStrategy) EditConfigOption { return testStrategy(op) } - -// WithErrorStrategy sets the `error-option` in the `` operation. -// This defines the behavior when errors are encountered applying the supplied -// config. See [ErrorStrategy] for the available options. -func WithErrorStrategy(opt ErrorStrategy) EditConfigOption { return errorStrategy(opt) } - -type EditConfigReq struct { - XMLName xml.Name `xml:"edit-config"` - Target Datastore `xml:"target"` - DefaultMergeStrategy MergeStrategy `xml:"default-operation,omitempty"` - TestStrategy TestStrategy `xml:"test-option,omitempty"` - ErrorStrategy ErrorStrategy `xml:"error-option,omitempty"` - - // either of these two values - Config any `xml:"config,omitempty"` - URL string `xml:"url,omitempty"` -} - -// EditOption is a optional arguments to [Session.EditConfig] method -type EditConfigOption interface { - apply(*EditConfigReq) -} - -// EditConfig issues the `` operation defined in [RFC6241 7.2] for -// updating an existing target config datastore. -// -// [RFC6241 7.2]: https://www.rfc-editor.org/rfc/rfc6241.html#section-7.2 -func (s *Session) EditConfig(ctx context.Context, target Datastore, config any, opts ...EditConfigOption) error { - req := EditConfigReq{ - Target: target, - } - - // XXX: Should we use reflect here? - switch v := config.(type) { - case string: - req.Config = struct { - Inner []byte `xml:",innerxml"` - }{Inner: []byte(v)} - case []byte: - req.Config = struct { - Inner []byte `xml:",innerxml"` - }{Inner: v} - case URL: - req.URL = string(v) - default: - req.Config = config - } - - for _, opt := range opts { - opt.apply(&req) - } - - var resp OKResp - return s.Call(ctx, &req, &resp) -} - -type CopyConfigReq struct { - XMLName xml.Name `xml:"copy-config"` - Source any `xml:"source"` - Target any `xml:"target"` -} - -// CopyConfig issues the `` operation as defined in [RFC6241 7.3] -// for copying an entire config to/from a source and target datastore. -// -// A `` element defining a full config can be used as the source. -// -// If a device supports the `:url` capability than a [URL] object can be used -// for the source or target datastore. -// -// [RFC6241 7.3] https://www.rfc-editor.org/rfc/rfc6241.html#section-7.3 -func (s *Session) CopyConfig(ctx context.Context, source, target any) error { - req := CopyConfigReq{ - Source: source, - Target: target, - } - - var resp OKResp - return s.Call(ctx, &req, &resp) -} - -type DeleteConfigReq struct { - XMLName xml.Name `xml:"delete-config"` - Target Datastore `xml:"target"` -} - -func (s *Session) DeleteConfig(ctx context.Context, target Datastore) error { - req := DeleteConfigReq{ - Target: target, - } - - var resp OKResp - return s.Call(ctx, &req, &resp) -} - -type LockReq struct { - XMLName xml.Name - Target Datastore `xml:"target"` -} - -func (s *Session) Lock(ctx context.Context, target Datastore) error { - req := LockReq{ - XMLName: xml.Name{Space: "urn:ietf:params:xml:ns:netconf:base:1.0", Local: "lock"}, - Target: target, - } - - var resp OKResp - return s.Call(ctx, &req, &resp) -} - -func (s *Session) Unlock(ctx context.Context, target Datastore) error { - req := LockReq{ - XMLName: xml.Name{Space: "urn:ietf:params:xml:ns:netconf:base:1.0", Local: "unlock"}, - Target: target, - } - - var resp OKResp - return s.Call(ctx, &req, &resp) -} - -/* -func (s *Session) Get(ctx context.Context, filter Filter) error { - panic("unimplemented") -} -*/ - -type KillSessionReq struct { - XMLName xml.Name `xml:"kill-session"` - SessionID uint32 `xml:"session-id"` -} - -func (s *Session) KillSession(ctx context.Context, sessionID uint32) error { - req := KillSessionReq{ - SessionID: sessionID, - } - - var resp OKResp - return s.Call(ctx, &req, &resp) -} - -type ValidateReq struct { - XMLName xml.Name `xml:"validate"` - Source any `xml:"source"` -} - -func (s *Session) Validate(ctx context.Context, source any) error { - req := ValidateReq{ - Source: source, - } - - var resp OKResp - return s.Call(ctx, &req, &resp) -} - -type CommitReq struct { - XMLName xml.Name `xml:"commit"` - Confirmed ExtantBool `xml:"confirmed,omitempty"` - ConfirmTimeout int64 `xml:"confirm-timeout,omitempty"` - Persist string `xml:"persist,omitempty"` - PersistID string `xml:"persist-id,omitempty"` -} - -// CommitOption is a optional arguments to [Session.Commit] method -type CommitOption interface { - apply(*CommitReq) -} - -type confirmed bool -type confirmedTimeout struct { - time.Duration -} -type persist string -type persistID string - -func (o confirmed) apply(req *CommitReq) { req.Confirmed = true } -func (o confirmedTimeout) apply(req *CommitReq) { - req.Confirmed = true - req.ConfirmTimeout = int64(o.Seconds()) -} -func (o persist) apply(req *CommitReq) { - req.Confirmed = true - req.Persist = string(o) -} -func (o persistID) apply(req *CommitReq) { req.PersistID = string(o) } - -// RollbackOnError will restore the configuration back to before the -// `` operation took place. This requires the device to -// support the `:rollback-on-error` capability. - -// WithConfirmed will mark the commits as requiring confirmation or will rollback -// after the default timeout on the device (default should be 600s). The commit -// can be confirmed with another `` call without the confirmed option, -// extended by calling with `Commit` With `WithConfirmed` or -// `WithConfirmedTimeout` or canceling the commit with a `CommitCancel` call. -// This requires the device to support the `:confirmed-commit:1.1` capability. -func WithConfirmed() CommitOption { return confirmed(true) } - -// WithConfirmedTimeout is like `WithConfirmed` but using the given timeout -// duration instead of the device's default. -func WithConfirmedTimeout(timeout time.Duration) CommitOption { return confirmedTimeout{timeout} } - -// WithPersist allows you to set a identifier to confirm a commit in another -// sessions. Confirming the commit requires setting the `WithPersistID` in the -// following `Commit` call matching the id set on the confirmed commit. Will -// mark the commit as confirmed if not already set. -func WithPersist(id string) CommitOption { return persist(id) } - -// WithPersistID is used to confirm a previous commit set with a given -// identifier. This allows you to confirm a commit from (potentially) another -// sesssion. -func WithPersistID(id string) persistID { return persistID(id) } - -// Commit will commit a canidate config to the running comming. This requires -// the device to support the `:canidate` capability. -func (s *Session) Commit(ctx context.Context, opts ...CommitOption) error { - var req CommitReq - for _, opt := range opts { - opt.apply(&req) - } - - if req.PersistID != "" && req.Confirmed { - return fmt.Errorf("PersistID cannot be used with Confirmed/ConfirmedTimeout or Persist options") - } - - var resp OKResp - return s.Call(ctx, &req, &resp) -} - -// CancelCommitOption is a optional arguments to [Session.CancelCommit] method -type CancelCommitOption interface { - applyCancelCommit(*CancelCommitReq) -} - -func (o persistID) applyCancelCommit(req *CancelCommitReq) { req.PersistID = string(o) } - -type CancelCommitReq struct { - XMLName xml.Name `xml:"cancel-commit"` - PersistID string `xml:"persist-id,omitempty"` -} - -func (s *Session) CancelCommit(ctx context.Context, opts ...CancelCommitOption) error { - var req CancelCommitReq - for _, opt := range opts { - opt.applyCancelCommit(&req) - } - - var resp OKResp - return s.Call(ctx, &req, &resp) -} - -// CreateSubscriptionOption is a optional arguments to [Session.CreateSubscription] method -type CreateSubscriptionOption interface { - apply(req *CreateSubscriptionReq) -} - -type CreateSubscriptionReq struct { - XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:notification:1.0 create-subscription"` - Stream string `xml:"stream,omitempty"` - // TODO: Implement filter - //Filter int64 `xml:"filter,omitempty"` - StartTime string `xml:"startTime,omitempty"` - EndTime string `xml:"endTime,omitempty"` -} - -type stream string -type startTime time.Time -type endTime time.Time - -func (o stream) apply(req *CreateSubscriptionReq) { - req.Stream = string(o) -} -func (o startTime) apply(req *CreateSubscriptionReq) { - req.StartTime = time.Time(o).Format(time.RFC3339) -} -func (o endTime) apply(req *CreateSubscriptionReq) { - req.EndTime = time.Time(o).Format(time.RFC3339) -} - -func WithStreamOption(s string) CreateSubscriptionOption { return stream(s) } -func WithStartTimeOption(st time.Time) CreateSubscriptionOption { return startTime(st) } -func WithEndTimeOption(et time.Time) CreateSubscriptionOption { return endTime(et) } - -func (s *Session) CreateSubscription(ctx context.Context, opts ...CreateSubscriptionOption) error { - var req CreateSubscriptionReq - for _, opt := range opts { - opt.apply(&req) - } - // TODO: eventual custom notifications rpc logic, e.g. create subscription only if notification capability is present - - var resp OKResp - return s.Call(ctx, &req, &resp) -} diff --git a/ops_test.go b/ops_test.go deleted file mode 100644 index 79097a4..0000000 --- a/ops_test.go +++ /dev/null @@ -1,598 +0,0 @@ -package netconf - -import ( - "context" - "encoding/xml" - "regexp" - "strconv" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestUnmarshalOk(t *testing.T) { - tt := []struct { - name string - input string - want bool - }{ - {"selfclosing", ">", true}, - {"missing", "", false}, - {"closetag", "", true}, - } - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - var v struct { - XMLName xml.Name `xml:"foo"` - Ok ExtantBool `xml:"ok"` - } - - err := xml.Unmarshal([]byte(tc.input), &v) - assert.NoError(t, err) - assert.Equal(t, tc.want, bool(v.Ok)) - }) - } -} - -func TestMarshalDatastore(t *testing.T) { - tt := []struct { - input Datastore - want string - shouldErr bool - }{ - {Running, "", false}, - {Startup, "", false}, - {Candidate, "", false}, - {Datastore("custom-store"), "", false}, - {Datastore(""), "", true}, - {Datastore(""), "<<xml-elements>/>", true}, - } - - for _, tc := range tt { - t.Run(string(tc.input), func(t *testing.T) { - v := struct { - XMLName xml.Name `xml:"rpc"` - Target Datastore `xml:"target"` - }{Target: tc.input} - - got, err := xml.Marshal(&v) - if !tc.shouldErr { - assert.NoError(t, err) - } - assert.Equal(t, tc.want, string(got)) - }) - } -} - -func TestGetConfig(t *testing.T) { - ts := newTestServer(t) - sess := newSession(ts.transport()) - go sess.recv() - - ts.queueRespString("foo") - - got, err := sess.GetConfig(context.Background(), Running) - assert.NoError(t, err) - - _, err = ts.popReqString() - assert.NoError(t, err) - - want := []byte("foo") - assert.Equal(t, want, got) -} - -type structuredCfg struct { - System structuredCfgSystem `xml:"system"` -} - -type structuredCfgSystem struct { - Hostname string `xml:"host-name"` -} - -const intfaceConfig = ` - - - ge-0/0/2 - - 0 - - -
- 2.2.2.1/32 -
-
-
-
-
-
-` - -func TestEditConfig(t *testing.T) { - tt := []struct { - name string - target Datastore - config any - options []EditConfigOption - mustMatch []*regexp.Regexp - noMatch []*regexp.Regexp - }{ - { - name: "running structured no options", - target: Running, - config: structuredCfg{ - System: structuredCfgSystem{ - Hostname: "darkstar", - }, - }, - mustMatch: []*regexp.Regexp{ - regexp.MustCompile(`\S*\S*`), - regexp.MustCompile( - `\S*\S*darkstar\S*\S*`, - ), - }, - noMatch: []*regexp.Regexp{ - regexp.MustCompile(``), - }, - }, - { - name: "canidate string all options", - target: Candidate, - config: intfaceConfig, - options: []EditConfigOption{ - WithDefaultMergeStrategy(ReplaceConfig), - WithErrorStrategy(ContinueOnError), - WithTestStrategy(TestOnly), - }, - mustMatch: []*regexp.Regexp{ - regexp.MustCompile(`\S*\S*`), - regexp.MustCompile(`ge-0/0/2`), - regexp.MustCompile(`replace`), - regexp.MustCompile(`test-only`), - regexp.MustCompile(`continue-on-error`), - }, - noMatch: []*regexp.Regexp{ - regexp.MustCompile(``), - }, - }, - { - name: "byteslice config", - target: Running, - config: []byte(""), - mustMatch: []*regexp.Regexp{ - regexp.MustCompile(``), - }, - }, - { - name: "startup url no options", - target: Startup, - config: URL("ftp://myftpesrver/foo/config.xml"), - mustMatch: []*regexp.Regexp{ - regexp.MustCompile(`\S*\S*`), - regexp.MustCompile(`ftp://myftpesrver/foo/config.xml`), - }, - noMatch: []*regexp.Regexp{ - regexp.MustCompile(``), - }, - }, - } - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - ts := newTestServer(t) - sess := newSession(ts.transport()) - go sess.recv() - - ts.queueRespString(``) - - err := sess.EditConfig(context.Background(), tc.target, tc.config, tc.options...) - assert.NoError(t, err) - - sentMsg, err := ts.popReq() - assert.NoError(t, err) - - for _, match := range tc.mustMatch { - assert.Regexp(t, match, string(sentMsg)) - } - - for _, match := range tc.noMatch { - assert.NotRegexp(t, match, string(sentMsg)) - } - }) - } -} - -// TODO: TestEditConfigError() - -func TestCopyConfig(t *testing.T) { - tt := []struct { - name string - source, target any - matches []*regexp.Regexp - }{ - { - name: "running->startup", - source: Running, - target: Startup, - matches: []*regexp.Regexp{ - regexp.MustCompile(`\S*\S*`), - regexp.MustCompile(`\S*\S*`), - }, - }, - { - name: "running->url", - source: Running, - target: URL("ftp://myserver.example.com/router.cfg"), - matches: []*regexp.Regexp{ - regexp.MustCompile(`\S*\S*`), - regexp.MustCompile(`\S*ftp://myserver.example.com/router.cfg\S*`), - }, - }, - { - name: "url->candidate", - source: URL("http://myserver.example.com/router.cfg"), - target: Candidate, - matches: []*regexp.Regexp{ - regexp.MustCompile(`\S*http://myserver.example.com/router.cfg\S*`), - regexp.MustCompile(`\S*\S*`), - }, - }, - } - - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - ts := newTestServer(t) - sess := newSession(ts.transport()) - go sess.recv() - - ts.queueRespString(``) - - err := sess.CopyConfig(context.Background(), tc.source, tc.target) - assert.NoError(t, err) - - sentMsg, err := ts.popReq() - assert.NoError(t, err) - - for _, match := range tc.matches { - assert.Regexp(t, match, string(sentMsg)) - } - }) - } -} - -func TestDeleteConfig(t *testing.T) { - tt := []struct { - target Datastore - matches []*regexp.Regexp - }{ - { - target: Startup, - matches: []*regexp.Regexp{ - regexp.MustCompile(`\S*\S*\S*\S*`), - }, - }, - } - - for _, tc := range tt { - t.Run(string(tc.target), func(t *testing.T) { - ts := newTestServer(t) - sess := newSession(ts.transport()) - go sess.recv() - - ts.queueRespString(``) - - err := sess.DeleteConfig(context.Background(), tc.target) - assert.NoError(t, err) - - sentMsg, err := ts.popReq() - assert.NoError(t, err) - - for _, match := range tc.matches { - assert.Regexp(t, match, string(sentMsg)) - } - }) - } -} - -func TestValidateConfig(t *testing.T) { - tt := []struct { - name string - source any - matches []*regexp.Regexp - }{ - { - name: "candidate", - source: Candidate, - matches: []*regexp.Regexp{ - regexp.MustCompile(`\S*\S*\S*\S*`), - }, - }, - // XXX: test []byte,string - // XXX: test xml object - } - - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - ts := newTestServer(t) - sess := newSession(ts.transport()) - go sess.recv() - - ts.queueRespString(``) - - err := sess.Validate(context.Background(), tc.source) - assert.NoError(t, err) - - sentMsg, err := ts.popReq() - assert.NoError(t, err) - - for _, match := range tc.matches { - assert.Regexp(t, match, string(sentMsg)) - } - }) - } -} - -func TestLock(t *testing.T) { - tt := []struct { - target Datastore - matches []*regexp.Regexp - }{ - { - target: Candidate, - matches: []*regexp.Regexp{ - regexp.MustCompile(`\S*\S*\S*\S*`), - }, - }, - } - - for _, tc := range tt { - t.Run(string(tc.target), func(t *testing.T) { - ts := newTestServer(t) - sess := newSession(ts.transport()) - go sess.recv() - - ts.queueRespString(``) - - err := sess.Lock(context.Background(), tc.target) - assert.NoError(t, err) - - sentMsg, err := ts.popReq() - assert.NoError(t, err) - - for _, match := range tc.matches { - assert.Regexp(t, match, string(sentMsg)) - } - }) - } -} - -func TestUnlock(t *testing.T) { - tt := []struct { - target Datastore - matches []*regexp.Regexp - }{ - { - target: Candidate, - matches: []*regexp.Regexp{ - regexp.MustCompile(`\S*\S*\S*\S*`), - }, - }, - } - - for _, tc := range tt { - t.Run(string(tc.target), func(t *testing.T) { - ts := newTestServer(t) - sess := newSession(ts.transport()) - go sess.recv() - - ts.queueRespString(``) - - err := sess.Unlock(context.Background(), tc.target) - assert.NoError(t, err) - - sentMsg, err := ts.popReq() - assert.NoError(t, err) - - for _, match := range tc.matches { - assert.Regexp(t, match, string(sentMsg)) - } - }) - } -} - -func TestKillSession(t *testing.T) { - tt := []struct { - id uint32 - matches []*regexp.Regexp - }{ - { - id: 42, - matches: []*regexp.Regexp{ - regexp.MustCompile(`\S*42\S*`), - }, - }, - } - - for _, tc := range tt { - t.Run(strconv.Itoa(int(tc.id)), func(t *testing.T) { - ts := newTestServer(t) - sess := newSession(ts.transport()) - go sess.recv() - - ts.queueRespString(``) - - err := sess.KillSession(context.Background(), tc.id) - assert.NoError(t, err) - - sentMsg, err := ts.popReq() - assert.NoError(t, err) - - for _, match := range tc.matches { - assert.Regexp(t, match, string(sentMsg)) - } - }) - } -} - -func TestCommit(t *testing.T) { - tt := []struct { - name string - options []CommitOption - matches []*regexp.Regexp - }{ - { - name: "noOptions", - matches: []*regexp.Regexp{ - regexp.MustCompile(``), - }, - }, - { - name: "confirmed", - options: []CommitOption{WithConfirmed()}, - matches: []*regexp.Regexp{ - regexp.MustCompile(``), - }, - }, - { - name: "confirmed", - options: []CommitOption{WithConfirmedTimeout(1 * time.Minute)}, - matches: []*regexp.Regexp{ - regexp.MustCompile(`60`), - }, - }, - { - name: "persist", - options: []CommitOption{WithPersist("myid")}, - matches: []*regexp.Regexp{ - regexp.MustCompile(`myid`), - }, - }, - { - name: "persist_id", - options: []CommitOption{WithPersistID("myid")}, - matches: []*regexp.Regexp{ - regexp.MustCompile(`myid`), - }, - }, - } - - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - ts := newTestServer(t) - sess := newSession(ts.transport()) - go sess.recv() - - ts.queueRespString(``) - - err := sess.Commit(context.Background(), tc.options...) - assert.NoError(t, err) - - sentMsg, err := ts.popReq() - assert.NoError(t, err) - - for _, match := range tc.matches { - assert.Regexp(t, match, string(sentMsg)) - } - }) - } -} - -func TestCancelCommit(t *testing.T) { - tt := []struct { - name string - options []CancelCommitOption - matches []*regexp.Regexp - }{ - { - name: "noOptions", - matches: []*regexp.Regexp{ - regexp.MustCompile(``), - }, - }, - { - name: "persist_id", - options: []CancelCommitOption{WithPersistID("myid")}, - matches: []*regexp.Regexp{ - regexp.MustCompile(`myid`), - }, - }, - } - - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - ts := newTestServer(t) - sess := newSession(ts.transport()) - go sess.recv() - - ts.queueRespString(``) - - err := sess.CancelCommit(context.Background(), tc.options...) - assert.NoError(t, err) - - sentMsg, err := ts.popReq() - assert.NoError(t, err) - - for _, match := range tc.matches { - assert.Regexp(t, match, string(sentMsg)) - } - }) - } -} - -func TestCreateSubscription(t *testing.T) { - start := time.Date(2023, time.June, 07, 18, 31, 48, 00, time.UTC) - end := time.Date(2023, time.June, 07, 18, 33, 48, 00, time.UTC) - - tt := []struct { - name string - options []CreateSubscriptionOption - matches []*regexp.Regexp - }{ - { - name: "noOptions", - matches: []*regexp.Regexp{ - regexp.MustCompile(``), - }, - }, - { - name: "startTime option", - options: []CreateSubscriptionOption{WithStartTimeOption(start)}, - matches: []*regexp.Regexp{ - regexp.MustCompile(`` + regexp.QuoteMeta(start.Format(time.RFC3339)) + ``), - }, - }, - { - name: "endTime option", - options: []CreateSubscriptionOption{WithEndTimeOption(end)}, - matches: []*regexp.Regexp{ - regexp.MustCompile(`` + regexp.QuoteMeta(end.Format(time.RFC3339)) + ``), - }, - }, - { - name: "stream option", - options: []CreateSubscriptionOption{WithStreamOption("thestream")}, - matches: []*regexp.Regexp{ - regexp.MustCompile(`thestream`), - }, - }, - } - - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - ts := newTestServer(t) - sess := newSession(ts.transport()) - go sess.recv() - - ts.queueRespString(``) - - err := sess.CreateSubscription(context.Background(), tc.options...) - assert.NoError(t, err) - - sentMsg, err := ts.popReq() - assert.NoError(t, err) - - for _, match := range tc.matches { - assert.Regexp(t, match, string(sentMsg)) - } - }) - } -} diff --git a/rpc/config.go b/rpc/config.go new file mode 100644 index 0000000..ea24ad0 --- /dev/null +++ b/rpc/config.go @@ -0,0 +1,472 @@ +package rpc + +import ( + "context" + "encoding/xml" + "fmt" + + "nemith.io/netconf" +) + +// Datastore represents a NETCONF configuration datastore as defined in +// RFC6241 section 7.1 +type Datastore string + +func (d Datastore) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + if d == "" { + return fmt.Errorf("datastore name cannot be empty") + } + + for i := range len(d) { + c := d[i] + if (c < 'a' || c > 'z') && + (c < 'A' || c > 'Z') && + (c < '0' || c > '9') && + c != '_' && c != '-' && c != '.' { + return fmt.Errorf("invalid datastore name: %q", d) + } + } + + inner := struct { + Elem string `xml:",innerxml"` + }{Elem: "<" + string(d) + "/>"} + + // EncodeElement(nil, ...) creates a self-closing tag + return e.EncodeElement(&inner, start) +} + +const ( + // Running configuration datastore. Required by RFC6241 + Running Datastore = "running" + + // Candidate configuration configuration datastore. Supported with the + // `:candidate` capability defined in RFC6241 section 8.3 + Candidate Datastore = "candidate" + + // Startup configuration configuration datastore. Supported with the + // `:startup` capability defined in RFC6241 section 8.7 + Startup Datastore = "startup" +) + +type URL string + +func (u URL) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + v := struct { + URL string `xml:"url"` + }{string(u)} + return e.EncodeElement(&v, start) +} + +// GetConfig implements the rpc operation defined in [RFC6241 7.1]. +// `source` is the datastore to query. +// +// [RFC6241 7.1]: https://www.rfc-editor.org/rfc/rfc6241.html#section-7.1 +type GetConfig struct { + Source Datastore + Filter Filter +} + +func (op GetConfig) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + req := struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 get-config"` + Source Datastore `xml:"source"` + Filter Filter `xml:"filter,omitempty"` + }{ + Source: op.Source, + Filter: op.Filter, + } + + return e.Encode(&req) +} + +func (rpc GetConfig) Exec(ctx context.Context, session *netconf.Session) ([]byte, error) { + var reply GetConfigReply + if err := session.Exec(ctx, rpc, &reply); err != nil { + return nil, err + } + + return reply.Config, nil +} + +type GetConfigReply struct { + Config []byte +} + +func (r *GetConfigReply) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { + type dataWrapper struct { + Data struct { + Inner []byte `xml:",innerxml"` + } `xml:"data"` + } + + var wrapper dataWrapper + if err := d.DecodeElement(&wrapper, &start); err != nil { + return err + } + r.Config = wrapper.Data.Inner + return nil +} + +// DefaultOperation defines the strategies for merging configuration in a +// ` operation`. +type DefaultOperation string + +const ( + // MergeConfig configuration elements are merged together at the level at + // which this specified. Can be used for config elements as well as default. + MergeConfig DefaultOperation = "merge" + + // ReplaceConfig defines that the incoming config change should replace the + // existing config at the level which it is specified. This can be + // specified on individual config elements or set as the default strategy. + ReplaceConfig DefaultOperation = "replace" + + // NoneOperation indicates that no default operation should be applied and + // nothing is applied to the target configuration unless there are + // operations defined on the configs subelements. + NoneOperation DefaultOperation = "none" +) + +// TestOption defines the behavior for testing configuration before applying it +// in a `` operation. +type TestOption string + +const ( + // TestThenSet will validate the configuration and only if is is valid then + // apply the configuration to the datastore. + TestThenSet TestOption = "test-then-set" + + // SetOnly will not do any testing before applying it. + SetOnly TestOption = "set" + + // Test only will validatate the incoming configuration and return the + // results without modifying the underlying store. + TestOnly TestOption = "test-only" +) + +// ErrorOption defines the behavior when an error is encountered during a +// `` operation. +type ErrorOption string + +const ( + // StopOnError will abort the `` operation on the first error. + StopOnError ErrorOption = "stop-on-error" + + // ContinueOnError will continue to parse the configuration data even if an + // error is encountered. Errors are still recorded and reported in the + // reply. + ContinueOnError ErrorOption = "continue-on-error" + + // RollbackOnError will restore the configuration back to before the + // `` operation took place. This requires the device to + // support the `:rollback-on-error` capability. + RollbackOnError ErrorOption = "rollback-on-error" +) + +// EditConfig issues the `` operation defined in [RFC6241 7.2] for +// updating an existing target config datastore. +// +// [RFC6241 7.2]: https://www.rfc-editor.org/rfc/rfc6241.html#section-7.2 +type EditConfig struct { + Target Datastore + DefaultOperation DefaultOperation + TestOption TestOption + ErrorOption ErrorOption + Config any +} + +func (rpc EditConfig) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + req := struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 edit-config"` + Target Datastore `xml:"target"` + DefaultOperation DefaultOperation `xml:"default-operation,omitempty"` + TestOption TestOption `xml:"test-option,omitempty"` + ErrorOption ErrorOption `xml:"error-option,omitempty"` + + Config any `xml:"config,omitempty"` + URL string `xml:"url,omitempty"` + }{ + Target: rpc.Target, + DefaultOperation: rpc.DefaultOperation, + TestOption: rpc.TestOption, + ErrorOption: rpc.ErrorOption, + } + + switch v := rpc.Config.(type) { + case URL: + req.URL = string(v) + case string: + req.Config = struct { + Inner string `xml:",innerxml"` + }{Inner: v} + case []byte: + req.Config = struct { + Inner []byte `xml:",innerxml"` + }{Inner: v} + default: + req.Config = rpc.Config + } + + return e.Encode(&req) +} + +func (rpc EditConfig) Exec(ctx context.Context, session *netconf.Session) error { + var resp OkReply + if err := session.Exec(ctx, rpc, &resp); err != nil { + return err + } + + if !resp.OK { + return fmt.Errorf("edit-config: operation failed, not received") + } + return nil +} + +// CopyConfig issues the `` operation as defined in [RFC6241 7.3] +// for copying an entire config to/from a source and target datastore. +// +// A `` element defining a full config can be used as the source. +// +// If a device supports the `:url` capability than a [URL] object can be used +// for the source or target datastore. +// +// [RFC6241 7.3] https://www.rfc-editor.org/rfc/rfc6241.html#section-7.3 +type CopyConfig struct { + Source any + Target any +} + +func (rpc CopyConfig) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + req := struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 copy-config"` + Source any `xml:"source"` + Target any `xml:"target"` + }{ + Source: rpc.Source, + Target: rpc.Target, + } + + return e.Encode(&req) +} + +func (rpc CopyConfig) Exec(ctx context.Context, session *netconf.Session) error { + var resp OkReply + if err := session.Exec(ctx, rpc, &resp); err != nil { + return err + } + + if !resp.OK { + return fmt.Errorf("copy-config: operation failed, not received") + } + + return nil +} + +// DeleteConfigReq represents the `` operation defined in +// [RFC6241 7.4] for deleting a configuration datastore. +// +// [RFC6241 7.4]: https://www.rfc-editor.org/rfc/rfc6241.html#section-7.4 +type DeleteConfig struct { + Target Datastore +} + +func (rpc DeleteConfig) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + req := struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 delete-config"` + Target any `xml:"target"` + }{ + Target: rpc.Target, + } + + return e.Encode(&req) +} + +func (rpc DeleteConfig) Exec(ctx context.Context, session *netconf.Session) error { + var resp OkReply + if err := session.Exec(ctx, rpc, &resp); err != nil { + return err + } + + if !resp.OK { + return fmt.Errorf("delete-config: operation failed, not received") + } + + return nil +} + +// LockReq represents the `` operation defined in [RFC6241 7.5] for +// locking a configuration datastore. +// +// [RFC6241 7.5]: https://www.rfc-editor.org/rfc/rfc6241.html#section-7.5 +type Lock struct { + Target Datastore +} + +func (rpc Lock) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + req := struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 lock"` + Target Datastore `xml:"target"` + }{ + Target: rpc.Target, + } + + return e.Encode(&req) +} + +func (rpc Lock) Exec(ctx context.Context, session *netconf.Session) error { + var resp OkReply + if err := session.Exec(ctx, rpc, &resp); err != nil { + return err + } + + if !resp.OK { + return fmt.Errorf("lock: operation failed, not received") + } + + return nil +} + +type Unlock struct { + Target Datastore +} + +func (rpc Unlock) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + req := struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 unlock"` + Target Datastore `xml:"target"` + }{ + Target: rpc.Target, + } + + return e.Encode(&req) +} + +func (rpc Unlock) Exec(ctx context.Context, session *netconf.Session) error { + var resp OkReply + if err := session.Exec(ctx, rpc, &resp); err != nil { + return err + } + + if !resp.OK { + return fmt.Errorf("unlock: operation failed, not received") + } + + return nil +} + +type Validate struct { + Source any +} + +func (rpc Validate) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + req := struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 validate"` + Source any `xml:"source"` + }{ + Source: rpc.Source, + } + + return e.Encode(&req) +} + +func (rpc Validate) Exec(ctx context.Context, session *netconf.Session) error { + var resp OkReply + if err := session.Exec(ctx, rpc, &resp); err != nil { + return err + } + + if !resp.OK { + return fmt.Errorf("validate: operation failed, not received") + } + + return nil +} + +// Commit represents the `` operation defined in [RFC6241 8.5] for +// committing candidate configuration to the running datastore. +// +// [RFC6241 8.5]: https://www.rfc-editor.org/rfc/rfc6241.html#section-8.5 +type Commit struct { + // Confirmed indicates that the commit must be confirmed with a follow-up + // commit within the confirm-timeout period (default 600 seconds). If not + // confirmed, the commit will be reverted. + // + // Device must support :confirmed-commit:1.1 capability. + Confirmed bool + + // ConfirmTimeout is the time in seconds to wait before reverting a + // confirmed commit. + // + // Device must support :confirmed-commit:1.1 capability. + ConfirmTimeout int64 + + // Persist indicates that the confirmed commit can be persisted across + // sessions and confirmed in a different session. + // + // If Confirmed is set this expands to the element. + // + // If Confirmed is not set this expands to the element to + // confirm a previous commit with the same id. + PersistID string +} + +func (rpc Commit) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + req := struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 commit"` + Confirmed ExtantBool `xml:"confirmed,omitempty"` + ConfirmTimeout int64 `xml:"confirm-timeout,omitempty"` + Persist string `xml:"persist,omitempty"` + PersistID string `xml:"persist-id,omitempty"` + }{ + Confirmed: ExtantBool(rpc.Confirmed), + ConfirmTimeout: rpc.ConfirmTimeout, + } + + if rpc.PersistID != "" { + if rpc.Confirmed { + req.Persist = rpc.PersistID + } else { + req.PersistID = rpc.PersistID + } + } + + return e.Encode(&req) +} + +func (rpc Commit) Exec(ctx context.Context, session *netconf.Session) error { + var resp OkReply + if err := session.Exec(ctx, rpc, &resp); err != nil { + return err + } + + if !resp.OK { + return fmt.Errorf("commit: operation failed, not received") + } + return nil +} + +type CancelCommit struct { + PersistID string +} + +func (rpc CancelCommit) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + req := struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 cancel-commit"` + PersistID string `xml:"persist-id,omitempty"` + }{ + PersistID: rpc.PersistID, + } + return e.Encode(&req) +} + +func (rpc CancelCommit) Exec(ctx context.Context, session *netconf.Session) error { + var resp OkReply + if err := session.Exec(ctx, rpc, &resp); err != nil { + return err + } + + if !resp.OK { + return fmt.Errorf("cancel-commit: operation failed, not received") + } + return nil +} diff --git a/rpc/config_test.go b/rpc/config_test.go new file mode 100644 index 0000000..1fc6709 --- /dev/null +++ b/rpc/config_test.go @@ -0,0 +1,600 @@ +package rpc + +import ( + "encoding/xml" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMarshalDatastore(t *testing.T) { + tt := []struct { + input Datastore + want string + shouldErr bool + }{ + {Running, "", false}, + {Startup, "", false}, + {Candidate, "", false}, + {Datastore("custom-store"), "", false}, + {Datastore(""), "", true}, + {Datastore(""), "", true}, + } + + for _, tc := range tt { + t.Run(string(tc.input), func(t *testing.T) { + v := struct { + XMLName xml.Name `xml:"rpc"` + Target Datastore `xml:"target"` + }{Target: tc.input} + + got, err := xml.Marshal(&v) + if !tc.shouldErr { + assert.NoError(t, err) + } + assert.Equal(t, tc.want, string(got)) + }) + } +} + +func TestGetConfig_MarshalXML(t *testing.T) { + tests := []struct { + name string + op GetConfig + expected string + }{ + { + name: "basic", + op: GetConfig{ + Source: Running, + }, + expected: ``, + }, + { + name: "subtreeFilter", + op: GetConfig{ + Source: Running, + Filter: SubtreeFilter(``), + }, + expected: ``, + }, + { + name: "xpathFilter", + op: GetConfig{ + Source: Running, + Filter: XPathFilter("/interfaces/interface/name", nil), + }, + expected: ``, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := xml.Marshal(tt.op) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + assert.Equal(t, tt.expected, string(got)) + }) + } +} +func TestGetConfig_Exec(t *testing.T) { + tests := []struct { + name string + op GetConfig + serverReply string + shouldError bool + expected string + }{ + { + name: "good reply", + op: GetConfig{ + Source: Datastore("my-datastore"), + }, + serverReply: `root`, + shouldError: false, + expected: `root`, + }, + } + + for _, tc := range tests { + session, _ := mockSession(t, tc.serverReply) + got, err := tc.op.Exec(t.Context(), session) + assert.NoError(t, err) + expected := `root` + assert.Equal(t, expected, string(got)) + + } + +} + +func TestEditConfig_MarshalXML(t *testing.T) { + tests := []struct { + name string + op EditConfig + expected string + }{ + { + name: "stringConfig", + op: EditConfig{ + Target: Running, + Config: `eth0`, + }, + // Expect: ...content... + expected: `eth0`, + }, + { + name: "byteSliceConfig", + op: EditConfig{ + Target: Running, + Config: []byte(`eth0`), + }, + // Expect: ...content... + expected: `eth0`, + }, + { + name: "urlConfig", + op: EditConfig{ + Target: Candidate, + Config: URL("https://example.com/config.xml"), + }, + // Expect: ... NOT wrapped in + expected: `https://example.com/config.xml`, + }, + { + name: "optionsSet", + op: EditConfig{ + Target: Running, + DefaultOperation: ReplaceConfig, + TestOption: TestThenSet, + ErrorOption: RollbackOnError, + Config: "foo", + }, + expected: `replacetest-then-setrollback-on-errorfoo`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := xml.Marshal(tt.op) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + assert.Equal(t, tt.expected, string(got)) + }) + } +} + +func TestEditConfig_Exec(t *testing.T) { + tests := []struct { + name string + op EditConfig + serverReply string + shouldError bool + }{ + { + name: "okReply", + op: EditConfig{ + Target: Running, + Config: `eth0`, + }, + serverReply: ``, + shouldError: false, + }, + } + + for _, tc := range tests { + session, _ := mockSession(t, tc.serverReply) + err := tc.op.Exec(t.Context(), session) + if tc.shouldError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + } +} + +func TestCopyConfig_MarshalXML(t *testing.T) { + tests := []struct { + name string + op CopyConfig + expected string + }{ + { + name: "basic", + op: CopyConfig{ + Source: URL("ftp://example.com/config.xml"), + Target: Running, + }, + expected: `ftp://example.com/config.xml`, + }, + { + name: "withDefault", + op: CopyConfig{ + Source: Startup, + Target: Candidate, + }, + expected: ``, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := xml.Marshal(tt.op) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + assert.Equal(t, tt.expected, string(got)) + }) + } +} + +func TestCopyConfig_Exec(t *testing.T) { + tests := []struct { + name string + op CopyConfig + serverReply string + shouldError bool + }{ + { + name: "okReply", + op: CopyConfig{ + Source: URL("ftp://example.com/config.xml"), + Target: Running, + }, + serverReply: ``, + shouldError: false, + }, + } + + for _, tc := range tests { + session, _ := mockSession(t, tc.serverReply) + err := tc.op.Exec(t.Context(), session) + if tc.shouldError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + } +} + +func TestDeleteConfig_MarshalXML(t *testing.T) { + tests := []struct { + name string + op DeleteConfig + expected string + }{ + { + name: "basic", + op: DeleteConfig{ + Target: Datastore("my-custom-datastore"), + }, + expected: ``, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := xml.Marshal(tt.op) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + assert.Equal(t, tt.expected, string(got)) + }) + } +} + +func TestDeleteConfig_Exec(t *testing.T) { + tests := []struct { + name string + op DeleteConfig + serverReply string + shouldError bool + }{ + { + name: "okReply", + op: DeleteConfig{ + Target: Datastore("my-custom-datastore"), + }, + serverReply: ``, + shouldError: false, + }, + } + + for _, tc := range tests { + session, _ := mockSession(t, tc.serverReply) + err := tc.op.Exec(t.Context(), session) + if tc.shouldError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + } +} + +func TestLock_MarshalXML(t *testing.T) { + tests := []struct { + name string + op Lock + expected string + }{ + { + name: "basic", + op: Lock{ + Target: Running, + }, + expected: ``, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := xml.Marshal(tt.op) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + assert.Equal(t, tt.expected, string(got)) + }) + } +} + +func TestLock_Exec(t *testing.T) { + tests := []struct { + name string + op Lock + serverReply string + shouldError bool + }{ + { + name: "okReply", + op: Lock{ + Target: Running, + }, + serverReply: ``, + shouldError: false, + }, + } + + for _, tc := range tests { + session, _ := mockSession(t, tc.serverReply) + err := tc.op.Exec(t.Context(), session) + if tc.shouldError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + } +} + +func TestUnlock_MarshalXML(t *testing.T) { + tests := []struct { + name string + op Unlock + expected string + }{ + { + name: "basic", + op: Unlock{ + Target: Running, + }, + expected: ``, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := xml.Marshal(tt.op) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + assert.Equal(t, tt.expected, string(got)) + }) + } +} + +func TestUnlock_Exec(t *testing.T) { + tests := []struct { + name string + op Unlock + serverReply string + shouldError bool + }{ + { + name: "okReply", + op: Unlock{ + Target: Running, + }, + serverReply: ``, + shouldError: false, + }, + } + + for _, tc := range tests { + session, _ := mockSession(t, tc.serverReply) + err := tc.op.Exec(t.Context(), session) + if tc.shouldError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + } +} + +func TestValidate_MarshalXML(t *testing.T) { + tests := []struct { + name string + op Validate + expected string + }{ + { + name: "basic", + op: Validate{ + Source: Running, + }, + expected: ``, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := xml.Marshal(tt.op) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + assert.Equal(t, tt.expected, string(got)) + }) + } +} + +func TestValidate_Exec(t *testing.T) { + tests := []struct { + name string + op Validate + serverReply string + shouldError bool + }{ + { + name: "okReply", + op: Validate{ + Source: Running, + }, + serverReply: ``, + shouldError: false, + }, + } + + for _, tc := range tests { + session, _ := mockSession(t, tc.serverReply) + err := tc.op.Exec(t.Context(), session) + if tc.shouldError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + } +} + +func TestCommit_MarshalXML(t *testing.T) { + tests := []struct { + name string + op Commit + expected string + }{ + { + name: "basic", + op: Commit{}, + expected: ``, + }, + { + name: "confirmed", + op: Commit{ + Confirmed: true, + ConfirmTimeout: 300, + }, + expected: `300`, + }, + { + name: "confirmedPersist", + op: Commit{ + Confirmed: true, + PersistID: "foobar", + }, + expected: `foobar`, + }, + { + name: "confirmPersistID", + op: Commit{ + PersistID: "foobar2", + }, + expected: `foobar2`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := xml.Marshal(tt.op) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + assert.Equal(t, tt.expected, string(got)) + }) + } +} + +func TestCommit_Exec(t *testing.T) { + tests := []struct { + name string + op Commit + serverReply string + shouldError bool + }{ + { + name: "okReply", + op: Commit{ + Confirmed: true, + ConfirmTimeout: 200, + }, + serverReply: ``, + shouldError: false, + }, + } + + for _, tc := range tests { + session, _ := mockSession(t, tc.serverReply) + err := tc.op.Exec(t.Context(), session) + if tc.shouldError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + } +} + +func TestCancelCommit_MarshalXML(t *testing.T) { + tests := []struct { + name string + op CancelCommit + expected string + }{ + { + name: "basic", + op: CancelCommit{}, + expected: ``, + }, + { + name: "persistID", + op: CancelCommit{ + PersistID: "persist-123", + }, + expected: `persist-123`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := xml.Marshal(tt.op) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + assert.Equal(t, tt.expected, string(got)) + }) + } +} + +func TestCancelCommit_Exec(t *testing.T) { + tests := []struct { + name string + op CancelCommit + serverReply string + shouldError bool + }{ + { + name: "okReply", + op: CancelCommit{}, + serverReply: ``, + shouldError: false, + }, + } + + for _, tc := range tests { + session, _ := mockSession(t, tc.serverReply) + err := tc.op.Exec(t.Context(), session) + if tc.shouldError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + } +} diff --git a/rpc/filter.go b/rpc/filter.go new file mode 100644 index 0000000..4b0aa4c --- /dev/null +++ b/rpc/filter.go @@ -0,0 +1,72 @@ +package rpc + +import ( + "encoding/xml" + "maps" + "slices" +) + +type Filter interface { + xml.Marshaler + filter() +} + +type subtreeFilter struct { + f any +} + +func (f subtreeFilter) filter() {} + +func (f subtreeFilter) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + start.Attr = append(start.Attr, xml.Attr{Name: xml.Name{Local: "type"}, Value: "subtree"}) + + switch v := f.f.(type) { + case string: + inner := struct { + Data string `xml:",innerxml"` + }{Data: v} + return e.EncodeElement(&inner, start) + case []byte: + inner := struct { + Data []byte `xml:",innerxml"` + }{Data: v} + return e.EncodeElement(&inner, start) + default: + return e.EncodeElement(f.f, start) + + } +} + +// SubtreeFilter creates a filter matching the provided XML structure(s). +// Multiple arguments are merged into a single filter element as siblings. +func SubtreeFilter(filter any) Filter { + return subtreeFilter{f: filter} +} + +type xpathFilter struct { + Select string + Namespaces map[string]string +} + +func (f xpathFilter) filter() {} + +func (f xpathFilter) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + start.Attr = append(start.Attr, + xml.Attr{Name: xml.Name{Local: "type"}, Value: "xpath"}, + xml.Attr{Name: xml.Name{Local: "select"}, Value: f.Select}, + ) + + for _, prefix := range slices.Sorted(maps.Keys(f.Namespaces)) { + uri := f.Namespaces[prefix] + attrName := "xmlns:" + prefix + start.Attr = append(start.Attr, xml.Attr{Name: xml.Name{Local: attrName}, Value: uri}) + } + + return e.EncodeElement(struct{}{}, start) +} + +// XPathFilter creates a filter using XPath 1.0 expression. +// namespaces map prefixes used in the path to their URIs. +func XPathFilter(path string, namespaces map[string]string) Filter { + return xpathFilter{Select: path, Namespaces: namespaces} +} diff --git a/rpc/filter_test.go b/rpc/filter_test.go new file mode 100644 index 0000000..8cbcaf8 --- /dev/null +++ b/rpc/filter_test.go @@ -0,0 +1,94 @@ +package rpc + +import ( + "encoding/xml" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Helper struct for testing struct marshaling +type InterfaceFilter struct { + XMLName xml.Name `xml:"interfaces"` + Name string `xml:"interfaces>interface>name,omitempty"` +} + +func TestSubtreeFilter_MarshalXML(t *testing.T) { + tests := []struct { + name string + input Filter + expected string + }{ + { + name: "string", + input: SubtreeFilter(``), + expected: ``, + }, + { + name: "bytes", + input: SubtreeFilter([]byte(``)), + expected: ``, + }, + { + name: "struct", + input: SubtreeFilter(InterfaceFilter{Name: "eth0"}), + expected: `eth0`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + wrapper := struct { + XMLName xml.Name `xml:"root"` + F Filter `xml:"filter"` + }{F: tt.input} + + out, err := xml.Marshal(&wrapper) + assert.NoError(t, err) + assert.Equal(t, tt.expected, string(out)) + }) + } +} + +func TestXPathFilter_MarshalXML(t *testing.T) { + tests := []struct { + name string + input Filter + expected string + }{ + { + name: "xpath", + input: XPathFilter("/interfaces/interface/name", nil), + // Note: Attributes order in map iteration is random, but here we have none. + // Go's XML encoder usually alphabetizes attributes. + expected: ``, + }, + { + name: "xpathNamespaces", + input: XPathFilter( + "/if:interfaces/if:interface", + map[string]string{ + "if": "urn:ietf:params:xml:ns:yang:ietf-interfaces", + }, + ), + // Expected outcome needs to check for the xmlns attribute. + // Since map iteration order is random, exact string match might be flaky if we had multiple NS. + // But with one NS, it's deterministic. + expected: ``, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wrapper := struct { + XMLName xml.Name `xml:"root"` + F Filter `xml:"filter"` + }{F: tt.input} + + out, err := xml.Marshal(&wrapper) + assert.NoError(t, err) + assert.Equal(t, tt.expected, string(out)) + }) + } +} diff --git a/rpc/rpc.go b/rpc/rpc.go new file mode 100644 index 0000000..3ad8d82 --- /dev/null +++ b/rpc/rpc.go @@ -0,0 +1,64 @@ +package rpc + +import ( + "context" + "encoding/xml" + + "nemith.io/netconf" +) + +type ExtantBool bool + +func (b ExtantBool) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + if !b { + return nil + } + // This produces a empty start/end tag (i.e ) vs a self-closing + // tag () which should be the same in XML, however I know certain + // vendors may have issues with this format. We may have to process this + // after xml encoding. + // + // See https://github.com/golang/go/issues/21399 + // or https://github.com/golang/go/issues/26756 for a different hack. + return e.EncodeElement(struct{}{}, start) +} + +func (b *ExtantBool) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { + *b = true + return d.Skip() +} + +type OkReply struct { + netconf.RPCReply + OK ExtantBool `xml:"ok"` +} + +type Get struct { + Filter Filter `xml:"filter,omitempty"` +} + +func (rpc *Get) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + req := struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 get"` + Filter Filter `xml:"filter,omitempty"` + }{ + Filter: rpc.Filter, + } + return e.Encode(&req) +} + +type GetReply struct { + netconf.RPCReply + Data struct { + XML []byte `xml:",innerxml"` + } `xml:"data"` +} + +func (rpc *Get) Exec(ctx context.Context, session *netconf.Session) ([]byte, error) { + var resp GetReply + if err := session.Exec(ctx, rpc, &resp); err != nil { + return nil, err + } + + return resp.Data.XML, nil +} diff --git a/rpc/rpc_test.go b/rpc/rpc_test.go new file mode 100644 index 0000000..e9f9dcc --- /dev/null +++ b/rpc/rpc_test.go @@ -0,0 +1,103 @@ +package rpc + +import ( + "context" + "encoding/xml" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "nemith.io/netconf" + "nemith.io/netconf/transport" +) + +func mockSession(t *testing.T, rpcReplyInnerXML string) (*netconf.Session, *transport.TestTransport) { + tr := &transport.TestTransport{} + tr.AddResponse(` + + + urn:ietf:params:netconf:base:1.0 + + 42 + `) + + tr.AddResponse(fmt.Sprintf(` + + %s + `, rpcReplyInnerXML)) + + // 3. Create Session + // This will immediately consume the first message (Server Hello) + // and write the Client Hello to tr.outputs[0]. + s, err := netconf.Open(tr) + require.NoError(t, err, "Session handshake failed") + + return s, tr +} + +func TestUnmarshalOk(t *testing.T) { + tt := []struct { + name string + input string + want bool + }{ + {"selfclosing", ">", true}, + {"missing", "", false}, + {"closetag", "", true}, + } + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + var v struct { + XMLName xml.Name `xml:"foo"` + Ok ExtantBool `xml:"ok"` + } + + err := xml.Unmarshal([]byte(tc.input), &v) + assert.NoError(t, err) + assert.Equal(t, tc.want, bool(v.Ok)) + }) + } +} + +func TestGet_MarshalXML(t *testing.T) { + tests := []struct { + name string + op Get + expected string + }{ + { + name: "noFilter", + op: Get{}, + expected: ``, + }, + { + name: "withFilter", + op: Get{ + Filter: SubtreeFilter(``), + }, + expected: ``, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := xml.Marshal(&tt.op) + require.NoError(t, err) + assert.Equal(t, tt.expected, string(out)) + }) + } +} + +func TestGet_Exec(t *testing.T) { + const rpcReplyData = `eth0` + + sess, _ := mockSession(t, rpcReplyData) + + getOp := &Get{} + data, err := getOp.Exec(context.Background(), sess) + require.NoError(t, err) + + expectedData := `eth0` + assert.Equal(t, expectedData, string(data)) +} diff --git a/rpc/session.go b/rpc/session.go new file mode 100644 index 0000000..e09cd77 --- /dev/null +++ b/rpc/session.go @@ -0,0 +1,39 @@ +package rpc + +import ( + "context" + "encoding/xml" + "fmt" + + "nemith.io/netconf" +) + +// KillSessionReq represents the `` operation defined in +// [RFC6241 7.6] for terminating a NETCONF session. +// +// [RFC6241 7.6]: https://www.rfc-editor.org/rfc/rfc6241.html#section-7.6 +type KillSession struct { + SessionID uint +} + +func (rpc *KillSession) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + req := struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 kill-session"` + SessionID uint `xml:"session-id"` + }{ + SessionID: rpc.SessionID, + } + return e.EncodeElement(&req, start) +} + +func (rpc *KillSession) Exec(ctx context.Context, session *netconf.Session) error { + var resp OkReply + if err := session.Exec(ctx, rpc, &resp); err != nil { + return err + } + + if !resp.OK { + return fmt.Errorf("kill-session: operation failed, not received") + } + return nil +} diff --git a/session.go b/session.go index 6c2aaf2..b085362 100644 --- a/session.go +++ b/session.go @@ -1,6 +1,7 @@ package netconf import ( + "bytes" "context" "encoding/xml" "errors" @@ -9,6 +10,7 @@ import ( "log" "net" "slices" + "strconv" "sync" "sync/atomic" "syscall" @@ -16,11 +18,15 @@ import ( "nemith.io/netconf/transport" ) +const ( + NetconfNamespace = "urn:ietf:params:xml:ns:netconf:base:1.0" + NotificationNamespace = "urn:ietf:params:xml:ns:netconf:notification:1.0" +) + var ErrClosed = errors.New("closed connection") type sessionConfig struct { - clientCaps []string - notificationHandler NotificationHandler + clientCaps []string } type SessionOption interface { @@ -37,37 +43,20 @@ func WithCapability(capabilities ...string) SessionOption { return capabilityOpt(capabilities) } -type notificationHandlerOpt NotificationHandler - -func (o notificationHandlerOpt) apply(cfg *sessionConfig) { - cfg.notificationHandler = NotificationHandler(o) -} - -func WithNotificationHandler(nh NotificationHandler) SessionOption { - return notificationHandlerOpt(nh) -} - // Session is represents a netconf session to a one given device. type Session struct { tr transport.Transport sessionID uint64 seq atomic.Uint64 - clientCaps CapabilitySet - serverCaps CapabilitySet - notificationHandler NotificationHandler + clientCaps CapabilitySet + serverCaps CapabilitySet mu sync.Mutex - reqs map[uint64]*req + reqs map[string]*pendingReq closing bool } -// NotificationHandler function allows to work with received notifications. -// A NotificationHandler function can be passed in as an option when calling Open method of Session object -// A typical use of the NofificationHandler function is to retrieve notifications once they are received so -// that they can be parsed and/or stored somewhere. -type NotificationHandler func(msg Notification) - func newSession(transport transport.Transport, opts ...SessionOption) *Session { cfg := sessionConfig{ clientCaps: DefaultCapabilities, @@ -78,10 +67,9 @@ func newSession(transport transport.Transport, opts ...SessionOption) *Session { } s := &Session{ - tr: transport, - clientCaps: NewCapabilitySet(cfg.clientCaps...), - reqs: make(map[uint64]*req), - notificationHandler: cfg.notificationHandler, + tr: transport, + clientCaps: NewCapabilitySet(cfg.clientCaps...), + reqs: make(map[string]*pendingReq), } return s } @@ -97,26 +85,42 @@ func Open(transport transport.Transport, opts ...SessionOption) (*Session, error return nil, err } - go s.recv() + go s.recvLoop() return s, nil } // handshake exchanges handshake messages and reports if there are any errors. func (s *Session) handshake() error { - clientMsg := helloMsg{ + clientMsg := HelloMsg{ Capabilities: slices.Collect(s.clientCaps.All()), } - if err := s.writeMsg(&clientMsg); err != nil { + + w, err := s.tr.MsgWriter() + if err != nil { + return fmt.Errorf("failed to get hello message writer: %w", err) + } + defer func() { + // TODO: expose this error + _ = w.Close() + }() + + if err := xml.NewEncoder(w).Encode(&clientMsg); err != nil { return fmt.Errorf("failed to write hello message: %w", err) } + if err := w.Close(); err != nil { + return fmt.Errorf("failed to close hello message writer: %w", err) + } r, err := s.tr.MsgReader() if err != nil { - return err + return fmt.Errorf("failed to get hello message reader: %w", err) } - defer r.Close() // nolint:errcheck // TODO: catch and log err + defer func() { + // TODO: expose this error + _ = r.Close() + }() - var serverMsg helloMsg + var serverMsg HelloMsg if err := xml.NewDecoder(r).Decode(&serverMsg); err != nil { return fmt.Errorf("failed to read server hello message: %w", err) } @@ -134,8 +138,7 @@ func (s *Session) handshake() error { // upgrade the transport if we are on a larger version and the transport // supports it. - const baseCap11 = baseCap + ":1.1" - if s.serverCaps.Has(baseCap11) && s.clientCaps.Has(baseCap11) { + if s.serverCaps.Has(CapNetConf11) && s.clientCaps.Has(CapNetConf11) { if upgrader, ok := s.tr.(interface{ Upgrade() }); ok { upgrader.Upgrade() } @@ -176,178 +179,213 @@ func startElement(d *xml.Decoder) (*xml.StartElement, error) { } } -type req struct { - reply chan Reply +type pendingReq struct { + reply chan *Response ctx context.Context } -func (s *Session) recvMsg() error { - r, err := s.tr.MsgReader() - if err != nil { - return err - } - defer r.Close() // nolint:errcheck // TODO: catch error and log - dec := xml.NewDecoder(r) - - root, err := startElement(dec) - if err != nil { - return err - } - - const ( - ncNamespace = "urn:ietf:params:xml:ns:netconf:base:1.0" - notifNamespace = "urn:ietf:params:xml:ns:netconf:notification:1.0" - ) - - switch root.Name { - case xml.Name{Space: notifNamespace, Local: "notification"}: - if s.notificationHandler == nil { - return nil - } - var notif Notification - if err := dec.DecodeElement(¬if, root); err != nil { - return fmt.Errorf("failed to decode notification message: %w", err) - } - s.notificationHandler(notif) - case xml.Name{Space: ncNamespace, Local: "rpc-reply"}: - var reply Reply - if err := dec.DecodeElement(&reply, root); err != nil { - // What should we do here? Kill the connection? - return fmt.Errorf("failed to decode rpc-reply message: %w", err) - } - ok, req := s.req(reply.MessageID) - if !ok { - return fmt.Errorf("cannot find reply channel for message-id: %d", reply.MessageID) - } +type replyReader struct { + io.Reader + closer io.Closer - select { - case req.reply <- reply: - return nil - case <-req.ctx.Done(): - return fmt.Errorf("message %d context canceled: %s", reply.MessageID, req.ctx.Err().Error()) - } - default: - return fmt.Errorf("unknown message type: %q", root.Name.Local) - } - return nil + done chan struct{} + once sync.Once } -// recv is the main receive loop. It runs concurrently to be able to handle -// interleaved messages (like notifications). -func (s *Session) recv() { +func (r *replyReader) Close() error { var err error - var opErr *net.OpError + r.once.Do(func() { + err = r.closer.Close() + close(r.done) + }) + return err +} +// recvLoop is the main receive loop. It runs concurrently to be able to handle +// interleaved messages (like notifications). +func (s *Session) recvLoop() { + // buffer used to "peel" into the message enough to read the first element + // (i.e or ) + buf := make([]byte, 4096) for { - err = s.recvMsg() - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) || errors.As(err, &opErr) { + if err := s.recvMsg(buf); err != nil { + log.Printf("netconf: failed to receive message: %v", err) break } - if err != nil { - log.Printf("netconf: failed to read incoming message: %v", err) - } } - s.mu.Lock() - defer s.mu.Unlock() - // Close all outstanding requests + // Final cleanup when the loop exits + s.mu.Lock() for _, req := range s.reqs { close(req.reply) } + s.mu.Unlock() + // TODO: expose this error + _ = s.tr.Close() if !s.closing { log.Printf("netconf: connection closed unexpectedly") } } -func (s *Session) req(msgID uint64) (bool, *req) { - s.mu.Lock() - defer s.mu.Unlock() - - req, ok := s.reqs[msgID] - if !ok { - return false, nil +func getMessageID(attrs []xml.Attr) string { + for _, attr := range attrs { + if attr.Name.Local == "message-id" { + return attr.Value + } } - delete(s.reqs, msgID) - return true, req + return "" } -func (s *Session) writeMsg(v any) error { - w, err := s.tr.MsgWriter() +func (s *Session) recvMsg(buf []byte) error { + r, err := s.tr.MsgReader() if err != nil { return err } + defer func() { + // TODO: expose this error + _ = r.Close() + }() - if err := xml.NewEncoder(w).Encode(v); err != nil { + // 3. Peek/Read the start of the message + n, err := r.Read(buf) + if err != nil && !errors.Is(err, io.EOF) { + // It is okay to return EOF here; recv() handles the check. return err } - return w.Close() -} -func (s *Session) send(ctx context.Context, msg *request) (chan Reply, error) { - s.mu.Lock() - defer s.mu.Unlock() + chunk := buf[:n] + decoder := xml.NewDecoder(bytes.NewReader(chunk)) - if err := s.writeMsg(msg); err != nil { - return nil, err + startElem, err := startElement(decoder) + if err != nil { + return fmt.Errorf("failed to parse message start: %w", err) } - // cap of 1 makes sure we don't block on send - ch := make(chan Reply, 1) - s.reqs[msg.MessageID] = &req{ - reply: ch, - ctx: ctx, - } + msgReader := io.MultiReader(bytes.NewReader(chunk), r) + + switch startElem.Name { + case xml.Name{Space: NetconfNamespace, Local: "rpc-reply"}: + msgID := getMessageID(startElem.Attr) + if msgID == "" { + log.Printf("netconf: rpc-reply missing message-id") + return nil // Continue loop + } + + s.mu.Lock() + req, ok := s.reqs[msgID] + delete(s.reqs, msgID) + s.mu.Unlock() + + if !ok { + log.Printf("netconf: unexpected rpc-reply with message-id %s (possible timeout?)", msgID) + return nil // Continue loop + } + + readDone := make(chan struct{}) + reader := &replyReader{ + Reader: msgReader, + closer: r, // The raw transport reader + done: readDone, + } - return ch, nil + select { + case req.reply <- &Response{ + ReadCloser: reader, + MessageID: msgID, + Attributes: startElem.Attr, + }: + // We wait for the user to call Close() on the replyReader. + <-readDone + return nil + + case <-req.ctx.Done(): + return nil + } + + default: + return fmt.Errorf("netconf: unknown message type: %s", startElem.Name.Local) + } } -// Do issues a rpc call for the given NETCONF operation returning a Reply. RPC -// errors (i.e erros in the `` section of the ``) are -// converted into go errors automatically. Instead use `reply.Err()` or -// `reply.RPCErrors` to access the errors and/or warnings. -func (s *Session) Do(ctx context.Context, req any) (*Reply, error) { - msg := &request{ - MessageID: s.seq.Add(1), - Operation: req, +// Do issues a rpc message for the given Request. This is a low-level method +// that doesn't try to decode the response including any rpc-errors. +func (s *Session) Do(ctx context.Context, req *Request) (*Response, error) { + msgID := strconv.FormatUint(s.seq.Add(1), 10) + req.RPC.MessageID = msgID + + // Setup channel + ch := make(chan *Response, 1) + s.mu.Lock() + s.reqs[msgID] = &pendingReq{ + reply: ch, + ctx: ctx, } + s.mu.Unlock() + + // Cleanup if context triggers before send/recv + defer func() { + s.mu.Lock() + delete(s.reqs, msgID) + s.mu.Unlock() + }() - ch, err := s.send(ctx, msg) + w, err := s.tr.MsgWriter() if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get message writer: %w", err) + } + if err := xml.NewEncoder(w).Encode(req.RPC); err != nil { + _ = w.Close() // try to close anyway + return nil, fmt.Errorf("failed to encode request: %w", err) + } + if err := w.Close(); err != nil { + return nil, fmt.Errorf("failed to flush request: %w", err) } - // wait for reply or context to be cancelled. + // Wait for the response select { - case reply, ok := <-ch: + case resp, ok := <-ch: if !ok { return nil, ErrClosed } - return &reply, nil + return resp, nil case <-ctx.Done(): - // remove any existing request - s.mu.Lock() - delete(s.reqs, msg.MessageID) - s.mu.Unlock() - return nil, ctx.Err() } } -// Call issues a rpc message with `req` as the body and decodes the reponse into -// a pointer at `resp`. Any Call errors are presented as a go error. -func (s *Session) Call(ctx context.Context, req any, resp any) error { - reply, err := s.Do(ctx, &req) +// Exec issues a rpc message with `req` as the body and decodes the response into +// a pointer at `resp`. Resp must include the full structure. +func (s *Session) Exec(ctx context.Context, operation any, reply any) error { + req := Request{RPC: RPC{Operation: operation}} + + resp, err := s.Do(ctx, &req) if err != nil { return err } + defer func() { + _ = resp.Close() + }() - if err := reply.Err(); err != nil { - return err + raw, err := io.ReadAll(resp) + if err != nil { + return fmt.Errorf("failed to read reply: %w", err) } - if err := reply.Decode(&resp); err != nil { - return err + var rpcReply RPCReply + if err := xml.Unmarshal(raw, &rpcReply); err != nil { + return fmt.Errorf("failed to parse rpc-reply: %w", err) + } + // filter out warnings + rpcErrors := rpcReply.RPCErrors.Filter(SevError) + if len(rpcErrors) > 0 { + return rpcErrors + } + + if reply != nil { + if err := xml.Unmarshal(raw, reply); err != nil { + return fmt.Errorf("failed to decode response: %w", err) + } } return nil @@ -365,7 +403,11 @@ func (s *Session) Close(ctx context.Context) error { } // This may fail so save the error but still close the underlying transport. - _, callErr := s.Do(ctx, &closeSession{}) + req := NewRequest(&closeSession{}) + resp, _ := s.Do(ctx, req) + if resp != nil { + _ = resp.Close() + } // Close the connection and ignore errors if the remote side hung up first. if err := s.tr.Close(); err != nil && @@ -377,9 +419,5 @@ func (s *Session) Close(ctx context.Context) error { } } - if !errors.Is(callErr, io.EOF) { - return callErr - } - return nil } diff --git a/session_test.go b/session_test.go index fb145e4..4015725 100644 --- a/session_test.go +++ b/session_test.go @@ -66,7 +66,6 @@ func (s *testServer) transport() *testTransport { return newTestTransport(s.hand type testTransport struct { handler func(r io.ReadCloser, w io.WriteCloser) out chan io.ReadCloser - // msgReceived, msgSent int } func newTestTransport(handler func(r io.ReadCloser, w io.WriteCloser)) *testTransport { diff --git a/transport/transport.go b/transport/transport.go index e9ad32e..973b7f4 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -1,6 +1,7 @@ package transport import ( + "bytes" "errors" "io" ) @@ -26,3 +27,54 @@ type Transport interface { Close() error } + +// TestTransport mocks the underlying NETCONF transport layer. +// It allows us to queue up "Server Responses" and inspect "Client Requests". +type TestTransport struct { + // inputs is a queue of messages the Server "sends" to the Client. + // The Session calls ReadMsg() to pop from this queue. + inputs [][]byte + + // outputs captures messages the Client "sends" to the Server. + // The Session calls WriteMsg() to append to this list. + outputs [][]byte +} + +type readNoopCloser struct{ io.Reader } + +func (r readNoopCloser) Close() error { return nil } + +type testWriter struct { + tt *TestTransport + buf *bytes.Buffer +} + +func (w *testWriter) Write(p []byte) (int, error) { + return w.buf.Write(p) +} + +func (w *testWriter) Close() error { + w.tt.outputs = append(w.tt.outputs, w.buf.Bytes()) + return nil +} + +func (t *TestTransport) MsgReader() (io.ReadCloser, error) { + if len(t.inputs) == 0 { + return nil, io.EOF + } + + msg := t.inputs[0] + t.inputs = t.inputs[1:] + return readNoopCloser{bytes.NewReader(msg)}, nil +} + +func (t *TestTransport) MsgWriter() (io.WriteCloser, error) { + return &testWriter{tt: t, buf: &bytes.Buffer{}}, nil +} + +func (t *TestTransport) Close() error { return nil } + +// Helper to push a server response into the read queue +func (t *TestTransport) AddResponse(body string) { + t.inputs = append(t.inputs, []byte(body)) +}