Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Sources/Containerization/LinuxContainer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ extension LinuxContainer {
}

/// Dial a vsock port in the container.
public func dialVsock(port: UInt32) async throws -> FileHandle {
public func dialVsock(port: UInt32) async throws -> VsockConnection {
try await self.state.withLock {
let state = try $0.startedState("dialVsock")
return try await state.vm.dial(port)
Expand Down Expand Up @@ -1098,7 +1098,7 @@ extension LinuxContainer {
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
self.copyQueue.async {
do {
defer { conn.closeFile() }
defer { try? conn.close() }

if isArchive {
let writer = try ArchiveWriter(configuration: .init(format: .pax, filter: .gzip))
Expand Down Expand Up @@ -1209,7 +1209,7 @@ extension LinuxContainer {
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
self.copyQueue.async {
do {
defer { conn.closeFile() }
defer { try? conn.close() }

if metadata.isArchive {
try FileManager.default.createDirectory(at: destination, withIntermediateDirectories: true)
Expand Down
2 changes: 1 addition & 1 deletion Sources/Containerization/LinuxPod.swift
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ extension LinuxPod {
}

/// Dial a vsock port in the pod's VM.
public func dialVsock(port: UInt32) async throws -> FileHandle {
public func dialVsock(port: UInt32) async throws -> VsockConnection {
try await self.state.withLock { state in
let createdState = try state.phase.createdState("dialVsock")
return try await createdState.vm.dial(port)
Expand Down
38 changes: 28 additions & 10 deletions Sources/Containerization/LinuxProcess.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,26 +48,44 @@ public final class LinuxProcess: Sendable {
}

private struct StdioHandles: Sendable {
var stdin: FileHandle?
var stdout: FileHandle?
var stderr: FileHandle?
var stdin: VsockConnection?
var stdout: VsockConnection?
var stderr: VsockConnection?

mutating func close() throws {
var firstError: Error?

if let stdin {
try stdin.close()
stdin.readabilityHandler = nil
do {
try stdin.close()
} catch {
firstError = firstError ?? error
}
self.stdin = nil
}
if let stdout {
try stdout.close()
stdout.readabilityHandler = nil
do {
try stdout.close()
} catch {
firstError = firstError ?? error
}
self.stdout = nil
}
if let stderr {
try stderr.close()
stderr.readabilityHandler = nil
do {
try stderr.close()
} catch {
firstError = firstError ?? error
}
self.stderr = nil
}

if let firstError {
throw firstError
}
}
}

Expand Down Expand Up @@ -124,10 +142,10 @@ public final class LinuxProcess: Sendable {
}

extension LinuxProcess {
func setupIO(listeners: [VsockListener?]) async throws -> [FileHandle?] {
func setupIO(listeners: [VsockListener?]) async throws -> [VsockConnection?] {
let handles = try await Timeout.run(seconds: 3) {
try await withThrowingTaskGroup(of: (Int, FileHandle?).self) { group in
var results = [FileHandle?](repeating: nil, count: 3)
try await withThrowingTaskGroup(of: (Int, VsockConnection?).self) { group in
var results = [VsockConnection?](repeating: nil, count: 3)

for (index, listener) in listeners.enumerated() {
guard let listener else { continue }
Expand Down Expand Up @@ -196,7 +214,7 @@ extension LinuxProcess {
return handles
}

func startStdinRelay(handle: FileHandle) {
func startStdinRelay(handle: VsockConnection) {
guard let stdin = self.ioSetup.stdin else { return }

self.state.withLock {
Expand Down
42 changes: 31 additions & 11 deletions Sources/Containerization/UnixSocketRelay.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,13 @@ package final class UnixSocketRelay: Sendable {
private let log: Logger?
private let state: Mutex<State>

private struct ActiveRelay: Sendable {
let relay: BidirectionalRelay
let guestConnection: VsockConnection
}

private struct State {
var activeRelays: [String: BidirectionalRelay] = [:]
var activeRelays: [String: ActiveRelay] = [:]
var t: Task<(), Never>? = nil
var listener: VsockListener? = nil
}
Expand Down Expand Up @@ -75,10 +80,9 @@ extension UnixSocketRelay {
}
t.cancel()
$0.t = nil
for (_, relay) in $0.activeRelays {
relay.stop()
for (_, activeRelay) in $0.activeRelays {
activeRelay.relay.stop()
}
$0.activeRelays.removeAll()

switch configuration.direction {
case .outOf:
Expand Down Expand Up @@ -170,12 +174,12 @@ extension UnixSocketRelay {
"initiating connection from host to guest",
metadata: [
"vport": "\(port)",
"hostFd": "\(guestConn.fileDescriptor)",
"guestFd": "\(hostConn.fileDescriptor)",
"hostFd": "\(hostConn.fileDescriptor)",
"guestFd": "\(guestConn.fileDescriptor)",
])
try await self.relay(
hostConn: hostConn,
guestFd: guestConn.fileDescriptor
guestConn: guestConn
)
} catch {
log?.error("failed to relay between vsock \(port) and \(hostConn)")
Expand All @@ -184,7 +188,7 @@ extension UnixSocketRelay {
}

private func handleGuestVsockConn(
vsockConn: FileHandle,
vsockConn: VsockConnection,
hostConnectionPath: URL,
port: UInt32,
log: Logger?
Expand All @@ -207,7 +211,7 @@ extension UnixSocketRelay {
do {
try await self.relay(
hostConn: hostSocket,
guestFd: vsockConn.fileDescriptor
guestConn: vsockConn
)
} catch {
log?.error("failed to relay between vsock \(port) and \(hostPath)")
Expand All @@ -216,9 +220,13 @@ extension UnixSocketRelay {

private func relay(
hostConn: Socket,
guestFd: Int32
guestConn: VsockConnection
) async throws {
let hostFd = hostConn.fileDescriptor
let guestFd = dup(guestConn.fileDescriptor)
if guestFd == -1 {
throw POSIXError.fromErrno()
}

let relayID = UUID().uuidString
let relay = BidirectionalRelay(
Expand All @@ -229,9 +237,21 @@ extension UnixSocketRelay {
)

state.withLock {
$0.activeRelays[relayID] = relay
// Retain the original connection until the relay has fully completed.
// The relay owns its duplicated fd and will close it itself.
$0.activeRelays[relayID] = ActiveRelay(
relay: relay,
guestConnection: guestConn
)
}

relay.start()

Task {
await relay.waitForCompletion()
let _ = self.state.withLock {
$0.activeRelays.removeValue(forKey: relayID)
}
}
}
}
25 changes: 21 additions & 4 deletions Sources/Containerization/VZVirtualMachine+Helpers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,15 @@ extension VZVirtualMachine {
}

extension VZVirtualMachine {
func waitForAgent(queue: DispatchQueue) async throws -> FileHandle {
func waitForAgent(queue: DispatchQueue) async throws -> (FileHandle, VsockTransport) {
let agentConnectionRetryCount: Int = 200
let agentConnectionSleepDuration: Duration = .milliseconds(20)

for _ in 0...agentConnectionRetryCount {
do {
return try await self.connect(queue: queue, port: Vminitd.port).dupHandle()
let conn = try await self.connect(queue: queue, port: Vminitd.port)
let handle = try conn.dupFileDescriptor()
return (handle, VsockTransport(conn))
} catch {
try await Task.sleep(for: agentConnectionSleepDuration)
continue
Expand All @@ -139,12 +141,27 @@ extension VZVirtualMachine {
}

extension VZVirtioSocketConnection {
func dupHandle() throws -> FileHandle {
/// Duplicates the file descriptor and retains the originating vsock connection
/// until the returned connection is closed or deallocated.
///
/// Use this for file descriptors which cross an async boundary or may not be
/// consumed immediately by the caller.
func retainedConnection() throws -> VsockConnection {
try VsockConnection(connection: self)
}

/// Duplicates the connection's file descriptor without closing the connection.
///
/// The caller must keep the `VZVirtioSocketConnection` alive until the dup'd
/// descriptor is no longer needed. The Virtualization framework tears down the
/// vsock endpoint when the connection is closed, which invalidates dup'd
/// descriptors. This is intended for callers which manage lifetime separately,
/// such as gRPC transports stored on `Vminitd`.
func dupFileDescriptor() throws -> FileHandle {
let fd = dup(self.fileDescriptor)
if fd == -1 {
throw POSIXError.fromErrno()
}
self.close()
return FileHandle(fileDescriptor: fd, closeOnDealloc: false)
}
}
Expand Down
14 changes: 6 additions & 8 deletions Sources/Containerization/VZVirtualMachineInstance.swift
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,8 @@ extension VZVirtualMachineInstance: VirtualMachineInstance {

try await self.vm.start(queue: self.queue)

let agent = try Vminitd(
connection: try await self.vm.waitForAgent(queue: self.queue),
group: self.group
)
let (handle, transport) = try await self.vm.waitForAgent(queue: self.queue)
let agent = try Vminitd(connection: handle, transport: transport, group: self.group)

do {
if self.config.rosetta {
Expand Down Expand Up @@ -189,8 +187,8 @@ extension VZVirtualMachineInstance: VirtualMachineInstance {
queue: queue,
port: Vminitd.port
)
let handle = try conn.dupHandle()
return try Vminitd(connection: handle, group: self.group)
let handle = try conn.dupFileDescriptor()
return try Vminitd(connection: handle, transport: VsockTransport(conn), group: self.group)
} catch {
if let err = error as? ContainerizationError {
throw err
Expand All @@ -204,14 +202,14 @@ extension VZVirtualMachineInstance: VirtualMachineInstance {
}
}

func dial(_ port: UInt32) async throws -> FileHandle {
func dial(_ port: UInt32) async throws -> VsockConnection {
try await lock.withLock { _ in
do {
let conn = try await vm.connect(
queue: queue,
port: port
)
return try conn.dupHandle()
return try conn.retainedConnection()
} catch {
if let err = error as? ContainerizationError {
throw err
Expand Down
2 changes: 1 addition & 1 deletion Sources/Containerization/VirtualMachineInstance.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public protocol VirtualMachineInstance: Sendable {
/// what port the agent is listening on.
func dialAgent() async throws -> Agent
/// Dial a vsock port in the guest.
func dial(_ port: UInt32) async throws -> FileHandle
func dial(_ port: UInt32) async throws -> VsockConnection
/// Listen on a host vsock port.
func listen(_ port: UInt32) throws -> VsockListener
/// Start the virtual machine.
Expand Down
17 changes: 15 additions & 2 deletions Sources/Containerization/Vminitd.swift
Original file line number Diff line number Diff line change
Expand Up @@ -98,20 +98,32 @@ public struct Vminitd: Sendable {
private let grpcClient: GRPCClient<HTTP2ClientTransport.WrappedChannel>
private let connectionTask: Task<Void, Error>

/// Retains the underlying vsock connection to keep the file descriptor
/// valid for the gRPC client's lifetime. The Virtualization framework
/// tears down the vsock endpoint when the connection is closed, which
/// invalidates dup'd descriptors. Must remain open until the gRPC
/// channel is shut down.
private let transport: VsockTransport?

public init(connection: FileHandle, group: any EventLoopGroup) throws {
try self.init(connection: connection, transport: nil, group: group)
}

init(connection: FileHandle, transport: VsockTransport?, group: any EventLoopGroup) throws {
let channel = try ClientBootstrap(group: group)
.channelInitializer { channel in
channel.eventLoop.makeCompletedFuture(withResultOf: {
try channel.pipeline.syncOperations.addHandler(HTTP2ConnectBufferingHandler())
})
}
.withConnectedSocket(connection.fileDescriptor).wait()
let transport = HTTP2ClientTransport.WrappedChannel.wrapping(
let channelTransport = HTTP2ClientTransport.WrappedChannel.wrapping(
channel: channel,
)
let grpcClient = GRPCClient(transport: transport)
let grpcClient = GRPCClient(transport: channelTransport)
self.grpcClient = grpcClient
self.client = Com_Apple_Containerization_Sandbox_V3_SandboxContext.Client(wrapping: self.grpcClient)
self.transport = transport
// Not very structured concurrency friendly, but we'd need to expose a way on the protocol to "run" the
// agent otherwise, which some agents might not even need.
self.connectionTask = Task {
Expand All @@ -122,6 +134,7 @@ public struct Vminitd: Sendable {
/// Close the connection to the guest agent.
public func close() async throws {
self.grpcClient.beginGracefulShutdown()
defer { transport?.close() }
try await self.connectionTask.value
}
}
Expand Down
Loading