diff --git a/src/CodeGen_D3D12Compute_Dev.cpp b/src/CodeGen_D3D12Compute_Dev.cpp index 8385fb8b5448..b2f1a8e094f9 100644 --- a/src/CodeGen_D3D12Compute_Dev.cpp +++ b/src/CodeGen_D3D12Compute_Dev.cpp @@ -9,9 +9,11 @@ #include "CodeGen_Internal.h" #include "Debug.h" #include "DeviceArgument.h" +#include "EliminateBoolVectors.h" #include "IRMutator.h" #include "IROperator.h" #include "Simplify.h" +#include "StrictifyFloat.h" #define DEBUG_TYPES (0) @@ -25,8 +27,6 @@ using std::vector; namespace { -ostringstream nil; - class CodeGen_D3D12Compute_Dev : public CodeGen_GPU_Dev { public: CodeGen_D3D12Compute_Dev(const Target &target); @@ -81,7 +81,6 @@ class CodeGen_D3D12Compute_Dev : public CodeGen_GPU_Dev { alias("floor", "floor"); alias("ceil", "ceil"); alias("trunc", "trunc"); - alias("pow", "pow"); alias("asin", "asin"); alias("acos", "acos"); alias("tan", "tan"); @@ -130,6 +129,7 @@ class CodeGen_D3D12Compute_Dev : public CodeGen_GPU_Dev { void visit(const For *) override; void visit(const Ramp *op) override; void visit(const Broadcast *op) override; + void visit(const Shuffle *op) override; void visit(const Call *op) override; void visit(const Load *op) override; void visit(const Store *op) override; @@ -141,6 +141,7 @@ class CodeGen_D3D12Compute_Dev : public CodeGen_GPU_Dev { void visit(const FloatImm *op) override; Scope<> groupshared_allocations; + bool emit_precise = false; }; std::ostringstream src_stream; @@ -364,8 +365,40 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Broadcast *op print_assignment(op->type.with_lanes(op->lanes), rhs.str()); } +void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Shuffle *op) { + // HLSL uses WGSL-like syntax for vector shuffles, not the CLike syntax which is the default. + ScopedValue vector_style{vector_declaration_style, VectorDeclarationStyle::WGSLSyntax}; + CodeGen_GPU_C::visit(op); +} + void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Call *op) { - if (op->is_intrinsic(Call::gpu_thread_barrier)) { + if (op->is_intrinsic(Call::bool_to_mask)) { + if (op->args[0].type().is_vector()) { + // The argument is already a mask of the right width. Just + // sign-extend to the expected type. + op->args[0].accept(this); + } else { + // The argument is a scalar bool. Casting it to an int + // produces zero or one. Convert it to -1 of the requested + // type. + Expr equiv = -Cast::make(op->type, op->args[0]); + equiv.accept(this); + } + } else if (op->is_intrinsic(Call::cast_mask)) { + internal_assert(op->args.size() == 1); + // Sign-extension is fine + Expr equiv = Cast::make(op->type, op->args[0]); + equiv.accept(this); + } else if (op->is_intrinsic(Call::select_mask)) { + internal_assert(op->args.size() == 3); + string cond = print_expr(op->args[0]); + string true_val = print_expr(op->args[1]); + string false_val = print_expr(op->args[2]); + + ostringstream rhs; + rhs << "(" << cond << " ? " << true_val << " : " << false_val << ")"; + print_assignment(op->type, rhs.str()); + } else if (op->is_intrinsic(Call::gpu_thread_barrier)) { internal_assert(op->args.size() == 1) << "gpu_thread_barrier() intrinsic must specify memory fence type.\n"; auto fence_type_ptr = as_const_int(op->args[0]); @@ -385,11 +418,31 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Call *op) { } stream << get_indent() << "GroupMemoryBarrierWithGroupSync();\n"; print_assignment(op->type, "0"); - } else if (op->name == "pow_f32" && can_prove(op->args[0] > 0)) { + } else if ((op->name == "pow_f16" || op->name == "pow_f32" || op->name == "pow_f64") && op->type.is_vector()) { + internal_assert(op->args.size() == 2); + const string x = print_expr(op->args[0]); + const string y = print_expr(op->args[1]); + ostringstream rhs; + rhs << print_type(op->type) << "("; + for (int i = 0; i < op->type.lanes(); ++i) { + rhs << op->name << "(" << x << "[" << i << "], " << y << "[" << i << "])"; + if (i < op->type.lanes() - 1) { + rhs << ", "; + } + } + rhs << ")"; + print_assignment(op->type, rhs.str()); + } else if ((op->name == "pow_f16" || op->name == "pow_f32") && can_prove(op->args[0] > 0)) { // If we know pow(x, y) is called with x > 0, we can use HLSL's pow // directly. Expr equiv = Call::make(op->type, "pow", op->args, Call::PureExtern); equiv.accept(this); + } else if (op->is_strict_float_intrinsic()) { + ScopedValue old_emit_precise(emit_precise, true); + Expr equiv = op->is_intrinsic(Call::strict_fma) ? + Call::make(op->type, op->type.bits() == 64 ? "fma" : "mad", op->args, Call::PureExtern) : + unstrictify_float(op); + equiv.accept(this); } else if (op->is_intrinsic(Call::round)) { // HLSL's round intrinsic has the correct semantics for our rounding. Expr equiv = Call::make(op->type, "round", op->args, Call::PureExtern); @@ -763,6 +816,14 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Store *op) { } void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Select *op) { + if (op->type.is_vector()) { + // A vector of bool was recursively introduced while + // performing codegen. Eliminate it. + Expr equiv = eliminate_bool_vectors(op); + equiv.accept(this); + return; + } + ostringstream rhs; string true_val = print_expr(op->true_value); string false_val = print_expr(op->false_value); @@ -832,7 +893,25 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Free *op) { string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_assignment(Type type, const string &rhs) { string rhs_modified = print_reinforced_cast(type, rhs); - return CodeGen_GPU_C::print_assignment(type, rhs_modified); + const char *precise_flag = (type.is_float() && emit_precise) ? "precise " : ""; + const string cache_key = precise_flag + rhs_modified; + + auto cached = cache.find(cache_key); + if (cached == cache.end()) { + id = unique_name('_'); + const char *const_flag = output_kind == CPlusPlusImplementation ? " const " : ""; + if (type.is_handle()) { + // Don't print void *, which might lose useful type information. just use auto. + stream << get_indent() << "auto *"; + } else { + stream << get_indent() << precise_flag << print_type(type, AppendSpace); + } + stream << const_flag << id << " = " << rhs_modified << ";\n"; + cache[cache_key] = id; + } else { + id = cached->second; + } + return id; } string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_vanilla_cast(Type type, const string &value_expr) { @@ -992,7 +1071,7 @@ void CodeGen_D3D12Compute_Dev::add_kernel(Stmt s, // support predication. s = scalarize_predicated_loads_stores(s); - debug(2) << "CodeGen_D3D12Compute_Dev: after removing predication: \n" + debug(2) << "CodeGen_D3D12Compute_Dev: after removing predication:\n" << s; // TODO: do we have to uniquify these names, or can we trust that they are safe? @@ -1022,6 +1101,11 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s, debug(2) << "Adding D3D12Compute kernel " << name << "\n"; + debug(2) << "Eliminating bool vectors\n"; + s = eliminate_bool_vectors(s); + debug(2) << "After eliminating bool vectors:\n" + << s << "\n"; + // Figure out which arguments should be passed in constant. // Such arguments should be: // - not written to, @@ -1260,10 +1344,6 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s, void CodeGen_D3D12Compute_Dev::init_module() { debug(2) << "D3D12Compute device codegen init_module\n"; - // TODO: we could support strict float intrinsics with the precise qualifier - internal_assert(!any_strict_float) - << "strict float intrinsics not yet supported in d3d12compute backend"; - // wipe the internal kernel source src_stream.str(""); src_stream.clear(); @@ -1316,32 +1396,38 @@ void CodeGen_D3D12Compute_Dev::init_module() { << "#define asuint32 asuint\n" << "\n" #endif - << "float nan_f32() { return 1.#IND; } \n" // Quiet NaN with minimum fractional value. - << "float neg_inf_f32() { return -1.#INF; } \n" - << "float inf_f32() { return +1.#INF; } \n" - << "#define float_from_bits asfloat \n" + << "float nan_f32() { return 1.#IND; }\n" // Quiet NaN with minimum fractional value. + << "float neg_inf_f32() { return -1.#INF; }\n" + << "float inf_f32() { return +1.#INF; }\n" + << "#define float_from_bits asfloat\n" // pow() in HLSL has the same semantics as C if // x > 0. Otherwise, we need to emulate C // behavior. // TODO(shoaibkamil): Can we simplify this? - << "float pow_f32(float x, float y) { \n" - << " if (x > 0.0) { \n" - << " return pow(x, y); \n" - << " } else if (y == 0.0) { \n" - << " return 1.0f; \n" - << " } else if (trunc(y) == y) { \n" - << " if (fmod(y, 2) == 0) { \n" - << " return pow(abs(x), y); \n" - << " } else { \n" - << " return -pow(abs(x), y); \n" - << " } \n" - << " } else { \n" - << " return nan_f32(); \n" - << " } \n" - << "} \n" - << "#define asinh(x) (log(x + sqrt(x*x + 1))) \n" - << "#define acosh(x) (log(x + sqrt(x*x - 1))) \n" - << "#define atanh(x) (log((1+x)/(1-x))/2) \n" + << "float pow_f32(float x, float y) {\n" + << " if (x > 0.0) {\n" + << " return pow(x, y);\n" + << " } else if (y == 0.0) {\n" + << " return 1.0f;\n" + << " } else if (trunc(y) == y) {\n" + << " if (fmod(y, 2) == 0) {\n" + << " return pow(abs(x), y);\n" + << " } else {\n" + << " return -pow(abs(x), y);\n" + << " }\n" + << " } else {\n" + << " return nan_f32();\n" + << " }\n" + << "}\n" + << "half pow_f16(half x, half y) {\n" + << " return half(pow_f32(float(x), float(y)));\n" + << "}\n" + << "double pow_f64(double x, double y) {\n" + << " return double(pow_f32(float(x), float(y)));\n" + << "}\n" + << "#define asinh(x) (log(x + sqrt(x*x + 1)))\n" + << "#define acosh(x) (log(x + sqrt(x*x - 1)))\n" + << "#define atanh(x) (log((1+x)/(1-x))/2)\n" << "\n"; //<< "}\n"; // close namespace