Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 119 additions & 33 deletions src/CodeGen_D3D12Compute_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
#include "CodeGen_Internal.h"
#include "Debug.h"
#include "DeviceArgument.h"
#include "EliminateBoolVectors.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Simplify.h"
#include "StrictifyFloat.h"

#define DEBUG_TYPES (0)

Expand All @@ -25,8 +27,6 @@ using std::vector;

namespace {

ostringstream nil;

class CodeGen_D3D12Compute_Dev : public CodeGen_GPU_Dev {
public:
CodeGen_D3D12Compute_Dev(const Target &target);
Expand Down Expand Up @@ -81,7 +81,6 @@ class CodeGen_D3D12Compute_Dev : public CodeGen_GPU_Dev {
alias("floor", "floor");
alias("ceil", "ceil");
alias("trunc", "trunc");
alias("pow", "pow");
alias("asin", "asin");
alias("acos", "acos");
alias("tan", "tan");
Expand Down Expand Up @@ -130,6 +129,7 @@ class CodeGen_D3D12Compute_Dev : public CodeGen_GPU_Dev {
void visit(const For *) override;
void visit(const Ramp *op) override;
void visit(const Broadcast *op) override;
void visit(const Shuffle *op) override;
void visit(const Call *op) override;
void visit(const Load *op) override;
void visit(const Store *op) override;
Expand All @@ -141,6 +141,7 @@ class CodeGen_D3D12Compute_Dev : public CodeGen_GPU_Dev {
void visit(const FloatImm *op) override;

Scope<> groupshared_allocations;
bool emit_precise = false;
};

std::ostringstream src_stream;
Expand Down Expand Up @@ -364,8 +365,40 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Broadcast *op
print_assignment(op->type.with_lanes(op->lanes), rhs.str());
}

// Shuffle codegen is delegated to CodeGen_GPU_C, but with the vector
// declaration style temporarily switched to WGSLSyntax: HLSL spells vector
// shuffles the WGSL-like way rather than the C-like way that is the default.
// The previous style is restored when the scoped override goes out of scope.
void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Shuffle *op) {
    ScopedValue wgsl_style_override(vector_declaration_style, VectorDeclarationStyle::WGSLSyntax);
    CodeGen_GPU_C::visit(op);
}

void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Call *op) {
if (op->is_intrinsic(Call::gpu_thread_barrier)) {
if (op->is_intrinsic(Call::bool_to_mask)) {
if (op->args[0].type().is_vector()) {
// The argument is already a mask of the right width. Just
// sign-extend to the expected type.
op->args[0].accept(this);
} else {
// The argument is a scalar bool. Casting it to an int
// produces zero or one. Convert it to -1 of the requested
// type.
Expr equiv = -Cast::make(op->type, op->args[0]);
equiv.accept(this);
}
} else if (op->is_intrinsic(Call::cast_mask)) {
internal_assert(op->args.size() == 1);
// Sign-extension is fine
Expr equiv = Cast::make(op->type, op->args[0]);
equiv.accept(this);
} else if (op->is_intrinsic(Call::select_mask)) {
internal_assert(op->args.size() == 3);
string cond = print_expr(op->args[0]);
string true_val = print_expr(op->args[1]);
string false_val = print_expr(op->args[2]);

ostringstream rhs;
rhs << "(" << cond << " ? " << true_val << " : " << false_val << ")";
print_assignment(op->type, rhs.str());
} else if (op->is_intrinsic(Call::gpu_thread_barrier)) {
internal_assert(op->args.size() == 1) << "gpu_thread_barrier() intrinsic must specify memory fence type.\n";

auto fence_type_ptr = as_const_int(op->args[0]);
Expand All @@ -385,11 +418,31 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Call *op) {
}
stream << get_indent() << "GroupMemoryBarrierWithGroupSync();\n";
print_assignment(op->type, "0");
} else if (op->name == "pow_f32" && can_prove(op->args[0] > 0)) {
} else if ((op->name == "pow_f16" || op->name == "pow_f32" || op->name == "pow_f64") && op->type.is_vector()) {
internal_assert(op->args.size() == 2);
const string x = print_expr(op->args[0]);
const string y = print_expr(op->args[1]);
ostringstream rhs;
rhs << print_type(op->type) << "(";
for (int i = 0; i < op->type.lanes(); ++i) {
rhs << op->name << "(" << x << "[" << i << "], " << y << "[" << i << "])";
if (i < op->type.lanes() - 1) {
rhs << ", ";
}
}
rhs << ")";
print_assignment(op->type, rhs.str());
} else if ((op->name == "pow_f16" || op->name == "pow_f32") && can_prove(op->args[0] > 0)) {
// If we know pow(x, y) is called with x > 0, we can use HLSL's pow
// directly.
Expr equiv = Call::make(op->type, "pow", op->args, Call::PureExtern);
equiv.accept(this);
} else if (op->is_strict_float_intrinsic()) {
ScopedValue old_emit_precise(emit_precise, true);
Expr equiv = op->is_intrinsic(Call::strict_fma) ?
Call::make(op->type, op->type.bits() == 64 ? "fma" : "mad", op->args, Call::PureExtern) :
unstrictify_float(op);
equiv.accept(this);
} else if (op->is_intrinsic(Call::round)) {
// HLSL's round intrinsic has the correct semantics for our rounding.
Expr equiv = Call::make(op->type, "round", op->args, Call::PureExtern);
Expand Down Expand Up @@ -763,6 +816,14 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Store *op) {
}

void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Select *op) {
if (op->type.is_vector()) {
// A vector of bool was recursively introduced while
// performing codegen. Eliminate it.
Expr equiv = eliminate_bool_vectors(op);
equiv.accept(this);
return;
}

ostringstream rhs;
string true_val = print_expr(op->true_value);
string false_val = print_expr(op->false_value);
Expand Down Expand Up @@ -832,7 +893,25 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Free *op) {

// Emit a declaration that binds `rhs` to a fresh temporary of the given type,
// or reuse an existing temporary if the same expression was already emitted,
// and return the identifier of that temporary.
//
// The right-hand side is first run through print_reinforced_cast(). When
// emit_precise is set and the type is floating-point, the declaration is
// prefixed with the HLSL `precise` qualifier. The qualifier is folded into the
// cache key as well, so a precise temporary is never shared with a
// non-precise one that happens to have the same expression text.
//
// Note: this deliberately does not forward to CodeGen_GPU_C::print_assignment,
// because the declaration has to carry the `precise` prefix.
string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_assignment(Type type, const string &rhs) {
    string rhs_modified = print_reinforced_cast(type, rhs);
    const char *precise_flag = (type.is_float() && emit_precise) ? "precise " : "";
    const string cache_key = precise_flag + rhs_modified;

    auto cached = cache.find(cache_key);
    if (cached == cache.end()) {
        id = unique_name('_');
        const char *const_flag = output_kind == CPlusPlusImplementation ? " const " : "";
        if (type.is_handle()) {
            // Don't print void *, which might lose useful type information.
            // Just use auto.
            stream << get_indent() << "auto *";
        } else {
            stream << get_indent() << precise_flag << print_type(type, AppendSpace);
        }
        stream << const_flag << id << " = " << rhs_modified << ";\n";
        cache[cache_key] = id;
    } else {
        // Same (qualifier, expression) pair was emitted before: reuse its name.
        id = cached->second;
    }
    return id;
}

string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_vanilla_cast(Type type, const string &value_expr) {
Expand Down Expand Up @@ -992,7 +1071,7 @@ void CodeGen_D3D12Compute_Dev::add_kernel(Stmt s,
// support predication.
s = scalarize_predicated_loads_stores(s);

debug(2) << "CodeGen_D3D12Compute_Dev: after removing predication: \n"
debug(2) << "CodeGen_D3D12Compute_Dev: after removing predication:\n"
<< s;

// TODO: do we have to uniquify these names, or can we trust that they are safe?
Expand Down Expand Up @@ -1022,6 +1101,11 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,

debug(2) << "Adding D3D12Compute kernel " << name << "\n";

debug(2) << "Eliminating bool vectors\n";
s = eliminate_bool_vectors(s);
debug(2) << "After eliminating bool vectors:\n"
<< s << "\n";

// Figure out which arguments should be passed in constant.
// Such arguments should be:
// - not written to,
Expand Down Expand Up @@ -1260,10 +1344,6 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
void CodeGen_D3D12Compute_Dev::init_module() {
debug(2) << "D3D12Compute device codegen init_module\n";

// TODO: we could support strict float intrinsics with the precise qualifier
internal_assert(!any_strict_float)
<< "strict float intrinsics not yet supported in d3d12compute backend";

// wipe the internal kernel source
src_stream.str("");
src_stream.clear();
Expand Down Expand Up @@ -1316,32 +1396,38 @@ void CodeGen_D3D12Compute_Dev::init_module() {
<< "#define asuint32 asuint\n"
<< "\n"
#endif
<< "float nan_f32() { return 1.#IND; } \n" // Quiet NaN with minimum fractional value.
<< "float neg_inf_f32() { return -1.#INF; } \n"
<< "float inf_f32() { return +1.#INF; } \n"
<< "#define float_from_bits asfloat \n"
<< "float nan_f32() { return 1.#IND; }\n" // Quiet NaN with minimum fractional value.
<< "float neg_inf_f32() { return -1.#INF; }\n"
<< "float inf_f32() { return +1.#INF; }\n"
<< "#define float_from_bits asfloat\n"
// pow() in HLSL has the same semantics as C if
// x > 0. Otherwise, we need to emulate C
// behavior.
// TODO(shoaibkamil): Can we simplify this?
<< "float pow_f32(float x, float y) { \n"
<< " if (x > 0.0) { \n"
<< " return pow(x, y); \n"
<< " } else if (y == 0.0) { \n"
<< " return 1.0f; \n"
<< " } else if (trunc(y) == y) { \n"
<< " if (fmod(y, 2) == 0) { \n"
<< " return pow(abs(x), y); \n"
<< " } else { \n"
<< " return -pow(abs(x), y); \n"
<< " } \n"
<< " } else { \n"
<< " return nan_f32(); \n"
<< " } \n"
<< "} \n"
<< "#define asinh(x) (log(x + sqrt(x*x + 1))) \n"
<< "#define acosh(x) (log(x + sqrt(x*x - 1))) \n"
<< "#define atanh(x) (log((1+x)/(1-x))/2) \n"
<< "float pow_f32(float x, float y) {\n"
<< " if (x > 0.0) {\n"
<< " return pow(x, y);\n"
<< " } else if (y == 0.0) {\n"
<< " return 1.0f;\n"
<< " } else if (trunc(y) == y) {\n"
<< " if (fmod(y, 2) == 0) {\n"
<< " return pow(abs(x), y);\n"
<< " } else {\n"
<< " return -pow(abs(x), y);\n"
<< " }\n"
<< " } else {\n"
<< " return nan_f32();\n"
<< " }\n"
<< "}\n"
<< "half pow_f16(half x, half y) {\n"
<< " return half(pow_f32(float(x), float(y)));\n"
<< "}\n"
<< "double pow_f64(double x, double y) {\n"
<< " return double(pow_f32(float(x), float(y)));\n"
<< "}\n"
<< "#define asinh(x) (log(x + sqrt(x*x + 1)))\n"
<< "#define acosh(x) (log(x + sqrt(x*x - 1)))\n"
<< "#define atanh(x) (log((1+x)/(1-x))/2)\n"
<< "\n";
//<< "}\n"; // close namespace

Expand Down
Loading