diff --git a/Sources/Containerization/LinuxContainer.swift b/Sources/Containerization/LinuxContainer.swift index 9033a206..b6b03d12 100644 --- a/Sources/Containerization/LinuxContainer.swift +++ b/Sources/Containerization/LinuxContainer.swift @@ -969,7 +969,7 @@ extension LinuxContainer { } /// Dial a vsock port in the container. - public func dialVsock(port: UInt32) async throws -> FileHandle { + public func dialVsock(port: UInt32) async throws -> VsockConnection { try await self.state.withLock { let state = try $0.startedState("dialVsock") return try await state.vm.dial(port) @@ -1098,7 +1098,7 @@ extension LinuxContainer { try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in self.copyQueue.async { do { - defer { conn.closeFile() } + defer { try? conn.close() } if isArchive { let writer = try ArchiveWriter(configuration: .init(format: .pax, filter: .gzip)) @@ -1209,7 +1209,7 @@ extension LinuxContainer { try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in self.copyQueue.async { do { - defer { conn.closeFile() } + defer { try? conn.close() } if metadata.isArchive { try FileManager.default.createDirectory(at: destination, withIntermediateDirectories: true) diff --git a/Sources/Containerization/LinuxPod.swift b/Sources/Containerization/LinuxPod.swift index 7574d697..32a4ee1b 100644 --- a/Sources/Containerization/LinuxPod.swift +++ b/Sources/Containerization/LinuxPod.swift @@ -850,7 +850,7 @@ extension LinuxPod { } /// Dial a vsock port in the pod's VM. - public func dialVsock(port: UInt32) async throws -> FileHandle { + public func dialVsock(port: UInt32) async throws -> VsockConnection { try await self.state.withLock { state in let createdState = try state.phase.createdState("dialVsock") return try await createdState.vm.dial(port) diff --git a/Sources/Containerization/LinuxProcess.swift b/Sources/Containerization/LinuxProcess.swift index 4ef676d6..9a52aba0 100644 --- a/Sources/Containerization/LinuxProcess.swift +++ b/Sources/Containerization/LinuxProcess.swift @@ -48,26 +48,44 @@ public final class LinuxProcess: Sendable { } private struct StdioHandles: Sendable { - var stdin: FileHandle? - var stdout: FileHandle? - var stderr: FileHandle? + var stdin: VsockConnection? + var stdout: VsockConnection? + var stderr: VsockConnection? mutating func close() throws { + var firstError: Error? + if let stdin { - try stdin.close() stdin.readabilityHandler = nil + do { + try stdin.close() + } catch { + firstError = firstError ?? error + } self.stdin = nil } if let stdout { - try stdout.close() stdout.readabilityHandler = nil + do { + try stdout.close() + } catch { + firstError = firstError ?? error + } self.stdout = nil } if let stderr { - try stderr.close() stderr.readabilityHandler = nil + do { + try stderr.close() + } catch { + firstError = firstError ?? error + } self.stderr = nil } + + if let firstError { + throw firstError + } } } @@ -124,10 +142,10 @@ public final class LinuxProcess: Sendable { } extension LinuxProcess { - func setupIO(listeners: [VsockListener?]) async throws -> [FileHandle?] { + func setupIO(listeners: [VsockListener?]) async throws -> [VsockConnection?] { let handles = try await Timeout.run(seconds: 3) { - try await withThrowingTaskGroup(of: (Int, FileHandle?).self) { group in - var results = [FileHandle?](repeating: nil, count: 3) + try await withThrowingTaskGroup(of: (Int, VsockConnection?).self) { group in + var results = [VsockConnection?](repeating: nil, count: 3) for (index, listener) in listeners.enumerated() { guard let listener else { continue } @@ -196,7 +214,7 @@ extension LinuxProcess { return handles } - func startStdinRelay(handle: FileHandle) { + func startStdinRelay(handle: VsockConnection) { guard let stdin = self.ioSetup.stdin else { return } self.state.withLock { diff --git a/Sources/Containerization/UnixSocketRelay.swift b/Sources/Containerization/UnixSocketRelay.swift index e7a3304a..0c3872a1 100644 --- a/Sources/Containerization/UnixSocketRelay.swift +++ b/Sources/Containerization/UnixSocketRelay.swift @@ -29,8 +29,13 @@ package final class UnixSocketRelay: Sendable { private let log: Logger? private let state: Mutex + private struct ActiveRelay: Sendable { + let relay: BidirectionalRelay + let guestConnection: VsockConnection + } + private struct State { - var activeRelays: [String: BidirectionalRelay] = [:] + var activeRelays: [String: ActiveRelay] = [:] var t: Task<(), Never>? = nil var listener: VsockListener? = nil } @@ -75,10 +80,9 @@ extension UnixSocketRelay { } t.cancel() $0.t = nil - for (_, relay) in $0.activeRelays { - relay.stop() + for (_, activeRelay) in $0.activeRelays { + activeRelay.relay.stop() } - $0.activeRelays.removeAll() switch configuration.direction { case .outOf: @@ -170,12 +174,12 @@ extension UnixSocketRelay { "initiating connection from host to guest", metadata: [ "vport": "\(port)", - "hostFd": "\(guestConn.fileDescriptor)", - "guestFd": "\(hostConn.fileDescriptor)", + "hostFd": "\(hostConn.fileDescriptor)", + "guestFd": "\(guestConn.fileDescriptor)", ]) try await self.relay( hostConn: hostConn, - guestFd: guestConn.fileDescriptor + guestConn: guestConn ) } catch { log?.error("failed to relay between vsock \(port) and \(hostConn)") @@ -184,7 +188,7 @@ extension UnixSocketRelay { } private func handleGuestVsockConn( - vsockConn: FileHandle, + vsockConn: VsockConnection, hostConnectionPath: URL, port: UInt32, log: Logger? @@ -207,7 +211,7 @@ extension UnixSocketRelay { do { try await self.relay( hostConn: hostSocket, - guestFd: vsockConn.fileDescriptor + guestConn: vsockConn ) } catch { log?.error("failed to relay between vsock \(port) and \(hostPath)") @@ -216,9 +220,13 @@ extension UnixSocketRelay { private func relay( hostConn: Socket, - guestFd: Int32 + guestConn: VsockConnection ) async throws { let hostFd = hostConn.fileDescriptor + let guestFd = dup(guestConn.fileDescriptor) + if guestFd == -1 { + throw POSIXError.fromErrno() + } let relayID = UUID().uuidString let relay = BidirectionalRelay( @@ -229,9 +237,21 @@ extension UnixSocketRelay { ) state.withLock { - $0.activeRelays[relayID] = relay + // Retain the original connection until the relay has fully completed. + // The relay owns its duplicated fd and will close it itself. + $0.activeRelays[relayID] = ActiveRelay( + relay: relay, + guestConnection: guestConn + ) } relay.start() + + Task { + await relay.waitForCompletion() + let _ = self.state.withLock { + $0.activeRelays.removeValue(forKey: relayID) + } + } } } diff --git a/Sources/Containerization/VZVirtualMachine+Helpers.swift b/Sources/Containerization/VZVirtualMachine+Helpers.swift index 2cbadb1c..d8ab1adc 100644 --- a/Sources/Containerization/VZVirtualMachine+Helpers.swift +++ b/Sources/Containerization/VZVirtualMachine+Helpers.swift @@ -122,13 +122,15 @@ extension VZVirtualMachine { } extension VZVirtualMachine { - func waitForAgent(queue: DispatchQueue) async throws -> FileHandle { + func waitForAgent(queue: DispatchQueue) async throws -> (FileHandle, VsockTransport) { let agentConnectionRetryCount: Int = 200 let agentConnectionSleepDuration: Duration = .milliseconds(20) for _ in 0...agentConnectionRetryCount { do { - return try await self.connect(queue: queue, port: Vminitd.port).dupHandle() + let conn = try await self.connect(queue: queue, port: Vminitd.port) + let handle = try conn.dupFileDescriptor() + return (handle, VsockTransport(conn)) } catch { try await Task.sleep(for: agentConnectionSleepDuration) continue @@ -139,12 +141,27 @@ extension VZVirtualMachine { } extension VZVirtioSocketConnection { - func dupHandle() throws -> FileHandle { + /// Duplicates the file descriptor and retains the originating vsock connection + /// until the returned connection is closed or deallocated. + /// + /// Use this for file descriptors which cross an async boundary or may not be + /// consumed immediately by the caller. + func retainedConnection() throws -> VsockConnection { + try VsockConnection(connection: self) + } + + /// Duplicates the connection's file descriptor without closing the connection. + /// + /// The caller must keep the `VZVirtioSocketConnection` alive until the dup'd + /// descriptor is no longer needed. The Virtualization framework tears down the + /// vsock endpoint when the connection is closed, which invalidates dup'd + /// descriptors. This is intended for callers which manage lifetime separately, + /// such as gRPC transports stored on `Vminitd`. + func dupFileDescriptor() throws -> FileHandle { let fd = dup(self.fileDescriptor) if fd == -1 { throw POSIXError.fromErrno() } - self.close() return FileHandle(fileDescriptor: fd, closeOnDealloc: false) } } diff --git a/Sources/Containerization/VZVirtualMachineInstance.swift b/Sources/Containerization/VZVirtualMachineInstance.swift index 7e580aa2..cacc5fa5 100644 --- a/Sources/Containerization/VZVirtualMachineInstance.swift +++ b/Sources/Containerization/VZVirtualMachineInstance.swift @@ -125,10 +125,8 @@ extension VZVirtualMachineInstance: VirtualMachineInstance { try await self.vm.start(queue: self.queue) - let agent = try Vminitd( - connection: try await self.vm.waitForAgent(queue: self.queue), - group: self.group - ) + let (handle, transport) = try await self.vm.waitForAgent(queue: self.queue) + let agent = try Vminitd(connection: handle, transport: transport, group: self.group) do { if self.config.rosetta { @@ -189,8 +187,8 @@ extension VZVirtualMachineInstance: VirtualMachineInstance { queue: queue, port: Vminitd.port ) - let handle = try conn.dupHandle() - return try Vminitd(connection: handle, group: self.group) + let handle = try conn.dupFileDescriptor() + return try Vminitd(connection: handle, transport: VsockTransport(conn), group: self.group) } catch { if let err = error as? ContainerizationError { throw err @@ -204,14 +202,14 @@ extension VZVirtualMachineInstance: VirtualMachineInstance { } } - func dial(_ port: UInt32) async throws -> FileHandle { + func dial(_ port: UInt32) async throws -> VsockConnection { try await lock.withLock { _ in do { let conn = try await vm.connect( queue: queue, port: port ) - return try conn.dupHandle() + return try conn.retainedConnection() } catch { if let err = error as? ContainerizationError { throw err diff --git a/Sources/Containerization/VirtualMachineInstance.swift b/Sources/Containerization/VirtualMachineInstance.swift index 070cee87..01ddf7e2 100644 --- a/Sources/Containerization/VirtualMachineInstance.swift +++ b/Sources/Containerization/VirtualMachineInstance.swift @@ -38,7 +38,7 @@ public protocol VirtualMachineInstance: Sendable { /// what port the agent is listening on. func dialAgent() async throws -> Agent /// Dial a vsock port in the guest. - func dial(_ port: UInt32) async throws -> FileHandle + func dial(_ port: UInt32) async throws -> VsockConnection /// Listen on a host vsock port. func listen(_ port: UInt32) throws -> VsockListener /// Start the virtual machine. diff --git a/Sources/Containerization/Vminitd.swift b/Sources/Containerization/Vminitd.swift index cadce2c0..904c9d62 100644 --- a/Sources/Containerization/Vminitd.swift +++ b/Sources/Containerization/Vminitd.swift @@ -98,7 +98,18 @@ public struct Vminitd: Sendable { private let grpcClient: GRPCClient private let connectionTask: Task + /// Retains the underlying vsock connection to keep the file descriptor + /// valid for the gRPC client's lifetime. The Virtualization framework + /// tears down the vsock endpoint when the connection is closed, which + /// invalidates dup'd descriptors. Must remain open until the gRPC + /// channel is shut down. + private let transport: VsockTransport? + public init(connection: FileHandle, group: any EventLoopGroup) throws { + try self.init(connection: connection, transport: nil, group: group) + } + + init(connection: FileHandle, transport: VsockTransport?, group: any EventLoopGroup) throws { let channel = try ClientBootstrap(group: group) .channelInitializer { channel in channel.eventLoop.makeCompletedFuture(withResultOf: { @@ -106,12 +117,13 @@ public struct Vminitd: Sendable { }) } .withConnectedSocket(connection.fileDescriptor).wait() - let transport = HTTP2ClientTransport.WrappedChannel.wrapping( + let channelTransport = HTTP2ClientTransport.WrappedChannel.wrapping( channel: channel, ) - let grpcClient = GRPCClient(transport: transport) + let grpcClient = GRPCClient(transport: channelTransport) self.grpcClient = grpcClient self.client = Com_Apple_Containerization_Sandbox_V3_SandboxContext.Client(wrapping: self.grpcClient) + self.transport = transport // Not very structured concurrency friendly, but we'd need to expose a way on the protocol to "run" the // agent otherwise, which some agents might not even need. self.connectionTask = Task { @@ -122,6 +134,7 @@ public struct Vminitd: Sendable { /// Close the connection to the guest agent. public func close() async throws { self.grpcClient.beginGracefulShutdown() + defer { transport?.close() } try await self.connectionTask.value } } diff --git a/Sources/Containerization/VsockConnection.swift b/Sources/Containerization/VsockConnection.swift new file mode 100644 index 00000000..f4da0c3e --- /dev/null +++ b/Sources/Containerization/VsockConnection.swift @@ -0,0 +1,92 @@ +//===----------------------------------------------------------------------===// +// Copyright © 2025-2026 Apple Inc. and the Containerization project authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// + +import Foundation + +#if os(macOS) +import Virtualization +#endif + +/// A vsock connection whose duplicated file descriptor keeps the originating +/// transport alive until the connection is closed. +/// +/// Uses `@unchecked Sendable` because the mutable close state is protected by +/// `NSLock`, while the underlying `FileHandle` and `VsockTransport` are shared +/// across tasks. +public final class VsockConnection: @unchecked Sendable { + private let fileHandle: FileHandle + private let transport: VsockTransport + private let lock = NSLock() + private var isClosed = false + +#if os(macOS) + init(connection: VZVirtioSocketConnection) throws { + let fd = dup(connection.fileDescriptor) + if fd == -1 { + throw POSIXError.fromErrno() + } + self.fileHandle = FileHandle(fileDescriptor: fd, closeOnDealloc: false) + self.transport = VsockTransport(connection) + } +#endif + + init(fileDescriptor: Int32, transport: VsockTransport) { + self.fileHandle = FileHandle(fileDescriptor: fileDescriptor, closeOnDealloc: false) + self.transport = transport + } + + public var fileDescriptor: Int32 { + fileHandle.fileDescriptor + } + + public var readabilityHandler: (@Sendable (FileHandle) -> Void)? { + get { fileHandle.readabilityHandler } + set { fileHandle.readabilityHandler = newValue } + } + + public var availableData: Data { + fileHandle.availableData + } + + public func write(contentsOf data: some DataProtocol) throws { + try fileHandle.write(contentsOf: data) + } + + public func close() throws { + try closeIfNeeded { + try fileHandle.close() + } + } + + private func closeIfNeeded(_ closeUnderlying: () throws -> Void) throws { + lock.lock() + guard !isClosed else { + lock.unlock() + return + } + isClosed = true + lock.unlock() + + defer { transport.close() } + try closeUnderlying() + } + + deinit { + try? closeIfNeeded { + try fileHandle.close() + } + } +} diff --git a/Sources/Containerization/VsockListener.swift b/Sources/Containerization/VsockListener.swift index 7a7b36fa..a2a7d193 100644 --- a/Sources/Containerization/VsockListener.swift +++ b/Sources/Containerization/VsockListener.swift @@ -22,18 +22,18 @@ import Virtualization /// A stream of vsock connections. public final class VsockListener: NSObject, Sendable, AsyncSequence { - public typealias Element = FileHandle + public typealias Element = VsockConnection /// The port the connections are for. public let port: UInt32 - private let connections: AsyncStream - private let cont: AsyncStream.Continuation + private let connections: AsyncStream + private let cont: AsyncStream.Continuation private let stopListening: @Sendable (_ port: UInt32) throws -> Void package init(port: UInt32, stopListen: @Sendable @escaping (_ port: UInt32) throws -> Void) { self.port = port - let (stream, continuation) = AsyncStream.makeStream(of: FileHandle.self) + let (stream, continuation) = AsyncStream.makeStream(of: VsockConnection.self) self.connections = stream self.cont = continuation self.stopListening = stopListen @@ -44,7 +44,7 @@ public final class VsockListener: NSObject, Sendable, AsyncSequence { try self.stopListening(self.port) } - public func makeAsyncIterator() -> AsyncStream.AsyncIterator { + public func makeAsyncIterator() -> AsyncStream.AsyncIterator { connections.makeAsyncIterator() } } @@ -52,20 +52,20 @@ public final class VsockListener: NSObject, Sendable, AsyncSequence { #if os(macOS) extension VsockListener: VZVirtioSocketListenerDelegate { + /// Accepts a new vsock connection and yields a retained `VsockConnection`. public func listener( _: VZVirtioSocketListener, shouldAcceptNewConnection conn: VZVirtioSocketConnection, from _: VZVirtioSocketDevice ) -> Bool { - let fd = dup(conn.fileDescriptor) - guard fd != -1 else { + let connection: VsockConnection + do { + connection = try conn.retainedConnection() + } catch { return false } - conn.close() - - let fh = FileHandle(fileDescriptor: fd, closeOnDealloc: false) - let result = cont.yield(fh) + let result = cont.yield(connection) if case .terminated = result { - try? fh.close() + try? connection.close() return false } diff --git a/Sources/Containerization/VsockTransport.swift b/Sources/Containerization/VsockTransport.swift new file mode 100644 index 00000000..d1c52c11 --- /dev/null +++ b/Sources/Containerization/VsockTransport.swift @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// Copyright © 2025-2026 Apple Inc. and the Containerization project authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// + +#if os(macOS) +import Foundation +import Virtualization + +/// Manages the lifecycle of a VZVirtioSocketConnection for use as a gRPC transport. +/// +/// When a vsock connection's file descriptor is dup'd and handed to gRPC/NIO, +/// the original VZVirtioSocketConnection must remain open. The Virtualization +/// framework tears down the host-to-guest vsock mapping when the connection is +/// closed, which invalidates dup'd descriptors. This wrapper captures the +/// connection's close operation and provides explicit, idempotent close semantics. +/// +/// Uses `@unchecked Sendable` because the close state is protected by `NSLock`, +/// but the stored close closure may capture a non-Sendable +/// `VZVirtioSocketConnection`. +final class VsockTransport: @unchecked Sendable { + private let onClose: () -> Void + private let lock = NSLock() + private var isClosed = false + + init(_ connection: VZVirtioSocketConnection) { + self.onClose = { connection.close() } + } + + init(onClose: @escaping () -> Void) { + self.onClose = onClose + } + + /// Closes the underlying vsock connection, tearing down the host-side endpoint. + func close() { + lock.lock() + guard !isClosed else { + lock.unlock() + return + } + isClosed = true + lock.unlock() + onClose() + } + + deinit { + close() + } +} + +#endif diff --git a/Sources/Integration/ContainerTests.swift b/Sources/Integration/ContainerTests.swift index 2dd978b1..3c20d20a 100644 --- a/Sources/Integration/ContainerTests.swift +++ b/Sources/Integration/ContainerTests.swift @@ -3969,6 +3969,76 @@ extension IntegrationSuite { } } + /// Exercises the dialAgent() → gRPC RPC path that previously crashed with + /// EBADF when the VZVirtioSocketConnection was closed before the gRPC + /// client made its first call. + /// + /// Each exec() call creates a new vsock connection via dialAgent(). The + /// gRPC ClientConnection defers NIO channel creation until the first RPC + /// (createProcess). A delay between exec() and start() widens the window + /// where the fd must remain valid — if the VZVirtioSocketConnection is + /// closed prematurely, the fd may be invalidated by the time NIO tries + /// fcntl(F_SETNOSIGPIPE), causing a precondition failure. + /// + /// The same VsockTransport fix also applies to the waitForAgent() startup + /// path (where the first RPC is setTime via TimeSyncer). That path is + /// implicitly exercised by every integration test that boots a container, + /// but isn't stress-tested with an artificial delay here because the timing + /// depends on VM boot and Rosetta setup, which aren't controllable. + func testExecDeferredConnectionStability() async throws { + let id = "test-exec-deferred-connection-stability" + + let bs = try await bootstrap(id) + let container = try LinuxContainer(id, rootfs: bs.rootfs, vmm: bs.vmm) { config in + config.process.arguments = ["/bin/sleep", "1000"] + config.bootLog = bs.bootLog + } + + do { + try await container.create() + try await container.start() + + // Run multiple sequential exec calls with delays between creating the + // gRPC connection (exec) and making the first RPC (start). This is the + // pattern that triggered the EBADF crash: the fd was dup'd, the + // VZVirtioSocketConnection was closed, and by the time NIO tried to + // create the channel the fd was invalid. + for i in 0..<10 { + let buffer = BufferWriter() + let exec = try await container.exec("deferred-\(i)") { config in + config.arguments = ["/bin/echo", "exec-\(i)"] + config.stdout = buffer + } + + // Delay between exec() (which calls dialAgent/creates gRPC connection) + // and start() (which triggers the first RPC/NIO channel creation). + try await Task.sleep(for: .milliseconds(100)) + + try await exec.start() + let status = try await exec.wait() + try await exec.delete() + + guard status.exitCode == 0 else { + throw IntegrationError.assert(msg: "exec deferred-\(i) status \(status) != 0") + } + + guard let output = String(data: buffer.data, encoding: .utf8) else { + throw IntegrationError.assert(msg: "failed to read output from deferred-\(i)") + } + guard output.trimmingCharacters(in: .whitespacesAndNewlines) == "exec-\(i)" else { + throw IntegrationError.assert(msg: "deferred-\(i) output mismatch: \(output)") + } + } + + try await container.kill(SIGKILL) + try await container.wait() + try await container.stop() + } catch { + try? await container.stop() + throw error + } + } + @available(macOS 26.0, *) func testNetworkingDisabled() async throws { let id = "test-networking-disabled" diff --git a/Sources/Integration/Suite.swift b/Sources/Integration/Suite.swift index 9e18c8ae..a0551c1a 100644 --- a/Sources/Integration/Suite.swift +++ b/Sources/Integration/Suite.swift @@ -369,6 +369,7 @@ struct IntegrationSuite: AsyncParsableCommand { Test("container useInit zombie reaping", testUseInitZombieReaping), Test("container useInit with terminal", testUseInitWithTerminal), Test("container useInit with stdin", testUseInitWithStdin), + Test("exec deferred connection stability", testExecDeferredConnectionStability), Test("container sysctl", testSysctl), Test("container sysctl multiple", testSysctlMultiple), Test("container noNewPrivileges", testNoNewPrivileges), diff --git a/Tests/ContainerizationTests/VsockTransportTests.swift b/Tests/ContainerizationTests/VsockTransportTests.swift new file mode 100644 index 00000000..08ef8eeb --- /dev/null +++ b/Tests/ContainerizationTests/VsockTransportTests.swift @@ -0,0 +1,414 @@ +//===----------------------------------------------------------------------===// +// Copyright © 2025-2026 Apple Inc. and the Containerization project authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// + +#if os(macOS) + +import ContainerizationOS +import Darwin +import Foundation +import Testing + +@testable import Containerization + +/// Tests for the VsockTransport fd lifecycle fix. +/// +/// The Virtualization framework tears down the vsock endpoint when a +/// VZVirtioSocketConnection is closed, invalidating dup'd descriptors. +/// The fix keeps the connection alive via VsockTransport until the gRPC +/// channel is shut down. +/// +/// These tests use Unix socket pairs to verify: +/// 1. A dup'd fd is fully functional when the original is kept alive. +/// 2. The specific fcntl call that triggers the NIO crash (F_SETNOSIGPIPE) +/// works on the dup'd fd. +/// 3. The correct teardown order (close dup'd fd first, then original) +/// preserves the connection for the peer until the original is closed. +@Suite("VsockTransport tests") +struct VsockTransportTests { + private final class CloseCounter: @unchecked Sendable { + private let lock = NSLock() + private var value = 0 + + func increment() { + lock.lock() + value += 1 + lock.unlock() + } + + func count() -> Int { + lock.lock() + defer { lock.unlock() } + return value + } + } + + private struct FakeVM: VirtualMachineInstance { + typealias Agent = Vminitd + + let dialImpl: @Sendable (UInt32) async throws -> VsockConnection + + var state: VirtualMachineInstanceState { .running } + var mounts: [String: [AttachedFilesystem]] { [:] } + + func dialAgent() async throws -> Vminitd { + fatalError("unused in test") + } + + func dial(_ port: UInt32) async throws -> VsockConnection { + try await dialImpl(port) + } + + func listen(_ port: UInt32) throws -> VsockListener { + fatalError("unused in test") + } + + func start() async throws { + fatalError("unused in test") + } + + func stop() async throws { + fatalError("unused in test") + } + } + + /// Creates a connected Unix socket pair. Returns (fd0, fd1). + private func makeSocketPair() throws -> (Int32, Int32) { + var fds: [Int32] = [0, 0] + let result = socketpair(AF_UNIX, SOCK_STREAM, 0, &fds) + try #require(result == 0, "socketpair should succeed") + return (fds[0], fds[1]) + } + + private func setSocketTimeout(fd: Int32, seconds: Int) throws { + var timer = timeval() + timer.tv_sec = seconds + timer.tv_usec = 0 + + let rc = setsockopt( + fd, + SOL_SOCKET, + SO_RCVTIMEO, + &timer, + socklen_t(MemoryLayout.size) + ) + try #require(rc == 0, "setting socket timeout should succeed") + } + + private func uniqueSocketPath() -> String { + let dir = FileManager.default.temporaryDirectory + .appendingPathComponent(UUID().uuidString, isDirectory: true) + try? FileManager.default.createDirectory(at: dir, withIntermediateDirectories: true) + return dir.appendingPathComponent("relay.sock").path + } + + private func connectUnixSocket(path: String) throws -> Socket { + var lastError: Error? + for _ in 0..<50 { + do { + let socket = try Socket(type: UnixType(path: path)) + try socket.connect() + try socket.setTimeout(option: .receive, seconds: 1) + try socket.setTimeout(option: .send, seconds: 1) + return socket + } catch { + lastError = error + usleep(20_000) + } + } + + throw lastError ?? POSIXError(.ETIMEDOUT) + } + + // MARK: - fd lifecycle tests + + /// Verifies that F_SETNOSIGPIPE (the exact fcntl call where NIO crashes) + /// succeeds on a dup'd fd when the original is kept alive. + @Test func dupdDescriptorSupportsFcntlWhenOriginalAlive() throws { + let (fd0, fd1) = try makeSocketPair() + defer { + close(fd0) + close(fd1) + } + + let dupdFd = dup(fd0) + try #require(dupdFd != -1) + defer { close(dupdFd) } + + // This is the exact operation that triggers the NIO EBADF crash + // when the underlying vsock endpoint has been torn down. + let result = fcntl(dupdFd, F_SETNOSIGPIPE, 1) + #expect(result == 0, "F_SETNOSIGPIPE should succeed on dup'd fd when original is alive") + } + + /// Verifies that a dup'd fd can read data written by the peer when the + /// original fd is kept alive. + @Test func dupdDescriptorCanReadWhenOriginalAlive() throws { + let (fd0, fd1) = try makeSocketPair() + defer { + close(fd0) + close(fd1) + } + + let dupdFd = dup(fd0) + try #require(dupdFd != -1) + defer { close(dupdFd) } + + // Peer writes data. + let message: [UInt8] = [1, 2, 3] + let writeResult = message.withUnsafeBufferPointer { buf in + write(fd1, buf.baseAddress, buf.count) + } + try #require(writeResult == 3) + + // Dup'd fd can read because the original keeps the connection alive. + var readBuf = [UInt8](repeating: 0, count: 3) + let readResult = readBuf.withUnsafeMutableBufferPointer { buf in + read(dupdFd, buf.baseAddress, buf.count) + } + #expect(readResult == 3) + #expect(readBuf == [1, 2, 3]) + } + + /// Verifies the correct teardown order: closing the dup'd fd first (gRPC + /// channel shutdown) does not break the connection for the peer, because + /// the original fd (transport) is still alive. + @Test func peerCanWriteAfterDupdFdClosedWhileOriginalAlive() throws { + let (fd0, fd1) = try makeSocketPair() + defer { + close(fd0) + close(fd1) + } + + let dupdFd = dup(fd0) + try #require(dupdFd != -1) + + // Close the dup'd fd (simulates gRPC channel shutdown). + close(dupdFd) + + // The peer can still write because the original fd keeps the + // connection alive. This matters for orderly shutdown: the guest + // doesn't see an unexpected EOF while the host is still tearing + // down the gRPC channel. + let message: [UInt8] = [42] + let writeResult = message.withUnsafeBufferPointer { buf in + write(fd1, buf.baseAddress, buf.count) + } + #expect(writeResult == 1, "Peer can still write after dup'd fd is closed") + + // Read from the original to confirm data arrived. + var readBuf = [UInt8](repeating: 0, count: 1) + let readResult = readBuf.withUnsafeMutableBufferPointer { buf in + read(fd0, buf.baseAddress, buf.count) + } + #expect(readResult == 1) + #expect(readBuf == [42]) + } + + /// Verifies that after both the dup'd fd and the original are closed, + /// the peer sees EOF (read returns 0). + @Test func peerSeesEOFAfterBothDescriptorsClosed() throws { + let (fd0, fd1) = try makeSocketPair() + defer { close(fd1) } + + let dupdFd = dup(fd0) + try #require(dupdFd != -1) + + // Close dup'd fd first (gRPC shutdown), then original (transport.close()). + close(dupdFd) + close(fd0) + + // Peer should see EOF. + var readBuf = [UInt8](repeating: 0, count: 1) + let readResult = readBuf.withUnsafeMutableBufferPointer { buf in + read(fd1, buf.baseAddress, buf.count) + } + #expect(readResult == 0, "Peer should see EOF after both descriptors are closed") + } + + @Test func transportCloseIsIdempotent() { + let counter = CloseCounter() + let transport = VsockTransport(onClose: { + counter.increment() + }) + + transport.close() + transport.close() + + #expect(counter.count() == 1) + } + + @Test func retainedConnectionCloseClosesTransportOnce() throws { + let (fd0, fd1) = try makeSocketPair() + defer { + close(fd0) + close(fd1) + } + + let dupdFd = dup(fd0) + try #require(dupdFd != -1) + + let counter = CloseCounter() + let transport = VsockTransport(onClose: { + counter.increment() + }) + let connection = VsockConnection(fileDescriptor: dupdFd, transport: transport) + + try connection.close() + try connection.close() + + #expect(counter.count() == 1) + } + + @Test func retainedConnectionDeinitClosesUnderlyingTransport() throws { + let (fd0, fd1) = try makeSocketPair() + defer { close(fd1) } + + let dupdFd = dup(fd0) + try #require(dupdFd != -1) + + let counter = CloseCounter() + do { + let connection = VsockConnection( + fileDescriptor: dupdFd, + transport: VsockTransport(onClose: { + counter.increment() + close(fd0) + }) + ) + _ = connection + } + + var readBuf = [UInt8](repeating: 0, count: 1) + let readResult = readBuf.withUnsafeMutableBufferPointer { buf in + read(fd1, buf.baseAddress, buf.count) + } + #expect(readResult == 0, "peer should see EOF after retained handle deallocation") + #expect(counter.count() == 1) + } + + @Test func unixSocketRelayRetainsDialedHandleForActiveRelay() async throws { + let (relayFd, peerFd) = try makeSocketPair() + defer { close(peerFd) } + + try setSocketTimeout(fd: peerFd, seconds: 1) + + let socketPath = uniqueSocketPath() + defer { + try? FileManager.default.removeItem(atPath: (socketPath as NSString).deletingLastPathComponent) + } + + let relay = try UnixSocketRelay( + port: 4242, + socket: UnixSocketConfiguration( + source: URL(filePath: "/guest/test.sock"), + destination: URL(filePath: socketPath), + direction: .outOf + ), + vm: FakeVM(dialImpl: { _ in + VsockConnection( + fileDescriptor: relayFd, + transport: VsockTransport(onClose: {}) + ) + }), + queue: DispatchQueue(label: "com.apple.containerization.tests.unix-socket-relay") + ) + + try await relay.start() + let hostSocket = try connectUnixSocket(path: socketPath) + defer { try? hostSocket.close() } + try? await Task.sleep(for: .milliseconds(100)) + + let guestToHost = Data("guest-to-host".utf8) + let guestWriteResult = guestToHost.withUnsafeBytes { ptr in + write(peerFd, ptr.baseAddress, ptr.count) + } + try #require(guestWriteResult == guestToHost.count) + + var hostBuffer = Data(repeating: 0, count: guestToHost.count) + let hostReadCount = try hostSocket.read(buffer: &hostBuffer) + #expect(hostReadCount == guestToHost.count) + #expect(Data(hostBuffer.prefix(hostReadCount)) == guestToHost) + + let hostToGuest = Data("host-to-guest".utf8) + let hostWriteCount = try hostSocket.write(data: hostToGuest) + #expect(hostWriteCount == hostToGuest.count) + + var guestBuffer = [UInt8](repeating: 0, count: hostToGuest.count) + let guestReadCount = guestBuffer.withUnsafeMutableBufferPointer { buf in + read(peerFd, buf.baseAddress, buf.count) + } + #expect(guestReadCount == hostToGuest.count) + #expect(Data(guestBuffer.prefix(guestReadCount)) == hostToGuest) + + try relay.stop() + } + + @Test func unixSocketRelayStopKeepsGuestConnectionAliveUntilRelayFinishes() async throws { + let (relayFd, peerFd) = try makeSocketPair() + defer { close(peerFd) } + + let counter = CloseCounter() + let queue = DispatchQueue(label: "com.apple.containerization.tests.unix-socket-relay.stop") + let socketPath = uniqueSocketPath() + defer { + try? FileManager.default.removeItem(atPath: (socketPath as NSString).deletingLastPathComponent) + } + + let relay = try UnixSocketRelay( + port: 4243, + socket: UnixSocketConfiguration( + source: URL(filePath: "/guest/test.sock"), + destination: URL(filePath: socketPath), + direction: .outOf + ), + vm: FakeVM(dialImpl: { _ in + VsockConnection( + fileDescriptor: relayFd, + transport: VsockTransport(onClose: { + counter.increment() + }) + ) + }), + queue: queue + ) + + try await relay.start() + let hostSocket = try connectUnixSocket(path: socketPath) + defer { try? hostSocket.close() } + try? await Task.sleep(for: .milliseconds(100)) + + queue.suspend() + var queueResumed = false + defer { + if !queueResumed { + queue.resume() + } + } + try relay.stop() + + #expect(counter.count() == 0) + #expect(fcntl(relayFd, F_GETFD) != -1) + + queue.resume() + queueResumed = true + try? await Task.sleep(for: .milliseconds(100)) + + #expect(counter.count() == 1) + #expect(fcntl(relayFd, F_GETFD) == -1) + } +} + +#endif