diff --git a/engine/cmd/database-lab/main.go b/engine/cmd/database-lab/main.go index e6a687747ac6285d5ec1842b6f1cf9ad225f8281..9e2919f39f8dbf3560eab61dc56a5b067fbb2b19 100644 --- a/engine/cmd/database-lab/main.go +++ b/engine/cmd/database-lab/main.go @@ -18,6 +18,7 @@ import ( "syscall" "time" + "github.com/docker/docker/api/types" "github.com/docker/docker/client" "github.com/pkg/errors" @@ -124,7 +125,9 @@ func main() { } // Create a cloning service to provision new clones. - provisioner, err := provision.New(ctx, &cfg.Provision, dbCfg, docker, pm, engProps.InstanceID, internalNetworkID) + networkGateway := getNetworkGateway(docker, internalNetworkID) + + provisioner, err := provision.New(ctx, &cfg.Provision, dbCfg, docker, pm, engProps.InstanceID, internalNetworkID, networkGateway) if err != nil { log.Errf(errors.WithMessage(err, `error in the "provision" section of the config`).Error()) } @@ -253,6 +256,22 @@ func main() { tm.SendEvent(ctxBackground, telemetry.EngineStoppedEvent, telemetry.EngineStopped{Uptime: server.Uptime()}) } +func getNetworkGateway(docker *client.Client, internalNetworkID string) string { + gateway := "" + + networkResource, err := docker.NetworkInspect(context.Background(), internalNetworkID, types.NetworkInspectOptions{}) + if err != nil { + log.Err(err.Error()) + return gateway + } + + if len(networkResource.IPAM.Config) > 0 { + gateway = networkResource.IPAM.Config[0].Gateway + } + + return gateway +} + func getEngineProperties(ctx context.Context, docker *client.Client, cfg *config.Config) (global.EngineProps, error) { hostname := os.Getenv("HOSTNAME") if hostname == "" { diff --git a/engine/internal/cloning/storage_test.go b/engine/internal/cloning/storage_test.go index e2a458d87f1de6b229763be8de8b386f9eacc351..4df70a2292749e3440c57f62b602f4622c942bc9 100644 --- a/engine/internal/cloning/storage_test.go +++ b/engine/internal/cloning/storage_test.go @@ -83,7 +83,7 @@ func newProvisioner() (*provision.Provisioner, error) { From: 1, To: 5, }, - }, nil, nil, nil, "instID", "nwID") + }, nil, nil, nil, "instID", "nwID", "") } func TestLoadingSessionState(t *testing.T) { diff --git a/engine/internal/provision/mode_local.go b/engine/internal/provision/mode_local.go index 82c680afbda6f51fe1f9025a5677e89ba0de5a99..922318fc011048b2a5e073bd95c51470a7cb9d42 100644 --- a/engine/internal/provision/mode_local.go +++ b/engine/internal/provision/mode_local.go @@ -15,6 +15,7 @@ import ( "regexp" "sort" "strconv" + "strings" "sync" "sync/atomic" "time" @@ -41,6 +42,7 @@ const ( maxNumberOfPortsToCheck = 5 portCheckingTimeout = 3 * time.Second unknownVersion = "unknown" + wildcardIP = "0.0.0.0" ) // PortPool describes an available port range for clones. @@ -73,11 +75,12 @@ type Provisioner struct { pm *pool.Manager networkID string instanceID string + gateway string } // New creates a new Provisioner instance. func New(ctx context.Context, cfg *Config, dbCfg *resources.DB, docker *client.Client, pm *pool.Manager, - instanceID, networkID string) (*Provisioner, error) { + instanceID, networkID, gateway string) (*Provisioner, error) { if err := IsValidConfig(*cfg); err != nil { return nil, errors.Wrap(err, "configuration is not valid") } @@ -93,6 +96,7 @@ func New(ctx context.Context, cfg *Config, dbCfg *resources.DB, docker *client.C pm: pm, networkID: networkID, instanceID: instanceID, + gateway: gateway, ports: make([]bool, cfg.PortPool.To-cfg.PortPool.From+1), } @@ -435,7 +439,7 @@ func getLatestSnapshot(snapshots []resources.Snapshot) (*resources.Snapshot, err func (p *Provisioner) RevisePortPool() error { log.Msg(fmt.Sprintf("Revising availability of the port range [%d - %d]", p.config.PortPool.From, p.config.PortPool.To)) - host, err := externalIP() + host, err := hostIP(p.gateway) if err != nil { return err } @@ -468,13 +472,21 @@ func (p *Provisioner) RevisePortPool() error { return nil } +func hostIP(gateway string) (string, error) { + if gateway != "" { + return gateway, nil + } + + return externalIP() +} + // allocatePort tries to find a free port and occupy it. func (p *Provisioner) allocatePort() (uint, error) { portOpts := p.config.PortPool attempts := 0 - host, err := externalIP() + host, err := hostIP(p.gateway) if err != nil { return 0, err } @@ -598,6 +610,8 @@ func (p *Provisioner) stopPoolSessions(fsm pool.FSManager, exceptClones map[stri } func (p *Provisioner) getAppConfig(pool *resources.Pool, name string, port uint) *resources.AppConfig { + provisionHosts := p.getProvisionHosts() + appConfig := &resources.AppConfig{ CloneName: name, DockerImage: p.config.DockerImage, @@ -607,12 +621,33 @@ func (p *Provisioner) getAppConfig(pool *resources.Pool, name string, port uint) Pool: pool, ContainerConf: p.config.ContainerConfig, NetworkID: p.networkID, - ProvisionHosts: p.config.CloneAccessAddresses, + ProvisionHosts: provisionHosts, } return appConfig } +// getProvisionHosts adds an internal Docker gateway to the hosts rule if the user restricts access to IP addresses. +func (p *Provisioner) getProvisionHosts() string { + provisionHosts := p.config.CloneAccessAddresses + + if provisionHosts == "" || provisionHosts == wildcardIP { + return provisionHosts + } + + hostSet := []string{p.gateway} + + for _, hostIP := range strings.Split(provisionHosts, ",") { + if hostIP != p.gateway { + hostSet = append(hostSet, hostIP) + } + } + + provisionHosts = strings.Join(hostSet, ",") + + return provisionHosts +} + // LastSessionActivity returns the time of the last session activity. func (p *Provisioner) LastSessionActivity(session *resources.Session, minimumTime time.Time) (*time.Time, error) { fsm, err := p.pm.GetFSManager(session.Pool) diff --git a/engine/internal/provision/mode_local_test.go b/engine/internal/provision/mode_local_test.go index cb01e63ce24d2e4739c0bd6bd419570442b50a9c..02fe78a34f324364b526c77dd08ee4fe5a7daba6 100644 --- a/engine/internal/provision/mode_local_test.go +++ b/engine/internal/provision/mode_local_test.go @@ -26,7 +26,7 @@ func TestPortAllocation(t *testing.T) { }, } - p, err := New(context.Background(), cfg, &resources.DB{}, &client.Client{}, &pool.Manager{}, "instanceID", "networkID") + p, err := New(context.Background(), cfg, &resources.DB{}, &client.Client{}, &pool.Manager{}, "instanceID", "networkID", "") require.NoError(t, err) // Allocate a new port. @@ -330,3 +330,57 @@ func createTempConfigFile(testCaseDir, fileName string, content string) error { return os.WriteFile(fn, []byte(content), 0666) } + +func TestProvisionHosts(t *testing.T) { + tests := []struct { + name string + udAddresses string + gateway string + expectedHosts string + }{ + { + name: "Empty fields", + udAddresses: "", + gateway: "", + expectedHosts: "", + }, + { + name: "Empty user-defined address", + udAddresses: "", + gateway: "172.20.0.1", + expectedHosts: "", + }, + { + name: "Wildcard IP", + udAddresses: "0.0.0.0", + gateway: "172.20.0.1", + expectedHosts: "0.0.0.0", + }, + { + name: "User-defined address", + udAddresses: "192.168.1.1", + gateway: "172.20.0.1", + expectedHosts: "172.20.0.1,192.168.1.1", + }, + { + name: "Multiple user-defined addresses", + udAddresses: "192.168.1.1,10.0.58.1", + gateway: "172.20.0.1", + expectedHosts: "172.20.0.1,192.168.1.1,10.0.58.1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + p := Provisioner{ + config: &Config{ + CloneAccessAddresses: tt.udAddresses, + }, + gateway: tt.gateway, + } + + assert.Equal(t, tt.expectedHosts, p.getProvisionHosts()) + }) + } +}