Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
277 changes: 266 additions & 11 deletions tools/clang/unittests/HLSLExec/LinAlgTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,14 @@ class DxilConf_SM610_LinAlg {

// Element access
TEST_METHOD(ElementAccess_Wave_16x16_F16);
TEST_METHOD(ElementSet_Wave_16x16_F16);

// Cast/Convert
TEST_METHOD(CopyConvert_Wave_16x16_F16);
TEST_METHOD(CopyConvert_Wave_16x16_F16_Transpose);

// Matrix Arithmetic
TEST_METHOD(MatMatMul_Wave_16x16x16_F16);

private:
CComPtr<ID3D12Device> D3DDevice;
Expand Down Expand Up @@ -537,14 +545,9 @@ static void runElementAccess(ID3D12Device *Device,
const MatrixParams &Params, bool Verbose) {
const size_t NumElements = Params.totalElements();
const size_t NumThreads = Params.NumThreads;
const size_t InputBufSize = Params.totalBytes();
const size_t ElementSize = elementSize(Params.CompType);

// Output: ElementSize bytes per element
// 1 element for each mat idx
// 1 uint for each thread's length
const size_t OutputBufSize =
NumElements * ElementSize + NumThreads * sizeof(uint32_t);
const size_t MatrixSize = Params.totalBytes();
// OutputBuf needs to fit the Matrix plus one uint per thread
const size_t OutputBufSize = MatrixSize + NumThreads * sizeof(uint32_t);

std::stringstream ExtraDefs;
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());
Expand All @@ -555,7 +558,7 @@ static void runElementAccess(ID3D12Device *Device,

auto Op = createComputeOp(ElementAccessShader, "cs_6_10", "UAV(u0), UAV(u1)",
Args.c_str());
addUAVBuffer(Op.get(), "Input", InputBufSize, false, "byname");
addUAVBuffer(Op.get(), "Input", MatrixSize, false, "byname");
addUAVBuffer(Op.get(), "Output", OutputBufSize, true);
addRootUAV(Op.get(), 0, "Input");
addRootUAV(Op.get(), 1, "Output");
Expand All @@ -579,9 +582,8 @@ static void runElementAccess(ID3D12Device *Device,
// Verify the end of the buffer is NumThreads number of lengths, whose
// sum is greater than or equal to NumElements
const BYTE *Out = static_cast<const BYTE *>(OutData.data());
size_t MatrixEndOffset = NumElements * ElementSize;
const uint32_t *Lengths =
reinterpret_cast<const uint32_t *>(Out + MatrixEndOffset);
reinterpret_cast<const uint32_t *>(Out + MatrixSize);
uint32_t TotalLength = 0;
for (size_t I = 0; I < NumThreads; ++I)
TotalLength += Lengths[I];
Expand All @@ -602,4 +604,257 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() {
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging);
}

// HLSL compute shader source for the element-set test.
//
// The shader loads a matrix from the Input buffer (u0) at offset 0, walks
// every index below __builtin_LinAlg_MatrixLength(Mat), reads that element,
// adds 5 to it, writes it back into the same matrix, and finally stores the
// matrix to the Output buffer (u1) at offset 0.
//
// Threads return early unless WaveReadLaneFirst(threadID) is 0, so only the
// wave whose first active lane has SV_GroupIndex 0 performs the work.
//
// NUMTHREADS, COMP_TYPE, M_DIM, N_DIM, USE, SCOPE, STRIDE, LAYOUT and
// ELEM_TYPE are not defined in this string; they must be supplied as
// preprocessor definitions when the shader is compiled (presumably by
// buildCompilerArgs -- confirm).
// NOTE(review): the trailing 128 passed to the load/store builtins is not
// explained here -- confirm its meaning against the builtin's definition.
static const char ElementSetShader[] = R"(
RWByteAddressBuffer Input : register(u0);
RWByteAddressBuffer Output : register(u1);

[WaveSize(4, 64)]
[numthreads(NUMTHREADS, 1, 1)]
void main(uint threadID : SV_GroupIndex) {
if (WaveReadLaneFirst(threadID) != 0)
return;

__builtin_LinAlgMatrix
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
Mat;
__builtin_LinAlg_MatrixLoadFromDescriptor(
Mat, Input, 0, STRIDE, LAYOUT, 128);

// Increment every element by 5
for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) {
ELEM_TYPE Elem;
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
Elem = Elem + 5;
__builtin_LinAlg_MatrixSetElement(Mat, Mat, I, Elem);
}

__builtin_LinAlg_MatrixStoreToDescriptor(
Mat, Output, 0, STRIDE, LAYOUT, 128);
}
)";

