diff --git a/src/Microsoft.Data.SqlClient/tests/Common/Fixtures/AzureKeyVaultKeyFixtureBase.cs b/src/Microsoft.Data.SqlClient/tests/Common/Fixtures/AzureKeyVaultKeyFixtureBase.cs index e6e3bf4996..b2232b3f69 100644 --- a/src/Microsoft.Data.SqlClient/tests/Common/Fixtures/AzureKeyVaultKeyFixtureBase.cs +++ b/src/Microsoft.Data.SqlClient/tests/Common/Fixtures/AzureKeyVaultKeyFixtureBase.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Security.Cryptography; using Azure.Core; using Azure.Security.KeyVault.Keys; @@ -18,20 +19,41 @@ namespace Microsoft.Data.SqlClient.Tests.Common.Fixtures; public abstract class AzureKeyVaultKeyFixtureBase : IDisposable { private readonly KeyClient _keyClient; - private readonly Random _randomGenerator; + private readonly RandomNumberGenerator _randomGenerator; private readonly List _createdKeys = new List(); protected AzureKeyVaultKeyFixtureBase(Uri keyVaultUri, TokenCredential keyVaultToken) { _keyClient = new KeyClient(keyVaultUri, keyVaultToken); - _randomGenerator = new Random(); + _randomGenerator = RandomNumberGenerator.Create(); } protected Uri CreateKey(string name, int keySize) { - CreateRsaKeyOptions createOptions = new CreateRsaKeyOptions(GenerateUniqueName(name)) { KeySize = keySize }; - KeyVaultKey created = _keyClient.CreateRsaKey(createOptions); + const int MaxConflictResolutions = 5; + KeyVaultKey created; + int i = 0; + + while (true) + { + CreateRsaKeyOptions createOptions = new CreateRsaKeyOptions(GenerateUniqueName(name)) { KeySize = keySize }; + + try + { + created = _keyClient.CreateRsaKey(createOptions); + break; + } + // It's possible for a key to already exist with the same name, even in a deleted state. If so, CreateRsaKey + // will throw an exception with HTTP status code 409 (Conflict.) + // We can't assume we possess permissions to purge or to recover the key, so regenerate the name and try again. + // Only make MaxConflictResolutions attempts, to avoid possible infinite loops. + catch (Azure.RequestFailedException conflictException) + when (conflictException.Status == 409 && i < MaxConflictResolutions) + { + i++; + } + } _createdKeys.Add(created); return created.Id; @@ -41,7 +63,7 @@ private string GenerateUniqueName(string name) { byte[] rndBytes = new byte[16]; - _randomGenerator.NextBytes(rndBytes); + _randomGenerator.GetBytes(rndBytes); return name + "-" + BitConverter.ToString(rndBytes); } @@ -64,5 +86,7 @@ protected virtual void Dispose(bool disposing) continue; } } + + _randomGenerator.Dispose(); } } diff --git a/src/Microsoft.Data.SqlClient/tests/Common/LocalAppContextSwitchesHelper.cs b/src/Microsoft.Data.SqlClient/tests/Common/LocalAppContextSwitchesHelper.cs index 430d6645a8..8a70df630a 100644 --- a/src/Microsoft.Data.SqlClient/tests/Common/LocalAppContextSwitchesHelper.cs +++ b/src/Microsoft.Data.SqlClient/tests/Common/LocalAppContextSwitchesHelper.cs @@ -208,7 +208,8 @@ public void Dispose() #region Switch Value Getters and Setters - // These properties get or set the like-named underlying switch field value. + // These properties get the like-named underlying switch *property* value and set the underlying + // switch *field* value. This allows tests to verify the default switch values. // // They all throw if the value cannot be retrieved or set. @@ -218,7 +219,7 @@ public void Dispose() /// public bool? DisableTnirByDefault { - get => GetSwitchValue("s_disableTnirByDefault"); + get => GetSwitchPropertyValue(nameof(DisableTnirByDefault)); set => SetSwitchValue("s_disableTnirByDefault", value); } #endif @@ -228,7 +229,7 @@ public bool? DisableTnirByDefault /// public bool? EnableMultiSubnetFailoverByDefault { - get => GetSwitchValue("s_enableMultiSubnetFailoverByDefault"); + get => GetSwitchPropertyValue(nameof(EnableMultiSubnetFailoverByDefault)); set => SetSwitchValue("s_enableMultiSubnetFailoverByDefault", value); } @@ -238,7 +239,7 @@ public bool? EnableMultiSubnetFailoverByDefault /// public bool? GlobalizationInvariantMode { - get => GetSwitchValue("s_globalizationInvariantMode"); + get => GetSwitchPropertyValue(nameof(GlobalizationInvariantMode)); set => SetSwitchValue("s_globalizationInvariantMode", value); } #endif @@ -248,7 +249,7 @@ public bool? GlobalizationInvariantMode /// public bool? IgnoreServerProvidedFailoverPartner { - get => GetSwitchValue("s_ignoreServerProvidedFailoverPartner"); + get => GetSwitchPropertyValue(nameof(IgnoreServerProvidedFailoverPartner)); set => SetSwitchValue("s_ignoreServerProvidedFailoverPartner", value); } @@ -257,7 +258,7 @@ public bool? IgnoreServerProvidedFailoverPartner /// public bool? UseLegacyFailoverAlternationOnLoginSqlErrors { - get => GetSwitchValue("s_useLegacyFailoverAlternationOnLoginSqlErrors"); + get => GetSwitchPropertyValue(nameof(UseLegacyFailoverAlternationOnLoginSqlErrors)); set => SetSwitchValue("s_useLegacyFailoverAlternationOnLoginSqlErrors", value); } @@ -266,7 +267,7 @@ public bool? UseLegacyFailoverAlternationOnLoginSqlErrors /// public bool? LegacyRowVersionNullBehavior { - get => GetSwitchValue("s_legacyRowVersionNullBehavior"); + get => GetSwitchPropertyValue(nameof(LegacyRowVersionNullBehavior)); set => SetSwitchValue("s_legacyRowVersionNullBehavior", value); } @@ -275,7 +276,7 @@ public bool? LegacyRowVersionNullBehavior /// public bool? LegacyVarTimeZeroScaleBehaviour { - get => GetSwitchValue("s_legacyVarTimeZeroScaleBehaviour"); + get => GetSwitchPropertyValue(nameof(LegacyVarTimeZeroScaleBehaviour)); set => SetSwitchValue("s_legacyVarTimeZeroScaleBehaviour", value); } @@ -284,7 +285,7 @@ public bool? LegacyVarTimeZeroScaleBehaviour /// public bool? MakeReadAsyncBlocking { - get => GetSwitchValue("s_makeReadAsyncBlocking"); + get => GetSwitchPropertyValue(nameof(MakeReadAsyncBlocking)); set => SetSwitchValue("s_makeReadAsyncBlocking", value); } @@ -293,7 +294,7 @@ public bool? MakeReadAsyncBlocking /// public bool? SuppressInsecureTlsWarning { - get => GetSwitchValue("s_suppressInsecureTlsWarning"); + get => GetSwitchPropertyValue(nameof(SuppressInsecureTlsWarning)); set => SetSwitchValue("s_suppressInsecureTlsWarning", value); } @@ -302,7 +303,7 @@ public bool? SuppressInsecureTlsWarning /// public bool? TruncateScaledDecimal { - get => GetSwitchValue("s_truncateScaledDecimal"); + get => GetSwitchPropertyValue(nameof(TruncateScaledDecimal)); set => SetSwitchValue("s_truncateScaledDecimal", value); } @@ -311,7 +312,7 @@ public bool? TruncateScaledDecimal /// public bool? UseCompatibilityAsyncBehaviour { - get => GetSwitchValue("s_useCompatibilityAsyncBehaviour"); + get => GetSwitchPropertyValue(nameof(UseCompatibilityAsyncBehaviour)); set => SetSwitchValue("s_useCompatibilityAsyncBehaviour", value); } @@ -320,7 +321,7 @@ public bool? UseCompatibilityAsyncBehaviour /// public bool? UseCompatibilityProcessSni { - get => GetSwitchValue("s_useCompatibilityProcessSni"); + get => GetSwitchPropertyValue(nameof(UseCompatibilityProcessSni)); set => SetSwitchValue("s_useCompatibilityProcessSni", value); } @@ -329,7 +330,7 @@ public bool? UseCompatibilityProcessSni /// public bool? UseConnectionPoolV2 { - get => GetSwitchValue("s_useConnectionPoolV2"); + get => GetSwitchPropertyValue(nameof(UseConnectionPoolV2)); set => SetSwitchValue("s_useConnectionPoolV2", value); } @@ -338,7 +339,7 @@ public bool? UseConnectionPoolV2 /// public bool? UseOverallConnectTimeoutForPoolWait { - get => GetSwitchValue("s_useOverallConnectTimeoutForPoolWait"); + get => GetSwitchPropertyValue(nameof(UseOverallConnectTimeoutForPoolWait)); set => SetSwitchValue("s_useOverallConnectTimeoutForPoolWait", value); } @@ -348,7 +349,7 @@ public bool? UseOverallConnectTimeoutForPoolWait /// public bool? UseManagedNetworking { - get => GetSwitchValue("s_useManagedNetworking"); + get => GetSwitchPropertyValue(nameof(UseManagedNetworking)); set => SetSwitchValue("s_useManagedNetworking", value); } #endif @@ -358,7 +359,7 @@ public bool? UseManagedNetworking /// public bool? UseMinimumLoginTimeout { - get => GetSwitchValue("s_useMinimumLoginTimeout"); + get => GetSwitchPropertyValue(nameof(UseMinimumLoginTimeout)); set => SetSwitchValue("s_useMinimumLoginTimeout", value); } @@ -371,19 +372,7 @@ public bool? UseMinimumLoginTimeout /// private static bool? GetSwitchValue(string fieldName) { - var assembly = Assembly.GetAssembly(typeof(SqlConnection)); - if (assembly is null) - { - throw new InvalidOperationException( - "Could not get assembly for Microsoft.Data.SqlClient"); - } - - var type = assembly.GetType("Microsoft.Data.SqlClient.LocalAppContextSwitches"); - if (type is null) - { - throw new InvalidOperationException( - "Could not get type LocalAppContextSwitches"); - } + var type = GetLocalAppContextSwitchesType(); var field = type.GetField( fieldName, @@ -418,19 +407,7 @@ public bool? UseMinimumLoginTimeout /// private static void SetSwitchValue(string fieldName, bool? value) { - var assembly = Assembly.GetAssembly(typeof(SqlConnection)); - if (assembly is null) - { - throw new InvalidOperationException( - "Could not get assembly for Microsoft.Data.SqlClient"); - } - - var type = assembly.GetType("Microsoft.Data.SqlClient.LocalAppContextSwitches"); - if (type is null) - { - throw new InvalidOperationException( - "Could not get type LocalAppContextSwitches"); - } + var type = GetLocalAppContextSwitchesType(); var field = type.GetField( fieldName, @@ -455,5 +432,49 @@ private static void SetSwitchValue(string fieldName, bool? value) field.SetValue(null, Enum.ToObject(field.FieldType, byteValue)); } + /// + /// Use reflection to get a switch property value from LocalAppContextSwitches. + /// + /// + /// Each property in LocalAppContextSwitchHelper corresponds to a like-named property in + /// LocalAppContextSwitches, which may return a different value when the AppContext switch + /// has not been set. + /// + private static bool GetSwitchPropertyValue(string propertyName) + { + var type = GetLocalAppContextSwitchesType(); + var property = type.GetProperty( + propertyName, + BindingFlags.Static | BindingFlags.Public); + + if (property == null) + { + throw new InvalidOperationException( + $"Property '{propertyName}' not found in LocalAppContextSwitches"); + } + + object? value = property.GetValue(null); + + return value is bool boolValue + ? boolValue + : throw new InvalidOperationException($"Property '{propertyName}' is not of type bool."); + } + + private static Type GetLocalAppContextSwitchesType() + { + var assembly = Assembly.GetAssembly(typeof(SqlConnection)); + if (assembly is null) + { + throw new InvalidOperationException("Could not get assembly for Microsoft.Data.SqlClient"); + } + + var type = assembly.GetType("Microsoft.Data.SqlClient.LocalAppContextSwitches"); + if (type is null) + { + throw new InvalidOperationException("Could not get type LocalAppContextSwitches"); + } + return type; + } + #endregion } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Batch/BatchTests.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Batch/BatchTests.cs index 01d54f2462..69c4b22316 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Batch/BatchTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Batch/BatchTests.cs @@ -7,6 +7,7 @@ using System.Data; using System.Data.Common; using System.Threading.Tasks; +using Microsoft.Data.SqlClient.Tests.Common.Fixtures.DatabaseObjects; using Xunit; namespace Microsoft.Data.SqlClient.ManualTesting.Tests @@ -378,10 +379,14 @@ public static void ExceptionWithoutBatchContainsNoBatch() [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] public static void ParameterInOutAndReturn() { - string create = - @" -CREATE PROCEDURE TestInAndOutParams - @Input int, + SqlParameter input = CreateParameter("@Input", SqlDbType.Int, 2); + SqlParameter inputOutput = CreateParameter("@InOut", SqlDbType.Int, 4, ParameterDirection.InputOutput); + SqlParameter output = CreateParameter("@Output", SqlDbType.Int, DBNull.Value, ParameterDirection.Output); + SqlParameter returned = CreateParameter("@RETURN_VALUE", SqlDbType.Int, DBNull.Value, ParameterDirection.ReturnValue); + + using (SqlConnection conn = new(DataTestUtility.TCPConnectionString)) + using (StoredProcedure spTestInAndOutParams = new(conn, "TestInAndOutParams", @" + @Input int, @InOut int OUTPUT, @Output int = default OUTPUT AS @@ -389,26 +394,14 @@ CREATE PROCEDURE TestInAndOutParams SET NOCOUNT ON; SELECT @InOut = 2 * @InOut, @Output = 2 * @Input RETURN @Input -END"; - string drop = "DROP PROCEDURE TestInAndOutParams"; - - SqlParameter input = CreateParameter("@Input", SqlDbType.Int, 2); - SqlParameter inputOutput = CreateParameter("@InOut", SqlDbType.Int, 4, ParameterDirection.InputOutput); - SqlParameter output = CreateParameter("@Output", SqlDbType.Int, DBNull.Value, ParameterDirection.Output); - SqlParameter returned = CreateParameter("@RETURN_VALUE", SqlDbType.Int, DBNull.Value, ParameterDirection.ReturnValue); - try +END")) { - TryExecuteNonQueryCommand(drop); - ExecuteNonQueryCommand(create); - - using (SqlConnection conn = new SqlConnection(DataTestUtility.TCPConnectionString)) using (SqlBatch batch = new SqlBatch(conn)) { - conn.Open(); batch.Commands.Add(new SqlBatchCommand("SELECT @@VERSION")); batch.Commands.Add( new SqlBatchCommand( - "TestInAndOutParams", + spTestInAndOutParams.Name, CommandType.StoredProcedure, new[] { input, inputOutput, output, returned } ) @@ -417,10 +410,6 @@ RETURN @Input batch.ExecuteNonQuery(); } } - finally - { - TryExecuteNonQueryCommand(drop); - } Assert.Equal(8, Convert.ToInt32(inputOutput.Value)); Assert.Equal(4, Convert.ToInt32(output.Value)); @@ -657,28 +646,5 @@ private static SqlParameter CreateParameter(string name, SqlDbType type, T va parameter.Value = value; return parameter; } - - private static void ExecuteNonQueryCommand(string command) - { - using (SqlConnection conn = new SqlConnection(DataTestUtility.TCPConnectionString)) - using (SqlCommand cmd = conn.CreateCommand()) - { - conn.Open(); - cmd.CommandText = command; - cmd.ExecuteNonQuery(); - } - } - private static bool TryExecuteNonQueryCommand(string command) - { - try - { - ExecuteNonQueryCommand(command); - return true; - } - catch - { - } - return false; - } } } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlCommand/SqlCommandCancelTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlCommand/SqlCommandCancelTest.cs index dc97d98052..830f1bb733 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlCommand/SqlCommandCancelTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlCommand/SqlCommandCancelTest.cs @@ -375,7 +375,9 @@ private static void ExecuteCommandCancelExpected(object state) string errorMessage = SystemDataResourceManager.Instance.SQL_OperationCancelled; string errorMessageSevereFailure = SystemDataResourceManager.Instance.SQL_SevereError; - DataTestUtility.ExpectFailure(() => + // This could fail with either a SqlException or an InvalidOperationException depending on timing, + // so we will accept either but require the message to match expected cancellation messages + DataTestUtility.ExpectFailure(() => { threadsReady.SignalAndWait(); using (SqlDataReader r = command.ExecuteReader()) @@ -387,7 +389,9 @@ private static void ExecuteCommandCancelExpected(object state) } } while (r.NextResult()); } - }, new string[] { errorMessage, errorMessageSevereFailure }); + }, + new string[] { errorMessage, errorMessageSevereFailure }, + customExceptionVerifier: (ex) => ex is SqlException or InvalidOperationException); } diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/LocalAppContextSwitchesTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/LocalAppContextSwitchesTest.cs index c5c6f7ec73..0d2fe8d83d 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/LocalAppContextSwitchesTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/LocalAppContextSwitchesTest.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. @@ -49,27 +49,27 @@ public void TestDefaultAppContextSwitchValues() switchesHelper.DisableTnirByDefault = null; #endif - Assert.False(LocalAppContextSwitches.LegacyRowVersionNullBehavior); - Assert.False(LocalAppContextSwitches.SuppressInsecureTlsWarning); - Assert.False(LocalAppContextSwitches.MakeReadAsyncBlocking); - Assert.True(LocalAppContextSwitches.UseMinimumLoginTimeout); - Assert.True(LocalAppContextSwitches.LegacyVarTimeZeroScaleBehaviour); - Assert.True(LocalAppContextSwitches.UseCompatibilityProcessSni); - Assert.True(LocalAppContextSwitches.UseCompatibilityAsyncBehaviour); - Assert.False(LocalAppContextSwitches.UseConnectionPoolV2); - Assert.False(LocalAppContextSwitches.UseOverallConnectTimeoutForPoolWait); - Assert.False(LocalAppContextSwitches.TruncateScaledDecimal); - Assert.False(LocalAppContextSwitches.IgnoreServerProvidedFailoverPartner); - Assert.False(LocalAppContextSwitches.UseLegacyFailoverAlternationOnLoginSqlErrors); - Assert.False(LocalAppContextSwitches.EnableMultiSubnetFailoverByDefault); + Assert.False(switchesHelper.LegacyRowVersionNullBehavior); + Assert.False(switchesHelper.SuppressInsecureTlsWarning); + Assert.False(switchesHelper.MakeReadAsyncBlocking); + Assert.True(switchesHelper.UseMinimumLoginTimeout); + Assert.True(switchesHelper.LegacyVarTimeZeroScaleBehaviour); + Assert.True(switchesHelper.UseCompatibilityProcessSni); + Assert.True(switchesHelper.UseCompatibilityAsyncBehaviour); + Assert.False(switchesHelper.UseConnectionPoolV2); + Assert.False(switchesHelper.UseOverallConnectTimeoutForPoolWait); + Assert.False(switchesHelper.TruncateScaledDecimal); + Assert.False(switchesHelper.IgnoreServerProvidedFailoverPartner); + Assert.False(switchesHelper.UseLegacyFailoverAlternationOnLoginSqlErrors); + Assert.False(switchesHelper.EnableMultiSubnetFailoverByDefault); #if NET - Assert.False(LocalAppContextSwitches.GlobalizationInvariantMode); + Assert.False(switchesHelper.GlobalizationInvariantMode); #endif #if NET && _WINDOWS - Assert.False(LocalAppContextSwitches.UseManagedNetworking); + Assert.False(switchesHelper.UseManagedNetworking); #endif #if NETFRAMEWORK - Assert.False(LocalAppContextSwitches.DisableTnirByDefault); + Assert.False(switchesHelper.DisableTnirByDefault); #endif } }