Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 97 additions & 15 deletions tools/clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5286,8 +5286,8 @@ class HLSLExternalSource : public ExternalSemaSource {
/// use for the signature, with the first being the return type.</remarks>
bool MatchArguments(const IntrinsicDefIter &cursor, QualType objectType,
QualType objectElement, QualType functionTemplateTypeArg,
ArrayRef<Expr *> Args, std::vector<QualType> *,
size_t &badArgIdx);
unsigned functionTemplateIntArg, ArrayRef<Expr *> Args,
std::vector<QualType> *, size_t &badArgIdx);

/// <summary>Validate object element on intrinsic to catch case like integer
/// on Sample.</summary> <param name="tableName">Intrinsic function to
Expand Down Expand Up @@ -5336,6 +5336,19 @@ class HLSLExternalSource : public ExternalSemaSource {
nameIdentifier, argumentCount));
}

static unsigned GetIntegralTemplateArg(ASTContext &context,
const TemplateArgument &arg) {
if (arg.getKind() == TemplateArgument::Integral)
return arg.getAsIntegral().getZExtValue();
if (arg.getKind() == TemplateArgument::Expression) {
llvm::APSInt result;
Expr *expr = arg.getAsExpr();
if (expr != nullptr && expr->isIntegerConstantExpr(result, context))
return result.getZExtValue();
}
return 0;
}