// Compiles and runs ElementSetShader for the matrix described by Params, then
// checks the "Output" buffer against the expected element values.
//
// Device      - device the shader op is executed on.
// DxcSupport  - compiler loader used for both the standalone compile and the
//               shader op run.
// Params      - matrix description (component type, dimensions, ...).
// Verbose     - forwarded to the compile and verification helpers.
static void runElementSet(ID3D12Device *Device,
                          dxc::SpecificDllLoader &DxcSupport,
                          const MatrixParams &Params, bool Verbose) {
  const size_t NumElements = Params.totalElements();
  const size_t ByteCount = Params.totalBytes();

  // This test needs no defines beyond the ones derived from Params.
  const std::string NoExtraDefs;
  std::string CompilerArgs = buildCompilerArgs(Params, NoExtraDefs.c_str());

  compileShader(DxcSupport, ElementSetShader, "cs_6_10", CompilerArgs,
                Verbose);

  // The shader adds 5 to every element, so the expected sequence starts at 6.
  auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 6);

  // Two root UAVs: the shader loads from "Input" and stores to "Output".
  // Both buffers are exactly one matrix in size.
  auto ComputeOp = createComputeOp(ElementSetShader, "cs_6_10",
                                   "UAV(u0), UAV(u1)", CompilerArgs.c_str());
  addUAVBuffer(ComputeOp.get(), "Input", ByteCount, false, "byname");
  addUAVBuffer(ComputeOp.get(), "Output", ByteCount, true);
  addRootUAV(ComputeOp.get(), 0, "Input");
  addRootUAV(ComputeOp.get(), 1, "Output");

  // Callback that populates named buffers before the dispatch.
  auto FillInput = [NumElements, Params](LPCSTR Name, std::vector<BYTE> &Data,
                                         st::ShaderOp *) {
    VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, NumElements),
                   "Saw unsupported component type");
  };

  auto RunResult =
      runShaderOp(Device, DxcSupport, std::move(ComputeOp), FillInput);

  MappedData OutData;
  RunResult->Test->GetReadBackData("Output", &OutData);

  // The output buffer holds NumElements values of the matrix component type.
  VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(),
                                       Expected, NumElements, Verbose));
}

void DxilConf_SM610_LinAlg::ElementSet_Wave_16x16_F16() {
MatrixParams Params = {};
Params.CompType = ComponentType::F16;
Params.M = 16;
Params.N = 16;
Params.Use = MatrixUse::Accumulator;
Params.Scope = MatrixScope::Wave;
Params.Layout = LinalgMatrixLayout::RowMajor;
Params.NumThreads = 64;
Params.Enable16Bit = true;
runElementSet(D3DDevice, DxcSupport, Params, VerboseLogging);
}

// HLSL compute shader source for the copy/convert test.
//
// The shader loads matrix Src from the Input buffer (u0) at offset 0, copies
// it into Dst with __builtin_LinAlg_CopyConvertMatrix (passing the TRANSPOSE
// macro as the last argument), and stores Dst to the Output buffer (u1) at
// offset 0.
//
// Threads return early unless WaveReadLaneFirst(threadID) is 0, so only the
// wave whose first active lane has SV_GroupIndex 0 performs the work.
//
// NUMTHREADS, COMP_TYPE, M_DIM, N_DIM, USE, SCOPE, STRIDE, LAYOUT and
// TRANSPOSE are not defined in this string; they must be supplied as
// preprocessor definitions when the shader is compiled.
//
// NOTE(review): Dst is always declared with the dimension macros swapped
// (N_DIM, M_DIM), whatever the value of TRANSPOSE, and the same STRIDE is
// used for the load and the store. For square matrices this makes no
// difference; confirm the intended behavior before using this shader with
// M_DIM != N_DIM.
static const char CopyConvertShader[] = R"(
RWByteAddressBuffer Input : register(u0);
RWByteAddressBuffer Output : register(u1);

