diff --git a/tools/clang/unittests/HLSLExec/LinAlgTests.cpp b/tools/clang/unittests/HLSLExec/LinAlgTests.cpp index da32f553c4..1d420d1ca6 100644 --- a/tools/clang/unittests/HLSLExec/LinAlgTests.cpp +++ b/tools/clang/unittests/HLSLExec/LinAlgTests.cpp @@ -307,6 +307,14 @@ class DxilConf_SM610_LinAlg { // Element access TEST_METHOD(ElementAccess_Wave_16x16_F16); + TEST_METHOD(ElementSet_Wave_16x16_F16); + + // Cast/Convert + TEST_METHOD(CopyConvert_Wave_16x16_F16); + TEST_METHOD(CopyConvert_Wave_16x16_F16_Transpose); + + // Matrix Arithmetic + TEST_METHOD(MatMatMul_Wave_16x16x16_F16); private: CComPtr D3DDevice; @@ -537,14 +545,9 @@ static void runElementAccess(ID3D12Device *Device, const MatrixParams &Params, bool Verbose) { const size_t NumElements = Params.totalElements(); const size_t NumThreads = Params.NumThreads; - const size_t InputBufSize = Params.totalBytes(); - const size_t ElementSize = elementSize(Params.CompType); - - // Output: ElementSize bytes per element - // 1 element for each mat idx - // 1 uint for each thread's length - const size_t OutputBufSize = - NumElements * ElementSize + NumThreads * sizeof(uint32_t); + const size_t MatrixSize = Params.totalBytes(); + // OutputBuf needs to fit the Matrix plus one uint per thread + const size_t OutputBufSize = MatrixSize + NumThreads * sizeof(uint32_t); std::stringstream ExtraDefs; std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); @@ -555,7 +558,7 @@ static void runElementAccess(ID3D12Device *Device, auto Op = createComputeOp(ElementAccessShader, "cs_6_10", "UAV(u0), UAV(u1)", Args.c_str()); - addUAVBuffer(Op.get(), "Input", InputBufSize, false, "byname"); + addUAVBuffer(Op.get(), "Input", MatrixSize, false, "byname"); addUAVBuffer(Op.get(), "Output", OutputBufSize, true); addRootUAV(Op.get(), 0, "Input"); addRootUAV(Op.get(), 1, "Output"); @@ -579,9 +582,8 @@ static void runElementAccess(ID3D12Device *Device, // Verify the end of the buffer is NumThreads number of lengths, whose // sum is greater than or equal to NumElements const BYTE *Out = static_cast(OutData.data()); - size_t MatrixEndOffset = NumElements * ElementSize; const uint32_t *Lengths = - reinterpret_cast(Out + MatrixEndOffset); + reinterpret_cast(Out + MatrixSize); uint32_t TotalLength = 0; for (size_t I = 0; I < NumThreads; ++I) TotalLength += Lengths[I]; @@ -602,4 +604,257 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() { runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging); } +static const char ElementSetShader[] = R"( + RWByteAddressBuffer Input : register(u0); + RWByteAddressBuffer Output : register(u1); + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main(uint threadID : SV_GroupIndex) { + if (WaveReadLaneFirst(threadID) != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]] + Mat; + __builtin_LinAlg_MatrixLoadFromDescriptor( + Mat, Input, 0, STRIDE, LAYOUT, 128); + + // Increment every element by 5 + for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) { + ELEM_TYPE Elem; + __builtin_LinAlg_MatrixGetElement(Elem, Mat, I); + Elem = Elem + 5; + __builtin_LinAlg_MatrixSetElement(Mat, Mat, I, Elem); + } + + __builtin_LinAlg_MatrixStoreToDescriptor( + Mat, Output, 0, STRIDE, LAYOUT, 128); + } +)"; + +static void runElementSet(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose) { + const size_t NumElements = Params.totalElements(); + const size_t MatrixSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, ElementSetShader, "cs_6_10", Args, Verbose); + + // Start counting from 6 since each element was increased by 5 + auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 6); + + auto Op = createComputeOp(ElementSetShader, "cs_6_10", "UAV(u0), UAV(u1)", + Args.c_str()); + addUAVBuffer(Op.get(), "Input", MatrixSize, false, "byname"); + addUAVBuffer(Op.get(), "Output", MatrixSize, true); + addRootUAV(Op.get(), 0, "Input"); + addRootUAV(Op.get(), 1, "Output"); + + auto Result = + runShaderOp(Device, DxcSupport, std::move(Op), + [NumElements, Params](LPCSTR Name, std::vector &Data, + st::ShaderOp *) { + VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, + NumElements), + "Saw unsupported component type"); + }); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + // Verify the front of the buffer is a list of elements of the expected type + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::ElementSet_Wave_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Use = MatrixUse::Accumulator; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runElementSet(D3DDevice, DxcSupport, Params, VerboseLogging); +} + +static const char CopyConvertShader[] = R"( + RWByteAddressBuffer Input : register(u0); + RWByteAddressBuffer Output : register(u1); + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main(uint threadID : SV_GroupIndex) { + if (WaveReadLaneFirst(threadID) != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]] + Src; + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, N_DIM, M_DIM, USE, SCOPE)]] + Dst; + + __builtin_LinAlg_MatrixLoadFromDescriptor( + Src, Input, 0, STRIDE, LAYOUT, 128); + __builtin_LinAlg_CopyConvertMatrix(Dst, Src, TRANSPOSE); + __builtin_LinAlg_MatrixStoreToDescriptor( + Dst, Output, 0, STRIDE, LAYOUT, 128); + } +)"; + +static void runCopyConvert(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose, + bool Transpose) { + const size_t NumElements = Params.totalElements(); + const size_t BufferSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + ExtraDefs << " -DTRANSPOSE=" << Transpose; + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, CopyConvertShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1, + /*Increment=*/true, Transpose); + + // Construct the ShaderOp: two UAV buffers, load from one, store to other. + auto Op = createComputeOp(CopyConvertShader, "cs_6_10", "UAV(u0), UAV(u1)", + Args.c_str()); + addUAVBuffer(Op.get(), "Input", BufferSize, false, "byname"); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootUAV(Op.get(), 0, "Input"); + addRootUAV(Op.get(), 1, "Output"); + + auto Result = + runShaderOp(Device, DxcSupport, std::move(Op), + [NumElements, Params](LPCSTR Name, std::vector &Data, + st::ShaderOp *) { + VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, + NumElements), + "Saw unsupported component type"); + }); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::CopyConvert_Wave_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Use = MatrixUse::A; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runCopyConvert(D3DDevice, DxcSupport, Params, VerboseLogging, + /*Transpose=*/false); +} + +void DxilConf_SM610_LinAlg::CopyConvert_Wave_16x16_F16_Transpose() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Use = MatrixUse::A; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runCopyConvert(D3DDevice, DxcSupport, Params, VerboseLogging, + /*Transpose=*/true); +} + +static const char MatMatMulShader[] = R"( + #define USE_A 0 + #define USE_B 1 + #define USE_ACC 2 + + RWByteAddressBuffer Output : register(u0); + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main(uint threadID : SV_GroupIndex) { + if (WaveReadLaneFirst(threadID) != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, K_DIM, USE_A, SCOPE)]] + MatA; + __builtin_LinAlg_FillMatrix(MatA, A_FILL); + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, K_DIM, N_DIM, USE_B, SCOPE)]] + MatB; + __builtin_LinAlg_FillMatrix(MatB, B_FILL); + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]] + MatC; + __builtin_LinAlg_MatrixMatrixMultiply(MatC, MatA, MatB); + + __builtin_LinAlg_MatrixStoreToDescriptor( + MatC, Output, 0, STRIDE, LAYOUT, 128); + } +)"; + +static void runMatMatMul(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose, MatrixDim K, + float AFill, float BFill) { + const size_t NumElements = Params.totalElements(); + const size_t BufferSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + ExtraDefs << " -DK_DIM=" << K; + ExtraDefs << " -DA_FILL=" << AFill; + ExtraDefs << " -DB_FILL=" << BFill; + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, MatMatMulShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpected(Params.CompType, Params.M, Params.N, + AFill * BFill * K, /*Increment=*/false); + + auto Op = + createComputeOp(MatMatMulShader, "cs_6_10", "UAV(u0)", Args.c_str()); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootUAV(Op.get(), 0, "Output"); + + auto Result = runShaderOp(Device, DxcSupport, std::move(Op)); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::MatMatMul_Wave_16x16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runMatMatMul(D3DDevice, DxcSupport, Params, VerboseLogging, /*K=*/16, + /*AFill=*/2.0f, /*BFill=*/3.0f); +} + } // namespace LinAlg