diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorFloat32Tests.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorFloat32Tests.cs index da55ceb1b3..5dd71c0f3f 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorFloat32Tests.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorFloat32Tests.cs @@ -1,627 +1,45 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections.Generic; -using System.Data; -using System.Data.SqlTypes; -using System.Text.Json; -using System.Threading.Tasks; -using Microsoft.Data.SqlTypes; using Xunit; -using Xunit.Abstractions; -namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.VectorTest -{ - public static class VectorFloat32TestData - { - public const int VectorHeaderSize = 8; - public static float[] testData = new float[] { 1.1f, 2.2f, 3.3f, 1.01f, float.MinValue, -0.0f }; - public static int vectorColumnLength = testData.Length; - // Incorrect size for SqlParameter.Size - public static int IncorrectParamSize = 3234; - public static IEnumerable GetVectorFloat32TestData() - { - // Pattern 1-4 with SqlVector(values: testData) - yield return new object[] { 1, new SqlVector(testData), testData, vectorColumnLength }; - yield return new object[] { 2, new SqlVector(testData), testData, vectorColumnLength }; - yield return new object[] { 3, new SqlVector(testData), testData, vectorColumnLength }; - yield return new object[] { 4, new SqlVector(testData), testData, vectorColumnLength }; - - // Pattern 1-4 with SqlVector(n) - yield return new object[] { 1, SqlVector.CreateNull(vectorColumnLength), Array.Empty(), vectorColumnLength }; - yield return new object[] { 2, SqlVector.CreateNull(vectorColumnLength), Array.Empty(), vectorColumnLength }; - yield return new object[] { 3, SqlVector.CreateNull(vectorColumnLength), Array.Empty(), vectorColumnLength }; - yield return new object[] { 4, SqlVector.CreateNull(vectorColumnLength), Array.Empty(), vectorColumnLength }; - - // Pattern 1-4 with DBNull - yield return new object[] { 1, DBNull.Value, Array.Empty(), vectorColumnLength }; - yield return new object[] { 2, DBNull.Value, Array.Empty(), vectorColumnLength }; - yield return new object[] { 3, DBNull.Value, Array.Empty(), vectorColumnLength }; - yield return new object[] { 4, DBNull.Value, Array.Empty(), vectorColumnLength }; +namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.VectorTest; - // Pattern 1-4 with SqlVector.Null - yield return new object[] { 1, SqlVector.Null, Array.Empty(), vectorColumnLength }; +#nullable enable - // Following scenario is not supported in SqlClient. - // This can only be fixed with a behavior change that SqlParameter.Value is internally set to DBNull.Value if it is set to null. - //yield return new object[] { 2, SqlVector.Null, Array.Empty(), vectorColumnLength }; - - yield return new object[] { 3, SqlVector.Null, Array.Empty(), vectorColumnLength }; - yield return new object[] { 4, SqlVector.Null, Array.Empty(), vectorColumnLength }; - } - } +public sealed class VectorFloat32TestData : NativeVectorTestDataBase +{ + public override float[] SampleScalarData => [1.1f, 2.2f, 3.3f, 1.01f, float.MinValue, -0.0f]; - [Trait("Set", "3")] - public sealed class NativeVectorFloat32Tests : IDisposable + public override float[,] SampleDataSet { - private readonly ITestOutputHelper _output; - private static readonly string s_connectionString = ManualTesting.Tests.DataTestUtility.TCPConnectionString; - private static readonly string s_tableName = DataTestUtility.GetShortName("VectorTestTable"); - private static readonly string s_bulkCopySrcTableName = DataTestUtility.GetShortName("VectorBulkCopyTestTable"); - private static readonly int s_vectorDimensions = VectorFloat32TestData.vectorColumnLength; - private static readonly string s_bulkCopySrcTableDef = $@"(Id INT PRIMARY KEY IDENTITY, VectorData vector({s_vectorDimensions}) NULL)"; - private static readonly string s_tableDefinition = $@"(Id INT PRIMARY KEY IDENTITY, VectorData vector({s_vectorDimensions}) NULL)"; - private static readonly string s_selectCmdString = $"SELECT VectorData FROM {s_tableName} ORDER BY Id DESC"; - private static readonly string s_insertCmdString = $"INSERT INTO {s_tableName} (VectorData) VALUES (@VectorData)"; - private static readonly string s_vectorParamName = $"@VectorData"; - private static readonly string s_outputVectorParamName = $"@OutputVectorData"; - private static readonly string s_storedProcName = DataTestUtility.GetShortName("VectorsAsVarcharSp"); - private static readonly string s_storedProcBody = $@" - {s_vectorParamName} vector({s_vectorDimensions}), -- Input: Serialized float[] as JSON string - {s_outputVectorParamName} vector({s_vectorDimensions}) OUTPUT -- Output: Echoed back from latest inserted row - AS - BEGIN - SET NOCOUNT ON; - - -- Insert into vector table - INSERT INTO {s_tableName} (VectorData) - VALUES ({s_vectorParamName}); - - -- Retrieve latest entry (assumes auto-incrementing ID) - SELECT TOP 1 {s_outputVectorParamName} = VectorData - FROM {s_tableName} - ORDER BY Id DESC; - END;"; - - public NativeVectorFloat32Tests(ITestOutputHelper output) - { - _output = output; - using var connection = new SqlConnection(s_connectionString); - connection.Open(); - DataTestUtility.CreateTable(connection, s_tableName, s_tableDefinition); - DataTestUtility.CreateTable(connection, s_bulkCopySrcTableName, s_bulkCopySrcTableDef); - DataTestUtility.CreateSP(connection, s_storedProcName, s_storedProcBody); - } - - public void Dispose() - { - using var connection = new SqlConnection(s_connectionString); - connection.Open(); - DataTestUtility.DropTable(connection, s_tableName); - DataTestUtility.DropTable(connection, s_bulkCopySrcTableName); - DataTestUtility.DropStoredProcedure(connection, s_storedProcName); - } - - private void ValidateSqlVectorFloat32Object(bool isNull, SqlVector sqlVectorFloat32, float[] expectedData, int expectedLength) - { - Assert.Equal(expectedData, sqlVectorFloat32.Memory.ToArray()); - Assert.Equal(expectedLength, sqlVectorFloat32.Length); - if (!isNull) - { - Assert.False(sqlVectorFloat32.IsNull, "IsNull set to true for a non-null value"); - } - else - { - Assert.True(sqlVectorFloat32.IsNull, "IsNull set to false for a null value"); - } - } - - private void ValidateInsertedData(SqlConnection connection, float[] expectedData, int expectedLength) - { - using var selectCmd = new SqlCommand(s_selectCmdString, connection); - using var reader = selectCmd.ExecuteReader(); - Assert.True(reader.Read(), "No data found in the table."); - - //For both null and non-null cases, validate the SqlVector object - ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader.GetSqlVector(0), expectedData, expectedLength); - ValidateSqlVectorFloat32Object(reader.IsDBNull(0), reader.GetFieldValue>(0), expectedData, expectedLength); - ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader.GetSqlValue(0), expectedData, expectedLength); - - if (!reader.IsDBNull(0)) - { - ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader.GetValue(0), expectedData, expectedLength); - ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader[0], expectedData, expectedLength); - ValidateSqlVectorFloat32Object(reader.IsDBNull(0), (SqlVector)reader["VectorData"], expectedData, expectedLength); - Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetString(0))); - Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetSqlString(0).Value)); - Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetFieldValue(0))); - } - else - { - Assert.Equal(DBNull.Value, reader.GetValue(0)); - Assert.Equal(DBNull.Value, reader[0]); - Assert.Equal(DBNull.Value, reader["VectorData"]); - Assert.Throws(() => reader.GetString(0)); - Assert.Throws(() => reader.GetSqlString(0).Value); - Assert.Throws(() => reader.GetFieldValue(0)); - } - } - - [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsSqlVectorSupported))] - [MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData), DisableDiscoveryEnumeration = true)] - public void TestSqlVectorFloat32ParameterInsertionAndReads( - int pattern, - object value, - float[] expectedValues, - int expectedLength) - { - using var conn = new SqlConnection(s_connectionString); - conn.Open(); - - using var insertCmd = new SqlCommand(s_insertCmdString, conn); - - SqlParameter param = pattern switch - { - 1 => new SqlParameter - { - ParameterName = s_vectorParamName, - SqlDbType = SqlDbTypeExtensions.Vector, - Value = value - }, - 2 => new SqlParameter(s_vectorParamName, value), - 3 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector) { Value = value }, - // Even if size is specified, the actual size is determined by the value passed and specified size is ignored. - 4 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector, VectorFloat32TestData.IncorrectParamSize) { Value = value }, - _ => throw new ArgumentOutOfRangeException(nameof(pattern), $"Unsupported pattern: {pattern}") - }; - - insertCmd.Parameters.Add(param); - Assert.Equal(1, insertCmd.ExecuteNonQuery()); - insertCmd.Parameters.Clear(); - - ValidateInsertedData(conn, expectedValues, expectedLength); - } - - private async Task ValidateInsertedDataAsync(SqlConnection connection, float[] expectedData, int expectedLength) - { - using var selectCmd = new SqlCommand(s_selectCmdString, connection); - using var reader = await selectCmd.ExecuteReaderAsync(); - Assert.True(await reader.ReadAsync(), "No data found in the table."); - - //For both null and non-null cases, validate the SqlVector object - ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader.GetSqlVector(0), expectedData, expectedLength); - ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), await reader.GetFieldValueAsync>(0), expectedData, expectedLength); - ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader.GetSqlValue(0), expectedData, expectedLength); - - if (!await reader.IsDBNullAsync(0)) - { - ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader.GetValue(0), expectedData, expectedLength); - ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader[0], expectedData, expectedLength); - ValidateSqlVectorFloat32Object(await reader.IsDBNullAsync(0), (SqlVector)reader["VectorData"], expectedData, expectedLength); - Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetString(0))); - Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetSqlString(0).Value)); - Assert.Equal(expectedData, JsonSerializer.Deserialize(await reader.GetFieldValueAsync(0))); - } - else - { - Assert.Equal(DBNull.Value, reader.GetValue(0)); - Assert.Equal(DBNull.Value, reader[0]); - Assert.Equal(DBNull.Value, reader["VectorData"]); - Assert.Throws(() => reader.GetString(0)); - Assert.Throws(() => reader.GetSqlString(0).Value); - await Assert.ThrowsAsync(async () => await reader.GetFieldValueAsync(0)); - } - } - - [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsSqlVectorSupported))] - [MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData), DisableDiscoveryEnumeration = true)] - public async Task TestSqlVectorFloat32ParameterInsertionAndReadsAsync( - int pattern, - object value, - float[] expectedValues, - int expectedLength) + get { - using var conn = new SqlConnection(s_connectionString); - await conn.OpenAsync(); - - using var insertCmd = new SqlCommand(s_insertCmdString, conn); - - SqlParameter param = pattern switch - { - 1 => new SqlParameter - { - ParameterName = s_vectorParamName, - SqlDbType = (SqlDbType)36, // SqlDbTypeExtension.Vector - Value = value - }, - 2 => new SqlParameter(s_vectorParamName, value), - 3 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector) { Value = value }, - 4 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector, VectorFloat32TestData.IncorrectParamSize) { Value = value }, - _ => throw new ArgumentOutOfRangeException(nameof(pattern), $"Unsupported pattern: {pattern}") - }; - - insertCmd.Parameters.Add(param); - Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); - insertCmd.Parameters.Clear(); - - await ValidateInsertedDataAsync(conn, expectedValues, expectedLength); - } + float[,] sampleData = new float[10, ValidSampleScalarDataLength]; - [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsSqlVectorSupported))] - [MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData), DisableDiscoveryEnumeration = true)] - public void TestStoredProcParamsForVectorFloat32( - int pattern, - object value, - float[] expectedValues, - int expectedLength) - { - //Create SP for test - using var conn = new SqlConnection(s_connectionString); - conn.Open(); - DataTestUtility.CreateSP(conn, s_storedProcName, s_storedProcBody); - using var command = new SqlCommand(s_storedProcName, conn) + for (int i = 0; i < sampleData.GetLength(0); i++) { - CommandType = CommandType.StoredProcedure - }; + float baseValue = i * 10; - // Set input and output parameters - SqlParameter inputParam = pattern switch - { - 1 => new SqlParameter + for (int j = 0; j < sampleData.GetLength(1); j++) { - ParameterName = s_vectorParamName, - SqlDbType = SqlDbTypeExtensions.Vector, // SqlDbTypeExtension.Vector - Value = value - }, - 2 => new SqlParameter(s_vectorParamName, value), - 3 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector) { Value = value }, - 4 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector, VectorFloat32TestData.IncorrectParamSize) { Value = value }, - _ => throw new ArgumentOutOfRangeException(nameof(pattern), $"Unsupported pattern: {pattern}") - }; - command.Parameters.Add(inputParam); - - var outputParam = new SqlParameter - { - ParameterName = s_outputVectorParamName, - SqlDbType = SqlDbTypeExtensions.Vector, - Direction = ParameterDirection.Output, - Value = SqlVector.CreateNull(VectorFloat32TestData.vectorColumnLength) - }; - command.Parameters.Add(outputParam); - - // Execute the stored procedure - command.ExecuteNonQuery(); - - // Validate the output parameter - var vector = (SqlVector)outputParam.Value; - ValidateSqlVectorFloat32Object(vector.IsNull, vector, expectedValues, expectedLength); - - // Validate error for conventional way of setting output parameters - command.Parameters.Clear(); - command.Parameters.Add(inputParam); - var outputParamWithoutVal = new SqlParameter(s_outputVectorParamName, SqlDbTypeExtensions.Vector, VectorFloat32TestData.IncorrectParamSize) { Direction = ParameterDirection.Output }; - command.Parameters.Add(outputParamWithoutVal); - Assert.Throws(() => command.ExecuteNonQuery()); - } - - [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsSqlVectorSupported))] - [MemberData(nameof(VectorFloat32TestData.GetVectorFloat32TestData), MemberType = typeof(VectorFloat32TestData), DisableDiscoveryEnumeration = true)] - public async Task TestStoredProcParamsForVectorFloat32Async( - int pattern, - object value, - float[] expectedValues, - int expectedLength) - { - //Create SP for test - using var conn = new SqlConnection(s_connectionString); - await conn.OpenAsync(); - DataTestUtility.CreateSP(conn, s_storedProcName, s_storedProcBody); - using var command = new SqlCommand(s_storedProcName, conn) - { - CommandType = CommandType.StoredProcedure - }; - - // Set input and output parameters - SqlParameter inputParam = pattern switch - { - 1 => new SqlParameter - { - ParameterName = s_vectorParamName, - SqlDbType = SqlDbTypeExtensions.Vector, // SqlDbTypeExtension.Vector - Value = value - }, - 2 => new SqlParameter(s_vectorParamName, value), - 3 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector) { Value = value }, - 4 => new SqlParameter(s_vectorParamName, SqlDbTypeExtensions.Vector, VectorFloat32TestData.IncorrectParamSize) { Value = value }, - _ => throw new ArgumentOutOfRangeException(nameof(pattern), $"Unsupported pattern: {pattern}") - }; - command.Parameters.Add(inputParam); - - var outputParam = new SqlParameter - { - ParameterName = s_outputVectorParamName, - SqlDbType = SqlDbTypeExtensions.Vector, - Direction = ParameterDirection.Output, - Value = SqlVector.CreateNull(VectorFloat32TestData.vectorColumnLength) - }; - command.Parameters.Add(outputParam); - - // Execute the stored procedure - await command.ExecuteNonQueryAsync(); - - // Validate the output parameter - var vector = (SqlVector)outputParam.Value; - ValidateSqlVectorFloat32Object(vector.IsNull, vector, expectedValues, expectedLength); - - // Validate error for conventional way of setting output parameters - command.Parameters.Clear(); - command.Parameters.Add(inputParam); - var outputParamWithoutVal = new SqlParameter(s_outputVectorParamName, SqlDbTypeExtensions.Vector, VectorFloat32TestData.IncorrectParamSize) { Direction = ParameterDirection.Output }; - command.Parameters.Add(outputParamWithoutVal); - await Assert.ThrowsAsync(async () => await command.ExecuteNonQueryAsync()); - } - - [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsSqlVectorSupported))] - [InlineData(1)] - [InlineData(2)] - public void TestBulkCopyFromSqlTable(int bulkCopySourceMode) - { - //Setup source with test data and create destination table for bulkcopy. - SqlConnection sourceConnection = new SqlConnection(s_connectionString); - sourceConnection.Open(); - SqlConnection destinationConnection = new SqlConnection(s_connectionString); - destinationConnection.Open(); - DataTable table = null; - switch (bulkCopySourceMode) - { - - case 1: - // Use SqlServer table as source - var insertCmd = new SqlCommand($"insert into {s_bulkCopySrcTableName} values (@VectorData)", sourceConnection); - var vectorParam = new SqlParameter(s_vectorParamName, new SqlVector(VectorFloat32TestData.testData)); - - // Insert 2 rows with one non-null and null value - insertCmd.Parameters.Add(vectorParam); - Assert.Equal(1, insertCmd.ExecuteNonQuery()); - insertCmd.Parameters.Clear(); - vectorParam.Value = DBNull.Value; - insertCmd.Parameters.Add(vectorParam); - Assert.Equal(1, insertCmd.ExecuteNonQuery()); - insertCmd.Parameters.Clear(); - break; - case 2: - table = new DataTable(s_bulkCopySrcTableName); - table.Columns.Add("Id", typeof(int)); - table.Columns.Add("VectorData", typeof(SqlVector)); - table.Rows.Add(1, new SqlVector(VectorFloat32TestData.testData)); - table.Rows.Add(2, DBNull.Value); - break; - default: - throw new ArgumentOutOfRangeException(nameof(bulkCopySourceMode), $"Unsupported bulk copy source mode: {bulkCopySourceMode}"); - } - - - - //Bulkcopy from sql server table to destination table - using SqlCommand sourceDataCommand = new SqlCommand($"SELECT Id, VectorData FROM {s_bulkCopySrcTableName}", sourceConnection); - using SqlDataReader reader = sourceDataCommand.ExecuteReader(); - - // Verify that the destination table is empty before bulk copy - using SqlCommand countCommand = new SqlCommand($"SELECT COUNT(*) FROM {s_tableName}", destinationConnection); - Assert.Equal(0, Convert.ToInt16(countCommand.ExecuteScalar())); - - // Initialize bulk copy configuration - using SqlBulkCopy bulkCopy = new SqlBulkCopy(destinationConnection) - { - DestinationTableName = s_tableName, - }; - - try - { - switch (bulkCopySourceMode) - { - case 1: - bulkCopy.WriteToServer(reader); - break; - case 2: - bulkCopy.WriteToServer(table); - break; - default: - throw new ArgumentOutOfRangeException(nameof(bulkCopySourceMode), $"Unsupported bulk copy source mode: {bulkCopySourceMode}"); + sampleData[i, j] = baseValue + (j * 0.1f); } } - catch (Exception ex) - { - // If bulk copy fails, fail the test with the exception message - Assert.Fail($"Bulk copy failed: {ex.Message}"); - } - - // Verify that the 2 rows from the source table have been copied into the destination table. - Assert.Equal(2, Convert.ToInt16(countCommand.ExecuteScalar())); - - // Read the data from destination table as varbinary to verify the UTF-8 byte sequence - using SqlCommand verifyCommand = new SqlCommand($"SELECT VectorData from {s_tableName}", destinationConnection); - using SqlDataReader verifyReader = verifyCommand.ExecuteReader(); - - // Verify that we have data in the destination table - Assert.True(verifyReader.Read(), "No data found in destination table after bulk copy."); - // Validate first non-null value. - Assert.True(!verifyReader.IsDBNull(0), "First row in the table is null."); - Assert.Equal(VectorFloat32TestData.testData, ((SqlVector)verifyReader.GetSqlVector(0)).Memory.ToArray()); - Assert.Equal(VectorFloat32TestData.testData.Length, ((SqlVector)verifyReader.GetSqlVector(0)).Length); - - // Verify that we have another row - Assert.True(verifyReader.Read(), "Second row not found in the table"); - - // Verify that we have encountered null. - Assert.True(verifyReader.IsDBNull(0)); - Assert.Equal(Array.Empty(), ((SqlVector)verifyReader.GetSqlVector(0)).Memory.ToArray()); - Assert.Equal(VectorFloat32TestData.testData.Length, ((SqlVector)verifyReader.GetSqlVector(0)).Length); + return sampleData; } + } - [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsSqlVectorSupported))] - [InlineData(1)] - [InlineData(2)] - public async Task TestBulkCopyFromSqlTableAsync(int bulkCopySourceMode) - { - //Setup source with test data and create destination table for bulkcopy. - SqlConnection sourceConnection = new SqlConnection(s_connectionString); - await sourceConnection.OpenAsync(); - SqlConnection destinationConnection = new SqlConnection(s_connectionString); - await destinationConnection.OpenAsync(); - - DataTable table = null; - switch (bulkCopySourceMode) - { - - case 1: - // Use SqlServer table as source - var insertCmd = new SqlCommand($"insert into {s_bulkCopySrcTableName} values (@VectorData)", sourceConnection); - var vectorParam = new SqlParameter(s_vectorParamName, new SqlVector(VectorFloat32TestData.testData)); - - // Insert 2 rows with one non-null and null value - insertCmd.Parameters.Add(vectorParam); - Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); - insertCmd.Parameters.Clear(); - vectorParam.Value = DBNull.Value; - insertCmd.Parameters.Add(vectorParam); - Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); - insertCmd.Parameters.Clear(); - break; - case 2: - table = new DataTable(s_bulkCopySrcTableName); - table.Columns.Add("Id", typeof(int)); - table.Columns.Add("VectorData", typeof(SqlVector)); - table.Rows.Add(1, new SqlVector(VectorFloat32TestData.testData)); - table.Rows.Add(2, DBNull.Value); - break; - default: - throw new ArgumentOutOfRangeException(nameof(bulkCopySourceMode), $"Unsupported bulk copy source mode: {bulkCopySourceMode}"); - } - - //Bulkcopy from sql server table to destination table - using SqlCommand sourceDataCommand = new SqlCommand($"SELECT Id, VectorData FROM {s_bulkCopySrcTableName}", sourceConnection); - using SqlDataReader reader = await sourceDataCommand.ExecuteReaderAsync(); - - // Verify that the destination table is empty before bulk copy - using SqlCommand countCommand = new SqlCommand($"SELECT COUNT(*) FROM {s_tableName}", destinationConnection); - Assert.Equal(0, Convert.ToInt16(await countCommand.ExecuteScalarAsync())); - - // Initialize bulk copy configuration - using SqlBulkCopy bulkCopy = new SqlBulkCopy(destinationConnection) - { - DestinationTableName = s_tableName, - }; - - try - { // Perform bulkcopy - switch (bulkCopySourceMode) - { - case 1: - await bulkCopy.WriteToServerAsync(reader); - break; - case 2: - await bulkCopy.WriteToServerAsync(table); - break; - default: - throw new ArgumentOutOfRangeException(nameof(bulkCopySourceMode), $"Unsupported bulk copy source mode: {bulkCopySourceMode}"); - } - } - catch (Exception ex) - { - // If bulk copy fails, fail the test with the exception message - Assert.Fail($"Bulk copy failed: {ex.Message}"); - } - - // Verify that the 2 rows from the source table have been copied into the destination table. - Assert.Equal(2, Convert.ToInt16(await countCommand.ExecuteScalarAsync())); - - // Read the data from destination table as varbinary to verify the UTF-8 byte sequence - using SqlCommand verifyCommand = new SqlCommand($"SELECT VectorData from {s_tableName}", destinationConnection); - using SqlDataReader verifyReader = await verifyCommand.ExecuteReaderAsync(); - - // Verify that we have data in the destination table - Assert.True(await verifyReader.ReadAsync(), "No data found in destination table after bulk copy."); - - // Validate first non-null value. - Assert.True(!await verifyReader.IsDBNullAsync(0), "First row in the table is null."); - var vector = await verifyReader.GetFieldValueAsync>(0); - Assert.Equal(VectorFloat32TestData.testData, vector.Memory.ToArray()); - Assert.Equal(VectorFloat32TestData.testData.Length, vector.Length); - - // Verify that we have another row - Assert.True(await verifyReader.ReadAsync(), "Second row not found in the table"); - - // Verify that we have encountered null. - Assert.True(await verifyReader.IsDBNullAsync(0)); - vector = await verifyReader.GetFieldValueAsync>(0); - Assert.Equal(Array.Empty(), vector.Memory.ToArray()); - Assert.Equal(VectorFloat32TestData.testData.Length, vector.Length); - } - - [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsSqlVectorSupported))] - public void TestGetFieldTypeReturnsSqlVectorForVectorColumn() - { - using var connection = new SqlConnection(s_connectionString); - connection.Open(); - - // Insert a row so we can query it - using (var insertCmd = new SqlCommand(s_insertCmdString, connection)) - { - var param = insertCmd.Parameters.Add(s_vectorParamName, SqlDbTypeExtensions.Vector); - param.Value = new SqlVector(VectorFloat32TestData.testData); - insertCmd.ExecuteNonQuery(); - } - - using var selectCmd = new SqlCommand(s_selectCmdString, connection); - using var reader = selectCmd.ExecuteReader(); - - // Verify GetFieldType returns SqlVector for the vector column - Assert.Equal(typeof(SqlVector), reader.GetFieldType(0)); - - // Verify GetProviderSpecificFieldType also returns SqlVector - Assert.Equal(typeof(SqlVector), reader.GetProviderSpecificFieldType(0)); + public override int IncorrectScalarDataParameterSize => 3234; - // Verify that GetValue returns an instance consistent with GetFieldType - Assert.True(reader.Read(), "No data found in the table."); - object value = reader.GetValue(0); - Assert.IsType>(value); - Assert.Equal(VectorFloat32TestData.testData, ((SqlVector)value).Memory.ToArray()); + public override bool IsSupported => DataTestUtility.IsSqlVectorSupported; - // Verify GetFieldValue> returns the correct typed value - SqlVector typedValue = reader.GetFieldValue>(0); - Assert.IsType>(typedValue); - Assert.Equal(VectorFloat32TestData.testData, typedValue.Memory.ToArray()); - } + public override string SqlServerTypeName => "float32"; +} - [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsSqlVectorSupported))] - public void TestInsertVectorsFloat32WithPrepare() - { - SqlConnection conn = new SqlConnection(s_connectionString); - conn.Open(); - SqlCommand command = new SqlCommand(s_insertCmdString, conn); - SqlParameter vectorParam = new SqlParameter("@VectorData", SqlDbTypeExtensions.Vector); - command.Parameters.Add(vectorParam); - command.Prepare(); - for (int i = 0; i < 10; i++) - { - vectorParam.Value = new SqlVector(new float[] { i + 0.1f, i + 0.2f, i + 0.3f, i + 0.4f, i + 0.5f, i + 0.6f }); - command.ExecuteNonQuery(); - } - SqlCommand validateCommand = new SqlCommand($"SELECT VectorData FROM {s_tableName}", conn); - using SqlDataReader reader = validateCommand.ExecuteReader(); - int rowcnt = 0; - while (reader.Read()) - { - float[] expectedData = new float[] { rowcnt + 0.1f, rowcnt + 0.2f, rowcnt + 0.3f, rowcnt + 0.4f, rowcnt + 0.5f, rowcnt + 0.6f }; - float[] dbData = reader.GetSqlVector(0).Memory.ToArray(); - Assert.Equal(expectedData, dbData); - rowcnt++; - } - Assert.Equal(10, rowcnt); - } - } +[Trait("Set", "3")] +public sealed class NativeVectorFloat32Tests : NativeVectorTestsBase +{ } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorTestsBase.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorTestsBase.cs new file mode 100644 index 0000000000..4a3251deed --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorTestsBase.cs @@ -0,0 +1,669 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.SqlTypes; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient.Tests.Common.Fixtures.DatabaseObjects; +using Microsoft.Data.SqlTypes; +using Xunit; + +#nullable enable + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.VectorTest +{ + /// + /// Base class for all data passed to a type derived from . + /// + /// The element type of the . + public abstract class NativeVectorTestDataBase + where TElement : unmanaged + { + public abstract TElement[] SampleScalarData { get; } + + public abstract TElement[,] SampleDataSet { get; } + + // Incorrect size for SqlParameter.Size + public abstract int IncorrectScalarDataParameterSize { get; } + + public abstract bool IsSupported { get; } + + public abstract string SqlServerTypeName { get; } + + public int ValidSampleScalarDataLength => SampleScalarData.Length; + + public IEnumerable TestData => + [ + // Pattern 1-4 with SqlVector(values: SampleScalarData) + [ 1, new SqlVector(SampleScalarData), SampleScalarData, ValidSampleScalarDataLength ], + [ 2, new SqlVector(SampleScalarData), SampleScalarData, ValidSampleScalarDataLength ], + [ 3, new SqlVector(SampleScalarData), SampleScalarData, ValidSampleScalarDataLength ], + [ 4, new SqlVector(SampleScalarData), SampleScalarData, ValidSampleScalarDataLength ], + + // Pattern 1-4 with SqlVector(n) + [ 1, SqlVector.CreateNull(ValidSampleScalarDataLength), Array.Empty(), ValidSampleScalarDataLength ], + [ 2, SqlVector.CreateNull(ValidSampleScalarDataLength), Array.Empty(), ValidSampleScalarDataLength ], + [ 3, SqlVector.CreateNull(ValidSampleScalarDataLength), Array.Empty(), ValidSampleScalarDataLength ], + [ 4, SqlVector.CreateNull(ValidSampleScalarDataLength), Array.Empty(), ValidSampleScalarDataLength ], + + // Pattern 1-4 with DBNull + [ 1, DBNull.Value, Array.Empty(), ValidSampleScalarDataLength ], + [ 2, DBNull.Value, Array.Empty(), ValidSampleScalarDataLength ], + [ 3, DBNull.Value, Array.Empty(), ValidSampleScalarDataLength ], + [ 4, DBNull.Value, Array.Empty(), ValidSampleScalarDataLength ], + + // Pattern 1-4 with SqlVector.Null + [ 1, SqlVector.Null, Array.Empty(), ValidSampleScalarDataLength ], + + // Following scenario is not supported in SqlClient. + // This can only be fixed with a behavior change that SqlParameter.Value is internally set to DBNull.Value if it is set to null. + // [ 2, SqlVector.Null, Array.Empty(), vectorColumnLength ], + + [ 3, SqlVector.Null, Array.Empty(), ValidSampleScalarDataLength ], + [ 4, SqlVector.Null, Array.Empty(), ValidSampleScalarDataLength ] + ]; + } + + /// + /// Base class for all strongly-typed manual tests for . + /// + /// The element type of the . + /// The type containing the sample data. + public abstract class NativeVectorTestsBase : IDisposable + where TElement : unmanaged + where TTestData : NativeVectorTestDataBase, new() + { + private const string VectorColumnName = "VectorData"; + private const string VectorParameterName = "@VectorData"; + private const string VectorOutputParameterName = "@OutputVectorData"; + + private readonly string _connectionString; + private readonly SqlConnection _managementConnection; + private readonly Table _vectorTable; + private readonly Table _bulkCopySourceTable; + private readonly StoredProcedure _vectorProcedure; + + private readonly string _selectCommand; + private readonly string _insertCommand; + + private bool _disposed; + + // xUnit only allows MemberData for a test to point to static methods, properties and variables. + // This presents a problem when the sample data needs to change based upon the element type of + // the SqlVector, so this compromises: it instantiates a class derived from NativeVectorTestDataBase, + // then projects the relevant fields from it as static properties in this base class. + private static TTestData TestDataInstance => + field ??= new(); + + public static bool IsSupported => TestDataInstance.IsSupported; + + public static IEnumerable TestData => TestDataInstance.TestData; + + public NativeVectorTestsBase() + { + int vectorDimensions = TestDataInstance.ValidSampleScalarDataLength; + string tableDefinition = $@"(Id INT PRIMARY KEY IDENTITY, {VectorColumnName} vector({vectorDimensions}, {TestDataInstance.SqlServerTypeName}) NULL)"; + + _connectionString = DataTestUtility.TCPConnectionString; + _managementConnection = new SqlConnection(_connectionString); + _vectorTable = new Table(_managementConnection, "VectorTestTable", tableDefinition); + _bulkCopySourceTable = new Table(_managementConnection, "VectorBulkCopyTestTable", tableDefinition); + _vectorProcedure = new StoredProcedure(_managementConnection, + prefix: "VectorsAsVarcharSp", + definition: $@" + {VectorParameterName} vector({vectorDimensions}, {TestDataInstance.SqlServerTypeName}), -- Input: Serialized TElement[] as JSON string + {VectorOutputParameterName} vector({vectorDimensions}, {TestDataInstance.SqlServerTypeName}) OUTPUT -- Output: Echoed back from latest inserted row + AS + BEGIN + SET NOCOUNT ON; + + -- Insert into vector table + INSERT INTO {_vectorTable.Name} ({VectorColumnName}) + VALUES ({VectorParameterName}); + + -- Retrieve latest entry (assumes auto-incrementing ID) + SELECT TOP 1 {VectorOutputParameterName} = {VectorColumnName} + FROM {_vectorTable.Name} + ORDER BY Id DESC; + END;"); + + _selectCommand = $"SELECT {VectorColumnName} FROM {_vectorTable.Name} ORDER BY Id DESC"; + _insertCommand = $"INSERT INTO {_vectorTable.Name} ({VectorColumnName}) VALUES ({VectorParameterName})"; + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + if (disposing) + { + _vectorProcedure?.Dispose(); + _bulkCopySourceTable?.Dispose(); + _vectorTable?.Dispose(); + _managementConnection?.Dispose(); + } + + _disposed = true; + } + + ~NativeVectorTestsBase() => + Dispose(false); + + /// + /// Wraps an inbound in a according to + /// the specified pattern. + /// + /// Pattern number. + /// to wrap. + /// A instance wrapping the . + /// is not a valid pattern. + /// + /// can be a number from 1 to 4, inclusive, with the following meaning: + /// + /// Parameterless constructor, manually setting ParameterName, SqlDbType and Value. + /// Specify the parameter name and value directly, relying upon type inference. + /// Specify the parameter name and SqlDbType, manually setting the Value property. + /// Identical to pattern 3, but with a known-invalid parameter size. + /// + /// + /// + /// + /// + /// + private static SqlParameter GetParameterByPattern(int pattern, object? value) => + pattern switch + { + 1 => new SqlParameter + { + ParameterName = VectorParameterName, + SqlDbType = SqlDbTypeExtensions.Vector, + Value = value + }, + 2 => new SqlParameter(VectorParameterName, value), + 3 => new SqlParameter(VectorParameterName, SqlDbTypeExtensions.Vector) { Value = value }, + // Even if size is specified, the actual size is determined by the value passed and specified size is ignored. + 4 => new SqlParameter(VectorParameterName, SqlDbTypeExtensions.Vector, TestDataInstance.IncorrectScalarDataParameterSize) { Value = value }, + _ => throw new ArgumentOutOfRangeException(nameof(pattern), $"Unsupported pattern: {pattern}") + }; + + private static void ValidateSqlVectorObject(bool isNull, SqlVector sqlVector, TElement[] expectedData, int expectedLength) + { + Assert.Equal(expectedData, sqlVector.Memory.ToArray()); + Assert.Equal(expectedLength, sqlVector.Length); + if (!isNull) + { + Assert.False(sqlVector.IsNull, "IsNull set to true for a non-null value"); + } + else + { + Assert.True(sqlVector.IsNull, "IsNull set to false for a null value"); + } + } + + private void ValidateInsertedData(SqlConnection connection, TElement[] expectedData, int expectedLength) + { + using SqlCommand selectCmd = new(_selectCommand, connection); + using SqlDataReader reader = selectCmd.ExecuteReader(); + Assert.True(reader.Read(), "No data found in the table."); + + // For both null and non-null cases, validate the SqlVector object + ValidateSqlVectorObject(reader.IsDBNull(0), (SqlVector)reader.GetSqlVector(0), expectedData, expectedLength); + ValidateSqlVectorObject(reader.IsDBNull(0), reader.GetFieldValue>(0), expectedData, expectedLength); + ValidateSqlVectorObject(reader.IsDBNull(0), (SqlVector)reader.GetSqlValue(0), expectedData, expectedLength); + + if (!reader.IsDBNull(0)) + { + ValidateSqlVectorObject(reader.IsDBNull(0), (SqlVector)reader.GetValue(0), expectedData, expectedLength); + ValidateSqlVectorObject(reader.IsDBNull(0), (SqlVector)reader[0], expectedData, expectedLength); + ValidateSqlVectorObject(reader.IsDBNull(0), (SqlVector)reader[VectorColumnName], expectedData, expectedLength); + Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetString(0))); + Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetSqlString(0).Value)); + Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetFieldValue(0))); + } + else + { + Assert.Equal(DBNull.Value, reader.GetValue(0)); + Assert.Equal(DBNull.Value, reader[0]); + Assert.Equal(DBNull.Value, reader[VectorColumnName]); + Assert.Throws(() => reader.GetString(0)); + Assert.Throws(() => reader.GetSqlString(0).Value); + Assert.Throws(() => reader.GetFieldValue(0)); + } + } + + [ConditionalTheory(nameof(IsSupported))] + [MemberData(nameof(TestData), DisableDiscoveryEnumeration = true)] + public void TestSqlVectorParameterInsertionAndReads( + int pattern, + object? value, + TElement[] expectedValues, + int expectedLength) + { + using SqlConnection conn = new(_connectionString); + conn.Open(); + + using SqlCommand insertCmd = new(_insertCommand, conn); + SqlParameter param = GetParameterByPattern(pattern, value); + + insertCmd.Parameters.Add(param); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + insertCmd.Parameters.Clear(); + + ValidateInsertedData(conn, expectedValues, expectedLength); + } + + private async Task ValidateInsertedDataAsync(SqlConnection connection, TElement[] expectedData, int expectedLength) + { + using SqlCommand selectCmd = new(_selectCommand, connection); + using SqlDataReader reader = await selectCmd.ExecuteReaderAsync(); + Assert.True(await reader.ReadAsync(), "No data found in the table."); + + // For both null and non-null cases, validate the SqlVector object + ValidateSqlVectorObject(await reader.IsDBNullAsync(0), (SqlVector)reader.GetSqlVector(0), expectedData, expectedLength); + ValidateSqlVectorObject(await reader.IsDBNullAsync(0), await reader.GetFieldValueAsync>(0), expectedData, expectedLength); + ValidateSqlVectorObject(await reader.IsDBNullAsync(0), (SqlVector)reader.GetSqlValue(0), expectedData, expectedLength); + + if (!await reader.IsDBNullAsync(0)) + { + ValidateSqlVectorObject(await reader.IsDBNullAsync(0), (SqlVector)reader.GetValue(0), expectedData, expectedLength); + ValidateSqlVectorObject(await reader.IsDBNullAsync(0), (SqlVector)reader[0], expectedData, expectedLength); + ValidateSqlVectorObject(await reader.IsDBNullAsync(0), (SqlVector)reader[VectorColumnName], expectedData, expectedLength); + Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetString(0))); + Assert.Equal(expectedData, JsonSerializer.Deserialize(reader.GetSqlString(0).Value)); + Assert.Equal(expectedData, JsonSerializer.Deserialize(await reader.GetFieldValueAsync(0))); + } + else + { + Assert.Equal(DBNull.Value, reader.GetValue(0)); + Assert.Equal(DBNull.Value, reader[0]); + Assert.Equal(DBNull.Value, reader[VectorColumnName]); + Assert.Throws(() => reader.GetString(0)); + Assert.Throws(() => reader.GetSqlString(0).Value); + await Assert.ThrowsAsync(async () => await reader.GetFieldValueAsync(0)); + } + } + + [ConditionalTheory(nameof(IsSupported))] + [MemberData(nameof(TestData), DisableDiscoveryEnumeration = true)] + public async Task TestSqlVectorParameterInsertionAndReadsAsync( + int pattern, + object? value, + TElement[] expectedValues, + int expectedLength) + { + using SqlConnection conn = new(_connectionString); + await conn.OpenAsync(); + + using SqlCommand insertCmd = new(_insertCommand, conn); + SqlParameter param = GetParameterByPattern(pattern, value); + + insertCmd.Parameters.Add(param); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + insertCmd.Parameters.Clear(); + + await ValidateInsertedDataAsync(conn, expectedValues, expectedLength); + } + + [ConditionalTheory(nameof(IsSupported))] + [MemberData(nameof(TestData), DisableDiscoveryEnumeration = true)] + public void TestStoredProcParamsForVector( + int pattern, + object? value, + TElement[] expectedValues, + int expectedLength) + { + using SqlConnection conn = new(_connectionString); + conn.Open(); + using SqlCommand command = new(_vectorProcedure.Name, conn) + { + CommandType = CommandType.StoredProcedure + }; + + // Set input and output parameters + SqlParameter inputParam = GetParameterByPattern(pattern, value); + command.Parameters.Add(inputParam); + + SqlParameter outputParam = new() + { + ParameterName = VectorOutputParameterName, + SqlDbType = SqlDbTypeExtensions.Vector, + Direction = ParameterDirection.Output, + Value = SqlVector.CreateNull(TestDataInstance.ValidSampleScalarDataLength) + }; + command.Parameters.Add(outputParam); + + // Execute the stored procedure + command.ExecuteNonQuery(); + + // Validate the output parameter + SqlVector vector = (SqlVector)outputParam.Value; + ValidateSqlVectorObject(vector.IsNull, vector, expectedValues, expectedLength); + + // Validate error for conventional way of setting output parameters + command.Parameters.Clear(); + command.Parameters.Add(inputParam); + SqlParameter outputParamWithoutVal = new(VectorOutputParameterName, SqlDbTypeExtensions.Vector, TestDataInstance.IncorrectScalarDataParameterSize) { Direction = ParameterDirection.Output }; + command.Parameters.Add(outputParamWithoutVal); + Assert.Throws(() => command.ExecuteNonQuery()); + } + + [ConditionalTheory(nameof(IsSupported))] + [MemberData(nameof(TestData), DisableDiscoveryEnumeration = true)] + public async Task TestStoredProcParamsForVectorAsync( + int pattern, + object? value, + TElement[] expectedValues, + int expectedLength) + { + using SqlConnection conn = new(_connectionString); + await conn.OpenAsync(); + using SqlCommand command = new(_vectorProcedure.Name, conn) + { + CommandType = CommandType.StoredProcedure + }; + + // Set input and output parameters + SqlParameter inputParam = GetParameterByPattern(pattern, value); + command.Parameters.Add(inputParam); + + SqlParameter outputParam = new() + { + ParameterName = VectorOutputParameterName, + SqlDbType = SqlDbTypeExtensions.Vector, + Direction = ParameterDirection.Output, + Value = SqlVector.CreateNull(TestDataInstance.ValidSampleScalarDataLength) + }; + command.Parameters.Add(outputParam); + + // Execute the stored procedure + await command.ExecuteNonQueryAsync(); + + // Validate the output parameter + SqlVector vector = (SqlVector)outputParam.Value; + ValidateSqlVectorObject(vector.IsNull, vector, expectedValues, expectedLength); + + // Validate error for conventional way of setting output parameters + command.Parameters.Clear(); + command.Parameters.Add(inputParam); + SqlParameter outputParamWithoutVal = new(VectorOutputParameterName, SqlDbTypeExtensions.Vector, TestDataInstance.IncorrectScalarDataParameterSize) { Direction = ParameterDirection.Output }; + command.Parameters.Add(outputParamWithoutVal); + await Assert.ThrowsAsync(async () => await command.ExecuteNonQueryAsync()); + } + + [ConditionalTheory(nameof(IsSupported))] + [InlineData(1)] + [InlineData(2)] + public void TestBulkCopyFromSqlTable(int bulkCopySourceMode) + { + // Setup source with test data and create destination table for bulkcopy. + using SqlConnection sourceConnection = new(_connectionString); + sourceConnection.Open(); + using SqlConnection destinationConnection = new(_connectionString); + destinationConnection.Open(); + DataTable? table = null; + switch (bulkCopySourceMode) + { + + case 1: + { + // Use SQL Server table as source + using SqlCommand insertCmd = new($"insert into {_bulkCopySourceTable.Name} values ({VectorParameterName})", sourceConnection); + SqlParameter vectorParam = new(VectorParameterName, new SqlVector(TestDataInstance.SampleScalarData)); + + // Insert 2 rows with one non-null and null value + insertCmd.Parameters.Add(vectorParam); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + insertCmd.Parameters.Clear(); + vectorParam.Value = DBNull.Value; + insertCmd.Parameters.Add(vectorParam); + Assert.Equal(1, insertCmd.ExecuteNonQuery()); + insertCmd.Parameters.Clear(); + break; + } + case 2: + table = new DataTable(_bulkCopySourceTable.Name); + table.Columns.Add("Id", typeof(int)); + table.Columns.Add(VectorColumnName, typeof(SqlVector)); + table.Rows.Add(1, new SqlVector(TestDataInstance.SampleScalarData)); + table.Rows.Add(2, DBNull.Value); + break; + default: + throw new ArgumentOutOfRangeException(nameof(bulkCopySourceMode), $"Unsupported bulk copy source mode: {bulkCopySourceMode}"); + } + + // Bulk copy from SQL Server table to destination table + using SqlCommand sourceDataCommand = new($"SELECT Id, {VectorColumnName} FROM {_bulkCopySourceTable.Name}", sourceConnection); + using SqlDataReader reader = sourceDataCommand.ExecuteReader(); + + // Verify that the destination table is empty before bulk copy + using SqlCommand countCommand = new($"SELECT COUNT(*) FROM {_vectorTable.Name}", destinationConnection); + Assert.Equal(0, Convert.ToInt16(countCommand.ExecuteScalar())); + + // Initialize bulk copy configuration + using SqlBulkCopy bulkCopy = new(destinationConnection) + { + DestinationTableName = _vectorTable.Name, + }; + + switch (bulkCopySourceMode) + { + case 1: + bulkCopy.WriteToServer(reader); + break; + case 2: + bulkCopy.WriteToServer(table); + break; + default: + throw new ArgumentOutOfRangeException(nameof(bulkCopySourceMode), $"Unsupported bulk copy source mode: {bulkCopySourceMode}"); + } + + // Verify that the 2 rows from the source table have been copied into the destination table. + Assert.Equal(2, Convert.ToInt16(countCommand.ExecuteScalar())); + + // Read the data from destination table as varbinary to verify the UTF-8 byte sequence + using SqlCommand verifyCommand = new($"SELECT {VectorColumnName} from {_vectorTable.Name}", destinationConnection); + using SqlDataReader verifyReader = verifyCommand.ExecuteReader(); + + // Verify that we have data in the destination table + Assert.True(verifyReader.Read(), "No data found in destination table after bulk copy."); + + // Validate first non-null value. + Assert.False(verifyReader.IsDBNull(0), "First row in the table is null."); + Assert.Equal(TestDataInstance.SampleScalarData, ((SqlVector)verifyReader.GetSqlVector(0)).Memory.ToArray()); + Assert.Equal(TestDataInstance.SampleScalarData.Length, ((SqlVector)verifyReader.GetSqlVector(0)).Length); + + // Verify that we have another row + Assert.True(verifyReader.Read(), "Second row not found in the table"); + + // Verify that we have encountered null. + Assert.True(verifyReader.IsDBNull(0)); + Assert.Equal([], ((SqlVector)verifyReader.GetSqlVector(0)).Memory.ToArray()); + Assert.Equal(TestDataInstance.SampleScalarData.Length, ((SqlVector)verifyReader.GetSqlVector(0)).Length); + } + + [ConditionalTheory(nameof(IsSupported))] + [InlineData(1)] + [InlineData(2)] + public async Task TestBulkCopyFromSqlTableAsync(int bulkCopySourceMode) + { + // Setup source with test data and create destination table for bulk copy. + using SqlConnection sourceConnection = new(_connectionString); + await sourceConnection.OpenAsync(); + using SqlConnection destinationConnection = new(_connectionString); + await destinationConnection.OpenAsync(); + + DataTable? table = null; + switch (bulkCopySourceMode) + { + + case 1: + { + // Use SQL Server table as source + using SqlCommand insertCmd = new($"insert into {_bulkCopySourceTable.Name} values ({VectorParameterName})", sourceConnection); + SqlParameter vectorParam = new(VectorParameterName, new SqlVector(TestDataInstance.SampleScalarData)); + + // Insert 2 rows with one non-null and null value + insertCmd.Parameters.Add(vectorParam); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + insertCmd.Parameters.Clear(); + vectorParam.Value = DBNull.Value; + insertCmd.Parameters.Add(vectorParam); + Assert.Equal(1, await insertCmd.ExecuteNonQueryAsync()); + insertCmd.Parameters.Clear(); + break; + } + case 2: + table = new DataTable(_bulkCopySourceTable.Name); + table.Columns.Add("Id", typeof(int)); + table.Columns.Add(VectorColumnName, typeof(SqlVector)); + table.Rows.Add(1, new SqlVector(TestDataInstance.SampleScalarData)); + table.Rows.Add(2, DBNull.Value); + break; + default: + throw new ArgumentOutOfRangeException(nameof(bulkCopySourceMode), $"Unsupported bulk copy source mode: {bulkCopySourceMode}"); + } + + // Bulk copy from SQL Server table to destination table + using SqlCommand sourceDataCommand = new($"SELECT Id, {VectorColumnName} FROM {_bulkCopySourceTable.Name}", sourceConnection); + using SqlDataReader reader = await sourceDataCommand.ExecuteReaderAsync(); + + // Verify that the destination table is empty before bulk copy + using SqlCommand countCommand = new($"SELECT COUNT(*) FROM {_vectorTable.Name}", destinationConnection); + Assert.Equal(0, Convert.ToInt16(await countCommand.ExecuteScalarAsync())); + + // Initialize bulk copy configuration + using SqlBulkCopy bulkCopy = new(destinationConnection) + { + DestinationTableName = _vectorTable.Name, + }; + + // Perform bulk copy + switch (bulkCopySourceMode) + { + case 1: + await bulkCopy.WriteToServerAsync(reader); + break; + case 2: + await bulkCopy.WriteToServerAsync(table); + break; + default: + throw new ArgumentOutOfRangeException(nameof(bulkCopySourceMode), $"Unsupported bulk copy source mode: {bulkCopySourceMode}"); + } + + // Verify that the 2 rows from the source table have been copied into the destination table. + Assert.Equal(2, Convert.ToInt16(await countCommand.ExecuteScalarAsync())); + + // Read the data from destination table as varbinary to verify the UTF-8 byte sequence + using SqlCommand verifyCommand = new($"SELECT {VectorColumnName} from {_vectorTable.Name}", destinationConnection); + using SqlDataReader verifyReader = await verifyCommand.ExecuteReaderAsync(); + + // Verify that we have data in the destination table + Assert.True(await verifyReader.ReadAsync(), "No data found in destination table after bulk copy."); + + // Validate first non-null value. + Assert.False(await verifyReader.IsDBNullAsync(0), "First row in the table is null."); + SqlVector vector = await verifyReader.GetFieldValueAsync>(0); + Assert.Equal(TestDataInstance.SampleScalarData, vector.Memory.ToArray()); + Assert.Equal(TestDataInstance.SampleScalarData.Length, vector.Length); + + // Verify that we have another row + Assert.True(await verifyReader.ReadAsync(), "Second row not found in the table"); + + // Verify that we have encountered null. + Assert.True(await verifyReader.IsDBNullAsync(0)); + vector = await verifyReader.GetFieldValueAsync>(0); + Assert.Equal([], vector.Memory.ToArray()); + Assert.Equal(TestDataInstance.SampleScalarData.Length, vector.Length); + } + + [ConditionalFact(nameof(IsSupported))] + public void TestGetFieldTypeReturnsSqlVectorForVectorColumn() + { + using SqlConnection connection = new(_connectionString); + connection.Open(); + + // Insert a row so we can query it + using (SqlCommand insertCmd = new(_insertCommand, connection)) + { + SqlParameter param = insertCmd.Parameters.Add(VectorParameterName, SqlDbTypeExtensions.Vector); + param.Value = new SqlVector(TestDataInstance.SampleScalarData); + insertCmd.ExecuteNonQuery(); + } + + using SqlCommand selectCmd = new(_selectCommand, connection); + using SqlDataReader reader = selectCmd.ExecuteReader(); + + // Verify GetFieldType returns SqlVector for the vector column + Assert.Equal(typeof(SqlVector), reader.GetFieldType(0)); + + // Verify GetProviderSpecificFieldType also returns SqlVector + Assert.Equal(typeof(SqlVector), reader.GetProviderSpecificFieldType(0)); + + // Verify that GetValue returns an instance consistent with GetFieldType + Assert.True(reader.Read(), "No data found in the table."); + object value = reader.GetValue(0); + Assert.IsType>(value); + Assert.Equal(TestDataInstance.SampleScalarData, ((SqlVector)value).Memory.ToArray()); + + // Verify GetFieldValue> returns the correct typed value + SqlVector typedValue = reader.GetFieldValue>(0); + Assert.IsType>(typedValue); + Assert.Equal(TestDataInstance.SampleScalarData, typedValue.Memory.ToArray()); + } + + [ConditionalFact(nameof(IsSupported))] + public void TestInsertVectorsWithPrepare() + { + using SqlConnection conn = new(_connectionString); + conn.Open(); + using SqlCommand command = new(_insertCommand, conn); + SqlParameter vectorParam = new(VectorParameterName, SqlDbTypeExtensions.Vector); + command.Parameters.Add(vectorParam); + command.Prepare(); + + TElement[,] sampleDataSet = TestDataInstance.SampleDataSet; + for (int i = 0; i < sampleDataSet.GetLength(0); i++) + { + TElement[] rowData = GetMultidimensionalArraySlice(sampleDataSet, i); + vectorParam.Value = new SqlVector(rowData); + command.ExecuteNonQuery(); + } + + using SqlCommand validateCommand = new($"SELECT {VectorColumnName} FROM {_vectorTable.Name}", conn); + using SqlDataReader reader = validateCommand.ExecuteReader(); + int rowcnt = 0; + while (reader.Read()) + { + TElement[] expectedData = GetMultidimensionalArraySlice(sampleDataSet, rowcnt); + TElement[] dbData = reader.GetSqlVector(0).Memory.ToArray(); + Assert.Equal(expectedData, dbData); + rowcnt++; + } + Assert.Equal(10, rowcnt); + + static TElement[] GetMultidimensionalArraySlice(TElement[,] sourceArray, int dimension) + { + TElement[] dst = new TElement[sourceArray.GetLength(1)]; + + for (int i = 0; i < dst.Length; i++) + { + dst[i] = sourceArray[dimension, i]; + } + return dst; + } + } + } +}