[WaveSize(4, 64)]
[numthreads(NUMTHREADS, 1, 1)]
void main(uint threadID : SV_GroupIndex) {
if (WaveReadLaneFirst(threadID) != 0)
return;

__builtin_LinAlgMatrix
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
Src;
__builtin_LinAlgMatrix
[[__LinAlgMatrix_Attributes(COMP_TYPE, N_DIM, M_DIM, USE, SCOPE)]]
Dst;

__builtin_LinAlg_MatrixLoadFromDescriptor(
Src, Input, 0, STRIDE, LAYOUT, 128);
__builtin_LinAlg_CopyConvertMatrix(Dst, Src, TRANSPOSE);
__builtin_LinAlg_MatrixStoreToDescriptor(
Dst, Output, 0, STRIDE, LAYOUT, 128);
}
)";

// Compiles and runs CopyConvertShader for the matrix described by Params,
// then checks the "Output" buffer against the expected element values.
//
// Device      - device the shader op is executed on.
// DxcSupport  - compiler loader used for both the standalone compile and the
//               shader op run.
// Params      - matrix description (component type, dimensions, ...).
// Verbose     - forwarded to the compile and verification helpers.
// Transpose   - passed to the shader as the TRANSPOSE define and to
//               makeExpected when building the reference values.
static void runCopyConvert(ID3D12Device *Device,
                           dxc::SpecificDllLoader &DxcSupport,
                           const MatrixParams &Params, bool Verbose,
                           bool Transpose) {
  const size_t NumElements = Params.totalElements();
  const size_t ByteCount = Params.totalBytes();

  // The flag reaches the shader as the literal 0 or 1.
  std::string TransposeDef = " -DTRANSPOSE=";
  TransposeDef += Transpose ? "1" : "0";

  std::string CompilerArgs = buildCompilerArgs(Params, TransposeDef.c_str());

  compileShader(DxcSupport, CopyConvertShader, "cs_6_10", CompilerArgs,
                Verbose);

  // Reference values: an incrementing sequence starting at 1, with Transpose
  // forwarded to makeExpected.
  auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1,
                               /*Increment=*/true, Transpose);

  // Two root UAVs: the shader loads from "Input" and stores to "Output".
  auto ComputeOp = createComputeOp(CopyConvertShader, "cs_6_10",
                                   "UAV(u0), UAV(u1)", CompilerArgs.c_str());
  addUAVBuffer(ComputeOp.get(), "Input", ByteCount, false, "byname");
  addUAVBuffer(ComputeOp.get(), "Output", ByteCount, true);
  addRootUAV(ComputeOp.get(), 0, "Input");
  addRootUAV(ComputeOp.get(), 1, "Output");

  // Callback that populates named buffers before the dispatch.
  auto FillInput = [NumElements, Params](LPCSTR Name, std::vector<BYTE> &Data,
                                         st::ShaderOp *) {
    VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, NumElements),
                   "Saw unsupported component type");
  };

  auto RunResult =
      runShaderOp(Device, DxcSupport, std::move(ComputeOp), FillInput);

  MappedData OutData;
  RunResult->Test->GetReadBackData("Output", &OutData);

  VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(),
                                       Expected, NumElements, Verbose));
}

void DxilConf_SM610_LinAlg::CopyConvert_Wave_16x16_F16() {
MatrixParams Params = {};
Params.CompType = ComponentType::F16;
Params.M = 16;
Params.N = 16;
Params.Use = MatrixUse::A;
Params.Scope = MatrixScope::Wave;
Params.Layout = LinalgMatrixLayout::RowMajor;
Params.NumThreads = 64;
Params.Enable16Bit = true;
runCopyConvert(D3DDevice, DxcSupport, Params, VerboseLogging,
/*Transpose=*/false);
}

void DxilConf_SM610_LinAlg::CopyConvert_Wave_16x16_F16_Transpose() {
MatrixParams Params = {};
Params.CompType = ComponentType::F16;
Params.M = 16;
Params.N = 16;
Params.Use = MatrixUse::A;
Params.Scope = MatrixScope::Wave;
Params.Layout = LinalgMatrixLayout::RowMajor;
Params.NumThreads = 64;
Params.Enable16Bit = true;
runCopyConvert(D3DDevice, DxcSupport, Params, VerboseLogging,
/*Transpose=*/true);
}