bool AddOverloadedCallCandidates(UnresolvedLookupExpr *ULE,
ArrayRef<Expr *> Args,
OverloadCandidateSet &CandidateSet, Scope *S,
Expand Down Expand Up @@ -5430,11 +5443,22 @@ class HLSLExternalSource : public ExternalSemaSource {
"otherwise g_MaxIntrinsicParamCount needs to be updated for "
"wider signatures");

QualType templateTypeArg;
unsigned templateIntArg = 0;
std::vector<QualType> functionArgTypes;
size_t badArgIdx;
if (ULE->hasExplicitTemplateArgs() && ULE->getNumTemplateArgs() >= 1) {
const TemplateArgumentLoc &TypeArgLoc = ULE->getTemplateArgs()[0];
if (TypeArgLoc.getArgument().getKind() == TemplateArgument::Type)
templateTypeArg = TypeArgLoc.getArgument().getAsType();
if (ULE->getNumTemplateArgs() >= 2)
templateIntArg = GetIntegralTemplateArg(
*m_context, ULE->getTemplateArgs()[1].getArgument());
}

bool argsMatch =
MatchArguments(cursor, QualType(), QualType(), QualType(), Args,
&functionArgTypes, badArgIdx);
MatchArguments(cursor, QualType(), QualType(), templateTypeArg,
templateIntArg, Args, &functionArgTypes, badArgIdx);
if (!functionArgTypes.size())
return false;

Expand Down Expand Up @@ -6843,8 +6867,9 @@ bool HLSLExternalSource::IsValidObjectElement(LPCSTR tableName,

bool HLSLExternalSource::MatchArguments(
const IntrinsicDefIter &cursor, QualType objectType, QualType objectElement,
QualType functionTemplateTypeArg, ArrayRef<Expr *> Args,
std::vector<QualType> *argTypesVector, size_t &badArgIdx) {
QualType functionTemplateTypeArg, unsigned functionTemplateIntArg,
ArrayRef<Expr *> Args, std::vector<QualType> *argTypesVector,
size_t &badArgIdx) {
const HLSL_INTRINSIC *pIntrinsic = *cursor;
LPCSTR tableName = cursor.GetTableName();
IntrinsicOp builtinOp = IntrinsicOp::Num_Intrinsics;
Expand Down Expand Up @@ -7336,7 +7361,45 @@ bool HLSLExternalSource::MatchArguments(
if (i == 0 &&
(builtinOp == hlsl::IntrinsicOp::IOP_Vkreinterpret_pointer_cast ||
builtinOp == hlsl::IntrinsicOp::IOP_Vkstatic_pointer_cast)) {
pNewType = Args[0]->getType();
if (functionTemplateTypeArg.isNull()) {
badArgIdx = std::min(badArgIdx, i);
continue;
}

// Build BufferPointer<T, A> where T is the template type argument and
// A is the template alignment argument (or the alignment of the
// source pointer if none is given).
unsigned srcAlignment =
functionTemplateIntArg
? functionTemplateIntArg
: hlsl::GetVKBufferPointerAlignment(Args[0]->getType());
TemplateArgument TemplateArgs[] = {
TemplateArgument(functionTemplateTypeArg),
TemplateArgument(*m_context,
llvm::APSInt(llvm::APInt(32, srcAlignment)),
m_context->UnsignedIntTy)};
void *InsertPos = nullptr;
ClassTemplateSpecializationDecl *Spec =
m_vkBufferPointerTemplateDecl->findSpecialization(
llvm::ArrayRef<TemplateArgument>(TemplateArgs, 2), InsertPos);
if (!Spec) {
Spec = ClassTemplateSpecializationDecl::Create(
*m_context, TagDecl::TagKind::TTK_Struct,
m_vkBufferPointerTemplateDecl->getDeclContext(), SourceLocation(),
SourceLocation(), m_vkBufferPointerTemplateDecl, TemplateArgs, 2,
nullptr);
m_vkBufferPointerTemplateDecl->AddSpecialization(Spec, InsertPos);
Spec->setImplicit(true);
DXVERIFY_NOMSG(
false ==
getSema()->InstantiateClassTemplateSpecialization(
SourceLocation(), Spec,
TemplateSpecializationKind::TSK_ImplicitInstantiation, true));
}

pNewType = m_context->getTemplateSpecializationType(
TemplateName(m_vkBufferPointerTemplateDecl), TemplateArgs, 2,
m_context->getTypeDeclType(Spec));
} else {
badArgIdx = std::min(badArgIdx, i);
}
Expand Down Expand Up @@ -11072,11 +11135,18 @@ HLSLExternalSource::DeduceTemplateArgumentsForHLSL(
QualType objectType = m_context->getTagDeclType(functionParentRecord);

QualType functionTemplateTypeArg{};
if (ExplicitTemplateArgs != nullptr && ExplicitTemplateArgs->size() == 1) {
unsigned functionTemplateIntArg = 0;
if (ExplicitTemplateArgs != nullptr && ExplicitTemplateArgs->size() >= 1) {
const TemplateArgument &firstTemplateArg =
(*ExplicitTemplateArgs)[0].getArgument();
if (firstTemplateArg.getKind() == TemplateArgument::ArgKind::Type)
functionTemplateTypeArg = firstTemplateArg.getAsType();
if (ExplicitTemplateArgs->size() > 1) {
const TemplateArgument &secondTemplateArg =
(*ExplicitTemplateArgs)[1].getArgument();
functionTemplateIntArg =
GetIntegralTemplateArg(*m_context, secondTemplateArg);
}
}

// Handle subscript overloads.
Expand Down Expand Up @@ -11150,7 +11220,8 @@ HLSLExternalSource::DeduceTemplateArgumentsForHLSL(
while (cursor != end) {
size_t badArgIdx;
if (!MatchArguments(cursor, objectType, objectElement,
functionTemplateTypeArg, Args, &argTypes, badArgIdx)) {
functionTemplateTypeArg, functionTemplateIntArg, Args,
&argTypes, badArgIdx)) {
++cursor;
continue;
}
Expand Down Expand Up @@ -11193,8 +11264,9 @@ HLSLExternalSource::DeduceTemplateArgumentsForHLSL(
if (!IsNull &&
getSema()->RequireCompleteType(Loc, functionTemplateTypeArg, 0))
return Sema::TemplateDeductionResult::TDK_Invalid;
if (IsNull || !hlsl::IsHLSLNumericOrAggregateOfNumericType(
functionTemplateTypeArg)) {
if (IsNull || ExplicitTemplateArgs->size() > 1 ||
!hlsl::IsHLSLNumericOrAggregateOfNumericType(
functionTemplateTypeArg)) {
getSema()->Diag(Loc, diag::err_hlsl_intrinsic_template_arg_numeric)
<< intrinsicName;
DiagnoseTypeElements(
Expand Down Expand Up @@ -11924,8 +11996,18 @@ static bool CheckBarrierCall(Sema &S, FunctionDecl *FD, CallExpr *CE,
}

#ifdef ENABLE_SPIRV_CODEGEN
static bool CheckVKBufferPointerCast(Sema &S, FunctionDecl *FD, CallExpr *CE,
bool isStatic) {
static bool CheckVKBufferPointerCast(Sema &S, CallExpr *CE, bool isStatic) {
const auto *callee = dyn_cast<DeclRefExpr>(CE->getCallee()->IgnoreImpCasts());
if (callee && callee->hasExplicitTemplateArgs() &&
callee->getNumTemplateArgs() > 2) {
StringRef castName =
isStatic ? "static_pointer_cast" : "reinterpret_pointer_cast";
S.Diags.Report(CE->getExprLoc(),
diag::err_template_arg_list_different_arity)
<< /*too many*/ 1 << /*function template*/ 1 << castName;
return true;
}

const Expr *argExpr = CE->getArg(0);
QualType srcType = argExpr->getType();
QualType destType = CE->getType();
Expand Down Expand Up @@ -12063,10 +12145,10 @@ void Sema::CheckHLSLFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall) {
break;
#ifdef ENABLE_SPIRV_CODEGEN
case hlsl::IntrinsicOp::IOP_Vkreinterpret_pointer_cast:
CheckVKBufferPointerCast(*this, FDecl, TheCall, false);
CheckVKBufferPointerCast(*this, TheCall, false);
break;
case hlsl::IntrinsicOp::IOP_Vkstatic_pointer_cast:
CheckVKBufferPointerCast(*this, FDecl, TheCall, true);
CheckVKBufferPointerCast(*this, TheCall, true);
break;
#endif
default:
Expand Down
45 changes: 45 additions & 0 deletions tools/clang/test/CodeGenSPIRV/vk.buffer-pointer.casts.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// RUN: %dxc -T cs_6_6 -spirv -fspv-target-env=vulkan1.3 %s | FileCheck %s --check-prefix=GOOD
// RUN: not %dxc -T cs_6_6 -spirv -fspv-target-env=vulkan1.3 -DBAD %s 2>&1 | FileCheck %s --check-prefix=BAD

struct Base {
};

struct Derived : Base {
int val;
};

cbuffer Test {
vk::BufferPointer<Derived, 32> derivedBuf;
vk::BufferPointer<int> intBuf;
};

[shader("compute")]
[numthreads(256, 1, 1)]
void main(in uint3 threadId : SV_DispatchThreadID) {
#ifdef BAD
vk::BufferPointer<Base, 64> derivedBufAsBase = vk::static_pointer_cast<Base, 64>(derivedBuf);
vk::BufferPointer<float> intBufAsFloat = vk::static_pointer_cast<float>(intBuf);
#else
vk::BufferPointer<Base, 16> derivedBufAsBase = vk::static_pointer_cast<Base, 16>(derivedBuf);
vk::BufferPointer<float> intBufAsFloat = vk::reinterpret_pointer_cast<float>(intBuf);
#endif

intBuf.Get() = (int)intBufAsFloat.Get();
}

// GOOD: [[INT:%[^ ]*]] = OpTypeInt 32 1
// GOOD: [[I1:%[^ ]*]] = OpConstant [[INT]] 1
// GOOD: [[PINT:%[^ ]*]] = OpTypePointer PhysicalStorageBuffer [[INT]]
// GOOD: [[FLOAT:%[^ ]*]] = OpTypeFloat 32
// GOOD: [[PFLOAT:%[^ ]*]] = OpTypePointer PhysicalStorageBuffer [[FLOAT]]
// GOOD: %Test = OpVariable %{{[^ ]*}} Uniform
// GOOD: [[V0:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %Test [[I1]]
// GOOD: [[V1:%[^ ]*]] = OpLoad [[PINT]] [[V0]]
// GOOD: [[V2:%[^ ]*]] = OpBitcast [[PFLOAT]] [[V1]]
// GOOD: [[V3:%[^ ]*]] = OpLoad [[FLOAT]] [[V2]] Aligned 4
// GOOD: [[V4:%[^ ]*]] = OpConvertFToS [[INT]] [[V3]]
// GOOD: OpStore [[V1]] [[V4]] Aligned 4

// BAD: error: Vulkan buffer pointer cannot be cast to greater alignment
// BAD: error: vk::static_pointer_cast() content type must be base class of argument's content type

Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ struct TestPushConstant_t
float4 MainPs(void) : SV_Target0
{
if (__has_feature(hlsl_vk_buffer_pointer)) {
[[vk::aliased_pointer]] block_p g_p =
vk::static_pointer_cast<block_t, 16>(g_PushConstants.root);
[[vk::aliased_pointer]] block_p g_p = g_PushConstants.root;
g_p = g_p.Get().next;
uint64_t addr = (uint64_t)g_p;
block_p copy1 = block_p(addr);
Expand Down
Loading