diff --git a/Sources/ContainerClient/Core/ClientImage.swift b/Sources/ContainerClient/Core/ClientImage.swift index fa6da0731..7cac67005 100644 --- a/Sources/ContainerClient/Core/ClientImage.swift +++ b/Sources/ContainerClient/Core/ClientImage.swift @@ -220,7 +220,13 @@ extension ClientImage { }) } - public static func pull(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil) async throws -> ClientImage { + public static func pull( + reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil, maxConcurrentDownloads: Int = 3 + ) async throws -> ClientImage { + guard maxConcurrentDownloads > 0 else { + throw ContainerizationError(.invalidArgument, message: "maximum number of concurrent downloads must be greater than 0, got \(maxConcurrentDownloads)") + } + let client = newXPCClient() let request = newRequest(.imagePull) @@ -234,6 +240,7 @@ extension ClientImage { let insecure = try scheme.schemeFor(host: host) == .http request.set(key: .insecureFlag, value: insecure) + request.set(key: .maxConcurrentDownloads, value: Int64(maxConcurrentDownloads)) var progressUpdateClient: ProgressUpdateClient? if let progressUpdate { @@ -313,8 +320,9 @@ extension ClientImage { return (totalCount: total, activeCount: active, totalSize: size, reclaimableSize: reclaimable) } - public static func fetch(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil) async throws -> ClientImage - { + public static func fetch( + reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil, maxConcurrentDownloads: Int = 3 + ) async throws -> ClientImage { do { let match = try await self.get(reference: reference) if let platform { @@ -327,7 +335,7 @@ extension ClientImage { guard err.isCode(.notFound) else { throw err } - return try await Self.pull(reference: reference, platform: platform, scheme: scheme, progressUpdate: progressUpdate) + return try await Self.pull(reference: reference, platform: platform, scheme: scheme, progressUpdate: progressUpdate, maxConcurrentDownloads: maxConcurrentDownloads) } } } diff --git a/Sources/ContainerClient/Flags.swift b/Sources/ContainerClient/Flags.swift index 939a55049..4a5cfd27b 100644 --- a/Sources/ContainerClient/Flags.swift +++ b/Sources/ContainerClient/Flags.swift @@ -215,4 +215,11 @@ public struct Flags { @Option(name: .long, help: ArgumentHelp("Progress type (format: none|ansi)", valueName: "type")) public var progress: ProgressType = .ansi } + + public struct ImageFetch: ParsableArguments { + public init() {} + + @Option(name: .long, help: "Maximum number of concurrent downloads (default: 3)") + public var maxConcurrentDownloads: Int = 3 + } } diff --git a/Sources/ContainerClient/Utility.swift b/Sources/ContainerClient/Utility.swift index 16f43398f..85c6c7c3a 100644 --- a/Sources/ContainerClient/Utility.swift +++ b/Sources/ContainerClient/Utility.swift @@ -93,6 +93,7 @@ public struct Utility { management: Flags.Management, resource: Flags.Resource, registry: Flags.Registry, + imageFetch: Flags.ImageFetch, progressUpdate: @escaping ProgressUpdateHandler ) async throws -> (ContainerConfiguration, Kernel) { var requestedPlatform = Parser.platform(os: management.os, arch: management.arch) @@ -112,7 +113,8 @@ public struct Utility { reference: image, platform: requestedPlatform, scheme: scheme, - progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progressUpdate) + progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progressUpdate), + maxConcurrentDownloads: imageFetch.maxConcurrentDownloads ) // Unpack a fetched image before use @@ -140,7 +142,8 @@ public struct Utility { let fetchInitTask = await taskManager.startTask() let initImage = try await ClientImage.fetch( reference: ClientImage.initImageRef, platform: .current, scheme: scheme, - progressUpdate: ProgressTaskCoordinator.handler(for: fetchInitTask, from: progressUpdate)) + progressUpdate: ProgressTaskCoordinator.handler(for: fetchInitTask, from: progressUpdate), + maxConcurrentDownloads: imageFetch.maxConcurrentDownloads) await progressUpdate([ .setDescription("Unpacking init image"), diff --git a/Sources/ContainerCommands/Container/ContainerCreate.swift b/Sources/ContainerCommands/Container/ContainerCreate.swift index 5e4ecd7c2..4f2ac09a6 100644 --- a/Sources/ContainerCommands/Container/ContainerCreate.swift +++ b/Sources/ContainerCommands/Container/ContainerCreate.swift @@ -40,6 +40,9 @@ extension Application { @OptionGroup(title: "Registry options") var registryFlags: Flags.Registry + @OptionGroup(title: "Image fetch options") + var imageFetchFlags: Flags.ImageFetch + @OptionGroup var global: Flags.Global @@ -73,6 +76,7 @@ extension Application { management: managementFlags, resource: resourceFlags, registry: registryFlags, + imageFetch: imageFetchFlags, progressUpdate: progress.handler ) diff --git a/Sources/ContainerCommands/Container/ContainerRun.swift b/Sources/ContainerCommands/Container/ContainerRun.swift index ec65329c3..790f80642 100644 --- a/Sources/ContainerCommands/Container/ContainerRun.swift +++ b/Sources/ContainerCommands/Container/ContainerRun.swift @@ -47,6 +47,9 @@ extension Application { @OptionGroup(title: "Progress options") var progressFlags: Flags.Progress + @OptionGroup(title: "Image fetch options") + var imageFetchFlags: Flags.ImageFetch + @OptionGroup var global: Flags.Global @@ -97,6 +100,7 @@ extension Application { management: managementFlags, resource: resourceFlags, registry: registryFlags, + imageFetch: imageFetchFlags, progressUpdate: progress.handler ) diff --git a/Sources/ContainerCommands/Image/ImagePull.swift b/Sources/ContainerCommands/Image/ImagePull.swift index 3f14cef81..55ca722da 100644 --- a/Sources/ContainerCommands/Image/ImagePull.swift +++ b/Sources/ContainerCommands/Image/ImagePull.swift @@ -36,6 +36,9 @@ extension Application { @OptionGroup var progressFlags: Flags.Progress + @OptionGroup + var imageFetchFlags: Flags.ImageFetch + @Option( name: .shortAndLong, help: "Limit the pull to the specified architecture" @@ -100,7 +103,8 @@ extension Application { let taskManager = ProgressTaskCoordinator() let fetchTask = await taskManager.startTask() let image = try await ClientImage.pull( - reference: processedReference, platform: p, scheme: scheme, progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progress.handler) + reference: processedReference, platform: p, scheme: scheme, progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progress.handler), + maxConcurrentDownloads: self.imageFetchFlags.maxConcurrentDownloads ) progress.set(description: "Unpacking image") diff --git a/Sources/Services/ContainerImagesService/Client/ImageServiceXPCKeys.swift b/Sources/Services/ContainerImagesService/Client/ImageServiceXPCKeys.swift index c81c76dc3..fcbed505b 100644 --- a/Sources/Services/ContainerImagesService/Client/ImageServiceXPCKeys.swift +++ b/Sources/Services/ContainerImagesService/Client/ImageServiceXPCKeys.swift @@ -35,6 +35,7 @@ public enum ImagesServiceXPCKeys: String { case ociPlatform case insecureFlag case garbageCollect + case maxConcurrentDownloads /// ContentStore case digest diff --git a/Sources/Services/ContainerImagesService/Server/ImageService.swift b/Sources/Services/ContainerImagesService/Server/ImageService.swift index e7cc1baf7..896421870 100644 --- a/Sources/Services/ContainerImagesService/Server/ImageService.swift +++ b/Sources/Services/ContainerImagesService/Server/ImageService.swift @@ -59,11 +59,15 @@ public actor ImagesService { return try await imageStore.list().map { $0.description.fromCZ } } - public func pull(reference: String, platform: Platform?, insecure: Bool, progressUpdate: ProgressUpdateHandler?) async throws -> ImageDescription { - self.log.info("ImagesService: \(#function) - ref: \(reference), platform: \(String(describing: platform)), insecure: \(insecure)") + public func pull(reference: String, platform: Platform?, insecure: Bool, progressUpdate: ProgressUpdateHandler?, maxConcurrentDownloads: Int = 3) async throws + -> ImageDescription + { + self.log.info( + "ImagesService: \(#function) - ref: \(reference), platform: \(String(describing: platform)), insecure: \(insecure), maxConcurrentDownloads: \(maxConcurrentDownloads)") let img = try await Self.withAuthentication(ref: reference) { auth in try await self.imageStore.pull( - reference: reference, platform: platform, insecure: insecure, auth: auth, progress: ContainerizationProgressAdapter.handler(from: progressUpdate)) + reference: reference, platform: platform, insecure: insecure, auth: auth, progress: ContainerizationProgressAdapter.handler(from: progressUpdate), + maxConcurrentDownloads: maxConcurrentDownloads) } guard let img else { throw ContainerizationError(.internalError, message: "failed to pull image \(reference)") diff --git a/Sources/Services/ContainerImagesService/Server/ImagesServiceHarness.swift b/Sources/Services/ContainerImagesService/Server/ImagesServiceHarness.swift index cc1e70455..966191b8b 100644 --- a/Sources/Services/ContainerImagesService/Server/ImagesServiceHarness.swift +++ b/Sources/Services/ContainerImagesService/Server/ImagesServiceHarness.swift @@ -47,9 +47,11 @@ public struct ImagesServiceHarness: Sendable { platform = try JSONDecoder().decode(ContainerizationOCI.Platform.self, from: platformData) } let insecure = message.bool(key: .insecureFlag) + let maxConcurrentDownloads = message.int64(key: .maxConcurrentDownloads) let progressUpdateService = ProgressUpdateService(message: message) - let imageDescription = try await service.pull(reference: ref, platform: platform, insecure: insecure, progressUpdate: progressUpdateService?.handler) + let imageDescription = try await service.pull( + reference: ref, platform: platform, insecure: insecure, progressUpdate: progressUpdateService?.handler, maxConcurrentDownloads: Int(maxConcurrentDownloads)) let imageData = try JSONEncoder().encode(imageDescription) let reply = message.reply() diff --git a/Sources/Services/ContainerSandboxService/SandboxService.swift b/Sources/Services/ContainerSandboxService/SandboxService.swift index 867747364..546fcdb98 100644 --- a/Sources/Services/ContainerSandboxService/SandboxService.swift +++ b/Sources/Services/ContainerSandboxService/SandboxService.swift @@ -119,6 +119,7 @@ public actor SandboxService { try bundle.createLogFile() var config = try bundle.configuration + let vmm = VZVirtualMachineManager( kernel: try bundle.kernel, initialFilesystem: bundle.initialFilesystem.asMount, diff --git a/Tests/CLITests/Subcommands/Images/TestCLIImages.swift b/Tests/CLITests/Subcommands/Images/TestCLIImages.swift index 5f7800e88..20d921c59 100644 --- a/Tests/CLITests/Subcommands/Images/TestCLIImages.swift +++ b/Tests/CLITests/Subcommands/Images/TestCLIImages.swift @@ -360,6 +360,41 @@ extension TestCLIImagesCommand { } } + @Test func testMaxConcurrentDownloadsValidation() throws { + // Test that invalid maxConcurrentDownloads value is rejected + let (_, _, error, status) = try run(arguments: [ + "image", + "pull", + "--max-concurrent-downloads", "0", + "alpine:latest", + ]) + + #expect(status != 0, "Expected command to fail with maxConcurrentDownloads=0") + #expect( + error.contains("maximum number of concurrent downloads must be greater than 0"), + "Expected validation error message in output") + } + + @Test func testMaxConcurrentDownloadsFlag() throws { + // Test that the flag is accepted with valid values + do { + try doPull(imageName: alpine, args: ["--max-concurrent-downloads", "1"]) + let imagePresent = try isImagePresent(targetImage: alpine) + #expect(imagePresent, "Expected image to be pulled with maxConcurrentDownloads=1") + + // Clean up + try? doRemoveImages(images: [alpine]) + + // Test with higher concurrency + try doPull(imageName: alpine, args: ["--max-concurrent-downloads", "6"]) + let imagePresent2 = try isImagePresent(targetImage: alpine) + #expect(imagePresent2, "Expected image to be pulled with maxConcurrentDownloads=6") + } catch { + Issue.record("failed to pull image with maxConcurrentDownloads flag: \(error)") + return + } + } + @Test func testImageSaveAndLoadStdinStdout() throws { do { // 1. pull image