// HLSL compute shader source for the matrix-matrix multiply test.
//
// The shader declares three matrices and never reads from a buffer:
//   MatA: M_DIM x K_DIM, use USE_A (0),   filled with A_FILL
//   MatB: K_DIM x N_DIM, use USE_B (1),   filled with B_FILL
//   MatC: M_DIM x N_DIM, use USE_ACC (2), receives the result of
//         __builtin_LinAlg_MatrixMatrixMultiply(MatC, MatA, MatB)
// MatC is then stored to the Output buffer (u0) at offset 0.
//
// The matrix uses are hard-coded through the local USE_A/USE_B/USE_ACC
// defines, so this shader does not reference a USE macro.
//
// Threads return early unless WaveReadLaneFirst(threadID) is 0, so only the
// wave whose first active lane has SV_GroupIndex 0 performs the work.
//
// NUMTHREADS, COMP_TYPE, M_DIM, N_DIM, K_DIM, SCOPE, STRIDE, LAYOUT, A_FILL
// and B_FILL are not defined in this string; they must be supplied as
// preprocessor definitions when the shader is compiled.
static const char MatMatMulShader[] = R"(
#define USE_A 0
#define USE_B 1
#define USE_ACC 2

RWByteAddressBuffer Output : register(u0);

[WaveSize(4, 64)]
[numthreads(NUMTHREADS, 1, 1)]
void main(uint threadID : SV_GroupIndex) {
if (WaveReadLaneFirst(threadID) != 0)
return;

__builtin_LinAlgMatrix
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, K_DIM, USE_A, SCOPE)]]
MatA;
__builtin_LinAlg_FillMatrix(MatA, A_FILL);

__builtin_LinAlgMatrix
[[__LinAlgMatrix_Attributes(COMP_TYPE, K_DIM, N_DIM, USE_B, SCOPE)]]
MatB;
__builtin_LinAlg_FillMatrix(MatB, B_FILL);

__builtin_LinAlgMatrix
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]]
MatC;
__builtin_LinAlg_MatrixMatrixMultiply(MatC, MatA, MatB);

__builtin_LinAlg_MatrixStoreToDescriptor(
MatC, Output, 0, STRIDE, LAYOUT, 128);
}
)";

// Compiles and runs MatMatMulShader for the result matrix described by
// Params, then checks the "Output" buffer against the expected values.
//
// Device      - device the shader op is executed on.
// DxcSupport  - compiler loader used for both the standalone compile and the
//               shader op run.
// Params      - description of the result matrix (component type, M, N, ...).
// Verbose     - forwarded to the compile and verification helpers.
// K           - shared inner dimension, passed to the shader as K_DIM.
// AFill/BFill - fill values for the two operands, passed to the shader as
//               A_FILL and B_FILL.
static void runMatMatMul(ID3D12Device *Device,
                         dxc::SpecificDllLoader &DxcSupport,
                         const MatrixParams &Params, bool Verbose, MatrixDim K,
                         float AFill, float BFill) {
  const size_t NumElements = Params.totalElements();
  const size_t ByteCount = Params.totalBytes();

  // Test-specific defines appended to the ones derived from Params.
  std::stringstream Defines;
  Defines << " -DK_DIM=" << K << " -DA_FILL=" << AFill
          << " -DB_FILL=" << BFill;

  std::string CompilerArgs = buildCompilerArgs(Params, Defines.str().c_str());

  compileShader(DxcSupport, MatMatMulShader, "cs_6_10", CompilerArgs, Verbose);

  // Both operands are constant-filled, so each result element is the sum of
  // K products AFill * BFill; every expected element has that same value.
  const auto ElementValue = AFill * BFill * K;
  auto Expected = makeExpected(Params.CompType, Params.M, Params.N,
                               ElementValue, /*Increment=*/false);

  // A single root UAV: the shader only writes its result to "Output".
  auto ComputeOp = createComputeOp(MatMatMulShader, "cs_6_10", "UAV(u0)",
                                   CompilerArgs.c_str());
  addUAVBuffer(ComputeOp.get(), "Output", ByteCount, true);
  addRootUAV(ComputeOp.get(), 0, "Output");

  auto RunResult = runShaderOp(Device, DxcSupport, std::move(ComputeOp));

  MappedData OutData;
  RunResult->Test->GetReadBackData("Output", &OutData);

  VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(),
                                       Expected, NumElements, Verbose));
}

void DxilConf_SM610_LinAlg::MatMatMul_Wave_16x16x16_F16() {
MatrixParams Params = {};
Params.CompType = ComponentType::F16;
Params.M = 16;
Params.N = 16;
Params.Scope = MatrixScope::Wave;
Params.Layout = LinalgMatrixLayout::RowMajor;
Params.NumThreads = 64;
Params.Enable16Bit = true;
runMatMatMul(D3DDevice, DxcSupport, Params, VerboseLogging, /*K=*/16,
/*AFill=*/2.0f, /*BFill=*/3.0f);
}

} // namespace LinAlg