diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index e7e55d82a7..880369afaf 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -5286,8 +5286,8 @@ class HLSLExternalSource : public ExternalSemaSource { /// use for the signature, with the first being the return type. bool MatchArguments(const IntrinsicDefIter &cursor, QualType objectType, QualType objectElement, QualType functionTemplateTypeArg, - ArrayRef Args, std::vector *, - size_t &badArgIdx); + unsigned functionTemplateIntArg, ArrayRef Args, + std::vector *, size_t &badArgIdx); /// Validate object element on intrinsic to catch case like integer /// on Sample. Intrinsic function to @@ -5336,6 +5336,19 @@ class HLSLExternalSource : public ExternalSemaSource { nameIdentifier, argumentCount)); } + static unsigned GetIntegralTemplateArg(ASTContext &context, + const TemplateArgument &arg) { + if (arg.getKind() == TemplateArgument::Integral) + return arg.getAsIntegral().getZExtValue(); + if (arg.getKind() == TemplateArgument::Expression) { + llvm::APSInt result; + Expr *expr = arg.getAsExpr(); + if (expr != nullptr && expr->isIntegerConstantExpr(result, context)) + return result.getZExtValue(); + } + return 0; + } + bool AddOverloadedCallCandidates(UnresolvedLookupExpr *ULE, ArrayRef Args, OverloadCandidateSet &CandidateSet, Scope *S, @@ -5430,11 +5443,22 @@ class HLSLExternalSource : public ExternalSemaSource { "otherwise g_MaxIntrinsicParamCount needs to be updated for " "wider signatures"); + QualType templateTypeArg; + unsigned templateIntArg = 0; std::vector functionArgTypes; size_t badArgIdx; + if (ULE->hasExplicitTemplateArgs() && ULE->getNumTemplateArgs() >= 1) { + const TemplateArgumentLoc &TypeArgLoc = ULE->getTemplateArgs()[0]; + if (TypeArgLoc.getArgument().getKind() == TemplateArgument::Type) + templateTypeArg = TypeArgLoc.getArgument().getAsType(); + if (ULE->getNumTemplateArgs() >= 2) + templateIntArg = GetIntegralTemplateArg( + *m_context, ULE->getTemplateArgs()[1].getArgument()); + } + bool argsMatch = - MatchArguments(cursor, QualType(), QualType(), QualType(), Args, - &functionArgTypes, badArgIdx); + MatchArguments(cursor, QualType(), QualType(), templateTypeArg, + templateIntArg, Args, &functionArgTypes, badArgIdx); if (!functionArgTypes.size()) return false; @@ -6843,8 +6867,9 @@ bool HLSLExternalSource::IsValidObjectElement(LPCSTR tableName, bool HLSLExternalSource::MatchArguments( const IntrinsicDefIter &cursor, QualType objectType, QualType objectElement, - QualType functionTemplateTypeArg, ArrayRef Args, - std::vector *argTypesVector, size_t &badArgIdx) { + QualType functionTemplateTypeArg, unsigned functionTemplateIntArg, + ArrayRef Args, std::vector *argTypesVector, + size_t &badArgIdx) { const HLSL_INTRINSIC *pIntrinsic = *cursor; LPCSTR tableName = cursor.GetTableName(); IntrinsicOp builtinOp = IntrinsicOp::Num_Intrinsics; @@ -7336,7 +7361,45 @@ bool HLSLExternalSource::MatchArguments( if (i == 0 && (builtinOp == hlsl::IntrinsicOp::IOP_Vkreinterpret_pointer_cast || builtinOp == hlsl::IntrinsicOp::IOP_Vkstatic_pointer_cast)) { - pNewType = Args[0]->getType(); + if (functionTemplateTypeArg.isNull()) { + badArgIdx = std::min(badArgIdx, i); + continue; + } + + // Build BufferPointer where T is the template type argument and + // A is the template alignment argument (or the alignment of the + // source pointer if none is given). + unsigned srcAlignment = + functionTemplateIntArg + ? functionTemplateIntArg + : hlsl::GetVKBufferPointerAlignment(Args[0]->getType()); + TemplateArgument TemplateArgs[] = { + TemplateArgument(functionTemplateTypeArg), + TemplateArgument(*m_context, + llvm::APSInt(llvm::APInt(32, srcAlignment)), + m_context->UnsignedIntTy)}; + void *InsertPos = nullptr; + ClassTemplateSpecializationDecl *Spec = + m_vkBufferPointerTemplateDecl->findSpecialization( + llvm::ArrayRef(TemplateArgs, 2), InsertPos); + if (!Spec) { + Spec = ClassTemplateSpecializationDecl::Create( + *m_context, TagDecl::TagKind::TTK_Struct, + m_vkBufferPointerTemplateDecl->getDeclContext(), SourceLocation(), + SourceLocation(), m_vkBufferPointerTemplateDecl, TemplateArgs, 2, + nullptr); + m_vkBufferPointerTemplateDecl->AddSpecialization(Spec, InsertPos); + Spec->setImplicit(true); + DXVERIFY_NOMSG( + false == + getSema()->InstantiateClassTemplateSpecialization( + SourceLocation(), Spec, + TemplateSpecializationKind::TSK_ImplicitInstantiation, true)); + } + + pNewType = m_context->getTemplateSpecializationType( + TemplateName(m_vkBufferPointerTemplateDecl), TemplateArgs, 2, + m_context->getTypeDeclType(Spec)); } else { badArgIdx = std::min(badArgIdx, i); } @@ -11072,11 +11135,18 @@ HLSLExternalSource::DeduceTemplateArgumentsForHLSL( QualType objectType = m_context->getTagDeclType(functionParentRecord); QualType functionTemplateTypeArg{}; - if (ExplicitTemplateArgs != nullptr && ExplicitTemplateArgs->size() == 1) { + unsigned functionTemplateIntArg = 0; + if (ExplicitTemplateArgs != nullptr && ExplicitTemplateArgs->size() >= 1) { const TemplateArgument &firstTemplateArg = (*ExplicitTemplateArgs)[0].getArgument(); if (firstTemplateArg.getKind() == TemplateArgument::ArgKind::Type) functionTemplateTypeArg = firstTemplateArg.getAsType(); + if (ExplicitTemplateArgs->size() > 1) { + const TemplateArgument &secondTemplateArg = + (*ExplicitTemplateArgs)[1].getArgument(); + functionTemplateIntArg = + GetIntegralTemplateArg(*m_context, secondTemplateArg); + } } // Handle subscript overloads. @@ -11150,7 +11220,8 @@ HLSLExternalSource::DeduceTemplateArgumentsForHLSL( while (cursor != end) { size_t badArgIdx; if (!MatchArguments(cursor, objectType, objectElement, - functionTemplateTypeArg, Args, &argTypes, badArgIdx)) { + functionTemplateTypeArg, functionTemplateIntArg, Args, + &argTypes, badArgIdx)) { ++cursor; continue; } @@ -11193,8 +11264,9 @@ HLSLExternalSource::DeduceTemplateArgumentsForHLSL( if (!IsNull && getSema()->RequireCompleteType(Loc, functionTemplateTypeArg, 0)) return Sema::TemplateDeductionResult::TDK_Invalid; - if (IsNull || !hlsl::IsHLSLNumericOrAggregateOfNumericType( - functionTemplateTypeArg)) { + if (IsNull || ExplicitTemplateArgs->size() > 1 || + !hlsl::IsHLSLNumericOrAggregateOfNumericType( + functionTemplateTypeArg)) { getSema()->Diag(Loc, diag::err_hlsl_intrinsic_template_arg_numeric) << intrinsicName; DiagnoseTypeElements( @@ -11924,8 +11996,18 @@ static bool CheckBarrierCall(Sema &S, FunctionDecl *FD, CallExpr *CE, } #ifdef ENABLE_SPIRV_CODEGEN -static bool CheckVKBufferPointerCast(Sema &S, FunctionDecl *FD, CallExpr *CE, - bool isStatic) { +static bool CheckVKBufferPointerCast(Sema &S, CallExpr *CE, bool isStatic) { + const auto *callee = dyn_cast(CE->getCallee()->IgnoreImpCasts()); + if (callee && callee->hasExplicitTemplateArgs() && + callee->getNumTemplateArgs() > 2) { + StringRef castName = + isStatic ? "static_pointer_cast" : "reinterpret_pointer_cast"; + S.Diags.Report(CE->getExprLoc(), + diag::err_template_arg_list_different_arity) + << /*too many*/ 1 << /*function template*/ 1 << castName; + return true; + } + const Expr *argExpr = CE->getArg(0); QualType srcType = argExpr->getType(); QualType destType = CE->getType(); @@ -12063,10 +12145,10 @@ void Sema::CheckHLSLFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall) { break; #ifdef ENABLE_SPIRV_CODEGEN case hlsl::IntrinsicOp::IOP_Vkreinterpret_pointer_cast: - CheckVKBufferPointerCast(*this, FDecl, TheCall, false); + CheckVKBufferPointerCast(*this, TheCall, false); break; case hlsl::IntrinsicOp::IOP_Vkstatic_pointer_cast: - CheckVKBufferPointerCast(*this, FDecl, TheCall, true); + CheckVKBufferPointerCast(*this, TheCall, true); break; #endif default: diff --git a/tools/clang/test/CodeGenSPIRV/vk.buffer-pointer.casts.hlsl b/tools/clang/test/CodeGenSPIRV/vk.buffer-pointer.casts.hlsl new file mode 100644 index 0000000000..9481aa2b68 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/vk.buffer-pointer.casts.hlsl @@ -0,0 +1,45 @@ +// RUN: %dxc -T cs_6_6 -spirv -fspv-target-env=vulkan1.3 %s | FileCheck %s --check-prefix=GOOD +// RUN: not %dxc -T cs_6_6 -spirv -fspv-target-env=vulkan1.3 -DBAD %s 2>&1 | FileCheck %s --check-prefix=BAD + +struct Base { +}; + +struct Derived : Base { + int val; +}; + +cbuffer Test { + vk::BufferPointer derivedBuf; + vk::BufferPointer intBuf; +}; + +[shader("compute")] +[numthreads(256, 1, 1)] +void main(in uint3 threadId : SV_DispatchThreadID) { +#ifdef BAD + vk::BufferPointer derivedBufAsBase = vk::static_pointer_cast(derivedBuf); + vk::BufferPointer intBufAsFloat = vk::static_pointer_cast(intBuf); +#else + vk::BufferPointer derivedBufAsBase = vk::static_pointer_cast(derivedBuf); + vk::BufferPointer intBufAsFloat = vk::reinterpret_pointer_cast(intBuf); +#endif + + intBuf.Get() = (int)intBufAsFloat.Get(); +} + +// GOOD: [[INT:%[^ ]*]] = OpTypeInt 32 1 +// GOOD: [[I1:%[^ ]*]] = OpConstant [[INT]] 1 +// GOOD: [[PINT:%[^ ]*]] = OpTypePointer PhysicalStorageBuffer [[INT]] +// GOOD: [[FLOAT:%[^ ]*]] = OpTypeFloat 32 +// GOOD: [[PFLOAT:%[^ ]*]] = OpTypePointer PhysicalStorageBuffer [[FLOAT]] +// GOOD: %Test = OpVariable %{{[^ ]*}} Uniform +// GOOD: [[V0:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %Test [[I1]] +// GOOD: [[V1:%[^ ]*]] = OpLoad [[PINT]] [[V0]] +// GOOD: [[V2:%[^ ]*]] = OpBitcast [[PFLOAT]] [[V1]] +// GOOD: [[V3:%[^ ]*]] = OpLoad [[FLOAT]] [[V2]] Aligned 4 +// GOOD: [[V4:%[^ ]*]] = OpConvertFToS [[INT]] [[V3]] +// GOOD: OpStore [[V1]] [[V4]] Aligned 4 + +// BAD: error: Vulkan buffer pointer cannot be cast to greater alignment +// BAD: error: vk::static_pointer_cast() content type must be base class of argument's content type + diff --git a/tools/clang/test/CodeGenSPIRV/vk.buffer-pointer.linked-list.hlsl b/tools/clang/test/CodeGenSPIRV/vk.buffer-pointer.linked-list.hlsl index 75380d3f4e..4f12218783 100644 --- a/tools/clang/test/CodeGenSPIRV/vk.buffer-pointer.linked-list.hlsl +++ b/tools/clang/test/CodeGenSPIRV/vk.buffer-pointer.linked-list.hlsl @@ -53,8 +53,7 @@ struct TestPushConstant_t float4 MainPs(void) : SV_Target0 { if (__has_feature(hlsl_vk_buffer_pointer)) { - [[vk::aliased_pointer]] block_p g_p = - vk::static_pointer_cast(g_PushConstants.root); + [[vk::aliased_pointer]] block_p g_p = g_PushConstants.root; g_p = g_p.Get().next; uint64_t addr = (uint64_t)g_p; block_p copy1 = block_p(addr);