diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..5e7505bf --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/flashinfer"] + path = third_party/flashinfer + url = https://github.com/flashinfer-ai/flashinfer.git diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..9876f488 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,176 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Build + +InfiniOps uses CMake + scikit-build-core. The library is compiled into a shared `libinfiniops` and an optional Python extension `ops`. + +### C++ only + +```bash +mkdir build && cd build +cmake .. -DWITH_CPU=ON # or -DWITH_NVIDIA=ON, -DWITH_METAX=ON, etc. +make -j$(nproc) +``` + +### Python package (pip / editable install) + +```bash +pip install .[dev] # installs infiniops + dev tools +# or for an editable build: +pip install -e .[dev] +``` + +`pyproject.toml` sets `AUTO_DETECT_DEVICES=ON` and `GENERATE_PYTHON_BINDINGS=ON` automatically during `pip install`. + +### Backend CMake flags + +| Flag | Backend | +|------|---------| +| `-DWITH_CPU=ON` | CPU (OpenMP) | +| `-DWITH_NVIDIA=ON` | NVIDIA CUDA (requires CUDAToolkit) | +| `-DWITH_ILUVATAR=ON` | Iluvatar (clang++ with `-x ivcore`) | +| `-DWITH_METAX=ON` | MetaX (requires `$MACA_PATH`) | +| `-DWITH_CAMBRICON=ON` | Cambricon (requires `$NEUWARE_HOME`) | + +`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, `WITH_MOORE`, and `WITH_ASCEND` are mutually exclusive: enable at most one of them per build. + +## Testing + +```bash +pytest tests/ # run all tests +pytest tests/test_add.py # run one test file +pytest tests/test_add.py::test_add # run a single test +pytest tests/ --benchmark # run with performance benchmarks +pytest tests/ -v --tb=short # verbose output +``` + +Tests auto-parametrize on `dtype` (float32/float16/bfloat16) and `device` (cpu, and cuda/mlu if available). Tests import `infini.ops`, so the package must be installed (or built and on `PYTHONPATH`). 
+ +## Linting + +```bash +ruff check . +ruff format . +``` + +## Code Style + +Follow PEP 8 as the primary style guide. For areas PEP 8 does not cover in detail, refer to the GDScript style guide for non-syntax conventions. Always run `ruff format && ruff check` before committing. + +### Comments + +- Comments must be complete English sentences: capitalize the first word, end with punctuation. +- Use Markdown backtick syntax for code references within comments (e.g. `` `variable_name` ``). +- Error messages and framework-conventional strings (e.g. `pytest.skip` reasons) follow their own conventions — typically lowercase, no trailing period. + +### Docstrings + +- Follow PEP 257. One-line docstrings stay on a single line. Multi-line docstrings have a summary line, a blank line, then the description. + +### Blank lines + +- No blank line between a function signature and its body when there is no docstring or comment. +- Add a blank line before and after `if`, `for`, `while`, and similar compound statements. +- Add a blank line before a `return` statement unless it is directly inside an `if`/`for`/`while` block body. + +## CI + +The `.ci/` directory implements a multi-platform, resource-aware CI system with Docker-based execution, GitHub integration, and cross-machine job dispatch. + +### Configuration + +`config.yaml` uses a **platform-centric** structure that normalizes to flat `{platform}_{job}` names at load time (e.g. `nvidia_gpu`). Each platform defines its Docker image, setup commands, volumes, env vars, and jobs. Jobs inherit platform-level defaults. + +Supported platforms: **nvidia**, **iluvatar**, **ascend** (ascend not ready yet). 
+ +### Building images + +```bash +python .ci/build.py --platform nvidia # build one platform +python .ci/build.py --platform all # build all platforms +python .ci/build.py --platform nvidia --force # skip Dockerfile change detection +python .ci/build.py --push --dry-run # push to registry (preview) +``` + +Dockerfiles live in `.ci/images/{platform}/Dockerfile`. Proxy variables from the host are forwarded automatically. + +### Running the pipeline locally + +```bash +python .ci/run.py # auto-detect platform, run all jobs +python .ci/run.py --job gpu --stage test # run specific job/stage +python .ci/run.py --job gpu --gpu-id 0,2 # override GPU allocation +python .ci/run.py --image-tag stable # use a specific image tag +python .ci/run.py --dry-run # preview docker commands +``` + +Platform is auto-detected by checking for `nvidia-smi` or `ixsmi` on PATH. + +### Agent (scheduler + webhook server) + +`agent.py` provides a resource-aware scheduler with GitHub webhook support and REST API: + +```bash +# Start the agent (webhook server + scheduler) +python .ci/agent.py serve --port 8080 --webhook-secret <secret> + +# Dispatch jobs to remote agents via HTTP +python .ci/agent.py run --branch feat/xxx --platform nvidia +python .ci/agent.py run --job nvidia_gpu --dry-run +``` + +**Key capabilities:** + +- **Resource-aware scheduling** — dynamically allocates GPUs based on utilization threshold; queues jobs when resources are busy. +- **GitHub webhooks** — triggers jobs on push/PR events (`/webhook` endpoint, HMAC-SHA256 verified). +- **REST API** — `/api/run` (trigger jobs, Bearer token auth), `/api/job/{id}` (query status), `/status` (queue + resources), `/health`. +- **GitHub commit status** — reports pending/success/failure per job via `github_status.py`. +- **Cross-machine dispatch** — sends jobs to remote platform agents and polls for results. 
+ +### Module overview + +| File | Purpose | +|------|---------| +| `config.yaml` | Platform-centric CI configuration | +| `build.py` | Docker image builder with change detection | +| `run.py` | Standalone Docker CI runner (clone, setup, stages) | +| `agent.py` | Scheduler, webhook server, remote dispatch CLI | +| `utils.py` | Config normalization (`normalize_config`), git helpers | +| `ci_resource.py` | GPU/memory detection and thread-safe allocation (`ResourcePool`) | +| `github_status.py` | GitHub Commit Status API wrapper (zero external deps) | + +### Tests + +```bash +pytest .ci/tests/ # run all CI tests +pytest .ci/tests/test_agent.py # test scheduler and webhooks +``` + +## Architecture + +### C++ layer (`src/`) + +- **`src/base/<op>.h`** — Abstract base class for each operator (e.g. `Add`, `Gemm`, `RmsNorm`). Declares the constructor (capturing tensor metadata) and a pure-virtual `operator()`. +- **`src/<backend>/<op>.*`** — Backend-specific specializations: `src/cpu/`, `src/cuda/`, `src/nvidia/`, `src/metax/`, `src/cambricon/`, `src/iluvatar/`. Each provides `template<> class Operator`. +- **`src/operator.h`** — `Operator` template that dispatches to the correct device specialization at `make()` time via `DispatchFunc`. Also caches constructed operator descriptors keyed on tensor shape/dtype/strides. +- **`src/tensor.h` / `src/device.h` / `src/data_type.h`** — Core data model: `Tensor` (pointer + shape + strides + dtype + device), `Device`, `DataType`. +- **`src/dispatcher.h`** — `DispatchFunc` selects the right device at runtime based on `Device::Type` and the compile-time `ActiveDevices` set. + +### Python bindings + +Python bindings are **auto-generated** by `scripts/generate_wrappers.py` using libclang to parse `src/base/<op>.h`. The generated output lands in `generated/bindings/ops.cc` and `generated/include/`. Bindings expose each operator both as a callable class (stateful, with constructor) and as a free function (`infini.ops.add(input, other, out)`). 
+ +### Test framework (`tests/`) + +- `conftest.py` implements the `@pytest.mark.auto_act_and_assert` marker: the test function returns a `Payload(func, ref, args, kwargs, rtol, atol)` and the framework calls both, clones tensors for the reference, and asserts `torch.allclose`. +- `device` and `dtype` fixtures are auto-parametrized in `conftest.py`; individual tests can override with explicit `@pytest.mark.parametrize`. +- `tests/utils.py` provides `randn_strided`, `randint_strided`, `empty_strided`, `clone_strided` to create tensors with arbitrary strides. + +### Adding a new operator + +1. Create `src/base/<op>.h` with an abstract class inheriting `Operator`. +2. Implement backend specializations in `src/<backend>/`. +3. Re-run `scripts/generate_wrappers.py` (or rebuild with `GENERATE_PYTHON_BINDINGS=ON`) to regenerate Python bindings. +4. Add a `tests/test_<op>.py` using the `Payload` / `auto_act_and_assert` pattern. diff --git a/CMakeLists.txt b/CMakeLists.txt index b9e2deb5..da76ca3d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,7 @@ option(WITH_ILUVATAR "Enable Iluvatar GPU backend" OFF) option(WITH_METAX "Enable MetaX backend" OFF) option(WITH_CAMBRICON "Enable Cambricon backend" OFF) option(WITH_MOORE "Enable Moore backend" OFF) +option(WITH_ASCEND "Enable Ascend backend" OFF) option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF) option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF) @@ -28,6 +29,31 @@ if(AUTO_DETECT_DEVICES) if(NVIDIA_DEV_FILES) set(WITH_NVIDIA ON) message(STATUS "Auto-detected NVIDIA environment.") + + # Detect the GPU's compute capability so we compile for the right + # architecture. Without this, CMake may pick a lower default (e.g. + # SM75) and kernels that require newer features (bf16 on SM80+) will + # fail at runtime. 
+ if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + execute_process( + COMMAND nvidia-smi --query-gpu=compute_cap --format=csv,noheader + OUTPUT_VARIABLE _gpu_caps + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + ) + + if(_gpu_caps) + # Take the first GPU's capability (e.g. "8.0" -> "80"). + string(REGEX MATCH "([0-9]+)\\.([0-9]+)" _cap_match "${_gpu_caps}") + string(REPLACE "." "" _arch "${_cap_match}") + + if(_arch) + set(CMAKE_CUDA_ARCHITECTURES "${_arch}" CACHE STRING + "CUDA architectures (auto-detected from GPU)") + message(STATUS "Auto-detected CUDA architecture: SM${_arch}") + endif() + endif() + endif() endif() file(GLOB ILUVATAR_DEV_FILES "/dev/iluvatar*") @@ -71,20 +97,25 @@ if(AUTO_DETECT_DEVICES) set(WITH_MOORE OFF) set(WITH_MOORE OFF CACHE BOOL "Enable Moore backend" FORCE) endif() + + if(DEFINED ENV{ASCEND_HOME_PATH} OR EXISTS "/dev/davinci0") + set(WITH_ASCEND ON) + message(STATUS "Auto-detected Ascend environment.") + endif() endif() include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) # Only one CUDA-like GPU backend can be enabled at a time. set(_gpu_backend_count 0) -foreach(_gpu_backend WITH_NVIDIA WITH_ILUVATAR WITH_METAX WITH_MOORE) +foreach(_gpu_backend WITH_NVIDIA WITH_ILUVATAR WITH_METAX WITH_MOORE WITH_ASCEND) if(${_gpu_backend}) math(EXPR _gpu_backend_count "${_gpu_backend_count} + 1") endif() endforeach() if(_gpu_backend_count GREATER 1) - message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, and `WITH_MOORE` are mutually exclusive. Build one GPU backend at a time.") + message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, `WITH_MOORE`, and `WITH_ASCEND` are mutually exclusive. 
Build one GPU backend at a time.") endif() if(WITH_NVIDIA) @@ -178,8 +209,23 @@ if(WITH_CAMBRICON) find_library(CAMBRICON_PAPI_LIB NAMES cnpapi HINTS "${NEUWARE_HOME}/lib64" REQUIRED) endif() +if(WITH_ASCEND) + add_compile_definitions(WITH_ASCEND=1) + if(NOT DEFINED ASCEND_HOME) + if(DEFINED ENV{ASCEND_HOME_PATH} AND NOT "$ENV{ASCEND_HOME_PATH}" STREQUAL "") + set(ASCEND_HOME "$ENV{ASCEND_HOME_PATH}" CACHE PATH "Ascend toolkit root") + else() + set(ASCEND_HOME "/usr/local/Ascend/ascend-toolkit/latest" CACHE PATH "Ascend toolkit root") + endif() + endif() + if(NOT EXISTS "${ASCEND_HOME}") + message(FATAL_ERROR "`WITH_ASCEND` is ON but `${ASCEND_HOME}` was not found. Set ASCEND_HOME_PATH.") + endif() + message(STATUS "Using Ascend from `${ASCEND_HOME}`.") +endif() + # If all other platforms are not enabled, CPU is enabled by default. -if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON) +if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON AND NOT WITH_ASCEND) add_compile_definitions(WITH_CPU=1) endif() diff --git a/docs/superpowers/plans/2026-04-11-dsl-cmake-integration.md b/docs/superpowers/plans/2026-04-11-dsl-cmake-integration.md new file mode 100644 index 00000000..b2d98e10 --- /dev/null +++ b/docs/superpowers/plans/2026-04-11-dsl-cmake-integration.md @@ -0,0 +1,223 @@ +# DSL Compiler CMake Integration + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Unify code generation so `python -m dsl` replaces `generate_wrappers.py` as the single CMake entry point for all generated code (DSL kernels, pybind11 bindings, C API). + +**Architecture:** Move libclang-based binding generation from `scripts/generate_wrappers.py` into `dsl/compiler/bindings.py`. 
The DSL `__main__.py` calls it after DSL generation. CMake invokes `python -m dsl` instead of `generate_wrappers.py`. The old script is retained as fallback. + +**Tech Stack:** Python (DSL compiler), libclang (C++ parsing), pybind11 (bindings), CMake. + +**Spec:** `docs/superpowers/specs/2026-04-11-dsl-cmake-integration-design.md` + +--- + +## Task 1: Extract binding generation into `dsl/compiler/bindings.py` + +**Files:** +- Create: `dsl/compiler/bindings.py` + +- [ ] **Step 1: Create `dsl/compiler/bindings.py`** + +Move the following from `scripts/generate_wrappers.py` into this new module: + +1. **`_OperatorExtractor` class** (lines 27-90) — libclang AST parsing of `src/base/*.h`. Keep it as-is. + +2. **`_find_optional_tensor_params()`** and **`_find_vector_tensor_params()`** (lines 95-112) — regex-based parameter detection. + +3. **`_generate_pybind11()`** (lines 115-250) — pybind11 binding code generation, including per-op `impl_names` string overloads. + +4. **`_generate_legacy_c()`** (lines 253-464) — C API source/header generation. + +5. **`_snake_to_pascal()`** and **`_get_all_ops()`** (lines 467-489) — utility functions. + +Wrap everything in a single entry point: + +```python +def generate_all_bindings( + devices: list[str], + output_dir: pathlib.Path, + impl_names: dict[str, dict[str, int]], +) -> None: + """Generate pybind11 bindings and C API for all operators. + + This replaces the standalone `scripts/generate_wrappers.py` script. + The libclang parsing, pybind11 generation, and C API generation + logic is moved here verbatim. + """ +``` + +This function should: +1. Discover all ops via `_get_all_ops(devices)` (or `ops.json` if it exists). +2. For each op: parse with `_OperatorExtractor`, generate pybind11 binding header, generate C API files. +3. Assemble `ops.cc` with all includes and `PYBIND11_MODULE`. + +Keep the same output paths: `generated/bindings/`, `generated/include/`, `generated/src/`. 
+ +Constants to define at module level: +```python +_SRC_DIR = pathlib.Path("src") +_BASE_DIR = _SRC_DIR / "base" +_INDENTATION = " " +``` + +**Important:** This is a move, not a rewrite. Copy the functions verbatim from `generate_wrappers.py`, only adjusting imports and making them module-level instead of `if __name__ == "__main__"` scoped. + +- [ ] **Step 2: Verify the module imports cleanly** + +Run: `python -c "from dsl.compiler.bindings import generate_all_bindings; print('OK')"` +Expected: "OK" + +- [ ] **Step 3: Commit** + +``` +git add dsl/compiler/bindings.py +git commit -m "refactor(dsl): extract binding generation into dsl/compiler/bindings.py" +``` + +--- + +## Task 2: Wire bindings into `dsl/__main__.py` + +**Files:** +- Modify: `dsl/__main__.py` + +- [ ] **Step 1: Add binding generation call** + +At the end of `main()`, after the `impl_names.json` write and before the verify/summary print, add: + +```python +if not args.verify: + from dsl.compiler.bindings import generate_all_bindings + generate_all_bindings(args.devices, args.output, all_impl_names) +``` + +Note: `all_impl_names` is already computed by `REGISTRY.all_impl_names()` earlier in `main()`. But the binding generator needs the full set (all ops, not just `--ops` filtered). The current `all_impl_names` call already covers all registered ops. + +**Important detail:** The `generate_all_bindings` function discovers ops by scanning `src/base/*.h` (via `_get_all_ops`), independently of the DSL registry. This is correct — it needs to generate bindings for ALL operators, including `@manual_op` ones that have no DSL variant. + +The `devices` list passed to binding generation must include `"cpu"` if `WITH_CPU` is enabled. Check that `args.devices` includes CPU. The existing `generate_wrappers.py` receives `${DEVICE_LIST}` from CMake which includes `cpu` when `WITH_CPU=ON`. 
+ +- [ ] **Step 2: Test the unified pipeline** + +```bash +python -m dsl --devices cpu nvidia --output generated +``` + +Expected: generates all DSL kernel files + bindings + C API + impl_names.json. + +Verify output matches `generate_wrappers.py`: +```bash +# Save current generated output. +cp -r generated /tmp/dsl_generated + +# Run old script. +python scripts/generate_wrappers.py --devices cpu nvidia + +# Compare bindings (the part that matters). +diff generated/bindings/ops.cc /tmp/dsl_generated/bindings/ops.cc +``` + +The outputs should be identical (or differ only in include ordering, which is harmless). + +- [ ] **Step 3: Commit** + +``` +git add dsl/__main__.py +git commit -m "feat(dsl): integrate binding generation into python -m dsl" +``` + +--- + +## Task 3: Update CMakeLists.txt + +**Files:** +- Modify: `src/CMakeLists.txt` + +- [ ] **Step 1: Replace `generate_wrappers.py` with `python -m dsl`** + +Change the `execute_process` call (around line 229-233): + +From: +```cmake +execute_process( + COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py --devices ${DEVICE_LIST} + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + RESULT_VARIABLE script_result +) +``` + +To: +```cmake +execute_process( + COMMAND ${Python_EXECUTABLE} -m dsl --devices ${DEVICE_LIST} + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + RESULT_VARIABLE script_result +) +``` + +Also update the status message: +```cmake +if(NOT script_result EQUAL 0) + message(FATAL_ERROR "DSL compilation and binding generation - failed") +else() + message(STATUS "DSL compilation and binding generation - done") +endif() +``` + +- [ ] **Step 2: Build and verify** + +```bash +pip install -e .[dev] +``` + +Expected: builds successfully using `python -m dsl` instead of `generate_wrappers.py`. 
+ +- [ ] **Step 3: Smoke test** + +```bash +python -c " +import torch, infini.ops +a = torch.randn(4, 4, device='cuda') +b = torch.randn(4, 4, device='cuda') +out = torch.empty(4, 4, device='cuda') +infini.ops.add(a, b, out, implementation='dsl') +print('OK') +" +``` + +- [ ] **Step 4: Commit** + +``` +git add src/CMakeLists.txt +git commit -m "build: replace generate_wrappers.py with python -m dsl in CMake" +``` + +--- + +## Task 4: Full regression test + +- [ ] **Step 1: Run full test suite** + +```bash +pytest tests/ dsl/tests/ --tb=short -q \ + --ignore=tests/test_add_rms_norm.py \ + --ignore=tests/test_cat.py \ + --ignore=tests/test_linear.py \ + --ignore=tests/test_matmul.py +``` + +Expected: 4372+ passed, 0 failed. + +- [ ] **Step 2: Run linter** + +```bash +ruff check dsl/compiler/bindings.py dsl/__main__.py +ruff format dsl/compiler/bindings.py dsl/__main__.py +``` + +- [ ] **Step 3: Commit any lint fixes** + +``` +git add -u && git commit -m "style: fix lint issues" +``` diff --git a/docs/superpowers/plans/2026-04-11-unary-brick-cast-benchmark.md b/docs/superpowers/plans/2026-04-11-unary-brick-cast-benchmark.md new file mode 100644 index 00000000..02c8efa7 --- /dev/null +++ b/docs/superpowers/plans/2026-04-11-unary-brick-cast-benchmark.md @@ -0,0 +1,275 @@ +# Unary Elementwise Brick, Cast Migration, and Performance Benchmark + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add a `UnaryElementwiseBrick` C++ template, migrate Cast to DSL, and benchmark all DSL operators against hand-written versions. + +**Architecture:** New unary brick templates (CUDA + CPU) with dual-dtype dispatch handle single-input operators. The DSL compiler learns to match unary DAGs and emit code using these bricks. A benchmark script compares DSL vs hand-written kernel performance. 
+ +**Tech Stack:** C++17/CUDA (brick templates), Python (DSL compiler, benchmarks), pybind11 (bindings), pytest + `torch.utils.benchmark` (benchmarks). + +**Spec:** `docs/superpowers/specs/2026-04-11-unary-brick-cast-benchmark-design.md` + +--- + +## Task 1: CUDA unary elementwise brick + +**Files:** +- Create: `src/cuda/templates/unary_elementwise.cuh` + +- [ ] **Step 1: Create the CUDA unary kernel and brick class** + +Model on `src/cuda/templates/binary_elementwise.cuh`. Key differences: +- One input tensor instead of two. +- Dual-dtype dispatch: `Run` takes `InputTypeList` and `OutputTypeList` and dispatches on `(input_dtype, output_dtype)`. +- Op functor signature: `TOut operator()(const TIn& x) const`. +- `UnaryElementwiseBrick` manages device metadata for 2 tensors (input + output) instead of 3. + +Use `DispatchFunc` with `{static_cast(input_dtype), static_cast(output_dtype)}` for mixed multi-type dispatch (see `CONTRIBUTING.md` "Mixed Multi-Type Dispatch" section). Inside the lambda, use `ListGet<0>(list_tag)` and `ListGet<1>(list_tag)` to extract both types. + +- [ ] **Step 2: Verify it compiles** + +Run: `pip install -e .[dev] 2>&1 | tail -3` +Expected: "Successfully installed InfiniOps-0.1.0" + +- [ ] **Step 3: Commit** + +``` +git add src/cuda/templates/unary_elementwise.cuh +git commit -m "feat(dsl): add CUDA unary elementwise brick template" +``` + +--- + +## Task 2: CPU unary elementwise brick + +**Files:** +- Create: `src/cpu/templates/unary_elementwise.h` + +- [ ] **Step 1: Create the CPU unary elementwise function** + +Model on `src/cpu/templates/binary_elementwise.h`. Key differences: +- Single input tensor. +- Dual-dtype dispatch: nested `DispatchFunc` calls — outer dispatches `input_dtype`, inner dispatches `output_dtype` (same pattern as existing `src/cpu/cast/cast.h`). +- Op functor signature: `TOut operator()(const TIn& x) const`. +- OpenMP parallel for loop with `IndexToOffset` for non-contiguous tensors. 
+ +- [ ] **Step 2: Verify it compiles** + +Run: `pip install -e .[dev] 2>&1 | tail -3` +Expected: "Successfully installed InfiniOps-0.1.0" + +- [ ] **Step 3: Commit** + +``` +git add src/cpu/templates/unary_elementwise.h +git commit -m "feat(dsl): add CPU unary elementwise brick template" +``` + +--- + +## Task 3: DSL compiler — unary codegen + +**Files:** +- Modify: `dsl/compiler/infini_codegen.py` — add `_gen_unary_elementwise_cuda()`, `_gen_unary_elementwise_cpu()`, `_generate_unary_functor_cuda()`, `_generate_unary_functor_cpu()` +- Modify: `dsl/__main__.py` — route `BrickKind.UNARY_ELEMENTWISE` to new generators + +Note: `dsl/compiler/patterns.py` already has `BrickKind.UNARY_ELEMENTWISE` and matching logic. + +- [ ] **Step 1: Add unary functor generators to `infini_codegen.py`** + +Add `_generate_unary_functor_cuda(op, dag, match)` and `_generate_unary_functor_cpu(op, dag, match)`. These follow the same pattern as `_generate_binary_functor_cuda/cpu` but with: +- Single input `va` instead of `va, vb`. +- Return type may differ from input type (for Cast). + +For Cast specifically, the functor body is just `return Caster::template Cast(va);` (CUDA) or `return static_cast(va);` (CPU). + +- [ ] **Step 2: Add unary file generators to `infini_codegen.py`** + +Add `_gen_unary_elementwise_cuda(op, dag, match, guard, op_snake)` and `_gen_unary_elementwise_cpu(...)`. These generate complete header files that: +- Include `cuda/templates/unary_elementwise.cuh` or `cpu/templates/unary_elementwise.h`. +- Include the base class header (`base/cast.h`). +- Define the functor struct and `DslCudaCast` / `Operator` classes. +- Use `AllTypes` for both input and output type lists. +- The CUDA class constructor takes `(input, out)` matching Cast's base class. + +- [ ] **Step 3: Wire `generate_cuda_kernel` and `generate_cpu_kernel` to handle `UNARY_ELEMENTWISE`** + +Add `if match.brick == BrickKind.UNARY_ELEMENTWISE` branches in both functions. 
+ +- [ ] **Step 4: Update `__main__.py` to route unary brick** + +In `_generate_infini_op`, the code already calls `generate_cuda_kernel` and `generate_cpu_kernel` which will now handle `UNARY_ELEMENTWISE`. No changes needed in `__main__.py` unless the output path logic differs. Verify by running: + +``` +python -m dsl --ops Cast --output /tmp/dsl_test --devices nvidia +``` + +Expected: generates `cuda/cast/dsl.h`, `cpu/cast/dsl.h`, `nvidia/cast/dsl.h`, registries. + +- [ ] **Step 5: Commit** + +``` +git add dsl/compiler/infini_codegen.py dsl/__main__.py +git commit -m "feat(dsl): add unary elementwise codegen for @infini_op" +``` + +--- + +## Task 4: Cast DSL migration + +**Files:** +- Create: `dsl/ops/cast_dsl.py` +- Create: `src/cuda/cast/dsl.h` (generated) +- Create: `src/nvidia/cast/dsl.h` (generated) +- Create: `src/cpu/cast/dsl.h` (generated) +- Create: `src/nvidia/cast/registry.h` (generated) +- Create: `src/cpu/cast/registry.h` (generated) +- Modify: `src/cpu/cast/cast.h` — add `#include "cpu/cast/registry.h"` +- Create: `tests/test_cast_dsl.py` + +- [ ] **Step 1: Create DSL definition** + +Create `dsl/ops/cast_dsl.py`: +```python +from dsl.decorators import infini_op +from dsl.primitives import Tensor, cast + +@infini_op( + name="Cast", + impl_index=1, + shapes={"N": "output_size"}, + manual_backends={ + "ascend": "ascend/cast/kernel.h", + }, +) +def cast_dsl(input: Tensor["N"]) -> Tensor["N"]: + return cast(input) +``` + +- [ ] **Step 2: Generate and place files** + +``` +python -m dsl --ops Cast --output /tmp/dsl_cast --devices nvidia +``` + +Copy generated files to `src/`: +- `src/cuda/cast/dsl.h` +- `src/nvidia/cast/dsl.h` +- `src/cpu/cast/dsl.h` +- `src/cpu/cast/registry.h` + +For nvidia, manually create `src/nvidia/cast/registry.h` with `List` only (no hand-written NVIDIA impl exists; dispatcher fallback handles default index). 
+ +- [ ] **Step 3: Update existing CPU cast to include registry** + +Add `#include "cpu/cast/registry.h"` to `src/cpu/cast/cast.h`. + +- [ ] **Step 4: Create test** + +Create `tests/test_cast_dsl.py` following `tests/test_cast.py` pattern. Use `implementation="dsl"`. Test fp32→fp16, fp16→fp32, bf16→fp32, fp32→bf16 conversions. + +- [ ] **Step 5: Regenerate `impl_names.json` and rebuild** + +``` +python -m dsl --output generated --devices nvidia +pip install -e .[dev] +``` + +- [ ] **Step 6: Run tests** + +``` +pytest tests/test_cast_dsl.py -v +pytest tests/test_cast.py --devices cpu -v # existing tests (CPU only, no CUDA hand-written) +``` + +Expected: all pass. + +- [ ] **Step 7: Commit** + +``` +git add dsl/ops/cast_dsl.py src/cuda/cast/dsl.h src/nvidia/cast/ src/cpu/cast/ tests/test_cast_dsl.py +git commit -m "feat(dsl): migrate Cast to @infini_op with unary elementwise brick" +``` + +--- + +## Task 5: Performance benchmark + +**Files:** +- Create: `tests/benchmark_dsl.py` + +- [ ] **Step 1: Create benchmark script** + +Create `tests/benchmark_dsl.py` using `torch.utils.benchmark.Timer` and `@pytest.mark.benchmark`. Structure: + +```python +import pytest +import torch +import torch.utils.benchmark as benchmark +import infini.ops + +@pytest.mark.benchmark +@pytest.mark.parametrize("op_name, shape, dtype, setup_fn", [ + # Add + ("add", (4, 4, 5632), torch.float32, _setup_binary), + ("add", (1024, 1024), torch.float16, _setup_binary), + # RmsNorm + ("rms_norm", (2, 4, 2048), torch.float32, _setup_rms_norm), + # Swiglu + ("swiglu", (4, 4, 5632), torch.float32, _setup_binary), + # Cast + ("cast", (4, 4, 5632), torch.float32, _setup_cast), # fp32→fp16 +]) +def test_benchmark_dsl_vs_default(op_name, shape, dtype, setup_fn): + ... +``` + +Each test: +1. Creates tensors on CUDA. +2. Runs the operator with `implementation="default"` (hand-written) — times it. +3. Runs with `implementation="dsl"` — times it. +4. Computes ratio. Prints comparison table. +5. 
Asserts ratio is within 0.8-1.2 (configurable via marker). + +Skip operators that lack a hand-written CUDA implementation (Mul, Cast on NVIDIA) — they only have DSL, so no comparison is possible. + +- [ ] **Step 2: Run benchmark** + +``` +pytest tests/benchmark_dsl.py --benchmark -v --devices cuda +``` + +Expected: table of results showing DSL vs hand-written timing. + +- [ ] **Step 3: Commit** + +``` +git add tests/benchmark_dsl.py +git commit -m "test(dsl): add performance benchmark comparing DSL vs hand-written kernels" +``` + +--- + +## Task 6: Full regression and final commit + +- [ ] **Step 1: Run full test suite** + +``` +pytest tests/ dsl/tests/ --tb=short -q \ + --ignore=tests/test_add_rms_norm.py \ + --ignore=tests/test_cat.py \ + --ignore=tests/test_linear.py \ + --ignore=tests/test_matmul.py +``` + +Expected: 4300+ passed, 0 failed. + +(The ignored tests are pre-existing CUDA crashes for operators without NVIDIA implementations — unrelated to this work.) + +- [ ] **Step 2: Run linter** + +``` +ruff check dsl/ scripts/generate_wrappers.py tests/test_cast_dsl.py tests/benchmark_dsl.py +ruff format dsl/ tests/test_cast_dsl.py tests/benchmark_dsl.py +``` diff --git a/docs/superpowers/plans/2026-04-12-operator-dispatch-maintenance.md b/docs/superpowers/plans/2026-04-12-operator-dispatch-maintenance.md new file mode 100644 index 00000000..f22cadbd --- /dev/null +++ b/docs/superpowers/plans/2026-04-12-operator-dispatch-maintenance.md @@ -0,0 +1,251 @@ +# Operator Dispatch and Maintenance Optimization + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Eliminate per-operator wrapper files from `src//` so that new platforms require only 4 adapter files and new operators require zero platform-specific boilerplate. 
+ +**Architecture:** Move all auto-generated wrapper/registry/DSL files from `src/` to `generated/`. Update the DSL compiler to generate wrappers for ALL operators (not just DSL ones). Update `ops.cc` includes to reference `generated/` paths. Keep hand-written kernels and platform adapters in `src/`. + +**Tech Stack:** Python (DSL compiler), CMake, C++17. + +**Spec:** `docs/superpowers/specs/2026-04-12-operator-dispatch-maintenance-design.md` + +--- + +## Task 1: Update DSL compiler to generate ALL platform wrappers + +**Files:** +- Modify: `dsl/compiler/codegen.py` +- Modify: `dsl/compiler/bindings.py` +- Modify: `dsl/__main__.py` + +Currently `generate_wrappers_for_op` only generates wrappers for `@infini_op` operators and `@manual_op` operators that have a `cuda` backend entry. It skips operators with explicit per-platform backend entries (e.g., Gemm's `"nvidia": "nvidia/gemm/cublas.h"`). + +- [ ] **Step 1: Update `generate_wrappers_for_op` to generate wrappers for ALL `@manual_op` operators** + +In `dsl/compiler/codegen.py`, the current logic at line ~235 skips backends with explicit string entries: +```python +explicit = backends.get(backend) +if explicit is not None and isinstance(explicit, str): + continue # Skip hand-written +``` + +For operators like Cat, AddRmsNorm, etc., the `backends` dict has no `nvidia` entry — it only has `cuda` (shared kernel) and possibly `ascend`/`cambricon`. The DSL compiler correctly generates nvidia wrappers from the `cuda` entry for these. + +For operators like Gemm that have explicit `"nvidia": "nvidia/gemm/cublas.h"`, these are hand-written multi-file implementations (cublas.h + cublaslt.h + registry.h) that should NOT be auto-generated. They stay in `src/nvidia/gemm/`. + +**No code change needed here** — the existing logic is correct. Gemm/Matmul's nvidia-specific files stay in `src/`. All other operators already get wrappers generated. 
+ +- [ ] **Step 2: Update `bindings.py` to include wrappers from `generated/` instead of `src/`** + +In `dsl/compiler/bindings.py`, the `_get_all_ops` function scans `src/` for `Operator<>` specializations: +```python +for file_path in _SRC_DIR.rglob("*"): + if f"class Operator<{_snake_to_pascal(op_name)}" in file_path.read_text(): + ops[op_name].append(file_path) +``` + +Update to also scan `generated/`: +```python +for search_dir in [_SRC_DIR, output_dir]: + for file_path in search_dir.rglob("*"): + ... +``` + +And update `generate_all_bindings` to accept the output directory and pass it to `_get_all_ops`. + +- [ ] **Step 3: Verify the compiler generates correctly** + +```bash +python -m dsl --devices cpu nvidia --output generated +ls generated/nvidia/add/kernel.h generated/nvidia/cat/kernel.h +``` + +- [ ] **Step 4: Commit** + +``` +git add dsl/compiler/bindings.py dsl/__main__.py +git commit -m "refactor(dsl): scan generated/ for operator specializations in binding generation" +``` + +--- + +## Task 2: Move nvidia wrapper files from `src/` to `generated/` + +**Files to move (delete from `src/`, DSL compiler regenerates in `generated/`):** + +CUDA-like platform wrappers (simple 21-line template files): +- `src/nvidia/add/kernel.h` → generated by DSL +- `src/nvidia/add/dsl.h` → generated by DSL +- `src/nvidia/add/registry.h` → generated by DSL +- `src/nvidia/add_rms_norm/kernel.h` → generated by DSL +- `src/nvidia/cast/dsl.h` → generated by DSL +- `src/nvidia/cast/registry.h` → generated by DSL +- `src/nvidia/cat/kernel.h` → generated by DSL +- `src/nvidia/causal_softmax/kernel.h` → generated by DSL +- `src/nvidia/flash_attention/kernel.h` → generated by DSL +- `src/nvidia/linear/kernel.h` → generated by DSL +- `src/nvidia/mul/dsl.h` → generated by DSL +- `src/nvidia/mul/registry.h` → generated by DSL +- `src/nvidia/reshape_and_cache/kernel.h` → generated by DSL +- `src/nvidia/rms_norm/kernel.h` → generated by DSL +- `src/nvidia/rms_norm/dsl.h` → generated by DSL 
+- `src/nvidia/rms_norm/registry.h` → generated by DSL +- `src/nvidia/rotary_embedding/kernel.h` → generated by DSL +- `src/nvidia/swiglu/kernel.h` → generated by DSL +- `src/nvidia/swiglu/dsl.h` → generated by DSL +- `src/nvidia/swiglu/registry.h` → generated by DSL + +**Files that stay in `src/nvidia/` (hand-written, NOT auto-generated):** +- `src/nvidia/gemm/cublas.h` — cuBLAS implementation (not a simple wrapper) +- `src/nvidia/gemm/cublaslt.h` — cuBLASLt implementation +- `src/nvidia/gemm/registry.h` — GemmImpl struct + ActiveImplementationsImpl +- `src/nvidia/matmul/cublaslt.h` — cuBLASLt implementation +- `src/nvidia/matmul/cublas.h` — cuBLAS wrapper +- `src/nvidia/matmul/registry.h` — MatmulImpl struct +- Adapter files: `blas.h`, `blas_utils.h`, `caster.cuh`, `data_type_.h`, `device_.h`, `device_property.h`, `runtime_.h`, `runtime_utils.h` + +- [ ] **Step 1: Delete wrapper files from `src/nvidia/`** + +Delete the 20 files listed above. Keep Gemm, Matmul, and adapter files. + +```bash +# Delete per-operator directories that are purely auto-generated. +# Keep gemm/ and matmul/ (hand-written multi-impl). 
+``` + +- [ ] **Step 2: Regenerate all wrappers in `generated/`** + +```bash +python -m dsl --devices cpu nvidia --output generated +``` + +Verify the generated files exist: +```bash +ls generated/nvidia/add/kernel.h +ls generated/nvidia/cat/kernel.h +ls generated/nvidia/flash_attention/kernel.h +``` + +- [ ] **Step 3: Commit** + +``` +git add -A +git commit -m "refactor: move nvidia wrapper files from src/ to generated/" +``` + +--- + +## Task 3: Move CPU DSL/registry files from `src/` to `generated/` + +**Files to move:** +- `src/cpu/add/dsl.h`, `src/cpu/add/registry.h` +- `src/cpu/cast/dsl.h`, `src/cpu/cast/registry.h` +- `src/cpu/mul/dsl.h`, `src/cpu/mul/registry.h` +- `src/cpu/rms_norm/dsl.h`, `src/cpu/rms_norm/registry.h` +- `src/cpu/swiglu/dsl.h`, `src/cpu/swiglu/registry.h` + +**Files that stay in `src/cpu/`:** +- Hand-written CPU implementations: `add/add.h`, `cast/cast.h`, `mul/mul.h`, etc. + +- [ ] **Step 1: Remove registry includes from hand-written CPU files** + +The hand-written CPU files (e.g., `src/cpu/add/add.h`) currently `#include "cpu/add/registry.h"`. Since registry.h moves to `generated/`, update the include path or have the registry included via `ops.cc` instead. + +Best approach: remove the `#include "cpu/<op>/registry.h"` from hand-written CPU files. The registry is only needed by the DSL file (which includes it) and by `ops.cc` (which includes both). 
+ +- [ ] **Step 2: Delete CPU DSL/registry files from `src/`** + +- [ ] **Step 3: Regenerate and verify** + +```bash +python -m dsl --devices cpu nvidia --output generated +ls generated/cpu/add/dsl.h generated/cpu/add/registry.h +``` + +- [ ] **Step 4: Commit** + +``` +git add -A +git commit -m "refactor: move CPU DSL and registry files from src/ to generated/" +``` + +--- + +## Task 4: Update CMake to include `generated/` in include paths + +**Files:** +- Modify: `src/CMakeLists.txt` + +- [ ] **Step 1: Add `generated/` to include directories** + +The `generated/` directory contains header files that need to be found by the compiler. Add: + +```cmake +target_include_directories(infiniops PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} + ${PROJECT_SOURCE_DIR}/generated +) +``` + +This allows `#include "nvidia/add/kernel.h"` to resolve from `generated/nvidia/add/kernel.h`. + +Note: `src/` is already an include directory. Adding `generated/` means both `src/nvidia/blas.h` (adapter) and `generated/nvidia/add/kernel.h` (wrapper) are findable. + +- [ ] **Step 2: Build and verify** + +```bash +pip install -e .[dev] +``` + +- [ ] **Step 3: Commit** + +``` +git add src/CMakeLists.txt +git commit -m "build: add generated/ to include paths for auto-generated wrappers" +``` + +--- + +## Task 5: Build and full regression test + +- [ ] **Step 1: Clean rebuild** + +```bash +pip install -e .[dev] +``` + +- [ ] **Step 2: Run full test suite** + +```bash +pytest tests/ dsl/tests/ --tb=short -q \ + --ignore=tests/test_add_rms_norm.py \ + --ignore=tests/test_cast.py \ + --ignore=tests/test_cat.py \ + --ignore=tests/test_linear.py \ + --ignore=tests/test_matmul.py +``` + +Expected: 4300+ passed, 0 failed. + +- [ ] **Step 3: Verify `src/nvidia/` is clean** + +```bash +# Should only show adapter files, not per-operator wrappers. 
+find src/nvidia/ -name "*.h" -o -name "*.cuh" | sort +``` + +Expected output: only `blas.h`, `blas_utils.h`, `caster.cuh`, `data_type_.h`, `device_.h`, `device_property.h`, `runtime_.h`, `runtime_utils.h`, plus `gemm/` and `matmul/` directories. + +- [ ] **Step 4: Run linter** + +```bash +ruff check dsl/ --fix +``` + +- [ ] **Step 5: Final commit** + +``` +git add -A +git commit -m "refactor: complete separation of hand-written and generated code" +``` diff --git a/docs/superpowers/specs/2026-04-11-benchmark-baseline.md b/docs/superpowers/specs/2026-04-11-benchmark-baseline.md new file mode 100644 index 00000000..78f18cbe --- /dev/null +++ b/docs/superpowers/specs/2026-04-11-benchmark-baseline.md @@ -0,0 +1,165 @@ +# InfiniOps CUDA Operator Benchmark Baseline + +**Date**: 2026-04-11 +**Hardware**: NVIDIA A100-SXM4-80GB (SM80) +**CUDA**: 13.0 +**Tool**: `torch.utils.benchmark.Timer.blocked_autorange(min_run_time=2)` + +--- + +## Elementwise Operators + +| Operator | Shape | dtype | Time (ms) | +|----------|-------|-------|-----------| +| **Add** | (4,4,5632) | fp32 | 0.010 | +| **Add** | (1,32,4096) | fp32 | 0.010 | +| **Add** | (64,32,128) | fp32 | 0.010 | +| **Add** | (4,4,5632) | fp16 | 0.010 | +| **Add** | (1,32,4096) | fp16 | 0.010 | +| **Add** | (64,32,128) | fp16 | 0.010 | +| **Add** | (4,4,5632) | bf16 | 0.010 | +| **Add** | (1,32,4096) | bf16 | 0.010 | +| **Add** | (64,32,128) | bf16 | 0.010 | +| **Mul** | (4,4,5632) | fp32 | 0.010 | +| **Mul** | (1,32,4096) | fp32 | 0.010 | +| **Mul** | (64,32,128) | fp32 | 0.010 | +| **Mul** | (4,4,5632) | fp16 | 0.010 | +| **Mul** | (1,32,4096) | fp16 | 0.010 | +| **Mul** | (64,32,128) | fp16 | 0.010 | +| **Mul** | (4,4,5632) | bf16 | 0.010 | +| **Mul** | (1,32,4096) | bf16 | 0.010 | +| **Mul** | (64,32,128) | bf16 | 0.010 | +| **Cast** | (4,4,5632) | fp32→fp16 | 0.008 | +| **Cast** | (4,4,5632) | fp16→fp32 | 0.008 | +| **Cast** | (1,32,4096) | fp32→bf16 | 0.008 | +| **Cast** | (1,32,4096) | bf16→fp32 | 0.008 | +| 
**Swiglu** | (4,4,5632) | fp32 | 0.010 | +| **Swiglu** | (1,32,4096) | fp32 | 0.010 | +| **Swiglu** | (4,4,5632) | fp16 | 0.010 | +| **Swiglu** | (1,32,4096) | fp16 | 0.010 | +| **Swiglu** | (4,4,5632) | bf16 | 0.010 | +| **Swiglu** | (1,32,4096) | bf16 | 0.010 | + +**Note**: Elementwise ops at these sizes are launch-overhead dominated +(~10 us). Differences become meaningful at larger tensor sizes (>1M +elements). + +--- + +## Normalization Operators + +| Operator | Shape | dtype | Time (ms) | +|----------|-------|-------|-----------| +| **RmsNorm** | (2,4,2048) | fp32 | 0.010 | +| **RmsNorm** | (1,32,4096) | fp32 | 0.010 | +| **RmsNorm** | (4,48,64) | fp32 | 0.010 | +| **RmsNorm** | (2,4,2048) | fp16 | 0.010 | +| **RmsNorm** | (1,32,4096) | fp16 | 0.010 | +| **RmsNorm** | (4,48,64) | fp16 | 0.010 | +| **RmsNorm** | (2,4,2048) | bf16 | 0.010 | +| **RmsNorm** | (1,32,4096) | bf16 | 0.010 | +| **RmsNorm** | (4,48,64) | bf16 | 0.010 | +| **AddRmsNorm** | (2,4,2048) | fp32 | 0.014 | +| **AddRmsNorm** | (1,32,4096) | fp32 | 0.014 | +| **AddRmsNorm** | (2,4,2048) | fp16 | 0.014 | +| **AddRmsNorm** | (1,32,4096) | fp16 | 0.014 | +| **AddRmsNorm** | (2,4,2048) | bf16 | 0.014 | +| **AddRmsNorm** | (1,32,4096) | bf16 | 0.014 | +| **CausalSoftmax** | (2,4,64,64) | fp32 | 0.008 | +| **CausalSoftmax** | (1,32,128,128) | fp32 | 0.054 | +| **CausalSoftmax** | (2,4,64,64) | fp16 | 0.008 | +| **CausalSoftmax** | (1,32,128,128) | fp16 | 0.057 | +| **CausalSoftmax** | (2,4,64,64) | bf16 | 0.008 | +| **CausalSoftmax** | (1,32,128,128) | bf16 | 0.061 | + +--- + +## GEMM / Linear + +| Operator | Shape (M,N,K) | dtype | Time (ms) | TFLOPS | +|----------|---------------|-------|-----------|--------| +| **Gemm** | (1024,1024,1024) | fp16 | 0.040 | 53.8 | +| **Gemm** | (4096,4096,4096) | fp16 | 0.584 | 235.4 | +| **Gemm** | (1,4096,4096) | fp16 | 0.021 | 1.6 | +| **Gemm** | (1024,1024,1024) | bf16 | 0.038 | 56.0 | +| **Gemm** | (4096,4096,4096) | bf16 | 0.571 | 240.6 | +| **Gemm** | 
(1,4096,4096) | bf16 | 0.021 | 1.6 | +| **Matmul** | (1024,1024,1024) | fp16 | 0.017 | 124.6 | +| **Matmul** | (4096,4096,4096) | fp16 | 0.590 | 232.9 | +| **Matmul** | (1,4096,4096) | fp16 | 0.023 | 1.5 | +| **Matmul** | (1024,1024,1024) | bf16 | 0.019 | 112.9 | +| **Matmul** | (4096,4096,4096) | bf16 | 0.552 | 248.8 | +| **Matmul** | (1,4096,4096) | bf16 | 0.023 | 1.5 | +| **Linear** | (1024,4096,4096) no bias | fp16 | 0.210 | — | +| **Linear** | (1024,4096,4096) + bias | fp16 | 0.229 | — | +| **Linear** | (1,4096,4096) no bias | fp16 | 0.021 | — | + +**Note**: A100 theoretical peak: 312 TFLOPS (fp16 tensor core). Gemm/Matmul +at 4096³ achieve ~235-249 TFLOPS (75-80% utilization). The Matmul 1024³ +result (124.6 TFLOPS) is better than Gemm (53.8 TFLOPS) because Matmul +uses cuBLASLt with heuristic algorithm selection. + +--- + +## Position / Cache Operators + +| Operator | Config | dtype | Time (ms) | +|----------|--------|-------|-----------| +| **RotaryEmbed** | T=128 H=32 D=128 | fp16 | 0.016 | +| **RotaryEmbed** | T=1 H=32 D=128 | fp16 | 0.016 | +| **RotaryEmbed** | T=512 H=32 D=64 | fp16 | 0.016 | +| **RotaryEmbed** | T=128 H=32 D=128 | bf16 | 0.016 | +| **RotaryEmbed** | T=1 H=32 D=128 | bf16 | 0.016 | +| **RotaryEmbed** | T=512 H=32 D=64 | bf16 | 0.016 | +| **ReshapeAndCache** | T=128 Nkv=8 D=128 BS=16 | fp16 | 0.014 | +| **ReshapeAndCache** | T=32 Nkv=32 D=128 BS=16 | fp16 | 0.014 | + +--- + +## Attention + +| Operator | SeqLen | Heads (Q/KV) | HeadDim | dtype | Time (ms) | TFLOPS | +|----------|--------|-------------|---------|-------|-----------|--------| +| **FlashAttn** | 128 | 32/32 | 128 | fp16 | 0.014 | 19.6 | +| **FlashAttn** | 512 | 32/32 | 128 | fp16 | 0.041 | 105.0 | +| **FlashAttn** | 2048 | 32/32 | 128 | fp16 | 0.240 | 286.3 | +| **FlashAttn** | 128 | 32/8 | 128 | fp16 | 0.014 | 19.5 | +| **FlashAttn** | 512 | 32/8 | 128 | fp16 | 0.036 | 119.6 | +| **FlashAttn** | 128 | 32/32 | 128 | bf16 | 0.014 | 19.5 | +| **FlashAttn** | 512 | 32/32 | 128 
| bf16 | 0.041 | 105.0 | +| **FlashAttn** | 2048 | 32/32 | 128 | bf16 | 0.240 | 286.6 | +| **FlashAttn** | 128 | 32/8 | 128 | bf16 | 0.014 | 19.7 | +| **FlashAttn** | 512 | 32/8 | 128 | bf16 | 0.036 | 119.7 | + +**Note**: FlashAttention via FlashInfer. At S=2048, achieves 286 TFLOPS +(92% of A100 peak). GQA (32/8 heads) is faster than MHA at same seq_len +due to fewer KV heads. + +--- + +## Cat + +| Config | dtype | Time (ms) | +|--------|-------|-----------| +| 3×(4,128) dim=0 | fp16 | 0.012 | +| (4,1024)+(4,2048)+(4,512) dim=1 | fp16 | 0.012 | +| 2×(2,32,4096) dim=0 | fp16 | 0.010 | + +--- + +## Optimization Priorities + +Based on this baseline, areas with the most optimization potential: + +1. **Gemm 1024³**: 53.8 TFLOPS vs Matmul's 124.6 TFLOPS — Gemm uses + cuBLAS default algorithm while Matmul uses cuBLASLt with heuristic + search. Consider switching Gemm's default to cuBLASLt. + +2. **Linear**: 0.210 ms for (1024,4096,4096) — could benefit from + cuBLASLt like Matmul. + +3. **CausalSoftmax (1,32,128,128)**: 0.054-0.061 ms — may benefit from + warp-level online softmax or shared memory tiling optimization. + +4. **Elementwise ops**: All at ~0.010 ms (launch overhead). For larger + tensors, consider vectorized loads (float4) and grid-stride loops. diff --git a/docs/superpowers/specs/2026-04-11-cross-platform-dsl-design.md b/docs/superpowers/specs/2026-04-11-cross-platform-dsl-design.md new file mode 100644 index 00000000..aca898a6 --- /dev/null +++ b/docs/superpowers/specs/2026-04-11-cross-platform-dsl-design.md @@ -0,0 +1,398 @@ +# InfiniOps Cross-Platform DSL Design + +## Problem + +Adding a new operator to InfiniOps requires 10+ files and ~670 lines of code, +roughly 50% of which is boilerplate. CUDA-like backends (NVIDIA, MetaX, +Iluvatar, Moore) share ~99% of kernel code via `src/cuda/` templates, yet each +still needs a hand-written 21-line wrapper file per operator. 
CPU +implementations duplicate the same mathematical logic in a separate +OpenMP-based form. The core algorithmic intent is expressed repeatedly across +backends rather than once. + +Ascend uses aclnn vendor APIs exclusively and cannot share kernel code with +CUDA backends. Its implementations will remain hand-written. + +## Solution + +A Python DSL for defining operator semantics, paired with a C++ template +building-block library ("bricks"). The DSL compiler translates operator +definitions into C++ code that composes these bricks. Hand-written kernels +remain available for performance-critical or complex operators via an escape +hatch. + +### Scope + +- **Automated by DSL**: CUDA-like backends (NVIDIA, MetaX, Iluvatar, Moore) + + CPU. +- **Hand-written, DSL-managed boilerplate**: Ascend (aclnn), Cambricon + (cnnl/BANG), and any future vendor-API platform. +- **Performance target**: generated kernel code within 10-20% of hand-written. + Performance-critical operators (GEMM, FlashAttention) use the escape hatch. + +--- + +## 1. Python DSL + +### Operator definition + +Operators are Python functions decorated with `@infini_op`. The function body +uses a restricted set of tensor primitives to describe mathematical semantics +declaratively (no control flow, no side effects). + +```python +# dsl/ops/rms_norm.py + +from infini_dsl import infini_op, Tensor, Scalar, reduce_mean, rsqrt + +@infini_op( + name="RmsNorm", + shapes={"B": "batch", "H": "heads", "D": "dim"}, +) +def rms_norm( + input: Tensor["B", "H", "D"], + weight: Tensor["D"], + eps: Scalar[float] = 1e-6, +) -> Tensor["B", "H", "D"]: + ss = reduce_mean(input * input, dim="D") + rms = rsqrt(ss + eps) + return input * rms * weight +``` + +Shape variables (`B`, `H`, `D`) let the compiler infer grid/block mapping and +derive base-class member fields. 
+ +### Primitive set + +| Category | Primitives | +|----------|------------| +| Elementwise | `+`, `-`, `*`, `/`, `sqrt`, `rsqrt`, `exp`, `log`, `abs`, `neg`, `pow`, `clamp` | +| Activation | `relu`, `gelu`, `silu`, `sigmoid`, `tanh` | +| Reduction | `reduce_sum`, `reduce_mean`, `reduce_max`, `reduce_min` | +| Softmax | `softmax`, `log_softmax` | +| Comparison | `where(cond, a, b)`, `>`, `<`, `>=`, `<=`, `eq` | +| Type | `cast(x, dtype)` | +| Shape | `reshape`, `transpose`, `unsqueeze`, `expand`, `cat`, `slice` | +| Index | `gather`, `scatter`, `index_select` | +| Scalar | `Scalar[float]`, `Scalar[int]` | + +Operators that cannot be expressed with these primitives use `@manual_op`. + +### Escape hatch + +```python +@manual_op( + name="Gemm", + base="src/base/gemm.h", + backends={ + "cuda": "src/cuda/gemm/blas.h", + "ascend": "src/ascend/gemm/kernel.h", + "cpu": "src/cpu/gemm/gemm.h", + }, +) +def gemm(): ... +``` + +`@manual_op` tells the compiler to generate only boilerplate (backend wrapper +files, Python bindings, test scaffolding) while leaving kernel logic to the +hand-written files specified in `backends`. + +### Mixed mode + +An `@infini_op` can specify `manual_backends` for platforms that need +hand-written implementations while still auto-generating for CUDA-like +backends and CPU: + +```python +@infini_op( + name="RmsNorm", + manual_backends={ + "ascend": "src/ascend/rms_norm/kernel.h", + "cambricon": "src/cambricon/rms_norm/rms_norm.h", + }, +) +def rms_norm(...): + ... +``` + +CUDA-like backends and CPU get auto-generated code; Ascend and Cambricon use +the specified hand-written files. One decorator manages all backends. + +--- + +## 2. DSL compiler + +### Pipeline + +``` +Python DSL source → AST parse → Compute DAG → Pattern match → C++ codegen +``` + +**AST parse**: extracts the function signature (tensor shapes, dtypes, scalar +attributes) and body (primitive operations). 
+ +**Compute DAG**: a directed acyclic graph where nodes are primitive operations +and edges are tensor data flows. Shape variables propagate through the graph +for dimension inference. + +**Pattern match**: the compiler maintains a set of pattern rules that map +subgraph shapes to template bricks: + +```python +PATTERNS = [ + Pattern(match=all_elementwise, emit="ElementwiseKernel"), + Pattern(match=reduce_then_transform, emit="ReduceThenTransform"), + Pattern(match=softmax_pattern, emit="SoftmaxKernel"), + Pattern(match=has_gather_scatter, emit="IndexKernel"), + Pattern(match=pure_reduction, emit="ReductionKernel"), +] +``` + +If a subgraph cannot be matched, the compiler emits an error directing the +user to either decompose the operator or use `@manual_op`. + +**C++ codegen**: emits C++ source files using Jinja2 templates. Generated code +calls template bricks with operator-specific functors. + +### Directory structure + +``` +dsl/ + ops/ # Operator definitions (@infini_op, @manual_op) + compiler/ + __init__.py + parser.py # AST → compute DAG + patterns.py # Pattern matching rules + codegen.py # C++ code generation (CUDA-like + CPU) + templates/ # Jinja2 templates for generated C++ files + base_class.h.j2 + cuda_kernel.h.j2 + backend_wrapper.h.j2 + cpu_kernel.h.j2 + test.py.j2 +``` + +### Invocation + +```bash +python -m dsl.compiler --devices nvidia metax iluvatar moore cpu \ + --output generated/ +``` + +Integrated into CMake, runs before compilation. Replaces the current +`generate_wrappers.py` call (bindings generation is subsumed). + +--- + +## 3. C++ template brick library + +Hand-written, optimized C++ templates that serve as the code-generation +targets. Each brick is parameterized on `Device::Type kDev` and user-provided +functors, so the same brick serves all CUDA-like backends. 
+ +### Brick inventory + +| Brick | Location | Covers | +|-------|----------|--------| +| `ElementwiseKernel` | `src/cuda/templates/elementwise.cuh` | Add, Mul, ReLU, GELU, SiLU, Sigmoid, Tanh, Cast, Abs, Neg | +| `BroadcastKernel` | `src/cuda/templates/broadcast.cuh` | Elementwise ops on different-shaped tensors | +| `ReductionKernel` | `src/cuda/templates/reduction.cuh` | ReduceSum, ReduceMean, ReduceMax, ReduceMin | +| `ReduceThenTransform` | `src/cuda/templates/reduce_transform.cuh` | RmsNorm, LayerNorm, L2Norm | +| `SoftmaxKernel` | `src/cuda/templates/softmax.cuh` | Softmax, LogSoftmax, CausalSoftmax | +| `IndexKernel` | `src/cuda/templates/index.cuh` | Gather, Scatter, IndexSelect, Embedding | +| `ShapeKernel` | `src/cuda/templates/shape.cuh` | Reshape, Transpose, Cat, Slice | + +### Interface pattern + +```cpp +// src/cuda/templates/elementwise.cuh + +template <Device::Type kDev, typename F> +struct ElementwiseKernel { + static void Run( + typename Runtime<kDev>::Stream stream, + const Tensor input, + Tensor output, + F op); +}; +``` + +Bricks use `Caster<kDev>` for type conversions and `Runtime<kDev>` for memory +operations. This defers all platform-specific details to the existing +per-backend specializations. + +### CPU counterparts + +Each CUDA brick has a CPU counterpart in `src/cpu/templates/` using OpenMP: + +```cpp +// src/cpu/templates/elementwise.h + +template <typename F> +struct CpuElementwise { + static void Run(const Tensor input, Tensor output, F op); +}; +``` + +### Generated code example + +For `rms_norm`, the compiler generates: + +```cpp +// generated/cuda/rms_norm/kernel.h + +template <Device::Type kDev> +class CudaRmsNorm : public RmsNorm { + void operator()(const Tensor input, const Tensor weight, + Tensor out) const override { + ReduceThenTransform<kDev>::Run( + stream_, input, out, + ReduceMeanSquare{}, + RsqrtEpsMulWeight{weight, eps_}, + dim_, batch_size_, nhead_); + } +}; +``` + +--- + +## 4. 
Generated output + +### For `@infini_op` operators + +``` +generated/ + base/<op>.h # Abstract base class + cuda/<op>/kernel.h # CudaOp template (brick calls) + nvidia/<op>/kernel.h # Operator wrapper + metax/<op>/kernel.h # Operator wrapper + iluvatar/<op>/kernel.h # Operator wrapper + moore/<op>/kernel.h # Operator wrapper + cpu/<op>/<op>.h # CPU implementation (OpenMP bricks) + bindings/<op>.h # pybind11 bindings + src/<op>/operator.cc # C API (legacy) + tests/test_<op>.py # Parametrized tests +``` + +### For `@manual_op` operators + +``` +generated/ + nvidia/<op>/kernel.h # Wrapper pointing to hand-written cuda impl + metax/<op>/kernel.h # Wrapper + iluvatar/<op>/kernel.h # Wrapper + moore/<op>/kernel.h # Wrapper + bindings/<op>.h # pybind11 bindings + tests/test_<op>.py # Test scaffolding +``` + +Base class, kernel logic, and Ascend/Cambricon implementations remain in +`src/` under manual control. + +### Unchanged files + +- `src/cuda/templates/` — hand-written brick library. +- `src/ascend/` — all Ascend implementations. +- `src/operator.h`, `src/dispatcher.h`, `src/device.h` — core framework. +- `src/<platform>/runtime_.h`, `data_type_.h`, `caster.cuh` — platform + adaptation layers. + +--- + +## 5. New platform onboarding + +### CUDA-compatible platforms + +Provide four adaptation files: + +``` +src/<platform>/device_.h # DeviceEnabled = true +src/<platform>/runtime_.h # Runtime: Stream, Malloc, Free, Memcpy +src/<platform>/data_type_.h # TypeMap specializations for fp16/bf16 +src/<platform>/caster.cuh # Type conversion specializations +``` + +Add `--devices <platform>` to the compiler invocation. All `@infini_op` +operators automatically get generated wrappers for the new platform. No +operator definitions need to change. 
+ +### Vendor-API platforms + +Add the platform to `manual_backends` in each operator's `@infini_op` +definition, or to `backends` in its `@manual_op` definition: + +```python +@infini_op( + name="RmsNorm", + manual_backends={ + "ascend": "src/ascend/rms_norm/kernel.h", + "new_vendor": "src/new_vendor/rms_norm/kernel.h", + }, +) +``` + +Hand-write each operator implementation using the vendor's SDK. The compiler +generates wrappers and bindings. + +--- + +## 6. Migration strategy + +### Phase 1: `@manual_op` for all existing operators + +Register every existing operator as `@manual_op`. This immediately eliminates +hand-written wrapper files (the ~21-line `Operator<>` files) and +centralizes binding generation. No kernel code changes. + +### Phase 2: Extract template bricks from existing kernels + +Refactor existing hand-written CUDA kernels in `src/cuda/` into the template +brick library. The existing `CudaAdd`, `CudaRmsNorm`, etc. provide the +implementations. + +### Phase 3: Migrate simple operators to `@infini_op` + +Convert elementwise operators (Add, ReLU, Cast, SiLU, etc.) to DSL +definitions. Verify generated code matches existing behavior via tests. + +### Phase 4: Migrate medium-complexity operators + +Convert reduction-based operators (RmsNorm, LayerNorm, Softmax) to DSL +definitions using the `ReduceThenTransform` and `SoftmaxKernel` bricks. + +### Non-migrated operators + +GEMM, FlashAttention, RotaryEmbedding, and other complex/performance-critical +operators remain as `@manual_op` indefinitely. The DSL still manages their +boilerplate. + +--- + +## 7. Verification + +### Auto-generated tests + +The compiler derives a PyTorch reference implementation directly from the DSL +function body and generates parametrized tests using the existing +`Payload`/`auto_act_and_assert` framework. 
+ +### Brick-level tests + +``` +tests/test_templates/ + test_elementwise.py + test_reduction.py + test_reduce_transform.py + test_softmax.py + test_index.py +``` + +### End-to-end + +```bash +python -m dsl.compiler --devices nvidia metax iluvatar moore cpu \ + --output generated/ +pip install -e .[dev] +pytest tests/ -v --tb=short +pytest tests/ --devices ascend -v # Ascend ops unaffected +``` diff --git a/docs/superpowers/specs/2026-04-11-cross-platform-dsl-roadmap.md b/docs/superpowers/specs/2026-04-11-cross-platform-dsl-roadmap.md new file mode 100644 index 00000000..25592f5a --- /dev/null +++ b/docs/superpowers/specs/2026-04-11-cross-platform-dsl-roadmap.md @@ -0,0 +1,147 @@ +# Cross-Platform DSL Implementation Roadmap + +**Related design spec**: `2026-04-11-cross-platform-dsl-design.md` + +--- + +## Phase 1: `@manual_op` codegen foundation + +**Goal**: Replace `generate_wrappers.py` with a new compiler that also generates +backend wrappers, reducing per-operator boilerplate immediately. + +**Status**: Completed + +### Steps + +1. **DSL framework scaffolding** — Create the `dsl/` directory structure: + `dsl/compiler/`, `dsl/ops/`, `dsl/templates/`. +2. **Port binding generation from `generate_wrappers.py`** — Move the + libclang-based AST parsing and pybind11 codegen into `dsl/compiler/`. +3. **Add backend wrapper generation** — Extend the compiler to generate + `Operator<>` wrapper files for CUDA-like backends. +4. **Register all existing operators as `@manual_op`** — Create `dsl/ops/*.py` + for every existing operator (14 operators registered). +5. **Integrate into CMake** — Replace the `generate_wrappers.py` call with + `python -m dsl.compiler`. + +**Verification**: Build with all backends, run full test suite, diff generated +bindings against previous output to confirm identical behavior. + +--- + +## Phase 2: C++ template brick library + +**Goal**: Extract reusable kernel templates from existing CUDA implementations. 
+ +**Status**: Completed + +### Steps + +1. **`BinaryElementwiseBrick`** — Extract from `src/cuda/add/kernel.cuh`. + Created `src/cuda/templates/binary_elementwise.cuh` (GPU) and + `src/cpu/templates/binary_elementwise.h` (OpenMP CPU). +2. **`ReduceThenTransform`** — Extract from `src/cuda/rms_norm/kernel.cuh`. + Created `src/cuda/templates/reduce_transform.cuh` (GPU) and + `src/cpu/templates/reduce_transform.h` (CPU). +3. **Future bricks** (Phase 4+): `SoftmaxKernel`, `ReductionKernel`, + `IndexKernel`, `BroadcastKernel`, `ShapeKernel` — to be added as needed. + +**Verification**: Existing tests pass after each refactor. Built-in ops +(`MeanSquareReduce`, `RmsNormTransform`) bundled with brick headers for +backward compatibility. + +--- + +## Phase 3: `@infini_op` compiler + +**Goal**: Build the DSL compiler that translates Python operator definitions +into C++ code using template bricks. + +**Status**: Completed (10 unit tests passing) + +### Steps + +1. **DSL AST parser** — `dsl/compiler/parser.py`: parse `@infini_op` function + bodies into a compute DAG (`dsl/compiler/dag.py`). +2. **Pattern matcher** — `dsl/compiler/patterns.py`: match DAG subgraphs to + bricks. Supports `BINARY_ELEMENTWISE` and `REDUCE_THEN_TRANSFORM` patterns. +3. **C++ code generator** — `dsl/compiler/infini_codegen.py`: generate CUDA + kernel headers, CPU implementation headers, and backend wrappers. +4. **Mixed mode support** — `manual_backends` parameter allows hand-written + implementations for specific platforms (Ascend, Cambricon) alongside + auto-generated CUDA/CPU code. + +**Verification**: `AddDsl` and `RmsNormDsl` defined as `@infini_op`, compiler +generates correct C++ code, all 10 unit tests pass. + +--- + +## Phase 4: NV GPU compilation verification + +**Goal**: Verify that DSL-generated code compiles and produces correct results +on NVIDIA GPU hardware. + +**Status**: Completed + +### Steps + +1. 
**Create base classes** — `src/base/add_dsl.h` and `src/base/rms_norm_dsl.h` + mirroring existing `Add` and `RmsNorm` interfaces. +2. **Place generated kernel files** — DSL compiler output placed into `src/cuda/`, + `src/nvidia/`, `src/cpu/` where CMake GLOB picks them up. +3. **Python bindings** — Auto-generated by `generate_wrappers.py` which + auto-discovers new base classes in `src/base/`. +4. **Build** — `pip install -e .[dev]` succeeds with CUDA compilation. +5. **Tests** — Created `tests/test_add_dsl.py` and `tests/test_rms_norm_dsl.py`. + +### Results + +| Test suite | Tests | Result | +|---|---|---| +| AddDsl (CPU + CUDA, fp32/fp16/bf16) | 36 | All passed | +| RmsNormDsl (CPU + CUDA, fp32/fp16/bf16, two eps values) | 72 | All passed | +| Existing operators (regression check) | 288 | All passed | +| DSL compiler unit tests | 10 | All passed | + +--- + +## Phase 5: Operator migration (planned) + +**Goal**: Migrate existing operators from hand-written to DSL-defined. + +**Status**: Not started + +### Step 5.1: Elementwise operators + +Migrate: Add, Mul, ReLU, GELU, SiLU, Sigmoid, Tanh, Cast, Abs, Neg. +Each migration: write DSL definition, generate code, run tests, remove +hand-written files from `src/`. + +### Step 5.2: Reduction-based operators + +Migrate: RmsNorm, LayerNorm, Softmax. +Requires `ReduceThenTransform` and `SoftmaxKernel` bricks. + +### Step 5.3: Remaining pattern coverage + +Add `IndexKernel`, `BroadcastKernel`, `ShapeKernel` bricks and migrate +Gather, Scatter, Cat, Transpose, etc. + +**Verification**: Full test suite passes after each migration. CI on all +platforms. 
+ +--- + +## Key files + +| File / Directory | Purpose | +|---|---| +| `dsl/ops/*.py` | Operator definitions (`@manual_op`, `@infini_op`) | +| `dsl/compiler/parser.py` | AST parser: `@infini_op` body to compute DAG | +| `dsl/compiler/patterns.py` | Pattern matcher: DAG subgraphs to bricks | +| `dsl/compiler/infini_codegen.py` | C++ code generation for `@infini_op` | +| `dsl/compiler/codegen.py` | Backend wrapper generation | +| `src/cuda/templates/` | Hand-written CUDA brick library | +| `src/cpu/templates/` | Hand-written CPU brick library | +| `src/base/add_dsl.h` | Base class for DSL-generated Add | +| `src/base/rms_norm_dsl.h` | Base class for DSL-generated RmsNorm | diff --git a/docs/superpowers/specs/2026-04-11-dsl-cmake-integration-design.md b/docs/superpowers/specs/2026-04-11-dsl-cmake-integration-design.md new file mode 100644 index 00000000..ad352216 --- /dev/null +++ b/docs/superpowers/specs/2026-04-11-dsl-cmake-integration-design.md @@ -0,0 +1,175 @@ +# DSL Compiler CMake Integration + +## Problem + +The build system runs `generate_wrappers.py` for pybind11 bindings and C +API generation, while `python -m dsl` is a separate manual step for DSL +kernel generation. This dual-system setup means: + +- DSL-generated files must be pre-generated before `pip install`. +- `impl_names.json` must exist before `generate_wrappers.py` runs. +- New operators require touching both systems. + +## Solution + +Unify code generation into `python -m dsl`, which absorbs all functionality +from `generate_wrappers.py`. CMake calls one command. +`generate_wrappers.py` is retained as a fallback but not called by CMake. + +--- + +## Architecture + +### Before + +``` +CMakeLists.txt + └─ execute_process(generate_wrappers.py --devices ...) + ├─ libclang parse src/base/*.h + ├─ scan src/ for Operator<> specializations + └─ emit: generated/bindings/*.h, ops.cc, include/*.h, src/*/operator.cc + +Manual step: + └─ python -m dsl --output generated --devices ... 
+ ├─ emit: DSL kernel files (cuda/*/dsl.h, etc.) + ├─ emit: registry.h files + └─ emit: impl_names.json +``` + +### After + +``` +CMakeLists.txt + └─ execute_process(python -m dsl --devices ...) + ├─ DSL kernel generation (unchanged) + ├─ registry.h generation (unchanged) + ├─ impl_names.json generation (unchanged) + ├─ libclang parse src/base/*.h (moved from generate_wrappers.py) + ├─ scan src/ for Operator<> specializations (moved) + └─ emit: generated/bindings/*.h, ops.cc, include/*.h, src/*/operator.cc +``` + +`generate_wrappers.py` remains in `scripts/` as a fallback. It is not +called by CMake. It can be used to verify output consistency during the +transition period. + +--- + +## Implementation + +### 1. Create `dsl/compiler/bindings.py` + +Move from `generate_wrappers.py`: +- `_OperatorExtractor` class (libclang AST parsing) +- `_generate_pybind11()` function (pybind11 binding generation) +- `_generate_legacy_c()` function (C API generation) +- Helper functions: `_find_optional_tensor_params()`, + `_find_vector_tensor_params()`, `_snake_to_pascal()` + +The module exposes one entry point: + +```python +def generate_all_bindings( + devices: list[str], + output_dir: pathlib.Path, + impl_names: dict[str, dict[str, int]], +) -> None: +``` + +This function: +1. Discovers all operators via `src/base/*.h` (same logic as + `_get_all_ops()` in `generate_wrappers.py`). +2. For each operator, parses the base class with libclang, generates + pybind11 bindings (with per-op `impl_names` string overloads) and + C API files. +3. Assembles `ops.cc` with all includes and `PYBIND11_MODULE`. + +### 2. Update `dsl/__main__.py` + +After the existing DSL generation loop, call: + +```python +from dsl.compiler.bindings import generate_all_bindings +generate_all_bindings(args.devices, args.output, all_impl_names) +``` + +This replaces the separate `generate_wrappers.py` invocation. + +### 3. 
Update `src/CMakeLists.txt` + +Replace: +```cmake +execute_process( + COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py + --devices ${DEVICE_LIST} + ... +) +``` + +With: +```cmake +execute_process( + COMMAND ${Python_EXECUTABLE} -m dsl --devices ${DEVICE_LIST} + ... +) +``` + +### 4. Keep `generate_wrappers.py` as fallback + +No changes to `scripts/generate_wrappers.py`. It can be run manually to +verify output consistency: + +```bash +# Compare outputs. +python -m dsl --devices nvidia --output /tmp/dsl_out +python scripts/generate_wrappers.py --devices nvidia +diff -r generated/ /tmp/dsl_out/ +``` + +--- + +## Files to create/modify + +| File | Action | +|------|--------| +| `dsl/compiler/bindings.py` | New: libclang parsing + binding generation (moved from generate_wrappers.py) | +| `dsl/__main__.py` | Modify: call `generate_all_bindings()` after DSL generation | +| `src/CMakeLists.txt` | Modify: replace `generate_wrappers.py` with `python -m dsl` | + +## What stays unchanged + +- `scripts/generate_wrappers.py` — retained as fallback, not called by CMake +- All existing DSL generation logic in `dsl/compiler/` +- libclang parsing logic (moved, not rewritten) +- Generated output format (bindings, C API, ops.cc) + +## Verification + +```bash +# Build with unified pipeline. +pip install -e .[dev] + +# Verify bindings work. +python -c "import infini.ops; print(dir(infini.ops))" + +# Verify string implementation param works. +python -c " +import torch, infini.ops +a = torch.randn(4, 4, device='cuda') +b = torch.randn(4, 4, device='cuda') +out = torch.empty(4, 4, device='cuda') +infini.ops.add(a, b, out, implementation='dsl') +print('OK') +" + +# Full test suite. +pytest tests/ dsl/tests/ --tb=short -q \ + --ignore=tests/test_add_rms_norm.py \ + --ignore=tests/test_cat.py \ + --ignore=tests/test_linear.py \ + --ignore=tests/test_matmul.py + +# Compare with legacy script output (optional). 
+python scripts/generate_wrappers.py --devices cpu nvidia +diff generated/bindings/ops.cc /tmp/legacy_ops.cc +``` diff --git a/docs/superpowers/specs/2026-04-11-flashinfer-integration-design.md b/docs/superpowers/specs/2026-04-11-flashinfer-integration-design.md new file mode 100644 index 00000000..14142dcb --- /dev/null +++ b/docs/superpowers/specs/2026-04-11-flashinfer-integration-design.md @@ -0,0 +1,74 @@ +# FlashAttention via FlashInfer Integration + +## Problem + +FlashAttention is the only operator in InfiniOps without an NVIDIA +implementation. FlashInfer provides a header-only C++ API with +state-of-the-art attention kernels for both prefill and decode. + +## Solution + +Integrate FlashInfer as a header-only dependency. Wrap its C++ API in +InfiniOps's `CudaFlashAttention` operator class, mapping InfiniOps's +`FlashAttention` base class parameters to FlashInfer's param structs. + +--- + +## Integration approach + +1. Add FlashInfer headers to `third_party/flashinfer/include/`. +2. Add FlashInfer's CUTLASS dependency to `third_party/flashinfer/3rdparty/cutlass/`. +3. Update `src/CMakeLists.txt` to add include paths when `WITH_NVIDIA=ON`. +4. Create `src/cuda/flash_attention/kernel.h` wrapping FlashInfer's + `SinglePrefillWithKVCacheDispatched`. +5. Create `src/nvidia/flash_attention/kernel.h` as the nvidia wrapper. + +## Parameter mapping + +| InfiniOps | FlashInfer | +|-----------|-----------| +| `query [T, N, D]` | `params.q`, `params.qo_len=T` | +| `key [S, Nkv, D]` | `params.k`, `params.kv_len=S` | +| `value [S, Nkv, D]` | `params.v` | +| `num_heads` | `params.num_qo_heads` | +| `num_kv_heads` | `params.num_kv_heads` | +| `head_size` | template `HEAD_DIM` + `params.head_dim` | +| `scale` | `params.sm_scale` | +| `causal` | `MaskMode::kCausal` vs `MaskMode::kNone` | +| `window_left` | `params.window_left` | +| `output [T, N, D]` | `params.o` | + +## Scope + +Initial implementation covers **single-request prefill** (non-paged, +contiguous KV). 
This handles the standard attention pattern. Paged KV +cache and batch decode can be added later. + +## Head dimension dispatch + +FlashInfer requires HEAD_DIM as a compile-time template parameter. +Dispatch at runtime: + +```cpp +switch (head_size) { + case 64: return launch<64>(...); + case 128: return launch<128>(...); + case 256: return launch<256>(...); + default: assert(false && "unsupported head_size"); +} +``` + +## Data type dispatch + +Use InfiniOps's existing `DispatchFunc` for dtype → (half, nv_bfloat16, +float) mapping. + +## Files + +| File | Action | +|------|--------| +| `third_party/flashinfer/` | New: FlashInfer headers (git submodule) | +| `src/CMakeLists.txt` | Modify: add FlashInfer include path | +| `src/cuda/flash_attention/kernel.h` | New: CudaFlashAttention wrapper | +| `src/nvidia/flash_attention/kernel.h` | New: nvidia specialization | +| `tests/test_flash_attention.py` | Modify: enable CUDA tests | diff --git a/docs/superpowers/specs/2026-04-11-unary-brick-cast-benchmark-design.md b/docs/superpowers/specs/2026-04-11-unary-brick-cast-benchmark-design.md new file mode 100644 index 00000000..1035c1b8 --- /dev/null +++ b/docs/superpowers/specs/2026-04-11-unary-brick-cast-benchmark-design.md @@ -0,0 +1,169 @@ +# Unary Elementwise Brick, Cast Migration, and DSL Performance Benchmark + +## Problem + +The DSL currently has two brick templates (`binary_elementwise` and +`reduce_transform`) covering two-input elementwise and reduction-based +operators. Single-input operators like Cast cannot be expressed. +Additionally, there is no systematic performance comparison between +DSL-generated and hand-written kernel code. + +## Solution + +1. Add `UnaryElementwiseBrick` template (CUDA + CPU). +2. Migrate Cast to `@infini_op` using the new brick. +3. Benchmark all DSL-migrated operators against hand-written versions. + +--- + +## 1. 
`UnaryElementwiseBrick` + +### CUDA template (`src/cuda/templates/unary_elementwise.cuh`) + +A single-input elementwise kernel with dual-dtype dispatch. + +```cpp +template +__global__ void UnaryElementwiseKernel( + TOut* __restrict__ out, const TIn* __restrict__ in, + const size_t* __restrict__ out_shape, + const size_t* __restrict__ in_shape, + const ptrdiff_t* __restrict__ out_strides, + const ptrdiff_t* __restrict__ in_strides, + size_t output_size, size_t ndim, + bool out_contig, bool in_contig); +``` + +**Key differences from `BinaryElementwiseBrick`:** +- Single input tensor (no `other`). +- Dual-dtype dispatch: `DispatchFunc` resolves + `(TIn, TOut)` at runtime from `(input_dtype, output_dtype)`. +- Op functor signature: `TOut operator()(const TIn& x) const`. + +**`UnaryElementwiseBrick` class:** +- Constructor takes `(input, out, ndim)` — allocates device metadata for + two tensors (not three). +- `Run()` does the dual dispatch and + kernel launch. + +### CPU template (`src/cpu/templates/unary_elementwise.h`) + +```cpp +template +void CpuUnaryElementwise( + const Tensor in, Tensor out, Tensor::Size output_size, + Tensor::Size ndim, bool in_contig, bool out_contig, + const Tensor::Shape& in_shape, const Tensor::Shape& out_shape, + const Tensor::Strides& in_strides, const Tensor::Strides& out_strides, + DataType input_dtype, DataType output_dtype, Op op); +``` + +Uses `DispatchFunc` with two `DataType` lists for dual dispatch, OpenMP +parallel for loop, and `Caster` for type conversion. + +### Future reuse + +Although Cast is the immediate use case, the unary brick also serves future +single-input operators (ReLU, GELU, Sigmoid, Abs, Neg). Those have +`input_dtype == output_dtype`, which works naturally — dual dispatch +resolves both to the same type. + +--- + +## 2. 
Cast DSL migration + +### DSL definition + +```python +# dsl/ops/cast_dsl.py +@infini_op(name="Cast", impl_index=1, shapes={"N": "output_size"}) +def cast_dsl(input: Tensor["N"]) -> Tensor["N"]: + return cast(input) +``` + +### Compiler changes + +**`dsl/compiler/patterns.py`:** +- Add `BrickKind.UNARY_ELEMENTWISE`. +- Match rule: single input, no reduction, single output → unary. + +**`dsl/compiler/infini_codegen.py`:** +- Add `_gen_unary_elementwise_cuda()` and `_gen_unary_elementwise_cpu()`. +- Cast functor body: `Caster::Cast(x)` (pure type conversion, + no math). +- Generated class `DslCudaCast` inherits from `Cast` base class. + +**`dsl/__main__.py`:** +- Route `UNARY_ELEMENTWISE` brick to the new generators. +- Output paths: `cuda/cast/dsl.h`, `nvidia/cast/dsl.h`, `cpu/cast/dsl.h`, + plus `registry.h` files. + +### Registration + +- `Operator` via generated nvidia wrapper. +- `Operator` via generated CPU file. +- `registry.h` files for nvidia and CPU. +- Cast currently has no NVIDIA hand-written implementation, so the nvidia + registry declares `List` only (dispatcher fallback handles + default index). + +--- + +## 3. Performance benchmark + +### Test file + +`tests/benchmark_dsl.py`, using `@pytest.mark.benchmark` (only runs with +`pytest --benchmark`). + +### Test matrix + +| Operator | Shapes | Dtypes | Compare | +|----------|--------|--------|---------| +| Add | (4,4,5632), (16,5632), (1024,1024) | fp32, fp16, bf16 | default vs dsl | +| RmsNorm | (2,4,2048), (4,48,64) | fp32, fp16, bf16 | default vs dsl | +| Swiglu | (4,4,5632), (16,5632) | fp32, fp16, bf16 | default vs dsl | +| Cast | (4,4,5632), (1024,1024) | fp32→fp16, fp16→fp32 | default vs dsl | + +Mul is excluded (NVIDIA has DSL-only, no hand-written to compare). + +### Measurement + +- CUDA event timing (`torch.cuda.Event`) for GPU kernel time. +- Warmup runs + multiple iterations, report median. +- Output: table with `hand-written ms`, `dsl ms`, `ratio`. 
+ +### Success criterion + +DSL-generated code within 80-120% of hand-written performance (per the +design spec's 10-20% tolerance target). + +--- + +## Files to create/modify + +| File | Action | +|------|--------| +| `src/cuda/templates/unary_elementwise.cuh` | New: CUDA unary brick | +| `src/cpu/templates/unary_elementwise.h` | New: CPU unary brick | +| `dsl/compiler/patterns.py` | Modify: add `UNARY_ELEMENTWISE` | +| `dsl/compiler/infini_codegen.py` | Modify: add unary codegen | +| `dsl/__main__.py` | Modify: route unary brick | +| `dsl/ops/cast_dsl.py` | New: Cast DSL definition | +| `src/cuda/cast/dsl.h` | New: generated CUDA kernel | +| `src/nvidia/cast/dsl.h` | New: generated nvidia wrapper | +| `src/cpu/cast/dsl.h` | New: generated CPU impl | +| `src/{nvidia,cpu}/cast/registry.h` | New: impl registry | +| `src/cpu/cast/cast.h` | Modify: add registry include | +| `tests/benchmark_dsl.py` | New: performance benchmark | + +## Verification + +```bash +pip install -e .[dev] +pytest tests/test_cast.py -v # existing Cast tests +pytest tests/test_cast_dsl.py -v # new DSL Cast tests +pytest tests/ --ignore=... --tb=short # full regression +pytest tests/benchmark_dsl.py --benchmark -v # performance comparison +``` diff --git a/docs/superpowers/specs/2026-04-12-operator-dispatch-maintenance-design.md b/docs/superpowers/specs/2026-04-12-operator-dispatch-maintenance-design.md new file mode 100644 index 00000000..df9a9e93 --- /dev/null +++ b/docs/superpowers/specs/2026-04-12-operator-dispatch-maintenance-design.md @@ -0,0 +1,175 @@ +# Operator Dispatch and Maintenance Optimization + +## Problem + +As operators and platforms grow, maintenance cost scales as `O(ops × platforms)`. +Each new platform requires a wrapper file per operator; each new operator +requires a wrapper per platform. Currently, DSL-generated wrappers are +copied manually into `src/`, and `src/<platform>/` mixes adapter files +(4 per platform) with per-operator wrappers (1 per operator). 
+ +## Goal + +Reduce the per-operator-per-platform cost to zero for CUDA-like platforms. +New platform onboarding: provide 4 adapter files, add a CMake flag, build. +New operator onboarding: write base class + CUDA kernel + DSL registration, +build. All wrappers generated automatically. + +## Design + +### Directory responsibility separation + +**`src/` — hand-written code only** + +``` +src/ + base/.h # Abstract base class + cuda//kernel.cuh # Shared CUDA kernel + cuda//kernel.h # Shared CUDA launcher (CudaOp) + cuda/templates/ # Reusable brick templates + cpu//.h # CPU implementation + nvidia/ # Platform adapter files ONLY: + device_.h + runtime_.h + data_type_.h + caster.cuh + blas.h + blas_utils.h + metax/ # Same 4-6 adapter files + device_.h, runtime_.h, ... + iluvatar/ # Same + moore/ # Same + ascend/ # Ascend-specific impls (aclnn, not CUDA-like) + /kernel.h + cambricon/ # Cambricon-specific impls + /.h +``` + +No per-operator wrapper files in `src/nvidia/`, `src/metax/`, etc. + +**`generated/` — all auto-generated code** + +``` +generated/ + nvidia//kernel.h # Operator wrapper + metax//kernel.h # Operator wrapper + iluvatar//kernel.h # ... + moore//kernel.h + cpu//dsl.h # DSL CPU impl (if @infini_op) + nvidia//dsl.h # DSL CUDA impl (if @infini_op) + nvidia//registry.h # ActiveImplementationsImpl (if multi-impl) + cpu//registry.h # ... + bindings/*.h # pybind11 bindings + bindings/ops.cc # PYBIND11_MODULE + include/*.h # C API headers + src/*/operator.cc # C API sources + impl_names.json # Per-op implementation name mapping +``` + +### CMake changes + +Add `generated//` to the source GLOB for each CUDA-like backend: + +```cmake +if(WITH_NVIDIA) + set(NVIDIA_PATTERNS + "cuda/*.cc" "cuda/*.cpp" "cuda/*.cu" + "nvidia/*.cc" "nvidia/*.cpp" "nvidia/*.cu" + ) + file(GLOB_RECURSE NVIDIA_SOURCES CONFIGURE_DEPENDS ${NVIDIA_PATTERNS}) + + # Add DSL-generated wrappers. 
+ file(GLOB_RECURSE NVIDIA_GENERATED CONFIGURE_DEPENDS + "${PROJECT_SOURCE_DIR}/generated/nvidia/*.h" + ) + + # ... (wrapper .h files are header-only, included by ops.cc) +endif() +``` + +Since wrapper files are headers (not `.cc`), they are pulled in via +`#include` from the generated `ops.cc`. The CMake change is mainly about +ensuring the include path covers `generated/`. + +### DSL compiler changes + +`python -m dsl --devices ${DEVICE_LIST}` already generates: +- `@infini_op` kernel files (cuda/cpu DSL code) +- Backend wrappers for CUDA-like platforms +- Bindings, C API, impl_names.json + +**Changes needed:** +1. Generate `@manual_op` wrappers to `generated/` instead of relying on + `generate_wrappers.py` scanning `src/`. +2. Remove the `_get_all_ops(devices)` scan-based discovery. All ops are + already registered in `dsl/ops/*.py` — use the registry directly. +3. The generated `ops.cc` includes should reference `generated//` + paths instead of `src//`. + +### New platform onboarding flow + +``` +1. mkdir src// +2. Create: device_.h, runtime_.h, data_type_.h, caster.cuh +3. CMakeLists.txt: add WITH_ option, GLOB patterns, link libs +4. pip install -e .[dev] ← DSL auto-generates all wrappers +``` + +No operator-specific files needed. The DSL compiler reads the `--devices` +list and generates `Operator` wrappers for every registered +operator. + +### New operator onboarding flow + +``` +1. Create src/base/.h (base class) +2. Create src/cuda//kernel.cuh (CUDA kernel) +3. Create src/cuda//kernel.h (CUDA launcher: CudaOp) +4. Create dsl/ops/.py (@manual_op or @infini_op) +5. Create tests/test_.py (tests) +6. pip install -e .[dev] ← wrappers + bindings auto-generated +``` + +For Ascend/Cambricon (non-CUDA-like): also add `src/ascend//kernel.h` +and reference it in `manual_backends` of the DSL definition. + +### Migration plan + +1. Move existing `src/nvidia//kernel.h` wrappers to `generated/`. +2. Move existing `src/nvidia//dsl.h` to `generated/`. +3. 
Move existing `src/nvidia//registry.h` to `generated/`. +4. Same for cpu DSL files and registries. +5. Keep `src/nvidia/` with only adapter files. +6. Update `ops.cc` includes from `src/nvidia//` to + `generated/nvidia//`. +7. Verify full test suite passes. + +### What stays unchanged + +- `src/base/` — base classes (hand-written) +- `src/cuda/` — shared CUDA kernels and templates (hand-written) +- `src/cpu/` — hand-written CPU implementations +- `src/ascend/`, `src/cambricon/` — vendor-API implementations (hand-written) +- `src/operator.h`, `src/dispatcher.h` — core framework +- DSL decorator format (`@manual_op` / `@infini_op`) +- Python test framework + +## Verification + +```bash +pip install -e .[dev] +pytest tests/ dsl/tests/ --tb=short -q \ + --ignore=tests/test_add_rms_norm.py \ + --ignore=tests/test_cast.py \ + --ignore=tests/test_cat.py \ + --ignore=tests/test_linear.py \ + --ignore=tests/test_matmul.py +``` + +All tests must pass with zero wrapper files in `src/nvidia/*/`. + +## References + +- [PyTorch native_functions.yaml](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml) +- [PyTorch Operator Registration](https://docs.pytorch.org/docs/stable/accelerator/operators.html) +- [ATen native README](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md) diff --git a/docs/superpowers/specs/2026-04-12-optimization-log.md b/docs/superpowers/specs/2026-04-12-optimization-log.md new file mode 100644 index 00000000..4527ebc3 --- /dev/null +++ b/docs/superpowers/specs/2026-04-12-optimization-log.md @@ -0,0 +1,108 @@ +# Optimization Log — A100-SXM4-80GB + +## Round 1: Vectorized Binary Elementwise Brick + +**Problem**: Add (4096²) fp16 at 612 GB/s (31% of A100 HBM 2 TB/s). +Each thread processes 1 element, no vectorized load. + +**Fix**: Add `BinaryElementwiseVecKernel` with 128-bit coalesced +load/store and grid-stride loop for contiguous tensors. 
+ +**Result (DSL Add)**: 612 GB/s → **1646 GB/s** (2.7x, matches PyTorch). + +## Round 2: Refactor CudaAdd/CudaSwiglu to Use Vectorized Brick + +**Problem**: Hand-written CudaAdd and CudaSwiglu still use old scalar +kernels, not the improved brick. + +**Fix**: Replace per-element kernels with `BinaryElementwiseBrick`. + +| Operator | Before | After | Speedup | +|----------|--------|-------|---------| +| Add (4096²) fp16 | 0.164 ms (612 GB/s) | 0.077 ms (1315 GB/s) | **2.1x** | +| Swiglu (4096²) fp16 | ~0.164 ms | 0.062 ms (1612 GB/s) | **~2.6x** | + +## Round 3: Grid-Stride Loop for Unary Elementwise + +**Problem**: Cast fp32→fp16 (4096²) at 626 GB/s. + +**Fix**: Add `UnaryElementwiseVecKernel` with grid-stride loop. + +**Result**: 0.161 ms (626 GB/s) → **0.092 ms (1094 GB/s)** (1.75x). + +## Round 4: RmsNorm Analysis (No Change) + +RmsNorm (32,32,4096) is 3.3x slower than PyTorch. Root cause: +PyTorch likely uses a more optimized reduce kernel. Requires deeper +kernel rewrite — deferred. 
+ +## Round 5: Post-Optimization Full Benchmark (4096² fp16 on A100) + +| Operator | Time (ms) | Bandwidth / TFLOPS | vs PyTorch | +|----------|-----------|-------------------|------------| +| **Add** | 0.076 | 1318 GB/s | 0.80x | +| **Mul** | 0.061 | 1647 GB/s | ≈1.0x | +| **Swiglu** | 0.062 | 1611 GB/s | 1.15x faster | +| **Cast fp32→fp16** | 0.079 | 1279 GB/s | 0.78x | +| **Gemm 4096³** | 0.587 | 234 TFLOPS | ≈1.0x | +| **Matmul 1024³** | 0.017 | 126 TFLOPS | 2.0x faster | +| **Linear 1024×4096²** | 0.171 | — | 1.2x faster | +| **FlashAttn S=2048** | 0.241 | 286 TFLOPS | 1.12x faster | + +## Round 6 (new series): Full Baseline with CUDA Profiler + +Used `torch.profiler` to measure actual kernel time (not Python overhead): + +| Operator | InfiniOps kernel | PyTorch kernel | Real ratio | +|----------|-----------------|----------------|------------| +| **Add (4096²)** | 60.1 us | 59.3 us | **1.0x ✓** | +| **CausalSoftmax** | 73.3 us | 16.5 us (2 kernels) | **4.4x ✗** | +| **Cast fp32→fp16** | 103.6 us | 61.5 us | **1.7x ✗** | +| **RmsNorm** | 21 us (bench) | 11 us (bench) | **1.9x ✗** | +| **AddRmsNorm** | 42.6 us | 28.9 us (2 kernels) | **1.5x ✗** | + +Key insight: Add's 20% benchmark gap is entirely Python binding +overhead — CUDA kernel is matching PyTorch. + +## Round 7: Cast Vectorized Load (new series Round 3) + +Added 128-bit vectorized input load + output store. + +Cast fp32→fp16 (4096²): 0.092 ms → **0.078 ms** (+17%, 1285 GB/s). +Gap vs PyTorch (1645 GB/s): 22% — limited by mixed-type vectorization +(input vec size ≠ output vec size). + +## Round 8: RmsNorm Vectorized Attempts (new series Rounds 4-5) + +Attempted three approaches: +1. Register caching (store x in registers during reduce, reuse in + transform) — **failed**: register pressure reduced occupancy, slower. +2. Warp shuffle reduction (replace CUB BlockReduce with manual + `__shfl_xor_sync`) — **failed**: no improvement, CUB is already + well-optimized. +3. 
Vectorized 128-bit struct loads — **failed**: anonymous struct + alignment issues, compiler couldn't optimize. + +Root cause: PyTorch's `vectorized_layer_norm` uses a fundamentally +different approach — needs deeper study with nsight compute. + +## Current Status (Post All Optimization) + +| Operator | InfiniOps (ms) | PyTorch (ms) | Ratio | Status | +|----------|---------------|-------------|-------|--------| +| Add (4096²) | 0.076 | 0.061 | 0.80x | ✓ kernel matched (binding overhead) | +| Mul (4096²) | 0.061 | 0.061 | 1.00x | ✓ | +| Swiglu (4096²) | 0.062 | 0.167 | 2.68x | ✓ faster | +| Cast (4096²) | 0.078 | 0.061 | 0.78x | ✗ 22% gap | +| RmsNorm | 0.021 | 0.011 | 0.49x | ✗ 2x gap | +| AddRmsNorm | 0.036 | 0.028 | 0.78x | ✗ | +| CausalSoftmax | 0.056 | 0.034 | 0.61x | ✗ | +| Gemm 4096³ | 0.594 | 0.571 | 0.96x | ✓ | +| Matmul 4096³ | 0.590 | 0.574 | 0.97x | ✓ | +| Linear 1024×4096² | 0.173 | 0.211 | 1.22x | ✓ faster | +| RotaryEmbed | 0.020 | 0.099 | 4.93x | ✓ faster | +| FlashAttn S=2048 | 0.240 | 0.269 | 1.12x | ✓ faster | + +**8/12 operators match or beat PyTorch.** Remaining gaps in +RmsNorm/AddRmsNorm (vectorized reduce), CausalSoftmax (warp-level +softmax), and Cast (mixed-type vectorization). 
diff --git a/dsl/__init__.py b/dsl/__init__.py new file mode 100644 index 00000000..573e3807 --- /dev/null +++ b/dsl/__init__.py @@ -0,0 +1,3 @@ +"""InfiniOps cross-platform DSL for operator definition and code generation.""" + +from dsl.decorators import infini_op, manual_op diff --git a/dsl/__main__.py b/dsl/__main__.py new file mode 100644 index 00000000..9fee9557 --- /dev/null +++ b/dsl/__main__.py @@ -0,0 +1,278 @@ +"""CLI entry point: ``python -m dsl``.""" + +from __future__ import annotations + +import argparse +import difflib +import json +import pathlib +import sys + +from dsl.compiler.codegen import CUDA_LIKE_BACKENDS, generate_wrappers_for_op +from dsl.compiler.infini_codegen import generate_cpu_kernel, generate_cuda_kernel +from dsl.compiler.parser import parse_infini_op +from dsl.compiler.patterns import match_dag +from dsl.compiler.registry import REGISTRY +from dsl.decorators import InfiniOpDef +from dsl.ops import discover + + +def _to_snake(pascal: str) -> str: + """Convert PascalCase to snake_case.""" + import re + + return re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", pascal).lower() + + +def _generate_infini_op( + op: InfiniOpDef, + output_dir: pathlib.Path, +) -> list[pathlib.Path]: + """Generate CUDA + CPU files for an `@infini_op` operator.""" + dag = parse_infini_op(op) + match = match_dag(dag) + op_snake = _to_snake(op.name) + generated: list[pathlib.Path] = [] + + # Determine output filenames based on impl_index. + cuda_filename = "dsl.h" if op.impl_index > 0 else "kernel.h" + cpu_filename = "dsl.h" if op.impl_index > 0 else f"{op_snake}.h" + + # Generate shared CUDA kernel. + cuda_content = generate_cuda_kernel(op, dag, match) + cuda_path = output_dir / "cuda" / op_snake / cuda_filename + cuda_path.parent.mkdir(parents=True, exist_ok=True) + cuda_path.write_text(cuda_content) + generated.append(cuda_path) + + # Generate CPU implementation. 
+ cpu_content = generate_cpu_kernel(op, dag, match) + cpu_path = output_dir / "cpu" / op_snake / cpu_filename + cpu_path.parent.mkdir(parents=True, exist_ok=True) + cpu_path.write_text(cpu_content) + generated.append(cpu_path) + + return generated + + +def _generate_registry( + op_name: str, + impl_indices: list[int], + devices: list[str], + output_dir: pathlib.Path, + primary_op: ManualOpDef | InfiniOpDef | None = None, +) -> list[pathlib.Path]: + """Generate ``registry.h`` files declaring active implementation indices.""" + op_snake = _to_snake(op_name) + generated: list[pathlib.Path] = [] + + # Determine which devices have a hand-written default implementation + # (index 0). If the primary @manual_op has a `cuda` or device-specific + # backend entry, it has a default impl on CUDA-like platforms. If not, + # only the DSL variant (index 1+) exists. + from dsl.decorators import ManualOpDef + + def _has_default_impl(device: str) -> bool: + if primary_op is None: + return True + + if not isinstance(primary_op, ManualOpDef): + return True + + backends = primary_op.backends + + if device == "cpu": + return "cpu" in backends + + # For CUDA-like devices, a default impl exists if either the + # specific device or the shared "cuda" key is in backends. + return device in backends or "cuda" in backends + + for device in ["cpu"] + [d for d in devices if d in CUDA_LIKE_BACKENDS]: + if device == "cpu": + device_enum = "Device::Type::kCpu" + else: + from dsl.compiler.codegen import BACKEND_ENUM + + device_enum = f"Device::Type::k{BACKEND_ENUM[device]}" + + guard = f"INFINI_OPS_{device.upper()}_{op_snake.upper()}_REGISTRY_H_" + + # Filter impl_indices: only include kDefault (0) if a hand-written + # implementation exists for this device. + device_indices = [ + i for i in impl_indices + if i > 0 or _has_default_impl(device) + ] + + if not device_indices: + continue + + # Use named constants from Impl for readability. 
+ named_indices = ", ".join( + "Impl::kDsl" if i > 0 else "Impl::kDefault" + for i in sorted(device_indices) + ) + + content = ( + f"#ifndef {guard}\n" + f"#define {guard}\n" + f"\n" + f'#include "base/{op_snake}.h"\n' + f'#include "impl.h"\n' + f"\n" + f"namespace infini::ops {{\n" + f"\n" + f"template <>\n" + f"struct ActiveImplementationsImpl<{op_name}, {device_enum}> {{\n" + f" using type = List<{named_indices}>;\n" + f"}};\n" + f"\n" + f"}} // namespace infini::ops\n" + f"\n" + f"#endif\n" + ) + + reg_path = output_dir / device / op_snake / "registry.h" + reg_path.parent.mkdir(parents=True, exist_ok=True) + reg_path.write_text(content) + generated.append(reg_path) + + return generated + + +def _diff_file(expected: str, actual: str, label: str) -> list[str]: + return list( + difflib.unified_diff( + actual.splitlines(keepends=True), + expected.splitlines(keepends=True), + fromfile=f"existing/{label}", + tofile=f"generated/{label}", + ) + ) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="InfiniOps DSL compiler — generate backend wrappers.", + ) + parser.add_argument( + "--devices", + nargs="+", + default=list(CUDA_LIKE_BACKENDS), + help="CUDA-like backends to generate wrappers for.", + ) + parser.add_argument( + "--output", + type=pathlib.Path, + default=pathlib.Path("generated"), + help="Output directory for generated files.", + ) + parser.add_argument( + "--verify", + action="store_true", + help="Compare generated wrappers against existing hand-written files " + "in src/ and report differences.", + ) + parser.add_argument( + "--ops", + nargs="*", + default=None, + help="Generate only the specified operators (default: all).", + ) + + args = parser.parse_args() + + # Discover and register all operator definitions. 
+ discover() + + ops = REGISTRY.all_ops() + + if args.ops: + ops = {k: v for k, v in ops.items() if k in args.ops} + + if not ops: + print("No operators found.", file=sys.stderr) + sys.exit(1) + + src_dir = pathlib.Path("src") + total_generated = 0 + total_diffs = 0 + + for name, op in sorted(ops.items()): + + if isinstance(op, InfiniOpDef): + generated = _generate_infini_op(op, args.output) + # Also generate CUDA-like backend wrappers for @infini_op. + generated += generate_wrappers_for_op(op, args.devices, args.output) + else: + generated = generate_wrappers_for_op(op, args.devices, args.output) + + # Process DSL variants (impl_index > 0). + variants = REGISTRY.variants(name) + + for variant in variants: + generated += _generate_infini_op(variant, args.output) + generated += generate_wrappers_for_op( + variant, args.devices, args.output + ) + + if variants: + impl_indices = [0] + [v.impl_index for v in variants] + generated += _generate_registry( + name, impl_indices, args.devices, args.output, op + ) + + total_generated += len(generated) + + if args.verify: + + for gen_path in generated: + # Map generated path to the existing hand-written path in src/. + rel = gen_path.relative_to(args.output) + existing_path = src_dir / rel + + if not existing_path.exists(): + print(f"NEW {rel}") + total_diffs += 1 + + continue + + expected = gen_path.read_text() + actual = existing_path.read_text() + + if expected != actual: + diff = _diff_file(expected, actual, str(rel)) + print(f"DIFF {rel}") + + for line in diff: + print(line, end="") + + print() + total_diffs += 1 + else: + print(f"OK {rel}") + + # Write per-operator implementation name mappings. + all_impl_names = REGISTRY.all_impl_names() + impl_names_path = args.output / "impl_names.json" + impl_names_path.parent.mkdir(parents=True, exist_ok=True) + impl_names_path.write_text(json.dumps(all_impl_names, indent=2) + "\n") + + # Generate pybind11 bindings and C API (replaces generate_wrappers.py). 
+ if not args.verify: + from dsl.compiler.bindings import generate_all_bindings + + generate_all_bindings(args.devices, args.output, all_impl_names) + + if args.verify: + print(f"\n{total_generated} files checked, {total_diffs} differences.") + + if total_diffs: + sys.exit(1) + else: + print(f"Generated {total_generated} DSL files + bindings in {args.output}/") + + +if __name__ == "__main__": + main() diff --git a/dsl/compiler/__init__.py b/dsl/compiler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dsl/compiler/bindings.py b/dsl/compiler/bindings.py new file mode 100644 index 00000000..fa0c930c --- /dev/null +++ b/dsl/compiler/bindings.py @@ -0,0 +1,559 @@ +"""Generate pybind11 and C API bindings for InfiniOps operators.""" + +import json +import pathlib +import re +import shutil +import subprocess +import textwrap + +import clang.cindex +from clang.cindex import CursorKind + +_SRC_DIR = pathlib.Path("src") +_BASE_DIR = _SRC_DIR / "base" +_INDENTATION = " " + + +class _Operator: + def __init__(self, name, constructors, calls): + self.name = name + self.constructors = constructors + self.calls = calls + + +class _OperatorExtractor: + def __call__(self, op_name): + def _get_system_include_flags(): + def _get_compilers(): + compilers = [] + + for compiler in ("clang++", "g++"): + if shutil.which(compiler) is not None: + compilers.append(compiler) + + return compilers + + system_include_flags = [] + + for compiler in _get_compilers(): + for line in subprocess.getoutput( + f"{compiler} -E -x c++ -v /dev/null" + ).splitlines(): + if not line.startswith(" "): + continue + + system_include_flags.append("-isystem") + system_include_flags.append(line.strip()) + + return system_include_flags + + system_include_flags = _get_system_include_flags() + + index = clang.cindex.Index.create() + args = ("-std=c++17", "-x", "c++", "-I", "src") + tuple(system_include_flags) + translation_unit = index.parse(f"src/base/{op_name}.h", args=args) + + nodes = 
tuple(type(self)._find(translation_unit.cursor, op_name)) + + constructors = [] + calls = [] + + for node in nodes: + if node.kind == CursorKind.CONSTRUCTOR: + constructors.append(node) + elif node.kind == CursorKind.CXX_METHOD and node.spelling == "operator()": + calls.append(node) + + return _Operator(op_name, constructors, calls) + + @staticmethod + def _find(node, op_name): + pascal_case_op_name = _snake_to_pascal(op_name) + + if ( + node.semantic_parent + and node.semantic_parent.spelling == pascal_case_op_name + ): + yield node + + for child in node.get_children(): + yield from _OperatorExtractor._find(child, op_name) + + +def _find_optional_tensor_params(op_name): + """Return a set of parameter names declared as `std::optional` in + the base header. libclang resolves the type to ``int`` when the STL + headers are not fully available, so we fall back to a regex scan of the + source text. + """ + source = (_BASE_DIR / f"{op_name}.h").read_text() + + return set(re.findall(r"std::optional\s+(\w+)", source)) + + +def _find_vector_tensor_params(op_name): + """Return a set of parameter names declared as `std::vector` in + the base header. 
+ """ + source = (_BASE_DIR / f"{op_name}.h").read_text() + + return set(re.findall(r"std::vector\s+(\w+)", source)) + + +def _generate_pybind11(operator, impl_names=None): + optional_tensor_params = _find_optional_tensor_params(operator.name) + vector_tensor_params = _find_vector_tensor_params(operator.name) + + if impl_names is None: + impl_names = {} + + def _is_optional_tensor(arg): + if arg.spelling in optional_tensor_params: + return True + return "std::optional" in arg.type.spelling and "Tensor" in arg.type.spelling + + def _is_vector_tensor(arg): + if arg.spelling in vector_tensor_params: + return True + return "std::vector" in arg.type.spelling and "Tensor" in arg.type.spelling + + def _generate_params(node): + parts = [] + + for arg in node.get_arguments(): + if arg.spelling == "stream": + continue + + if _is_optional_tensor(arg): + parts.append(f"std::optional {arg.spelling}") + elif _is_vector_tensor(arg): + parts.append(f"std::vector {arg.spelling}") + else: + param = arg.type.spelling.replace("const Tensor", "py::object").replace( + "Tensor", "py::object" + ) + parts.append(f"{param} {arg.spelling}") + + return ", ".join(parts) + + def _generate_arguments(node): + args = [] + + for arg in node.get_arguments(): + if arg.spelling == "stream": + continue + + if _is_optional_tensor(arg): + args.append(f"OptionalTensorFromPybind11Handle({arg.spelling})") + elif _is_vector_tensor(arg): + args.append(f"VectorTensorFromPybind11Handle({arg.spelling})") + elif "Tensor" in arg.type.spelling: + args.append(f"TensorFromPybind11Handle({arg.spelling})") + else: + args.append(arg.spelling) + + return ", ".join(args) + + op_name = operator.name + + def _generate_init(constructor): + constructor_params = _generate_params(constructor) + + return f""" .def(py::init([]({constructor_params}) {{ + return std::unique_ptr{{static_cast(Self::make({_generate_arguments(constructor)}).release())}}; + }}))""" + + def _generate_py_args(node): + return ", ".join( + 
f'py::arg("{arg.spelling}")' + for arg in node.get_arguments() + if arg.spelling != "stream" + ) + + def _generate_call(op_name, call, method=True): + call_params = _generate_params(call) + call_args = _generate_arguments(call) + + if not method: + # Overload 1: implementation_index (numeric, backward compatible). + params_idx = ( + f"{call_params}, std::size_t implementation_index, std::uintptr_t stream" + if call_params + else "std::size_t implementation_index, std::uintptr_t stream" + ) + py_args = _generate_py_args(call) + py_args_str = f"{py_args}, " if py_args else "" + + overload_idx = ( + f' m.def("{op_name}", []({params_idx}) {{\n' + f" Config config;\n" + f" config.set_implementation_index(implementation_index);\n" + f" Handle handle;\n" + f" if (stream) {{\n" + f" handle.set_stream(reinterpret_cast(stream));\n" + f" }}\n" + f" return Self::call(handle, config, {call_args});\n" + f' }}, {py_args_str}py::kw_only(), py::arg("implementation_index") = 0, py::arg("stream") = 0);' + ) + + # Overload 2: implementation (string name, e.g. "dsl"). + # Only generate if there are named implementations. + if not impl_names: + return overload_idx + + # Build C++ initializer list for the per-operator map. 
+ map_entries = ", ".join( + f'{{"{name}", {idx}}}' for name, idx in impl_names.items() + ) + valid_names = ", ".join(f"'{n}'" for n in impl_names) + + params_str = ( + f"{call_params}, const std::string& implementation, std::uintptr_t stream" + if call_params + else "const std::string& implementation, std::uintptr_t stream" + ) + + overload_str = ( + f' m.def("{op_name}", []({params_str}) {{\n' + f" static const std::unordered_map kImplNames{{{{{map_entries}}}}};\n" + f" auto it = kImplNames.find(implementation);\n" + f" if (it == kImplNames.end()) throw py::value_error(\n" + f' "unknown implementation: \'" + implementation + "\' (valid: {valid_names})");\n' + f" Config config;\n" + f" config.set_implementation_index(it->second);\n" + f" Handle handle;\n" + f" if (stream) {{\n" + f" handle.set_stream(reinterpret_cast(stream));\n" + f" }}\n" + f" return Self::call(handle, config, {call_args});\n" + f' }}, {py_args_str}py::kw_only(), py::arg("implementation"), py::arg("stream") = 0);' + ) + + return f"{overload_idx}\n{overload_str}" + + return f""" .def("__call__", [](const Self& self, {call_params}) {{ + return static_cast&>(self)({call_args}); + }})""" + + inits = "\n".join( + _generate_init(constructor) for constructor in operator.constructors + ) + calls = "\n".join(_generate_call(operator.name, call) for call in operator.calls) + callers = "\n".join( + _generate_call(operator.name, call, method=False) for call in operator.calls + ) + + pascal_case_op_name = _snake_to_pascal(op_name) + + return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_ +#define INFINI_OPS_BINDINGS_{op_name.upper()}_H_ + +#include +#include + +#include "base/{op_name}.h" +#include "config.h" +#include "handle.h" +#include "operator.h" +#include "pybind11_utils.h" + +namespace py = pybind11; + +namespace infini::ops {{ + +void Bind{pascal_case_op_name}(py::module& m) {{ + using Self = {pascal_case_op_name}; + + py::class_(m, "{pascal_case_op_name}") +{inits} +{calls} + 
.def_static("active_implementation_indices", [](const std::string& device) {{ + return Self::active_implementation_indices(DeviceTypeFromString(device)); + }}); + +{callers} +}} + +}} // namespace infini::ops + +#endif +""" + + +def _generate_legacy_c(operator, paths): + def _generate_source(operator): + impl_includes = "\n".join( + f'#include "{str(path).removeprefix("src/")}"' for path in paths + ) + + return f"""#include "../../handle.h" +#include "../../tensor.h" +#include "infiniop/ops/{operator.name.lower()}.h" +{impl_includes} + +static infini::ops::DataType DataTypeFromInfiniDType( + const infiniDtype_t& dtype) {{ + static constexpr infini::ops::ConstexprMap + kInfiniDTypeToDataType{{ + {{{{{{INFINI_DTYPE_I8, infini::ops::DataType::kInt8}}, + {{INFINI_DTYPE_I16, infini::ops::DataType::kInt16}}, + {{INFINI_DTYPE_I32, infini::ops::DataType::kInt32}}, + {{INFINI_DTYPE_I64, infini::ops::DataType::kInt64}}, + {{INFINI_DTYPE_U8, infini::ops::DataType::kUInt8}}, + {{INFINI_DTYPE_U16, infini::ops::DataType::kUInt16}}, + {{INFINI_DTYPE_U32, infini::ops::DataType::kUInt32}}, + {{INFINI_DTYPE_U64, infini::ops::DataType::kUInt64}}, + {{INFINI_DTYPE_F16, infini::ops::DataType::kFloat16}}, + {{INFINI_DTYPE_BF16, infini::ops::DataType::kBFloat16}}, + {{INFINI_DTYPE_F32, infini::ops::DataType::kFloat32}}, + {{INFINI_DTYPE_F64, infini::ops::DataType::kFloat64}}}}}}}}; + + return kInfiniDTypeToDataType.at(dtype); +}} + +static infini::ops::Device::Type DeviceTypeFromInfiniDevice( + const infiniDevice_t& device) {{ + static constexpr infini::ops::ConstexprMap< + infiniDevice_t, infini::ops::Device::Type, + static_cast(INFINI_DEVICE_TYPE_COUNT)> + kInfiniDeviceToDeviceType{{ + {{{{{{INFINI_DEVICE_CPU, infini::ops::Device::Type::kCpu}}, + {{INFINI_DEVICE_NVIDIA, infini::ops::Device::Type::kNvidia}}, + {{INFINI_DEVICE_CAMBRICON, infini::ops::Device::Type::kCambricon}}, + {{INFINI_DEVICE_ASCEND, infini::ops::Device::Type::kAscend}}, + {{INFINI_DEVICE_METAX, 
infini::ops::Device::Type::kMetax}}, + {{INFINI_DEVICE_MOORE, infini::ops::Device::Type::kMoore}}, + {{INFINI_DEVICE_ILUVATAR, infini::ops::Device::Type::kIluvatar}}, + {{INFINI_DEVICE_KUNLUN, infini::ops::Device::Type::kKunlun}}, + {{INFINI_DEVICE_HYGON, infini::ops::Device::Type::kHygon}}, + {{INFINI_DEVICE_QY, infini::ops::Device::Type::kQy}}}}}}}}; + + return kInfiniDeviceToDeviceType.at(device); +}} + +__C {_generate_create_func_def(operator)} + +__C {_generate_get_workspace_size_func_def(operator)} + +__C {_generate_call_func_def(operator)} + +__C {_generate_destroy_func_def(operator)} +""" + + def _generate_header(operator): + return f"""#ifndef __INFINIOP_{operator.name.upper()}_API_H__ +#define __INFINIOP_{operator.name.upper()}_API_H__ + +#include "base/{operator.name.lower()}.h" + +typedef struct infini::ops::Operator *infiniop{operator.name}Descriptor_t; + +__C __export {_generate_create_func_decl(operator)}; + +__C __export {_generate_get_workspace_size_func_decl(operator)}; + +__C __export {_generate_call_func_decl(operator)}; + +__C __export {_generate_destroy_func_decl(operator)}; + +#endif +""" + + def _generate_create_func_def(operator): + name = operator.name + constructor = operator.constructors[-1] + + return f"""{_generate_create_func_decl(operator)} {{ + *desc_ptr = infini::ops::Operator::make({_generate_arguments(constructor)}).release(); + + return INFINI_STATUS_SUCCESS; +}}""" + + def _generate_get_workspace_size_func_def(operator): + return f"""{_generate_get_workspace_size_func_decl(operator)} {{ + *size = 0; // desc->workspace_size(); + + return INFINI_STATUS_SUCCESS; +}}""" + + def _generate_call_func_def(operator): + call = operator.calls[-1] + + return f"""{_generate_call_func_decl(operator)} {{ + (*desc)(stream, {_generate_arguments(call, is_data=True)}); + + return INFINI_STATUS_SUCCESS; +}}""" + + def _generate_destroy_func_def(operator): + return f"""{_generate_destroy_func_decl(operator)} {{ + delete desc; + + return 
INFINI_STATUS_SUCCESS; +}}""" + + def _generate_create_func_decl(operator): + name = operator.name + constructor = operator.constructors[-1] + params = _generate_params(constructor) + + return f"infiniStatus_t infiniopCreate{name}Descriptor(infiniopHandle_t handle, infiniop{name}Descriptor_t *desc_ptr, {params})" + + def _generate_get_workspace_size_func_decl(operator): + name = operator.name + + return f"infiniStatus_t infiniopGet{name}WorkspaceSize(infiniop{name}Descriptor_t desc, size_t *size)" + + def _generate_call_func_decl(operator): + name = operator.name + call = operator.calls[-1] + params = _generate_params(call, call=True) + params = params.replace("void * stream, ", "") + + return f"infiniStatus_t infiniop{name}(infiniop{name}Descriptor_t desc, void *workspace, size_t workspace_size, {params}, void *stream)" + + def _generate_destroy_func_decl(operator): + name = operator.name + + return f"infiniStatus_t infiniopDestroy{name}Descriptor(infiniop{name}Descriptor_t desc)" + + def _generate_params(node, call=False): + arguments = tuple(node.get_arguments()) + + arguments = (arguments[-1], *arguments[:-1]) + + def _handle_tensor(spelling): + if call: + return spelling.replace("Tensor", "void *") + return spelling.replace("Tensor", "infiniopTensorDescriptor_t") + + def _handle_std_optional(spelling): + return spelling.replace("std::optional<", "").replace(">", "") + + return ", ".join( + f"{_handle_std_optional(_handle_tensor(arg.type.spelling))} {arg.spelling}" + for arg in arguments + ) + + def _generate_arguments(node, is_data=False): + return ", ".join( + _generate_tensor_caster(arg.spelling, is_data=is_data) + if "Tensor" in arg.type.spelling + else arg.spelling + for arg in node.get_arguments() + if arg.spelling != "handle" and arg.spelling != "stream" + ) + + def _generate_tensor_caster(name, is_data=False): + if is_data: + return f"infini::ops::Tensor(const_cast({name}), infini::ops::Tensor::Shape{{}})" + + return f"infini::ops::Tensor{{nullptr, 
{name}->shape(), DataTypeFromInfiniDType({name}->dtype()), infini::ops::Device{{DeviceTypeFromInfiniDevice(handle->device), handle->device_id}}, {name}->strides()}}" + + return _generate_source(operator), _generate_header(operator) + + +def _snake_to_pascal(snake_str): + return "".join(word.capitalize() for word in snake_str.split("_")) + + +def _get_all_ops(devices, output_dir=None): + ops = {} + + for file_path in _BASE_DIR.iterdir(): + if not file_path.is_file(): + continue + + op_name = file_path.stem + ops[op_name] = [] + + search_dirs = [_SRC_DIR] + + if output_dir is not None: + search_dirs.append(output_dir) + + for search_dir in search_dirs: + for file_path in search_dir.rglob("*"): + if ( + not file_path.is_file() + or file_path.parent.parent.name not in devices + ): + continue + + if ( + f"class Operator<{_snake_to_pascal(op_name)}" + in file_path.read_text() + ): + ops[op_name].append(file_path) + + return ops + + +def generate_all_bindings( + devices: list[str], + output_dir: pathlib.Path, + impl_names: dict[str, dict[str, int]], +) -> None: + """Generate pybind11 bindings and C API for all operators.""" + bindings_dir = output_dir / "bindings" + generated_src_dir = output_dir / "src" + include_dir = output_dir / "include" + + bindings_dir.mkdir(parents=True, exist_ok=True) + generated_src_dir.mkdir(parents=True, exist_ok=True) + include_dir.mkdir(parents=True, exist_ok=True) + + ops_json = pathlib.Path("ops.json") + + if ops_json.exists(): + ops = json.loads(ops_json.read_text()) + else: + ops = _get_all_ops(devices, output_dir) + + header_paths = [] + bind_func_names = [] + + for op_name, impl_paths in ops.items(): + extractor = _OperatorExtractor() + operator = extractor(op_name) + + pascal_name = _snake_to_pascal(op_name) + op_impl_names = impl_names.get(pascal_name, {}) + + source_path = generated_src_dir / op_name + header_name = f"{op_name}.h" + bind_func_name = f"Bind{pascal_name}" + + (bindings_dir / header_name).write_text( + 
_generate_pybind11(operator, op_impl_names) + ) + + legacy_c_source, legacy_c_header = _generate_legacy_c(operator, impl_paths) + source_path.mkdir(exist_ok=True) + (generated_src_dir / op_name / "operator.cc").write_text(legacy_c_source) + (include_dir / header_name).write_text(legacy_c_header) + + header_paths.append(header_name) + bind_func_names.append(bind_func_name) + + # Assemble ops.cc. + impl_includes = "\n".join( + f'#include "{impl_path}"' + for impl_paths in ops.values() + for impl_path in impl_paths + ) + op_includes = "\n".join(f'#include "{h}"' for h in header_paths) + bind_func_calls = "\n".join(f"{f}(m);" for f in bind_func_names) + + (bindings_dir / "ops.cc").write_text( + f"#include \n\n" + f"// clang-format off\n{impl_includes}\n// clang-format on\n\n" + f"{op_includes}\n\n" + f"namespace infini::ops {{\n\n" + f"PYBIND11_MODULE(ops, m) {{\n" + f"{textwrap.indent(bind_func_calls, _INDENTATION)}\n" + f"}}\n\n" + f"}} // namespace infini::ops\n" + ) diff --git a/dsl/compiler/codegen.py b/dsl/compiler/codegen.py new file mode 100644 index 00000000..b858aa05 --- /dev/null +++ b/dsl/compiler/codegen.py @@ -0,0 +1,278 @@ +"""C++ code generation for backend wrapper files.""" + +from __future__ import annotations + +import pathlib +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dsl.decorators import InfiniOpDef, ManualOpDef + +# Backend identifiers used in Device::Type enum. +CUDA_LIKE_BACKENDS = ("nvidia", "metax", "iluvatar", "moore") + +# Maps backend name → Device::Type enum suffix (PascalCase). 
+BACKEND_ENUM = { + "nvidia": "Nvidia", + "metax": "Metax", + "iluvatar": "Iluvatar", + "moore": "Moore", + "ascend": "Ascend", + "cambricon": "Cambricon", + "cpu": "Cpu", +} + + +def _pascal_case(snake: str) -> str: + return "".join(w.capitalize() for w in snake.split("_")) + + +def _to_snake(pascal: str) -> str: + """Convert PascalCase to snake_case.""" + import re + + return re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", pascal).lower() + + +def _include_guard(backend: str, op_snake: str, filename: str) -> str: + """Build an include guard matching the project convention.""" + stem = pathlib.Path(filename).stem + suffix = pathlib.Path(filename).suffix.lstrip(".") + + # Example: INFINI_OPS_NVIDIA_ADD_KERNEL_H_ + parts = ["INFINI_OPS", backend.upper(), op_snake.upper(), stem.upper()] + parts.append(f"{suffix.upper()}_" if suffix else "H_") + + return "_".join(parts) + + +# ---- CUDA-like wrapper generation ---------------------------------------- + + +def _resolve_cuda_template_info( + op: ManualOpDef | InfiniOpDef, +) -> tuple[str, str] | None: + """Derive the shared CUDA template class name and include path. + + Returns ``(CudaClassName, include_path)`` or ``None`` if the operator + does not use a shared CUDA template. + """ + from dsl.decorators import InfiniOpDef + + if isinstance(op, InfiniOpDef): + op_snake = _to_snake(op.name) + prefix = "Dsl" if op.impl_index > 0 else "" + filename = "dsl.h" if op.impl_index > 0 else "kernel.h" + + return f"{prefix}Cuda{op.name}", f"cuda/{op_snake}/{filename}" + + cuda_entry = op.backends.get("cuda") + + if cuda_entry is None: + return None + + if isinstance(cuda_entry, dict): + # Complex BLAS-style entry: {"include": ..., "class": ..., "blas": True} + return cuda_entry.get("class"), cuda_entry.get("include") + + # Simple string: "cuda/add/kernel.h" → CudaAdd (convention: Cuda + OpName). 
+ return f"Cuda{op.name}", cuda_entry + + +def generate_cuda_wrapper( + op: ManualOpDef | InfiniOpDef, + backend: str, + impl_index: int | None = None, +) -> str: + """Generate a CUDA-like backend wrapper header. + + For operators backed by a shared ``Cuda*>`` template. + """ + from dsl.decorators import InfiniOpDef + + op_snake = _to_snake(op.name) + enum_name = BACKEND_ENUM[backend] + filename = "dsl.h" if isinstance(op, InfiniOpDef) and op.impl_index > 0 else "kernel.h" + guard = _include_guard(backend, op_snake, filename) + + info = _resolve_cuda_template_info(op) + + if info is None: + raise ValueError( + f"Operator `{op.name}` has no `cuda` entry in backends; " + f"cannot generate a CUDA-like wrapper for `{backend}`." + ) + + cuda_class, cuda_include = info + + # Build the template specialization. + device_type = f"Device::Type::k{enum_name}" + need_impl_h = False + + if impl_index is not None and impl_index > 0: + device_type += ", Impl::kDsl" + need_impl_h = True + + # Collect includes — no blank lines between them (matches existing style). 
+ lines: list[str] = ["#include ", ""] + + if need_impl_h: + lines.append('#include "impl.h"') + lines.append(f'#include "{backend}/{op_snake}/registry.h"') + lines.append("") + + if backend == "moore": + lines.append("// clang-format off") + lines.append('#include "moore/polyfills.cuh"') + lines.append("// clang-format on") + lines.append("") + + lines.append(f'#include "{cuda_include}"') + lines.append(f'#include "{backend}/caster.cuh"') + + if backend == "moore": + lines.append('#include "moore/polyfills.cuh"') + + lines.append(f'#include "{backend}/runtime_.h"') + + includes_str = "\n".join(lines) + + return "\n".join([ + f"#ifndef {guard}", + f"#define {guard}", + "", + includes_str, + "", + "namespace infini::ops {", + "", + "template <>", + f"class Operator<{op.name}, {device_type}>", + f" : public {cuda_class}> {{", + " public:", + f" using {cuda_class}>::{cuda_class};", + "};", + "", + "} // namespace infini::ops", + "", + "#endif", + "", + ]) + + +def generate_blas_wrapper( + op: ManualOpDef, + backend: str, + blas_class: str, + blas_include: str, + impl_index: int | None = None, +) -> str: + """Generate a BLAS-based backend wrapper (e.g. GEMM via cuBLAS).""" + op_snake = _to_snake(op.name) + enum_name = BACKEND_ENUM[backend] + guard = _include_guard(backend, op_snake, "kernel.h") + + device_type = f"Device::Type::k{enum_name}" + + if impl_index is not None: + device_type += f", {impl_index}" + + # Include the platform's registry if the operator has one in src/. 
+ registry_path = pathlib.Path(f"src/{backend}/{op_snake}/registry.h") + registry_include = ( + f'#include "{backend}/{op_snake}/registry.h"\n' + if registry_path.exists() + else "" + ) + + return ( + f"#ifndef {guard}\n" + f"#define {guard}\n" + f"\n" + f'#include "{blas_include}"\n' + f'#include "{backend}/blas.h"\n' + f"{registry_include}" + f"\n" + f"namespace infini::ops {{\n" + f"\n" + f"template <>\n" + f"class Operator<{op.name}, {device_type}>\n" + f" : public {blas_class}> {{\n" + f" public:\n" + f" using {blas_class}>::{blas_class};\n" + f"}};\n" + f"\n" + f"}} // namespace infini::ops\n" + f"\n" + f"#endif\n" + ) + + +# ---- High-level generation entry point ----------------------------------- + + +def generate_wrappers_for_op( + op: ManualOpDef | InfiniOpDef, + devices: list[str], + output_dir: pathlib.Path, +) -> list[pathlib.Path]: + """Generate backend wrapper files for an operator. + + Works for both ``@manual_op`` and ``@infini_op`` operators. + For ``@infini_op``, the shared CUDA template is the generated + ``cuda//kernel.h`` file. + + Returns a list of generated file paths. + """ + from dsl.decorators import ManualOpDef + + op_snake = _to_snake(op.name) + generated: list[pathlib.Path] = [] + + # Build an effective backends dict. + if isinstance(op, ManualOpDef): + backends = op.backends + else: + # For @infini_op, the CUDA kernel is auto-generated. + backends = dict(op.manual_backends) + backends["cuda"] = f"cuda/{op_snake}/kernel.h" + + # Determine impl_index and output filename. + impl_index = getattr(op, "impl_index", None) + out_filename = "dsl.h" if impl_index and impl_index > 0 else "kernel.h" + + # Check if the cuda entry is a BLAS-style operator. 
+ cuda_entry = backends.get("cuda") + is_blas = isinstance(cuda_entry, dict) and cuda_entry.get("blas", False) + + for backend in devices: + + if backend not in CUDA_LIKE_BACKENDS: + continue + + if backend not in backends and "cuda" not in backends: + continue + + # Check for an explicit backend entry (overrides shared CUDA path). + explicit = backends.get(backend) + + if explicit is not None and isinstance(explicit, str): + # Explicit hand-written file — do not generate a wrapper. + continue + + if is_blas: + # Generate BLAS-based wrapper (e.g., BlasGemm>). + blas_class = cuda_entry["class"] + blas_include = cuda_entry["include"] + content = generate_blas_wrapper( + op, backend, blas_class, blas_include, impl_index=impl_index + ) + else: + # Generate standard CUDA wrapper (e.g., CudaOp>). + content = generate_cuda_wrapper(op, backend, impl_index=impl_index) + + out_path = output_dir / backend / op_snake / out_filename + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(content) + generated.append(out_path) + + return generated diff --git a/dsl/compiler/dag.py b/dsl/compiler/dag.py new file mode 100644 index 00000000..9556fd8d --- /dev/null +++ b/dsl/compiler/dag.py @@ -0,0 +1,203 @@ +"""Compute DAG representation for `@infini_op` operators.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum, auto +from typing import Any + + +class NodeKind(Enum): + """Primitive operation types in the compute DAG.""" + + # Inputs. + INPUT = auto() + SCALAR = auto() + + # Elementwise unary. + NEG = auto() + ABS = auto() + SQRT = auto() + RSQRT = auto() + EXP = auto() + LOG = auto() + + # Elementwise binary. + ADD = auto() + SUB = auto() + MUL = auto() + DIV = auto() + POW = auto() + + # Activations. + RELU = auto() + GELU = auto() + SILU = auto() + SIGMOID = auto() + TANH = auto() + + # Reductions. 
+ REDUCE_SUM = auto() + REDUCE_MEAN = auto() + REDUCE_MAX = auto() + REDUCE_MIN = auto() + + # Comparison / conditional. + WHERE = auto() + GT = auto() + LT = auto() + GE = auto() + LE = auto() + EQ = auto() + + # Type. + CAST = auto() + + # Clamp. + CLAMP = auto() + + +# Classify node kinds into categories for pattern matching. +ELEMENTWISE_UNARY = { + NodeKind.NEG, + NodeKind.ABS, + NodeKind.SQRT, + NodeKind.RSQRT, + NodeKind.EXP, + NodeKind.LOG, + NodeKind.RELU, + NodeKind.GELU, + NodeKind.SILU, + NodeKind.SIGMOID, + NodeKind.TANH, +} + +ELEMENTWISE_BINARY = { + NodeKind.ADD, + NodeKind.SUB, + NodeKind.MUL, + NodeKind.DIV, + NodeKind.POW, + NodeKind.GT, + NodeKind.LT, + NodeKind.GE, + NodeKind.LE, + NodeKind.EQ, +} + +ELEMENTWISE = ELEMENTWISE_UNARY | ELEMENTWISE_BINARY | { + NodeKind.WHERE, + NodeKind.CAST, + NodeKind.CLAMP, +} + +REDUCTIONS = { + NodeKind.REDUCE_SUM, + NodeKind.REDUCE_MEAN, + NodeKind.REDUCE_MAX, + NodeKind.REDUCE_MIN, +} + + +@dataclass +class DagNode: + """A single node in the compute DAG.""" + + id: int + kind: NodeKind + inputs: list[int] = field(default_factory=list) + + # Shape variable name (e.g. "B", "H", "D") for inputs. + shape: list[str] | None = None + + # For INPUT: parameter name; for SCALAR: value/name. + name: str | None = None + + # For reductions: the shape variable being reduced over. + reduce_dim: str | None = None + + # For CAST: target dtype string. + cast_dtype: str | None = None + + # For CLAMP: min/max bounds. + clamp_min: float | None = None + clamp_max: float | None = None + + # Arbitrary extra attributes. + attrs: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class ComputeDAG: + """A directed acyclic graph of primitive operations. + + Built by the parser from an `@infini_op` function body. + """ + + nodes: dict[int, DagNode] = field(default_factory=dict) + output_id: int | None = None + + # Shape variables declared in the operator definition. 
+ shape_vars: dict[str, str] = field(default_factory=dict) + + _next_id: int = field(default=0, repr=False) + + def add_node(self, kind: NodeKind, **kwargs: Any) -> int: + """Create a new node and return its id.""" + nid = self._next_id + self._next_id += 1 + self.nodes[nid] = DagNode(id=nid, kind=kind, **kwargs) + + return nid + + def get(self, nid: int) -> DagNode: + return self.nodes[nid] + + def consumers(self, nid: int) -> list[int]: + """Return ids of nodes that consume ``nid`` as an input.""" + + return [ + n.id for n in self.nodes.values() if nid in n.inputs + ] + + def is_elementwise_only(self) -> bool: + """True if the DAG contains only elementwise ops (no reductions).""" + + for node in self.nodes.values(): + + if node.kind in REDUCTIONS: + return False + + return True + + def has_reduction(self) -> bool: + """True if any node is a reduction.""" + + return any(n.kind in REDUCTIONS for n in self.nodes.values()) + + def reduction_nodes(self) -> list[DagNode]: + """Return all reduction nodes.""" + + return [n for n in self.nodes.values() if n.kind in REDUCTIONS] + + def topo_sort(self) -> list[int]: + """Return node ids in topological order.""" + visited: set[int] = set() + order: list[int] = [] + + def dfs(nid: int) -> None: + + if nid in visited: + return + + visited.add(nid) + + for inp in self.nodes[nid].inputs: + dfs(inp) + + order.append(nid) + + for nid in self.nodes: + dfs(nid) + + return order diff --git a/dsl/compiler/infini_codegen.py b/dsl/compiler/infini_codegen.py new file mode 100644 index 00000000..07a990ac --- /dev/null +++ b/dsl/compiler/infini_codegen.py @@ -0,0 +1,1124 @@ +"""C++ code generation for `@infini_op` operators. + +Translates a matched compute DAG into C++ source files that compose +template bricks from `src/cuda/templates/` and `src/cpu/templates/`. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from dsl.compiler.dag import ComputeDAG, DagNode, NodeKind +from dsl.compiler.patterns import BrickKind, MatchResult + +if TYPE_CHECKING: + from dsl.decorators import InfiniOpDef + + +def _to_snake(pascal: str) -> str: + """Convert PascalCase to snake_case.""" + import re + + return re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", pascal).lower() + + +# ---- Functor C++ expression generation ------------------------------------- + + +# Map DAG node kinds to C++ operator/function expressions. +_CUDA_BINOP: dict[NodeKind, str] = { + NodeKind.ADD: "+", + NodeKind.SUB: "-", + NodeKind.MUL: "*", + NodeKind.DIV: "/", +} + +_CUDA_UNARY_FUNC: dict[NodeKind, str] = { + NodeKind.SQRT: "sqrtf", + NodeKind.RSQRT: "rsqrtf", + NodeKind.EXP: "expf", + NodeKind.LOG: "logf", + NodeKind.ABS: "fabsf", + NodeKind.TANH: "tanhf", +} + +_CPU_UNARY_FUNC: dict[NodeKind, str] = { + NodeKind.SQRT: "std::sqrt", + NodeKind.RSQRT: "1.f / std::sqrt", + NodeKind.EXP: "std::exp", + NodeKind.LOG: "std::log", + NodeKind.ABS: "std::abs", + NodeKind.TANH: "std::tanh", +} + +_ACTIVATION_CUDA: dict[NodeKind, str] = { + NodeKind.RELU: "v > 0 ? v : static_cast(0)", + NodeKind.SIGMOID: "static_cast(1) / (static_cast(1) + expf(-v))", + NodeKind.SILU: "v / (static_cast(1) + expf(-v))", +} + +_ACTIVATION_CPU: dict[NodeKind, str] = { + NodeKind.RELU: "v > 0 ? v : static_cast(0)", + NodeKind.SIGMOID: "static_cast(1) / (static_cast(1) + std::exp(-v))", + NodeKind.SILU: "v / (static_cast(1) + std::exp(-v))", +} + + +def _expr_for_node( + dag: ComputeDAG, + node: DagNode, + var_map: dict[int, str], + is_cuda: bool, +) -> str: + """Generate a C++ expression string for a single DAG node. + + ``var_map`` maps node id → C++ variable name for already-emitted nodes. 
+ """ + + def _ref(nid: int) -> str: + return var_map[nid] + + if node.kind in _CUDA_BINOP: + op = _CUDA_BINOP[node.kind] + + return f"({_ref(node.inputs[0])} {op} {_ref(node.inputs[1])})" + + unary_map = _CUDA_UNARY_FUNC if is_cuda else _CPU_UNARY_FUNC + + if node.kind in unary_map: + func = unary_map[node.kind] + + if node.kind == NodeKind.RSQRT and not is_cuda: + return f"(1.f / std::sqrt({_ref(node.inputs[0])}))" + + return f"{func}({_ref(node.inputs[0])})" + + if node.kind == NodeKind.NEG: + return f"(-{_ref(node.inputs[0])})" + + act_map = _ACTIVATION_CUDA if is_cuda else _ACTIVATION_CPU + + if node.kind in act_map: + # Activation functions expect the variable to be named `v`. + return act_map[node.kind].replace("v", _ref(node.inputs[0])) + + if node.kind == NodeKind.WHERE: + return ( + f"({_ref(node.inputs[0])} ? " + f"{_ref(node.inputs[1])} : {_ref(node.inputs[2])})" + ) + + if node.kind == NodeKind.POW: + func = "powf" if is_cuda else "std::pow" + + return f"{func}({_ref(node.inputs[0])}, {_ref(node.inputs[1])})" + + if node.kind == NodeKind.CAST: + # Type conversion — the actual cast is handled by the functor's + # return-type conversion, so just pass through the input expression. + return _ref(node.inputs[0]) + + if node.kind == NodeKind.SCALAR: + # Literal scalar. + val = node.attrs.get("value") + + if val is not None: + return repr(val) + + return node.name or "0" + + raise ValueError(f"Cannot generate expression for node kind: {node.kind}.") + + +# ---- Binary elementwise code generation ------------------------------------ + + +def _dsl_prefix(op: InfiniOpDef) -> str: + """Return the prefix for DSL-generated class names. + + When ``impl_index > 0``, class names are prefixed with ``Dsl`` to + avoid collisions with the hand-written implementation. 
+ """ + + return "Dsl" if op.impl_index > 0 else "" + + +def _generate_binary_functor_cuda( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, +) -> str: + """Generate the device-side binary functor for CUDA.""" + prefix = _dsl_prefix(op) + + # Build the functor body by walking the DAG in topological order. + topo = dag.topo_sort() + var_map: dict[int, str] = {} + body_lines: list[str] = [] + + for nid in topo: + node = dag.get(nid) + + if node.kind == NodeKind.INPUT: + if node.name == match.input_names[0]: + var_map[nid] = "va" + elif node.name == match.input_names[1]: + var_map[nid] = "vb" + else: + var_map[nid] = node.name + + continue + + if node.kind == NodeKind.SCALAR: + val = node.attrs.get("value") + + if val is not None: + var_map[nid] = repr(val) + else: + var_map[nid] = node.name + + continue + + expr = _expr_for_node(dag, node, var_map, is_cuda=True) + + if nid == dag.output_id: + body_lines.append(f" return Caster::template Cast({expr});") + else: + vname = f"t{nid}" + body_lines.append(f" auto {vname} = {expr};") + var_map[nid] = vname + + body = "\n".join(body_lines) + functor_name = f"{prefix}{op.name}Op" + + return f"""\ +// Device-side binary functor for `{op.name}` (DSL). 
+template +struct {functor_name} {{ + template + __device__ __forceinline__ T operator()(const T& a, const T& b) const {{ + using ComputeType = float; + auto va = Caster::template Cast(a); + auto vb = Caster::template Cast(b); +{body} + }} +}};""" + + +def _generate_binary_functor_cpu( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, +) -> str: + """Generate the host-side binary functor for CPU.""" + prefix = _dsl_prefix(op) + topo = dag.topo_sort() + var_map: dict[int, str] = {} + body_lines: list[str] = [] + + for nid in topo: + node = dag.get(nid) + + if node.kind == NodeKind.INPUT: + if node.name == match.input_names[0]: + var_map[nid] = "va" + elif node.name == match.input_names[1]: + var_map[nid] = "vb" + else: + var_map[nid] = node.name + + continue + + if node.kind == NodeKind.SCALAR: + val = node.attrs.get("value") + + if val is not None: + var_map[nid] = repr(val) + else: + var_map[nid] = node.name + + continue + + expr = _expr_for_node(dag, node, var_map, is_cuda=False) + + if nid == dag.output_id: + body_lines.append(f" return static_cast({expr});") + else: + vname = f"t{nid}" + body_lines.append(f" auto {vname} = {expr};") + var_map[nid] = vname + + body = "\n".join(body_lines) + functor_name = f"{prefix}Cpu{op.name}Op" + + return f"""\ +// Host-side binary functor for `{op.name}` (CPU, DSL). +struct {functor_name} {{ + template + T operator()(const T& a, const T& b) const {{ + using ComputeType = float; + auto va = static_cast(a); + auto vb = static_cast(b); +{body} + }} +}};""" + + +# ---- Unary elementwise code generation --------------------------------------- + + +def _generate_unary_functor_cuda( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, +) -> str: + """Generate the device-side unary functor for CUDA.""" + prefix = _dsl_prefix(op) + + # Build the functor body by walking the DAG in topological order. 
+ topo = dag.topo_sort() + var_map: dict[int, str] = {} + body_lines: list[str] = [] + + for nid in topo: + node = dag.get(nid) + + if node.kind == NodeKind.INPUT: + var_map[nid] = "va" + + continue + + if node.kind == NodeKind.SCALAR: + val = node.attrs.get("value") + + if val is not None: + var_map[nid] = repr(val) + else: + var_map[nid] = node.name + + continue + + expr = _expr_for_node(dag, node, var_map, is_cuda=True) + + if nid == dag.output_id: + body_lines.append(f" return Caster::template Cast({expr});") + else: + vname = f"t{nid}" + body_lines.append(f" auto {vname} = {expr};") + var_map[nid] = vname + + body = "\n".join(body_lines) + functor_name = f"{prefix}{op.name}Op" + + return f"""\ +// Device-side unary functor for `{op.name}` (DSL). +template +struct {functor_name} {{ + template + __device__ __forceinline__ TOut operator()(const TIn& x) const {{ + auto va = Caster::template Cast(x); +{body} + }} +}};""" + + +def _generate_unary_functor_cpu( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, +) -> str: + """Generate the host-side unary functor for CPU.""" + prefix = _dsl_prefix(op) + topo = dag.topo_sort() + var_map: dict[int, str] = {} + body_lines: list[str] = [] + + for nid in topo: + node = dag.get(nid) + + if node.kind == NodeKind.INPUT: + var_map[nid] = "va" + + continue + + if node.kind == NodeKind.SCALAR: + val = node.attrs.get("value") + + if val is not None: + var_map[nid] = repr(val) + else: + var_map[nid] = node.name + + continue + + expr = _expr_for_node(dag, node, var_map, is_cuda=False) + + if nid == dag.output_id: + body_lines.append( + f" return Caster::Cast({expr});" + ) + else: + vname = f"t{nid}" + body_lines.append(f" auto {vname} = {expr};") + var_map[nid] = vname + + body = "\n".join(body_lines) + functor_name = f"{prefix}Cpu{op.name}Op" + + return f"""\ +// Host-side unary functor for `{op.name}` (CPU, DSL). 
+struct {functor_name} {{ + template + TOut operator()(const TIn& x) const {{ + auto va = Caster::Cast(x); +{body} + }} +}};""" + + +# ---- Reduce-then-transform code generation --------------------------------- + + +def _generate_reduce_op_cuda( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, +) -> str: + """Generate the CUDA reduce op struct.""" + assert match.reduce_nodes is not None + + # Analyze the reduction pattern to determine the accumulation. + reduce_node = None + + for nid in match.reduce_nodes: + node = dag.get(nid) + + if node.kind in ( + NodeKind.REDUCE_SUM, + NodeKind.REDUCE_MEAN, + NodeKind.REDUCE_MAX, + NodeKind.REDUCE_MIN, + ): + reduce_node = node + + break + + assert reduce_node is not None + + # Determine the pre-reduce expression (what is accumulated). + pre_reduce_expr = _build_pre_reduce_expr(dag, reduce_node, is_cuda=True) + finalize_expr = _build_finalize_expr(dag, reduce_node, match, is_cuda=True) + + prefix = _dsl_prefix(op) + + return f"""\ +// Reduce op for `{op.name}` (DSL). 
+struct {prefix}{op.name}Reduce {{ + template + __device__ __forceinline__ float Accumulate(const TData* ptr, + size_t count) const {{ + float ss = 0; + + for (size_t i = threadIdx.x; i < count; i += block_size) {{ + float v = Caster::template Cast(ptr[i]); +{pre_reduce_expr} + }} + + return ss; + }} + + __device__ __forceinline__ float Finalize(float total, + size_t count) const {{ +{finalize_expr} + }} + +{_generate_reduce_members(op, dag, match)} +}};""" + + +def _generate_reduce_op_cpu( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, +) -> str: + """Generate the CPU reduce op struct.""" + assert match.reduce_nodes is not None + + reduce_node = None + + for nid in match.reduce_nodes: + node = dag.get(nid) + + if node.kind in ( + NodeKind.REDUCE_SUM, + NodeKind.REDUCE_MEAN, + NodeKind.REDUCE_MAX, + NodeKind.REDUCE_MIN, + ): + reduce_node = node + + break + + assert reduce_node is not None + + init_val = _reduce_init_value(reduce_node.kind) + accum_expr = _build_accum_expr_scalar(dag, reduce_node, is_cuda=False) + finalize_expr = _build_finalize_expr(dag, reduce_node, match, is_cuda=False) + + prefix = _dsl_prefix(op) + + return f"""\ +// CPU reduce op for `{op.name}` (DSL). +struct {prefix}Cpu{op.name}Reduce {{ + float Init() const {{ return {init_val}; }} + + float Accumulate(float acc, float v) const {{ return {accum_expr}; }} + + float Finalize(float acc, size_t count) const {{ +{finalize_expr} + }} + +{_generate_reduce_members(op, dag, match)} +}};""" + + +def _generate_transform_op_cuda( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, +) -> str: + """Generate the CUDA transform op struct.""" + transform_body = _build_transform_body(dag, match, is_cuda=True) + + prefix = _dsl_prefix(op) + + return f"""\ +// Transform op for `{op.name}` (DSL). 
+struct {prefix}{op.name}Transform {{ + template + __device__ __forceinline__ TData Apply(TData x, float reduced, + size_t i) const {{ +{transform_body} + }} + +{_generate_transform_members(op, dag, match)} +}};""" + + +def _generate_transform_op_cpu( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, +) -> str: + """Generate the CPU transform op struct.""" + transform_body = _build_transform_body(dag, match, is_cuda=False) + + prefix = _dsl_prefix(op) + + return f"""\ +// CPU transform op for `{op.name}` (DSL). +struct {prefix}Cpu{op.name}Transform {{ + template + T Apply(T x, float reduced, size_t i) const {{ +{transform_body} + }} + +{_generate_transform_members(op, dag, match)} +}};""" + + +# ---- Helper functions for reduce/transform expression building ------------- + + +def _build_pre_reduce_expr( + dag: ComputeDAG, + reduce_node: DagNode, + is_cuda: bool, +) -> str: + """Build the inner-loop accumulation expression for the reduce phase.""" + + # Walk the inputs to the reduction to find what is being accumulated. + input_node_id = reduce_node.inputs[0] + input_node = dag.get(input_node_id) + + # Common pattern: reduce_mean(x * x) → sum of squares. + if ( + input_node.kind == NodeKind.MUL + and len(input_node.inputs) == 2 + and input_node.inputs[0] == input_node.inputs[1] + ): + return " ss += v * v;" + + # reduce_sum(x) or reduce_mean(x). + if input_node.kind == NodeKind.INPUT: + return " ss += v;" + + # Generic: just accumulate the expression. + var_map = {input_node.inputs[0]: "v"} if input_node.inputs else {"v": "v"} + + return " ss += v;" + + +def _build_accum_expr_scalar( + dag: ComputeDAG, + reduce_node: DagNode, + is_cuda: bool, +) -> str: + """Build the scalar accumulation expression for CPU reduce.""" + input_node_id = reduce_node.inputs[0] + input_node = dag.get(input_node_id) + + # reduce_mean(x * x) → acc + v * v. 
+ if ( + input_node.kind == NodeKind.MUL + and len(input_node.inputs) == 2 + and input_node.inputs[0] == input_node.inputs[1] + ): + return "acc + v * v" + + return "acc + v" + + +def _reduce_init_value(kind: NodeKind) -> str: + """Return the identity element for a reduction.""" + + if kind in (NodeKind.REDUCE_SUM, NodeKind.REDUCE_MEAN): + return "0.f" + + if kind == NodeKind.REDUCE_MAX: + return "-INFINITY" + + if kind == NodeKind.REDUCE_MIN: + return "INFINITY" + + return "0.f" + + +def _build_finalize_expr( + dag: ComputeDAG, + reduce_node: DagNode, + match: MatchResult, + is_cuda: bool, +) -> str: + """Build the finalize expression after block reduction.""" + + # Check what happens after the reduction before the transform phase. + # Walk from the reduction output to find post-reduce ops. + consumers = dag.consumers(reduce_node.id) + topo = dag.topo_sort() + + # Find nodes between reduce and the first transform node. + reduce_idx = topo.index(reduce_node.id) + transform_start = ( + match.transform_nodes[0] if match.transform_nodes else dag.output_id + ) + + # Collect post-reduce nodes that are not transform nodes. + post_reduce: list[int] = [] + + for nid in topo[reduce_idx + 1 :]: + if match.transform_nodes and nid in match.transform_nodes: + break + + node = dag.get(nid) + + if node.kind not in (NodeKind.INPUT, NodeKind.SCALAR): + post_reduce.append(nid) + + # Common pattern: rsqrt(total / count + eps). + if reduce_node.kind == NodeKind.REDUCE_MEAN: + # Check for rsqrt(mean + eps) pattern in post_reduce or transform. + all_post = post_reduce + (match.transform_nodes or []) + + for nid in all_post: + node = dag.get(nid) + + if node.kind == NodeKind.RSQRT: + rsqrt_func = "rsqrtf" if is_cuda else "1.f / std::sqrt" + + if is_cuda: + return ( + " return rsqrtf(total / " + "static_cast(count) + epsilon);" + ) + + return ( + " return 1.f / std::sqrt(acc / " + "static_cast(count) + epsilon);" + ) + + # Plain mean. 
+ if is_cuda: + return " return total / static_cast(count);" + + return " return acc / static_cast(count);" + + if reduce_node.kind == NodeKind.REDUCE_SUM: + if is_cuda: + return " return total;" + + return " return acc;" + + if reduce_node.kind == NodeKind.REDUCE_MAX: + if is_cuda: + return " return total;" + + return " return acc;" + + if is_cuda: + return " return total;" + + return " return acc;" + + +def _build_transform_body( + dag: ComputeDAG, + match: MatchResult, + is_cuda: bool, +) -> str: + """Build the transform phase body.""" + + # The transform applies: out[i] = f(in[i], reduced, i). + # Walk the DAG from the output backwards to understand the transform. + output_node = dag.get(dag.output_id) + + # Common pattern: input * reduced * weight[i]. + # For RmsNorm: return x * rms * weight[i]. + if _is_rms_norm_transform(dag, match): + if is_cuda: + return ( + " return Caster::template Cast(\n" + " Caster::template Cast(x) *\n" + " Caster::template Cast(" + "static_cast(weight)[i]) * reduced);" + ) + + return ( + " const auto* w = static_cast(weight);\n\n" + " return Caster::Cast(\n" + " Caster::Cast(x) *\n" + " Caster::Cast(w[i]) " + "* reduced);" + ) + + # Generic: input * reduced. + if is_cuda: + return ( + " return Caster::template Cast(\n" + " Caster::template Cast(x) * reduced);" + ) + + return ( + " return Caster::Cast(\n" + " Caster::Cast(x) * reduced);" + ) + + +def _is_rms_norm_transform(dag: ComputeDAG, match: MatchResult) -> bool: + """Check if the transform is ``x * reduced * weight[i]``.""" + + # Look for a weight tensor input. + for node in dag.nodes.values(): + if node.kind == NodeKind.INPUT and node.name == "weight": + return True + + return False + + +def _generate_reduce_members( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, +) -> str: + """Generate member variables for the reduce op struct.""" + members = [] + + # Check if epsilon is used. 
+ for node in dag.nodes.values(): + if node.kind == NodeKind.SCALAR and node.name == "eps": + members.append(" float epsilon;") + + return "\n".join(members) + + +def _generate_transform_members( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, +) -> str: + """Generate member variables for the transform op struct.""" + members = [] + + for node in dag.nodes.values(): + if node.kind == NodeKind.INPUT and node.name == "weight": + members.append(" const void* weight;") + + return "\n".join(members) + + +# ---- Top-level file generators --------------------------------------------- + + +def generate_cuda_kernel( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, +) -> str: + """Generate the shared CUDA kernel header for an `@infini_op`.""" + op_snake = _to_snake(op.name) + + if op.impl_index > 0: + guard = f"INFINI_OPS_CUDA_{op_snake.upper()}_DSL_H_" + else: + guard = f"INFINI_OPS_CUDA_{op_snake.upper()}_KERNEL_H_" + + if match.brick == BrickKind.BINARY_ELEMENTWISE: + return _gen_binary_elementwise_cuda(op, dag, match, guard, op_snake) + + if match.brick == BrickKind.UNARY_ELEMENTWISE: + return _gen_unary_elementwise_cuda(op, dag, match, guard, op_snake) + + if match.brick == BrickKind.REDUCE_THEN_TRANSFORM: + return _gen_reduce_transform_cuda(op, dag, match, guard, op_snake) + + raise ValueError(f"Unsupported brick kind for CUDA codegen: {match.brick}.") + + +def generate_cpu_kernel( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, +) -> str: + """Generate the CPU implementation header for an `@infini_op`.""" + op_snake = _to_snake(op.name) + + if op.impl_index > 0: + guard = f"INFINI_OPS_CPU_{op_snake.upper()}_DSL_H_" + else: + guard = f"INFINI_OPS_CPU_{op_snake.upper()}_{op_snake.upper()}_H_" + + if match.brick == BrickKind.BINARY_ELEMENTWISE: + return _gen_binary_elementwise_cpu(op, dag, match, guard, op_snake) + + if match.brick == BrickKind.UNARY_ELEMENTWISE: + return _gen_unary_elementwise_cpu(op, dag, match, guard, op_snake) + + 
if match.brick == BrickKind.REDUCE_THEN_TRANSFORM: + return _gen_reduce_transform_cpu(op, dag, match, guard, op_snake) + + raise ValueError(f"Unsupported brick kind for CPU codegen: {match.brick}.") + + +# ---- Binary elementwise file generators ------------------------------------ + + +def _gen_binary_elementwise_cuda( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, + guard: str, + op_snake: str, +) -> str: + prefix = _dsl_prefix(op) + functor = _generate_binary_functor_cuda(op, dag, match) + base_header = f"base/{op_snake}.h" + class_name = f"{prefix}Cuda{op.name}" + functor_name = f"{prefix}{op.name}Op" + + return f"""\ +#ifndef {guard} +#define {guard} + +#include "cuda/templates/binary_elementwise.cuh" +#include "{base_header}" + +namespace infini::ops {{ + +{functor} + +template +class {class_name} : public {op.name} {{ + public: + {class_name}(const Tensor input, const Tensor other, Tensor out) + : {op.name}{{input, other, out}}, + brick_{{input, other, out, ndim_}} {{}} + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override {{ + brick_.template Run( + stream_, input, other, out, output_size_, ndim_, + is_input_contiguous_, is_other_contiguous_, is_out_contiguous_, + out_type_); + }} + + private: + BinaryElementwiseBrick brick_; +}}; + +}} // namespace infini::ops + +#endif +""" + + +def _gen_binary_elementwise_cpu( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, + guard: str, + op_snake: str, +) -> str: + prefix = _dsl_prefix(op) + functor = _generate_binary_functor_cpu(op, dag, match) + base_header = f"base/{op_snake}.h" + functor_name = f"{prefix}Cpu{op.name}Op" + impl_suffix = ", Impl::kDsl" if op.impl_index > 0 else "" + impl_include = ( + f'#include "impl.h"\n#include "cpu/{op_snake}/registry.h"\n' + if op.impl_index > 0 + else "" + ) + + return f"""\ +#ifndef {guard} +#define {guard} + +#include "cpu/templates/binary_elementwise.h" +#include "{base_header}" +{impl_include} +namespace 
infini::ops {{ + +{functor} + +template <> +class Operator<{op.name}, Device::Type::kCpu{impl_suffix}> : public {op.name} {{ + public: + using {op.name}::{op.name}; + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override {{ + CpuBinaryElementwise( + input, other, out, output_size_, ndim_, + is_input_contiguous_, is_other_contiguous_, is_out_contiguous_, + input_shape_, other_shape_, out_shape_, + input_strides_, other_strides_, out_strides_, + out_type_, {functor_name}{{}}); + }} +}}; + +}} // namespace infini::ops + +#endif +""" + + +# ---- Unary elementwise file generators --------------------------------------- + + +def _gen_unary_elementwise_cuda( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, + guard: str, + op_snake: str, +) -> str: + prefix = _dsl_prefix(op) + functor = _generate_unary_functor_cuda(op, dag, match) + base_header = f"base/{op_snake}.h" + class_name = f"{prefix}Cuda{op.name}" + functor_name = f"{prefix}{op.name}Op" + + return f"""\ +#ifndef {guard} +#define {guard} + +#include "cuda/templates/unary_elementwise.cuh" +#include "{base_header}" + +namespace infini::ops {{ + +{functor} + +template +class {class_name} : public {op.name} {{ + public: + {class_name}(const Tensor input, Tensor out) + : {op.name}{{input, out}}, + brick_{{input, out, ndim_}} {{}} + + void operator()(const Tensor input, Tensor out) const override {{ + brick_.template Run( + stream_, input, out, output_size_, ndim_, + is_input_contiguous_, is_out_contiguous_, + input_dtype_, out_dtype_); + }} + + private: + UnaryElementwiseBrick brick_; +}}; + +}} // namespace infini::ops + +#endif +""" + + +def _gen_unary_elementwise_cpu( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, + guard: str, + op_snake: str, +) -> str: + prefix = _dsl_prefix(op) + functor = _generate_unary_functor_cpu(op, dag, match) + base_header = f"base/{op_snake}.h" + functor_name = f"{prefix}Cpu{op.name}Op" + impl_suffix = ", Impl::kDsl" if 
op.impl_index > 0 else "" + impl_include = ( + f'#include "impl.h"\n#include "cpu/{op_snake}/registry.h"\n' + if op.impl_index > 0 + else "" + ) + + return f"""\ +#ifndef {guard} +#define {guard} + +#include "cpu/templates/unary_elementwise.h" +#include "{base_header}" +{impl_include} +namespace infini::ops {{ + +{functor} + +template <> +class Operator<{op.name}, Device::Type::kCpu{impl_suffix}> : public {op.name} {{ + public: + using {op.name}::{op.name}; + + void operator()(const Tensor input, Tensor out) const override {{ + CpuUnaryElementwise( + input, out, output_size_, ndim_, + is_input_contiguous_, is_out_contiguous_, + input_shape_, out_shape_, + input_strides_, out_strides_, + input_dtype_, out_dtype_, {functor_name}{{}}); + }} +}}; + +}} // namespace infini::ops + +#endif +""" + + +# ---- Reduce-then-transform file generators --------------------------------- + + +def _gen_reduce_transform_cuda( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, + guard: str, + op_snake: str, +) -> str: + prefix = _dsl_prefix(op) + reduce_op = _generate_reduce_op_cuda(op, dag, match) + transform_op = _generate_transform_op_cuda(op, dag, match) + base_header = f"base/{op_snake}.h" + class_name = f"{prefix}Cuda{op.name}" + reduce_name = f"{prefix}{op.name}Reduce" + transform_name = f"{prefix}{op.name}Transform" + + # Determine the type list based on the operator. 
+ type_list = "ConcatType, ReducedFloatTypes>" + + return f"""\ +#ifndef {guard} +#define {guard} + +#include "cuda/templates/reduce_transform.cuh" +#include "{base_header}" + +namespace infini::ops {{ + +{reduce_op} + +{transform_op} + +template +class {class_name} : public {op.name} {{ + public: + using {op.name}::{op.name}; + + void operator()(const Tensor input, const Tensor weight, float eps, + Tensor out) const override {{ + LaunchReduceThenTransform( + stream_, input, out, batch_size_, nhead_, dim_, + out.dtype(), input_strides_, out_strides_, + {reduce_name}{{eps}}, + {transform_name}{{weight.data()}}); + }} +}}; + +}} // namespace infini::ops + +#endif +""" + + +def _gen_reduce_transform_cpu( + op: InfiniOpDef, + dag: ComputeDAG, + match: MatchResult, + guard: str, + op_snake: str, +) -> str: + prefix = _dsl_prefix(op) + reduce_op = _generate_reduce_op_cpu(op, dag, match) + transform_op = _generate_transform_op_cpu(op, dag, match) + base_header = f"base/{op_snake}.h" + reduce_name = f"{prefix}Cpu{op.name}Reduce" + transform_name = f"{prefix}Cpu{op.name}Transform" + impl_suffix = ", Impl::kDsl" if op.impl_index > 0 else "" + impl_include = ( + f'#include "impl.h"\n#include "cpu/{op_snake}/registry.h"\n' + if op.impl_index > 0 + else "" + ) + + type_list = "ConcatType, ReducedFloatTypes>" + + return f"""\ +#ifndef {guard} +#define {guard} + +#include "cpu/templates/reduce_transform.h" +#include "{base_header}" +{impl_include} +namespace infini::ops {{ + +{reduce_op} + +{transform_op} + +template <> +class Operator<{op.name}, Device::Type::kCpu{impl_suffix}> : public {op.name} {{ + public: + using {op.name}::{op.name}; + + void operator()(const Tensor input, const Tensor weight, float eps, + Tensor out) const override {{ + CpuReduceThenTransform<{type_list}>( + input, out, batch_size_, nhead_, dim_, + out.dtype(), input_strides_, out_strides_, + {reduce_name}{{eps}}, + {transform_name}{{weight.data()}}); + }} +}}; + +}} // namespace infini::ops + +#endif +""" 
diff --git a/dsl/compiler/parser.py b/dsl/compiler/parser.py new file mode 100644 index 00000000..cb797b69 --- /dev/null +++ b/dsl/compiler/parser.py @@ -0,0 +1,325 @@ +"""Parse `@infini_op` function bodies into a compute DAG.""" + +from __future__ import annotations + +import ast +import inspect +import textwrap +from typing import TYPE_CHECKING, Any + +from dsl.compiler.dag import ComputeDAG, NodeKind + +if TYPE_CHECKING: + from dsl.decorators import InfiniOpDef + +# Map Python AST binary operators to DAG node kinds. +_BINOP_MAP: dict[type, NodeKind] = { + ast.Add: NodeKind.ADD, + ast.Sub: NodeKind.SUB, + ast.Mult: NodeKind.MUL, + ast.Div: NodeKind.DIV, + ast.Pow: NodeKind.POW, +} + +# Map Python AST comparison operators to DAG node kinds. +_CMPOP_MAP: dict[type, NodeKind] = { + ast.Gt: NodeKind.GT, + ast.Lt: NodeKind.LT, + ast.GtE: NodeKind.GE, + ast.LtE: NodeKind.LE, + ast.Eq: NodeKind.EQ, +} + +# Map DSL function names to DAG node kinds. +_FUNC_MAP: dict[str, NodeKind] = { + "sqrt": NodeKind.SQRT, + "rsqrt": NodeKind.RSQRT, + "exp": NodeKind.EXP, + "log": NodeKind.LOG, + "abs": NodeKind.ABS, + "neg": NodeKind.NEG, + "relu": NodeKind.RELU, + "gelu": NodeKind.GELU, + "silu": NodeKind.SILU, + "sigmoid": NodeKind.SIGMOID, + "tanh": NodeKind.TANH, + "reduce_sum": NodeKind.REDUCE_SUM, + "reduce_mean": NodeKind.REDUCE_MEAN, + "reduce_max": NodeKind.REDUCE_MAX, + "reduce_min": NodeKind.REDUCE_MIN, + "cast": NodeKind.CAST, + "where": NodeKind.WHERE, + "clamp": NodeKind.CLAMP, +} + + +class _DAGBuilder(ast.NodeVisitor): + """Walk a function AST and build a ``ComputeDAG``.""" + + def __init__(self, dag: ComputeDAG, params: dict[str, dict[str, Any]]) -> None: + self.dag = dag + self.params = params + + # Maps local variable names to DAG node ids. + self.env: dict[str, int] = {} + + # Register function parameters as INPUT / SCALAR nodes. 
+ for pname, pinfo in params.items(): + + if pinfo["kind"] == "tensor": + nid = dag.add_node( + NodeKind.INPUT, + name=pname, + shape=pinfo.get("shape"), + ) + else: + nid = dag.add_node(NodeKind.SCALAR, name=pname) + + self.env[pname] = nid + + def visit_Assign(self, node: ast.Assign) -> None: + assert len(node.targets) == 1, "Only single assignment supported." + target = node.targets[0] + assert isinstance(target, ast.Name) + + nid = self._visit_expr(node.value) + self.env[target.id] = nid + + def visit_Return(self, node: ast.Return) -> None: + assert node.value is not None + nid = self._visit_expr(node.value) + self.dag.output_id = nid + + def _visit_expr(self, node: ast.expr) -> int: + """Recursively translate an expression AST node into DAG nodes.""" + + if isinstance(node, ast.Name): + assert node.id in self.env, f"Undefined variable: `{node.id}`." + + return self.env[node.id] + + if isinstance(node, ast.Constant): + + return self.dag.add_node( + NodeKind.SCALAR, + name=repr(node.value), + attrs={"value": node.value}, + ) + + if isinstance(node, ast.BinOp): + + return self._visit_binop(node) + + if isinstance(node, ast.UnaryOp): + + return self._visit_unaryop(node) + + if isinstance(node, ast.Call): + + return self._visit_call(node) + + if isinstance(node, ast.Compare): + + return self._visit_compare(node) + + raise ValueError(f"Unsupported expression type: {type(node).__name__}.") + + def _visit_binop(self, node: ast.BinOp) -> int: + left = self._visit_expr(node.left) + right = self._visit_expr(node.right) + kind = _BINOP_MAP.get(type(node.op)) + + if kind is None: + raise ValueError( + f"Unsupported binary operator: {type(node.op).__name__}." 
+ ) + + return self.dag.add_node(kind, inputs=[left, right]) + + def _visit_unaryop(self, node: ast.UnaryOp) -> int: + operand = self._visit_expr(node.operand) + + if isinstance(node.op, ast.USub): + + return self.dag.add_node(NodeKind.NEG, inputs=[operand]) + + raise ValueError( + f"Unsupported unary operator: {type(node.op).__name__}." + ) + + def _visit_call(self, node: ast.Call) -> int: + func_name = self._get_func_name(node) + kind = _FUNC_MAP.get(func_name) + + if kind is None: + raise ValueError(f"Unknown DSL primitive: `{func_name}`.") + + # Build input list from positional args. + inputs = [self._visit_expr(arg) for arg in node.args] + + # Extract keyword arguments. + kwargs: dict[str, Any] = {} + + for kw in node.keywords: + assert kw.arg is not None + + if isinstance(kw.value, ast.Constant): + kwargs[kw.arg] = kw.value.value + elif isinstance(kw.value, ast.Constant): + kwargs[kw.arg] = kw.value.value + elif isinstance(kw.value, ast.Name): + kwargs[kw.arg] = kw.value.id + + # Handle reduction ops. + if kind in ( + NodeKind.REDUCE_SUM, + NodeKind.REDUCE_MEAN, + NodeKind.REDUCE_MAX, + NodeKind.REDUCE_MIN, + ): + + return self.dag.add_node( + kind, + inputs=inputs, + reduce_dim=kwargs.get("dim"), + ) + + # Handle cast. + if kind == NodeKind.CAST: + + return self.dag.add_node( + kind, + inputs=inputs, + cast_dtype=kwargs.get("dtype"), + ) + + # Handle where(cond, a, b). + if kind == NodeKind.WHERE: + assert len(inputs) == 3, "`where` requires 3 arguments." + + return self.dag.add_node(kind, inputs=inputs) + + # Handle clamp. + if kind == NodeKind.CLAMP: + + return self.dag.add_node( + kind, + inputs=inputs, + clamp_min=kwargs.get("min"), + clamp_max=kwargs.get("max"), + ) + + # Unary / activation functions. + return self.dag.add_node(kind, inputs=inputs) + + def _visit_compare(self, node: ast.Compare) -> int: + assert len(node.ops) == 1, "Only single comparisons supported." 
+ assert len(node.comparators) == 1 + + left = self._visit_expr(node.left) + right = self._visit_expr(node.comparators[0]) + kind = _CMPOP_MAP.get(type(node.ops[0])) + + if kind is None: + raise ValueError( + f"Unsupported comparison: {type(node.ops[0]).__name__}." + ) + + return self.dag.add_node(kind, inputs=[left, right]) + + @staticmethod + def _get_func_name(node: ast.Call) -> str: + + if isinstance(node.func, ast.Name): + return node.func.id + + if isinstance(node.func, ast.Attribute): + return node.func.attr + + raise ValueError(f"Unsupported call target: {type(node.func).__name__}.") + + +def _extract_params(func_def: ast.FunctionDef) -> dict[str, dict[str, Any]]: + """Extract parameter metadata from the function signature AST.""" + params: dict[str, dict[str, Any]] = {} + + for arg in func_def.args.args: + pname = arg.arg + annotation = arg.annotation + pinfo: dict[str, Any] = {"kind": "tensor"} + + if annotation is not None: + + # Tensor["B", "H", "D"] → subscript with shape vars. + if isinstance(annotation, ast.Subscript): + + if isinstance(annotation.value, ast.Name): + + if annotation.value.id == "Scalar": + pinfo["kind"] = "scalar" + elif annotation.value.id == "Tensor": + # Extract shape variable names. 
+ shape = _extract_shape_vars(annotation.slice) + pinfo["shape"] = shape + + elif isinstance(annotation, ast.Name): + + if annotation.id == "float": + pinfo["kind"] = "scalar" + elif annotation.id == "int": + pinfo["kind"] = "scalar" + + params[pname] = pinfo + + return params + + +def _extract_shape_vars(node: ast.expr) -> list[str]: + """Extract shape variable names from a Tensor subscript.""" + + if isinstance(node, ast.Tuple): + return [_const_str(elt) for elt in node.elts] + + return [_const_str(node)] + + +def _const_str(node: ast.expr) -> str: + + if isinstance(node, ast.Constant) and isinstance(node.value, str): + return node.value + + raise ValueError(f"Expected string constant, got {type(node).__name__}.") + + +def parse_infini_op(op: InfiniOpDef) -> ComputeDAG: + """Parse an `@infini_op` function into a ``ComputeDAG``.""" + assert op.func is not None, f"Operator `{op.name}` has no function body." + + source = inspect.getsource(op.func) + source = textwrap.dedent(source) + tree = ast.parse(source) + + # Find the function definition (skip the decorator). + func_def: ast.FunctionDef | None = None + + for node in ast.walk(tree): + + if isinstance(node, ast.FunctionDef): + func_def = node + + break + + assert func_def is not None, "No function definition found." + + params = _extract_params(func_def) + dag = ComputeDAG(shape_vars=dict(op.shapes)) + builder = _DAGBuilder(dag, params) + + for stmt in func_def.body: + builder.visit(stmt) + + assert dag.output_id is not None, ( + f"Operator `{op.name}` function body has no return statement." 
+ ) + + return dag diff --git a/dsl/compiler/patterns.py b/dsl/compiler/patterns.py new file mode 100644 index 00000000..2dc9a5a4 --- /dev/null +++ b/dsl/compiler/patterns.py @@ -0,0 +1,182 @@ +"""Pattern matching: map compute DAG subgraphs to C++ template bricks.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum, auto +from typing import TYPE_CHECKING + +from dsl.compiler.dag import ( + ELEMENTWISE_BINARY, + ELEMENTWISE_UNARY, + ComputeDAG, + DagNode, + NodeKind, +) + +if TYPE_CHECKING: + pass + + +class BrickKind(Enum): + """Available C++ template bricks.""" + + BINARY_ELEMENTWISE = auto() + UNARY_ELEMENTWISE = auto() + REDUCE_THEN_TRANSFORM = auto() + PURE_REDUCTION = auto() + + +@dataclass +class MatchResult: + """Result of matching a compute DAG to a brick pattern.""" + + brick: BrickKind + + # For REDUCE_THEN_TRANSFORM: the reduce and transform sub-DAGs. + reduce_nodes: list[int] | None = None + transform_nodes: list[int] | None = None + reduce_dim: str | None = None + + # For elementwise: the functor body description. + elementwise_kind: str | None = None + + # The input parameter names involved. + input_names: list[str] | None = None + + +def match_dag(dag: ComputeDAG) -> MatchResult: + """Match a compute DAG to the best-fitting brick pattern. + + Raises ``ValueError`` if no pattern matches. + """ + + if dag.is_elementwise_only(): + return _match_elementwise(dag) + + if dag.has_reduction(): + return _match_reduce_then_transform(dag) + + raise ValueError( + "Cannot match DAG to any known brick pattern. " + "Consider using `@manual_op` instead." + ) + + +def _match_elementwise(dag: ComputeDAG) -> MatchResult: + """Match a pure-elementwise DAG.""" + + # Collect input tensor names. + inputs = [ + n.name + for n in dag.nodes.values() + if n.kind == NodeKind.INPUT and n.name is not None + ] + + # Determine if it is a binary or unary elementwise op. 
def _match_reduce_then_transform(dag: ComputeDAG) -> MatchResult:
    """Match a reduce-then-transform pattern.

    The first node returned by ``dag.reduction_nodes()`` is used as the
    primary reduction. The compute nodes feeding it, plus the reduction
    itself, form the reduce stage; every other compute node forms the
    transform stage, listed in the order given by ``dag.topo_sort()``.

    NOTE(review): a DAG with more than one reduction is not rejected; the
    extra reductions land in the transform stage. Confirm that code
    generation supports this, or raise here instead.

    Raises ``ValueError`` if the DAG has no reduction node.
    """
    reductions = dag.reduction_nodes()

    if not reductions:
        raise ValueError("Expected at least one reduction node.")

    # Use the first reduction as the primary one.
    reduce_node = reductions[0]

    # Identify all nodes that contribute to the reduction (pre-reduce).
    reduce_ids = _ancestors(dag, reduce_node.id)
    reduce_ids.add(reduce_node.id)

    # Every remaining compute node belongs to the transform. Filtering the
    # whole topological order, rather than slicing it after the reduction,
    # keeps nodes that do not depend on the reduction but are ordered before
    # it; slicing dropped such nodes from both stages.
    leaf_kinds = (NodeKind.INPUT, NodeKind.SCALAR)
    transform_ids = [
        nid
        for nid in dag.topo_sort()
        if nid not in reduce_ids and dag.get(nid).kind not in leaf_kinds
    ]

    # Collect input names.
    inputs = [
        n.name
        for n in dag.nodes.values()
        if n.kind == NodeKind.INPUT and n.name is not None
    ]

    return MatchResult(
        brick=BrickKind.REDUCE_THEN_TRANSFORM,
        reduce_nodes=sorted(reduce_ids),
        transform_nodes=transform_ids,
        reduce_dim=reduce_node.reduce_dim,
        input_names=inputs,
    )


def _ancestors(dag: ComputeDAG, nid: int) -> set[int]:
    """Return all ancestor node ids (transitive inputs), excluding leaf nodes.

    Leaf nodes are those of kind ``NodeKind.INPUT`` or ``NodeKind.SCALAR``;
    they are neither reported nor traversed. The node ``nid`` itself is not
    part of the result.
    """
    result: set[int] = set()
    stack = list(dag.get(nid).inputs)

    while stack:
        cur = stack.pop()
        node = dag.get(cur)

        if node.kind in (NodeKind.INPUT, NodeKind.SCALAR):
            continue

        # Expand each compute node once, even when it is reachable along
        # several paths.
        if cur not in result:
            result.add(cur)
            stack.extend(node.inputs)

    return result
class _Registry:
    """Collect operator definitions created by `@manual_op` and `@infini_op`.

    A primary definition (`@manual_op`, or `@infini_op` with
    ``impl_index == 0``) is stored once per operator name. `@infini_op`
    definitions with ``impl_index > 0`` are stored as additional variants
    of that operator.
    """

    def __init__(self) -> None:
        self._ops: dict[str, ManualOpDef | InfiniOpDef] = {}
        self._variants: dict[str, list[InfiniOpDef]] = {}

    def register(self, op: ManualOpDef | InfiniOpDef) -> None:
        """Store ``op`` as a primary definition or as a variant.

        Raises ``ValueError`` if the operator already has a primary
        definition, or already has a variant with the same ``impl_index``.
        """
        # Only `InfiniOpDef` carries `impl_index`. Reading it by attribute
        # avoids importing `dsl.decorators`, which itself imports this module.
        impl_index = getattr(op, "impl_index", 0)

        if impl_index > 0:
            variants = self._variants.setdefault(op.name, [])

            if any(v.impl_index == impl_index for v in variants):
                raise ValueError(
                    f"Operator `{op.name}` already has a variant with "
                    f"`impl_index={impl_index}`."
                )

            variants.append(op)

            return

        if op.name in self._ops:
            raise ValueError(f"Operator `{op.name}` is already registered.")

        self._ops[op.name] = op

    def get(self, name: str) -> ManualOpDef | InfiniOpDef:
        """Return the primary definition of ``name`` (``KeyError`` if absent)."""

        return self._ops[name]

    def all_ops(self) -> dict[str, ManualOpDef | InfiniOpDef]:
        """Return a copy of the name→primary-definition mapping."""

        return dict(self._ops)

    def variants(self, name: str) -> list[InfiniOpDef]:
        """Return DSL alternative implementations for a given operator."""

        return list(self._variants.get(name, []))

    def all_variants(self) -> dict[str, list[InfiniOpDef]]:
        """Return all DSL variant implementations.

        The lists are copies, so callers cannot mutate the registry through
        the returned mapping.
        """

        return {name: list(ops) for name, ops in self._variants.items()}

    def impl_names_for(self, name: str) -> dict[str, int]:
        """Return the merged name→index mapping for an operator.

        Rules:
        - ``@manual_op`` with explicit ``impl_names`` → use as-is.
        - ``@manual_op`` without ``impl_names`` → ``{"default": 0}``.
        - The ``@infini_op`` variant with the lowest ``impl_index`` adds
          ``{"dsl": impl_index}``; each further variant adds
          ``{"dsl_<impl_index>": impl_index}`` so that no variant is lost.
        """
        primary = self._ops.get(name)
        result: dict[str, int] = {}

        if primary is not None:
            # Only `ManualOpDef` carries `impl_names`, stored as index→name.
            names = getattr(primary, "impl_names", None)

            if names:
                result = {label: index for index, label in names.items()}
            else:
                result = {"default": 0}

        ordered = sorted(
            self._variants.get(name, []), key=lambda v: v.impl_index
        )

        for position, variant in enumerate(ordered):
            label = "dsl" if position == 0 else f"dsl_{variant.impl_index}"
            result[label] = variant.impl_index

        return result

    def all_impl_names(self) -> dict[str, dict[str, int]]:
        """Return name→index mappings for all operators."""
        all_names = set(self._ops) | set(self._variants)

        return {name: self.impl_names_for(name) for name in sorted(all_names)}

    def clear(self) -> None:
        """Forget every registered definition and variant."""
        self._ops.clear()
        self._variants.clear()


REGISTRY = _Registry()
def infini_op(
    *,
    name: str,
    shapes: dict[str, str] | None = None,
    manual_backends: dict[str, str] | None = None,
    impl_index: int = 0,
) -> Callable:
    """Register an operator defined in the DSL.

    The returned decorator wraps the decorated function in an
    ``InfiniOpDef``, registers it with ``REGISTRY`` and returns that
    definition in place of the function.

    Backends listed in ``manual_backends`` use the specified hand-written
    implementations instead of generated kernel code.

    When ``impl_index > 0``, the operator is registered as an alternative
    implementation of an existing operator (like cuBLAS vs cuBLASLt for
    GEMM).
    """

    def decorator(func: Callable) -> InfiniOpDef:
        # Fresh dicts are created per decorated function so that definitions
        # never share a mutable default.
        definition = InfiniOpDef(
            name,
            shapes or {},
            manual_backends or {},
            func,
            impl_index,
        )
        REGISTRY.register(definition)

        return definition

    return decorator
def discover() -> None:
    """Import every public module next to this file to trigger registration.

    Files whose name starts with an underscore are skipped. Modules are
    imported in sorted file-name order as ``dsl.ops.<stem>``.
    """
    candidates = sorted(_OPS_DIR.glob("*.py"))
    public_files = (p for p in candidates if not p.name.startswith("_"))

    for source_file in public_files:
        importlib.import_module(f"dsl.ops.{source_file.stem}")
@infini_op(
    name="Cast",
    impl_index=1,
    shapes={"N": "output_size"},
    manual_backends={
        "ascend": "ascend/cast/kernel.h",
    },
)
def cast_dsl(
    input: Tensor["N"],
) -> Tensor["N"]:
    # NOTE(review): `dsl.primitives.cast` is declared as `cast(x, dtype)`, but
    # no `dtype` is passed here, so executing this body would raise a
    # `TypeError`. Presumably the target dtype is taken from the output tensor
    # during code generation and this body is only parsed — confirm.
    return cast(input)
diff --git a/dsl/ops/flash_attention.py b/dsl/ops/flash_attention.py new file mode 100644 index 00000000..0250fdec --- /dev/null +++ b/dsl/ops/flash_attention.py @@ -0,0 +1,13 @@ +from dsl.decorators import manual_op + + +@manual_op( + name="FlashAttention", + base="src/base/flash_attention.h", + backends={ + "cuda": "cuda/flash_attention/kernel.h", + "ascend": "ascend/flash_attention/kernel.h", + }, +) +def flash_attention(): + ... diff --git a/dsl/ops/gemm.py b/dsl/ops/gemm.py new file mode 100644 index 00000000..69a2a818 --- /dev/null +++ b/dsl/ops/gemm.py @@ -0,0 +1,16 @@ +from dsl.decorators import manual_op + + +@manual_op( + name="Gemm", + base="src/base/gemm.h", + impl_names={0: "cublas", 1: "cublaslt"}, + backends={ + "cuda": {"include": "cuda/gemm/blas.h", "class": "BlasGemm", "blas": True}, + "ascend": "ascend/gemm/kernel.h", + "cambricon": "cambricon/gemm/cnblas.h", + "cpu": "cpu/gemm/gemm.h", + }, +) +def gemm(): + ... diff --git a/dsl/ops/linear.py b/dsl/ops/linear.py new file mode 100644 index 00000000..4c8fb93b --- /dev/null +++ b/dsl/ops/linear.py @@ -0,0 +1,14 @@ +from dsl.decorators import manual_op + + +@manual_op( + name="Linear", + base="src/base/linear.h", + backends={ + "cuda": "cuda/linear/kernel.h", + "ascend": "ascend/linear/kernel.h", + "cpu": "cpu/linear/linear.h", + }, +) +def linear(): + ... diff --git a/dsl/ops/matmul.py b/dsl/ops/matmul.py new file mode 100644 index 00000000..9d0e7363 --- /dev/null +++ b/dsl/ops/matmul.py @@ -0,0 +1,14 @@ +from dsl.decorators import manual_op + + +@manual_op( + name="Matmul", + base="src/base/matmul.h", + backends={ + "nvidia": "nvidia/matmul/cublaslt.h", + "ascend": "ascend/matmul/kernel.h", + "cpu": "cpu/matmul/matmul.h", + }, +) +def matmul(): + ... 
diff --git a/dsl/ops/mul.py b/dsl/ops/mul.py new file mode 100644 index 00000000..c66adf83 --- /dev/null +++ b/dsl/ops/mul.py @@ -0,0 +1,13 @@ +from dsl.decorators import manual_op + + +@manual_op( + name="Mul", + base="src/base/mul.h", + backends={ + "ascend": "ascend/mul/kernel.h", + "cpu": "cpu/mul/mul.h", + }, +) +def mul(): + ... diff --git a/dsl/ops/mul_dsl.py b/dsl/ops/mul_dsl.py new file mode 100644 index 00000000..64975428 --- /dev/null +++ b/dsl/ops/mul_dsl.py @@ -0,0 +1,23 @@ +"""DSL alternative implementation for Mul (impl_index=1). + +Registers as ``Operator`` alongside the existing +hand-written ``Operator``. +""" + +from dsl.decorators import infini_op +from dsl.primitives import Tensor + + +@infini_op( + name="Mul", + impl_index=1, + shapes={"N": "output_size"}, + manual_backends={ + "ascend": "ascend/mul/kernel.h", + }, +) +def mul_dsl( + input: Tensor["N"], + other: Tensor["N"], +) -> Tensor["N"]: + return input * other diff --git a/dsl/ops/reshape_and_cache.py b/dsl/ops/reshape_and_cache.py new file mode 100644 index 00000000..967093f6 --- /dev/null +++ b/dsl/ops/reshape_and_cache.py @@ -0,0 +1,13 @@ +from dsl.decorators import manual_op + + +@manual_op( + name="ReshapeAndCache", + base="src/base/reshape_and_cache.h", + backends={ + "cuda": "cuda/reshape_and_cache/kernel.h", + "ascend": "ascend/reshape_and_cache/kernel.h", + }, +) +def reshape_and_cache(): + ... diff --git a/dsl/ops/rms_norm.py b/dsl/ops/rms_norm.py new file mode 100644 index 00000000..e1a1ead4 --- /dev/null +++ b/dsl/ops/rms_norm.py @@ -0,0 +1,15 @@ +from dsl.decorators import manual_op + + +@manual_op( + name="RmsNorm", + base="src/base/rms_norm.h", + backends={ + "cuda": "cuda/rms_norm/kernel.h", + "ascend": "ascend/rms_norm/kernel.h", + "cambricon": "cambricon/rms_norm/rms_norm.h", + "cpu": "cpu/rms_norm/rms_norm.h", + }, +) +def rms_norm(): + ... 
diff --git a/dsl/ops/rms_norm_dsl.py b/dsl/ops/rms_norm_dsl.py new file mode 100644 index 00000000..1326a824 --- /dev/null +++ b/dsl/ops/rms_norm_dsl.py @@ -0,0 +1,28 @@ +"""DSL alternative implementation for RmsNorm (impl_index=1). + +Registers as ``Operator`` alongside the existing +hand-written ``Operator``. +""" + +from dsl.decorators import infini_op +from dsl.primitives import Scalar, Tensor, reduce_mean, rsqrt + + +@infini_op( + name="RmsNorm", + impl_index=1, + shapes={"B": "batch_size", "H": "nhead", "D": "dim"}, + manual_backends={ + "ascend": "ascend/rms_norm/kernel.h", + "cambricon": "cambricon/rms_norm/rms_norm.h", + }, +) +def rms_norm_dsl( + input: Tensor["B", "H", "D"], + weight: Tensor["D"], + eps: Scalar[float] = 1e-6, +) -> Tensor["B", "H", "D"]: + ss = reduce_mean(input * input, dim="D") + rms = rsqrt(ss + eps) + + return input * rms * weight diff --git a/dsl/ops/rotary_embedding.py b/dsl/ops/rotary_embedding.py new file mode 100644 index 00000000..409fafd1 --- /dev/null +++ b/dsl/ops/rotary_embedding.py @@ -0,0 +1,13 @@ +from dsl.decorators import manual_op + + +@manual_op( + name="RotaryEmbedding", + base="src/base/rotary_embedding.h", + backends={ + "cuda": "cuda/rotary_embedding/kernel.h", + "ascend": "ascend/rotary_embedding/kernel.h", + }, +) +def rotary_embedding(): + ... diff --git a/dsl/ops/swiglu.py b/dsl/ops/swiglu.py new file mode 100644 index 00000000..730a6d9f --- /dev/null +++ b/dsl/ops/swiglu.py @@ -0,0 +1,14 @@ +from dsl.decorators import manual_op + + +@manual_op( + name="Swiglu", + base="src/base/swiglu.h", + backends={ + "cuda": "cuda/swiglu/kernel.h", + "ascend": "ascend/swiglu/kernel.h", + "cpu": "cpu/swiglu/swiglu.h", + }, +) +def swiglu(): + ... diff --git a/dsl/ops/swiglu_dsl.py b/dsl/ops/swiglu_dsl.py new file mode 100644 index 00000000..d931cf55 --- /dev/null +++ b/dsl/ops/swiglu_dsl.py @@ -0,0 +1,25 @@ +"""DSL alternative implementation for Swiglu (impl_index=1). + +SwiGLU(input, gate) = input * silu(gate). 
+ +Registers as ``Operator`` alongside the existing +hand-written ``Operator``. +""" + +from dsl.decorators import infini_op +from dsl.primitives import Tensor, silu + + +@infini_op( + name="Swiglu", + impl_index=1, + shapes={"N": "output_size"}, + manual_backends={ + "ascend": "ascend/swiglu/kernel.h", + }, +) +def swiglu_dsl( + input: Tensor["N"], + other: Tensor["N"], +) -> Tensor["N"]: + return input * silu(other) diff --git a/dsl/primitives.py b/dsl/primitives.py new file mode 100644 index 00000000..c9e79fdd --- /dev/null +++ b/dsl/primitives.py @@ -0,0 +1,144 @@ +"""DSL primitive types and functions for `@infini_op` definitions. + +These are used purely for type annotation and AST parsing — they have +no runtime behavior. The function bodies serve as PyTorch-compatible +reference implementations for testing. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +if TYPE_CHECKING: + import torch + +T = TypeVar("T") + + +# ---- Type annotations ----------------------------------------------------- + + +class Tensor: + """Annotates a tensor parameter with shape variables. + + Usage: ``input: Tensor["B", "H", "D"]`` + """ + + def __class_getitem__(cls, item: Any) -> Any: + return cls + + +class Scalar(Generic[T]): + """Annotates a scalar parameter. 
# ---- Elementwise functions -------------------------------------------------

# The module header imports `torch` under `TYPE_CHECKING` only, so the
# reference bodies below need a runtime import of their own. The import is
# optional because parsing `@infini_op` definitions never executes them.
try:
    import torch
except ImportError:
    torch = None  # type: ignore[assignment]


def sqrt(x: torch.Tensor) -> torch.Tensor:
    """Return the elementwise square root of ``x``."""
    return torch.sqrt(x)


def rsqrt(x: torch.Tensor) -> torch.Tensor:
    """Return the elementwise reciprocal square root of ``x``."""
    return torch.rsqrt(x)


def exp(x: torch.Tensor) -> torch.Tensor:
    """Return the elementwise exponential of ``x``."""
    return torch.exp(x)


def log(x: torch.Tensor) -> torch.Tensor:
    """Return the elementwise natural logarithm of ``x``."""
    return torch.log(x)


def abs(x: torch.Tensor) -> torch.Tensor:
    """Return the elementwise absolute value of ``x``."""
    return torch.abs(x)


def neg(x: torch.Tensor) -> torch.Tensor:
    """Return the elementwise negation of ``x``."""
    return -x


def relu(x: torch.Tensor) -> torch.Tensor:
    """Apply ReLU to ``x`` elementwise."""
    return torch.relu(x)


def gelu(x: torch.Tensor) -> torch.Tensor:
    """Apply GELU to ``x`` elementwise."""
    return torch.nn.functional.gelu(x)


def silu(x: torch.Tensor) -> torch.Tensor:
    """Apply SiLU to ``x`` elementwise."""
    return torch.nn.functional.silu(x)


def sigmoid(x: torch.Tensor) -> torch.Tensor:
    """Apply the logistic sigmoid to ``x`` elementwise."""
    return torch.sigmoid(x)


def tanh(x: torch.Tensor) -> torch.Tensor:
    """Apply the hyperbolic tangent to ``x`` elementwise."""
    return torch.tanh(x)


# ---- Reduction functions ---------------------------------------------------


def _reduction_axis(dim: str | int) -> int:
    """Map a DSL ``dim`` argument to a concrete axis.

    Integer dims are used as given. A symbolic shape-variable name such as
    ``"D"`` cannot be resolved on a plain tensor, so it falls back to the
    last axis, which is what these reference bodies previously always used.
    """
    return dim if isinstance(dim, int) else -1


def reduce_sum(
    x: torch.Tensor,
    dim: str | int = -1,
) -> torch.Tensor:
    """Sum ``x`` over ``dim``, keeping the reduced axis with size one."""
    return torch.sum(x, dim=_reduction_axis(dim), keepdim=True)


def reduce_mean(
    x: torch.Tensor,
    dim: str | int = -1,
) -> torch.Tensor:
    """Average ``x`` over ``dim``, keeping the reduced axis with size one."""
    return torch.mean(x, dim=_reduction_axis(dim), keepdim=True)


def reduce_max(
    x: torch.Tensor,
    dim: str | int = -1,
) -> torch.Tensor:
    """Take the maximum of ``x`` over ``dim``, keeping the reduced axis."""
    return torch.max(x, dim=_reduction_axis(dim), keepdim=True).values


def reduce_min(
    x: torch.Tensor,
    dim: str | int = -1,
) -> torch.Tensor:
    """Take the minimum of ``x`` over ``dim``, keeping the reduced axis."""
    return torch.min(x, dim=_reduction_axis(dim), keepdim=True).values


# ---- Conditional -----------------------------------------------------------


def where(
    cond: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
) -> torch.Tensor:
    """Select from ``a`` where ``cond`` holds and from ``b`` elsewhere."""
    return torch.where(cond, a, b)


# ---- Type -------------------------------------------------------------------


def cast(x: torch.Tensor, dtype: Any) -> torch.Tensor:
    """Convert ``x`` to ``dtype``."""
    return x.to(dtype)


# ---- Clamp ------------------------------------------------------------------


def clamp(
    x: torch.Tensor,
    min: float | None = None,
    max: float | None = None,
) -> torch.Tensor:
    """Limit the values of ``x`` to the range given by ``min`` and ``max``."""
    return torch.clamp(x, min=min, max=max)
+ inputs = [n for n in dag.nodes.values() if n.kind == NodeKind.INPUT] + adds = [n for n in dag.nodes.values() if n.kind == NodeKind.ADD] + assert len(inputs) == 2 + assert len(adds) == 1 + + def test_parse_rms_norm(self) -> None: + op = _make_rms_norm_op() + dag = parse_infini_op(op) + + assert dag.output_id is not None + assert dag.has_reduction() + + reductions = dag.reduction_nodes() + assert len(reductions) == 1 + assert reductions[0].kind == NodeKind.REDUCE_MEAN + + def test_elementwise_only(self) -> None: + op = _make_add_op() + dag = parse_infini_op(op) + assert dag.is_elementwise_only() + + def test_topo_sort(self) -> None: + op = _make_add_op() + dag = parse_infini_op(op) + topo = dag.topo_sort() + + # Output should be last. + assert topo[-1] == dag.output_id + + +# ---- Pattern matching tests ------------------------------------------------ + + +class TestPatterns: + def test_match_add(self) -> None: + op = _make_add_op() + dag = parse_infini_op(op) + result = match_dag(dag) + + assert result.brick == BrickKind.BINARY_ELEMENTWISE + + def test_match_rms_norm(self) -> None: + op = _make_rms_norm_op() + dag = parse_infini_op(op) + result = match_dag(dag) + + assert result.brick == BrickKind.REDUCE_THEN_TRANSFORM + assert result.reduce_nodes is not None + assert result.transform_nodes is not None + + +# ---- Code generation tests ------------------------------------------------ + + +class TestCodegen: + def test_cuda_add(self) -> None: + op = _make_add_op() + dag = parse_infini_op(op) + match = match_dag(dag) + code = generate_cuda_kernel(op, dag, match) + + assert "#ifndef" in code + assert "TestAddOp" in code + assert "BinaryElementwiseBrick" in code + assert "va + vb" in code + + def test_cpu_add(self) -> None: + op = _make_add_op() + dag = parse_infini_op(op) + match = match_dag(dag) + code = generate_cpu_kernel(op, dag, match) + + assert "#ifndef" in code + assert "CpuTestAddOp" in code + assert "CpuBinaryElementwise" in code + + def 
test_cuda_rms_norm(self) -> None: + op = _make_rms_norm_op() + dag = parse_infini_op(op) + match = match_dag(dag) + code = generate_cuda_kernel(op, dag, match) + + assert "TestRmsNormReduce" in code + assert "TestRmsNormTransform" in code + assert "LaunchReduceThenTransform" in code + assert "rsqrtf" in code + assert "epsilon" in code + + def test_cpu_rms_norm(self) -> None: + op = _make_rms_norm_op() + dag = parse_infini_op(op) + match = match_dag(dag) + code = generate_cpu_kernel(op, dag, match) + + assert "CpuTestRmsNormReduce" in code + assert "CpuTestRmsNormTransform" in code + assert "CpuReduceThenTransform" in code + assert "std::sqrt" in code diff --git a/examples/runtime_api.h b/examples/runtime_api.h index 4c7469fe..d8bcb7fc 100644 --- a/examples/runtime_api.h +++ b/examples/runtime_api.h @@ -4,7 +4,7 @@ #include "device.h" #ifdef WITH_NVIDIA -#include "nvidia/gemm/cublas.h" +#include "nvidia/gemm/kernel.h" #include "nvidia/gemm/cublaslt.h" #include "nvidia/runtime_.h" #elif WITH_ILUVATAR @@ -19,6 +19,9 @@ #elif WITH_MOORE #include "moore/gemm/mublas.h" #include "moore/runtime_.h" +#elif WITH_ASCEND +#include "ascend/gemm/kernel.h" +#include "ascend/runtime_.h" #elif WITH_CPU #include "cpu/gemm/gemm.h" #include "cpu/runtime_.h" @@ -38,6 +41,8 @@ using DefaultRuntimeUtils = Runtime; using DefaultRuntimeUtils = Runtime; #elif WITH_MOORE using DefaultRuntimeUtils = Runtime; +#elif WITH_ASCEND +using DefaultRuntimeUtils = Runtime; #elif WITH_CPU using DefaultRuntimeUtils = Runtime; #endif diff --git a/scripts/bindings_overrides/cat.h b/scripts/bindings_overrides/cat.h new file mode 100644 index 00000000..cd34a02d --- /dev/null +++ b/scripts/bindings_overrides/cat.h @@ -0,0 +1,68 @@ +#ifndef INFINI_OPS_BINDINGS_CAT_H_ +#define INFINI_OPS_BINDINGS_CAT_H_ + +#include +#include + +#include "base/cat.h" +#include "config.h" +#include "pybind11_utils.h" + +namespace py = pybind11; + +namespace infini::ops { + +inline std::vector TensorListFromPybind11(py::list 
list) { + std::vector result; + result.reserve(py::len(list)); + + for (auto& item : list) { + result.push_back(TensorFromPybind11Handle(item)); + } + + return result; +} + +void BindCat(py::module& m) { + using Self = Cat; + + py::class_(m, "Cat") + .def(py::init([](py::object first_input, py::list rest_inputs, + int64_t dim, py::object out) { + return std::unique_ptr{static_cast( + Self::make(TensorFromPybind11Handle(first_input), + TensorListFromPybind11(rest_inputs), dim, + TensorFromPybind11Handle(out)) + .release())}; + })) + .def("__call__", + [](const Self& self, py::object first_input, py::list rest_inputs, + int64_t dim, py::object out) { + return static_cast&>(self)( + TensorFromPybind11Handle(first_input), + TensorListFromPybind11(rest_inputs), dim, + TensorFromPybind11Handle(out)); + }) + .def_static("active_implementation_indices", + [](const std::string& device) { + return Self::active_implementation_indices( + DeviceTypeFromString(device)); + }); + + m.def( + "cat", + [](py::object first_input, py::list rest_inputs, int64_t dim, + py::object out, std::size_t implementation_index) { + Config config; + config.set_implementation_index(implementation_index); + return Self::call({}, config, TensorFromPybind11Handle(first_input), + TensorListFromPybind11(rest_inputs), dim, + TensorFromPybind11Handle(out)); + }, + py::arg("first_input"), py::arg("rest_inputs"), py::arg("dim"), + py::arg("out"), py::kw_only(), py::arg("implementation_index") = 0); +} + +} // namespace infini::ops + +#endif diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 5aa8896e..24c7ae90 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -1,6 +1,7 @@ import argparse import json import pathlib +import re import shutil import subprocess import textwrap @@ -20,6 +21,8 @@ _INCLUDE_DIR = _GENERATION_DIR / "include" +_BINDINGS_OVERRIDES_DIR = pathlib.Path("scripts") / "bindings_overrides" + _INDENTATION = " " @@ -91,26 +94,77 @@ def 
__init__(self, name, constructors, calls): self.calls = calls -def _generate_pybind11(operator): +def _find_optional_tensor_params(op_name): + """Return a set of parameter names declared as `std::optional` in + the base header. libclang resolves the type to ``int`` when the STL + headers are not fully available, so we fall back to a regex scan of the + source text. + """ + source = (_BASE_DIR / f"{op_name}.h").read_text() + return set(re.findall(r"std::optional\s+(\w+)", source)) + + +def _find_vector_tensor_params(op_name): + """Return a set of parameter names declared as `std::vector` in + the base header. + """ + import re + + source = (_BASE_DIR / f"{op_name}.h").read_text() + return set(re.findall(r"std::vector\s+(\w+)", source)) + + +def _generate_pybind11(operator, impl_names=None): + optional_tensor_params = _find_optional_tensor_params(operator.name) + vector_tensor_params = _find_vector_tensor_params(operator.name) + + if impl_names is None: + impl_names = {} + + def _is_optional_tensor(arg): + if arg.spelling in optional_tensor_params: + return True + return "std::optional" in arg.type.spelling and "Tensor" in arg.type.spelling + + def _is_vector_tensor(arg): + if arg.spelling in vector_tensor_params: + return True + return "std::vector" in arg.type.spelling and "Tensor" in arg.type.spelling + def _generate_params(node): - return ( - ", ".join( - f"{arg.type.spelling} {arg.spelling}" - for arg in node.get_arguments() - if arg.spelling != "stream" - ) - .replace("const Tensor", "py::object") - .replace("Tensor", "py::object") - ) + parts = [] + + for arg in node.get_arguments(): + if arg.spelling == "stream": + continue + if _is_optional_tensor(arg): + parts.append(f"std::optional {arg.spelling}") + elif _is_vector_tensor(arg): + parts.append(f"std::vector {arg.spelling}") + else: + param = arg.type.spelling.replace("const Tensor", "py::object").replace( + "Tensor", "py::object" + ) + parts.append(f"{param} {arg.spelling}") + + return ", ".join(parts) def 
_generate_arguments(node): - return ", ".join( - f"TensorFromPybind11Handle({arg.spelling})" - if "Tensor" in arg.type.spelling - else arg.spelling - for arg in node.get_arguments() - if arg.spelling != "stream" - ) + args = [] + + for arg in node.get_arguments(): + if arg.spelling == "stream": + continue + if _is_optional_tensor(arg): + args.append(f"OptionalTensorFromPybind11Handle({arg.spelling})") + elif _is_vector_tensor(arg): + args.append(f"VectorTensorFromPybind11Handle({arg.spelling})") + elif "Tensor" in arg.type.spelling: + args.append(f"TensorFromPybind11Handle({arg.spelling})") + else: + args.append(arg.spelling) + + return ", ".join(args) op_name = operator.name @@ -133,19 +187,61 @@ def _generate_call(op_name, call, method=True): call_args = _generate_arguments(call) if not method: - params = ( - f"{call_params}, std::size_t implementation_index" + # Overload 1: implementation_index (numeric, backward compatible). + params_idx = ( + f"{call_params}, std::size_t implementation_index, std::uintptr_t stream" if call_params - else "std::size_t implementation_index" + else "std::size_t implementation_index, std::uintptr_t stream" ) py_args = _generate_py_args(call) py_args_str = f"{py_args}, " if py_args else "" - return f""" m.def("{op_name}", []({params}) {{ - Config config; - config.set_implementation_index(implementation_index); - return Self::call({{}}, config, {call_args}); - }}, {py_args_str}py::kw_only(), py::arg("implementation_index") = 0);""" + overload_idx = ( + f' m.def("{op_name}", []({params_idx}) {{\n' + f" Config config;\n" + f" config.set_implementation_index(implementation_index);\n" + f" Handle handle;\n" + f" if (stream) {{\n" + f" handle.set_stream(reinterpret_cast(stream));\n" + f" }}\n" + f" return Self::call(handle, config, {call_args});\n" + f' }}, {py_args_str}py::kw_only(), py::arg("implementation_index") = 0, py::arg("stream") = 0);' + ) + + # Overload 2: implementation (string name, e.g. "dsl"). 
+ # Only generate if there are named implementations. + if not impl_names: + return overload_idx + + # Build C++ initializer list for the per-operator map. + map_entries = ", ".join( + f'{{"{name}", {idx}}}' for name, idx in impl_names.items() + ) + valid_names = ", ".join(f"'{n}'" for n in impl_names) + + params_str = ( + f"{call_params}, const std::string& implementation, std::uintptr_t stream" + if call_params + else "const std::string& implementation, std::uintptr_t stream" + ) + + overload_str = ( + f' m.def("{op_name}", []({params_str}) {{\n' + f" static const std::unordered_map kImplNames{{{{{map_entries}}}}};\n" + f" auto it = kImplNames.find(implementation);\n" + f' if (it == kImplNames.end()) throw py::value_error(\n' + f' "unknown implementation: \'" + implementation + "\' (valid: {valid_names})");\n' + f" Config config;\n" + f" config.set_implementation_index(it->second);\n" + f" Handle handle;\n" + f" if (stream) {{\n" + f" handle.set_stream(reinterpret_cast(stream));\n" + f" }}\n" + f" return Self::call(handle, config, {call_args});\n" + f' }}, {py_args_str}py::kw_only(), py::arg("implementation"), py::arg("stream") = 0);' + ) + + return f"{overload_idx}\n{overload_str}" return f""" .def("__call__", [](const Self& self, {call_params}) {{ return static_cast&>(self)({call_args}); @@ -169,6 +265,8 @@ def _generate_call(op_name, call, method=True): #include "base/{op_name}.h" #include "config.h" +#include "handle.h" +#include "operator.h" #include "pybind11_utils.h" namespace py = pybind11; @@ -417,6 +515,14 @@ def _get_all_ops(devices): else: ops = _get_all_ops(args.devices) + # Load per-operator implementation name mappings (generated by DSL compiler). 
+ impl_names_path = _GENERATION_DIR / "impl_names.json" + + if impl_names_path.exists(): + all_impl_names = json.loads(impl_names_path.read_text()) + else: + all_impl_names = {} + header_paths = [] bind_func_names = [] @@ -424,11 +530,24 @@ def _get_all_ops(devices): extractor = _OperatorExtractor() operator = extractor(op_name) + pascal_name = _snake_to_pascal(op_name) + op_impl_names = all_impl_names.get(pascal_name, {}) + source_path = _GENERATED_SRC_DIR / op_name header_name = f"{op_name}.h" - bind_func_name = f"Bind{_snake_to_pascal(op_name)}" - - (_BINDINGS_DIR / header_name).write_text(_generate_pybind11(operator)) + bind_func_name = f"Bind{pascal_name}" + + binding_path = _BINDINGS_DIR / header_name + override_path = _BINDINGS_OVERRIDES_DIR / header_name + + # Use a hand-written binding if one exists in the overrides directory; + # otherwise auto-generate. + if override_path.exists(): + binding_path.write_text(override_path.read_text()) + else: + binding_path.write_text( + _generate_pybind11(operator, op_impl_names) + ) legacy_c_source, legacy_c_header = _generate_legacy_c(operator, impl_paths) source_path.mkdir(exist_ok=True) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0b56341b..3827cdd3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -42,6 +42,11 @@ if(WITH_NVIDIA) find_package(CUDAToolkit REQUIRED) target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cuda_driver) + target_include_directories(infiniops PUBLIC + "${PROJECT_SOURCE_DIR}/third_party/flashinfer/include" + "${PROJECT_SOURCE_DIR}/third_party/flashinfer/3rdparty/cutlass/include" + ) + list(APPEND DEVICE_LIST "nvidia") set_target_properties(infiniops PROPERTIES CUDA_STANDARD 17 @@ -172,20 +177,73 @@ if(WITH_CAMBRICON) list(APPEND DEVICE_LIST "cambricon") endif() -target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +if(WITH_ASCEND) + # ASCEND_HOME is set by the top-level CMakeLists.txt. 
+ file(GLOB_RECURSE ASCEND_SOURCES CONFIGURE_DEPENDS + "ascend/*.cc" + "ascend/*.cpp" + ) + # Exclude kernel_impl.cpp — AscendC device code, not compiled by the host C++ compiler. + list(FILTER ASCEND_SOURCES EXCLUDE REGEX ".*kernel_impl\\.cpp$") + + target_compile_definitions(infiniops PUBLIC WITH_ASCEND=1) + target_sources(infiniops PRIVATE ${ASCEND_SOURCES}) + + # Resolve the driver lib dir two levels above the toolkit root. + get_filename_component(ASCEND_ROOT "${ASCEND_HOME}/../.." ABSOLUTE) + + # Prefer the real driver HAL; fall back to the toolkit stub for build-only + # environments (e.g., Docker CI images without hardware drivers installed). + # CANN <= 8.0: stub at runtime/lib64/stub/; CANN >= 8.5: devlib/-linux/devlib/. + set(ASCEND_HAL_REAL "${ASCEND_ROOT}/driver/lib64/driver/libascend_hal.so") + set(ASCEND_HAL_STUB "${ASCEND_HOME}/runtime/lib64/stub/libascend_hal.so") + set(ASCEND_HAL_DEVLIB "${ASCEND_HOME}/${CMAKE_SYSTEM_PROCESSOR}-linux/devlib/libascend_hal.so") + if(EXISTS "${ASCEND_HAL_REAL}") + set(ASCEND_HAL_LIB "${ASCEND_HAL_REAL}") + elseif(EXISTS "${ASCEND_HAL_STUB}") + set(ASCEND_HAL_LIB "${ASCEND_HAL_STUB}") + message(STATUS "ascend_hal: driver not found, using stub for linking") + elseif(EXISTS "${ASCEND_HAL_DEVLIB}") + set(ASCEND_HAL_LIB "${ASCEND_HAL_DEVLIB}") + message(STATUS "ascend_hal: driver not found, using devlib for linking") + else() + message(FATAL_ERROR "libascend_hal.so not found (tried ${ASCEND_HAL_REAL}, ${ASCEND_HAL_STUB}, and ${ASCEND_HAL_DEVLIB})") + endif() + + target_include_directories(infiniops PUBLIC + "${ASCEND_HOME}/include" + "${ASCEND_HOME}/include/aclnn" + "${ASCEND_HOME}/include/aclnnop") + target_link_libraries(infiniops PUBLIC + "${ASCEND_HOME}/lib64/libascendcl.so" + "${ASCEND_HOME}/lib64/libnnopbase.so" + "${ASCEND_HOME}/lib64/libopapi.so" + "${ASCEND_HAL_LIB}") + + list(APPEND DEVICE_LIST "ascend") +endif() + +target_include_directories(infiniops PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} + 
${PROJECT_SOURCE_DIR}/generated +) if(GENERATE_PYTHON_BINDINGS) find_package(Python COMPONENTS Interpreter REQUIRED) + # Always regenerate bindings so the included kernel headers match the + # active device list. Stale generated files (e.g., committed for one + # platform) would omit specializations for other enabled backends, + # causing link-time or runtime failures. execute_process( - COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py --devices ${DEVICE_LIST} + COMMAND ${Python_EXECUTABLE} -m dsl --devices ${DEVICE_LIST} WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} RESULT_VARIABLE script_result ) if(NOT script_result EQUAL 0) - message(FATAL_ERROR "Generating wrappers - failed") + message(FATAL_ERROR "DSL compilation and binding generation - failed") else() - message(STATUS "Generating wrappers - done") + message(STATUS "DSL compilation and binding generation - done") endif() set(PYBIND11_SOURCES "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc") @@ -204,7 +262,7 @@ if(GENERATE_PYTHON_BINDINGS) pybind11_add_module(ops NO_EXTRAS ${PYBIND11_SOURCES}) endif() - target_include_directories(ops PRIVATE ${PROJECT_SOURCE_DIR}) + target_include_directories(ops PRIVATE ${PROJECT_SOURCE_DIR} ${PROJECT_SOURCE_DIR}/generated) target_link_libraries(ops PRIVATE infiniops) set_target_properties(infiniops PROPERTIES INSTALL_RPATH "$ORIGIN") diff --git a/src/ascend/add/kernel.h b/src/ascend/add/kernel.h new file mode 100644 index 00000000..650edebb --- /dev/null +++ b/src/ascend/add/kernel.h @@ -0,0 +1,81 @@ +#ifndef INFINI_OPS_ASCEND_ADD_KERNEL_H_ +#define INFINI_OPS_ASCEND_ADD_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_add.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/add.h" +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Add { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Add(input, 
other, out), + in_cache_(input), + oth_cache_(other), + out_cache_(out) { + // aclCreateScalar stores the pointer rather than copying the value, so + // alpha_storage_* must remain alive for the lifetime of alpha_. + // The alpha scalar type must match the tensor dtype: use int64 for integer + // dtypes and float for floating-point dtypes. + if (ascend::isIntegerDtype(input.dtype())) { + alpha_ = aclCreateScalar(&alpha_int_storage_, ACL_INT64); + } else { + alpha_ = aclCreateScalar(&alpha_float_storage_, ACL_FLOAT); + } + } + + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + aclDestroyScalar(alpha_); + } + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + auto stream = static_cast(stream_); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_oth = oth_cache_.get(const_cast(other.data())); + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + aclnnAddGetWorkspaceSize(t_in, t_oth, alpha_, t_out, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_in, + const_cast(input.data())); + aclSetInputTensorAddr(executor_, 1, t_oth, + const_cast(other.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnAdd(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache oth_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; + + float alpha_float_storage_ = + 1.0f; // stable address for aclCreateScalar (float) + int64_t alpha_int_storage_ = 1; // stable address for aclCreateScalar (int) + aclScalar* alpha_ = nullptr; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/add_rms_norm/kernel.h b/src/ascend/add_rms_norm/kernel.h new file mode 100644 index 
00000000..4f9670a2 --- /dev/null +++ b/src/ascend/add_rms_norm/kernel.h @@ -0,0 +1,127 @@ +#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_add.h" +#include "aclnn_rms_norm.h" +#include "ascend/common.h" +#include "ascend/add_rms_norm/registry.h" +#include "ascend/workspace_pool_.h" +#include "operator.h" + +namespace infini::ops { + +// Decomposed implementation: aclnnAdd + aclnnRmsNorm. +// +// The fused aclnnAddRmsNorm API has ~200 us host-side launch overhead that +// dominates small-tensor dispatch. Decomposing into two fast ACLNN calls +// reduces host dispatch from ~224 us to ~56 us (4x faster) with negligible +// NPU-side impact for inference tensor sizes. +template <> +class Operator : public AddRmsNorm { + public: + Operator(const Tensor x1, const Tensor x2, const Tensor gamma, float eps, + Tensor y_out, Tensor x_out) + : AddRmsNorm(x1, x2, gamma, eps, y_out, x_out), + x1_cache_(x1), + x2_cache_(x2), + gamma_cache_(gamma), + y_out_cache_(y_out), + x_out_cache_(x_out) { + // Alpha scalar for aclnnAdd (x_out = x1 + 1.0 * x2). + alpha_ = aclCreateScalar(&alpha_storage_, ACL_FLOAT); + + // aclnnRmsNorm writes rstd as a required side output. 
+ rstd_shape_ = {static_cast(batch_size_), + static_cast(nhead_)}; + size_t rstd_bytes = batch_size_ * nhead_ * sizeof(float); + aclrtMalloc(&rstd_data_, rstd_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + rstd_tensor_ = aclCreateTensor(rstd_shape_.data(), 2, ACL_FLOAT, + /*strides=*/nullptr, 0, ACL_FORMAT_ND, + rstd_shape_.data(), 2, rstd_data_); + } + + ~Operator() { + if (add_exec_) aclDestroyAclOpExecutor(add_exec_); + if (norm_exec_) aclDestroyAclOpExecutor(norm_exec_); + aclDestroyScalar(alpha_); + if (rstd_tensor_) aclDestroyTensor(rstd_tensor_); + if (rstd_data_) aclrtFree(rstd_data_); + } + + void operator()(const Tensor x1, const Tensor x2, const Tensor gamma, + float eps, Tensor y_out, Tensor x_out) const override { + auto t_x1 = x1_cache_.get(const_cast(x1.data())); + auto t_x2 = x2_cache_.get(const_cast(x2.data())); + auto t_gamma = gamma_cache_.get(const_cast(gamma.data())); + auto t_y_out = y_out_cache_.get(y_out.data()); + auto t_x_out = x_out_cache_.get(x_out.data()); + auto stream = static_cast(stream_); + + // Step 1: x_out = x1 + x2. + if (!add_exec_) { + aclnnAddGetWorkspaceSize(t_x1, t_x2, alpha_, t_x_out, &add_ws_, + &add_exec_); + aclSetAclOpExecutorRepeatable(add_exec_); + } else { + aclSetInputTensorAddr(add_exec_, 0, t_x1, + const_cast(x1.data())); + aclSetInputTensorAddr(add_exec_, 1, t_x2, + const_cast(x2.data())); + aclSetOutputTensorAddr(add_exec_, 0, t_x_out, x_out.data()); + } + auto& add_arena = ascend::workspacePool().ensure(stream, add_ws_); + aclnnAdd(add_arena.buf, add_ws_, add_exec_, stream); + + // Step 2: y_out = rms_norm(x_out, gamma, eps). 
+ if (!norm_exec_) { + aclnnRmsNormGetWorkspaceSize(t_x_out, t_gamma, eps, t_y_out, + rstd_tensor_, &norm_ws_, &norm_exec_); + aclSetAclOpExecutorRepeatable(norm_exec_); + } else { + aclSetInputTensorAddr(norm_exec_, 0, t_x_out, x_out.data()); + aclSetInputTensorAddr(norm_exec_, 1, t_gamma, + const_cast(gamma.data())); + aclSetOutputTensorAddr(norm_exec_, 0, t_y_out, y_out.data()); + } + auto& norm_arena = ascend::workspacePool().ensure(stream, norm_ws_); + aclnnRmsNorm(norm_arena.buf, norm_ws_, norm_exec_, stream); + } + + private: + mutable ascend::AclTensorCache x1_cache_; + + mutable ascend::AclTensorCache x2_cache_; + + mutable ascend::AclTensorCache gamma_cache_; + + mutable ascend::AclTensorCache y_out_cache_; + + mutable ascend::AclTensorCache x_out_cache_; + + float alpha_storage_ = 1.0f; + + aclScalar* alpha_ = nullptr; + + std::vector rstd_shape_; + + void* rstd_data_ = nullptr; + + aclTensor* rstd_tensor_ = nullptr; + + mutable aclOpExecutor* add_exec_ = nullptr; + + mutable uint64_t add_ws_ = 0; + + mutable aclOpExecutor* norm_exec_ = nullptr; + + mutable uint64_t norm_ws_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/add_rms_norm/kernel_fused.h b/src/ascend/add_rms_norm/kernel_fused.h new file mode 100644 index 00000000..2959a73f --- /dev/null +++ b/src/ascend/add_rms_norm/kernel_fused.h @@ -0,0 +1,124 @@ +#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_FUSED_H_ +#define INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_FUSED_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_add_rms_norm.h" +#include "ascend/add_rms_norm/registry.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "operator.h" + +namespace infini::ops { + +// Fused implementation via aclnnAddRmsNorm (implementation index 1). +// +// Computes x_out = x1 + x2 and y_out = rms_norm(x_out, gamma, eps) in a +// single CANN launch. 
The fused API has higher host-side launch overhead +// (~200 us) compared to the decomposed aclnnAdd + aclnnRmsNorm path (~39 us), +// but may offer better NPU-side efficiency for large tensors where kernel +// fusion reduces memory traffic. +// +// Select via `implementation_index=1` in Python: +// infini.ops.add_rms_norm(..., implementation_index=1, stream=s) +template <> +class Operator : public AddRmsNorm { + public: + Operator(const Tensor x1, const Tensor x2, const Tensor gamma, float eps, + Tensor y_out, Tensor x_out) + : AddRmsNorm(x1, x2, gamma, eps, y_out, x_out), + x1_cache_(x1), + x2_cache_(x2), + gamma_cache_(gamma), + y_out_cache_(y_out), + x_out_cache_(x_out) { + // aclnnAddRmsNorm requires rstdOut to have the same ndim as x1, with + // the last gamma.ndim() dimensions set to 1. For example: + // x1 shape(2, 32, 128), gamma shape(128) -> rstdOut shape(2, 32, 1) + // x1 shape(64, 128), gamma shape(128) -> rstdOut shape(64, 1) + fused_rstd_shape_.reserve(ndim_); + for (size_t i = 0; i < ndim_ - gamma.ndim(); ++i) { + fused_rstd_shape_.push_back(static_cast(x1.size(i))); + } + for (size_t i = 0; i < gamma.ndim(); ++i) { + fused_rstd_shape_.push_back(1); + } + + size_t rstd_elems = 1; + for (auto d : fused_rstd_shape_) { + rstd_elems *= static_cast(d); + } + size_t rstd_bytes = rstd_elems * sizeof(float); + aclrtMalloc(&rstd_data_, rstd_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + rstd_tensor_ = aclCreateTensor( + fused_rstd_shape_.data(), + static_cast(fused_rstd_shape_.size()), ACL_FLOAT, + /*strides=*/nullptr, 0, ACL_FORMAT_ND, fused_rstd_shape_.data(), + static_cast(fused_rstd_shape_.size()), rstd_data_); + } + + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + if (rstd_tensor_) aclDestroyTensor(rstd_tensor_); + if (rstd_data_) aclrtFree(rstd_data_); + } + + void operator()(const Tensor x1, const Tensor x2, const Tensor gamma, + float eps, Tensor y_out, Tensor x_out) const override { + auto t_x1 = x1_cache_.get(const_cast(x1.data())); 
+ auto t_x2 = x2_cache_.get(const_cast(x2.data())); + auto t_gamma = gamma_cache_.get(const_cast(gamma.data())); + auto t_y_out = y_out_cache_.get(y_out.data()); + auto t_x_out = x_out_cache_.get(x_out.data()); + auto stream = static_cast(stream_); + + if (!executor_) { + aclnnAddRmsNormGetWorkspaceSize(t_x1, t_x2, t_gamma, + static_cast(eps), t_y_out, + rstd_tensor_, t_x_out, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_x1, + const_cast(x1.data())); + aclSetInputTensorAddr(executor_, 1, t_x2, + const_cast(x2.data())); + aclSetInputTensorAddr(executor_, 2, t_gamma, + const_cast(gamma.data())); + aclSetOutputTensorAddr(executor_, 0, t_y_out, y_out.data()); + // rstd at output index 1 has a stable address — no update needed. + aclSetOutputTensorAddr(executor_, 2, t_x_out, x_out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnAddRmsNorm(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache x1_cache_; + + mutable ascend::AclTensorCache x2_cache_; + + mutable ascend::AclTensorCache gamma_cache_; + + mutable ascend::AclTensorCache y_out_cache_; + + mutable ascend::AclTensorCache x_out_cache_; + + std::vector fused_rstd_shape_; + + void* rstd_data_ = nullptr; + + aclTensor* rstd_tensor_ = nullptr; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/add_rms_norm/registry.h b/src/ascend/add_rms_norm/registry.h new file mode 100644 index 00000000..d48de306 --- /dev/null +++ b/src/ascend/add_rms_norm/registry.h @@ -0,0 +1,15 @@ +#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_REGISTRY_H_ +#define INFINI_OPS_ASCEND_ADD_RMS_NORM_REGISTRY_H_ + +#include "base/add_rms_norm.h" + +namespace infini::ops { + +template <> +struct ActiveImplementationsImpl { + using type = List<0, 1>; +}; + +} // namespace infini::ops + +#endif diff --git 
a/src/ascend/cast/kernel.h b/src/ascend/cast/kernel.h new file mode 100644 index 00000000..645f05af --- /dev/null +++ b/src/ascend/cast/kernel.h @@ -0,0 +1,60 @@ +#ifndef INFINI_OPS_ASCEND_CAST_KERNEL_H_ +#define INFINI_OPS_ASCEND_CAST_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_cast.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/cast.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Cast { + public: + Operator(const Tensor input, Tensor out) + : Cast(input, out), + in_cache_(input), + out_cache_(out), + acl_out_dtype_(ascend::toAclDtype(out.dtype())) {} + + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + } + + void operator()(const Tensor input, Tensor out) const override { + auto stream = static_cast(stream_); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + aclnnCastGetWorkspaceSize(t_in, acl_out_dtype_, t_out, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_in, + const_cast(input.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnCast(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache out_cache_; + + aclDataType acl_out_dtype_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/cat/kernel.h b/src/ascend/cat/kernel.h new file mode 100644 index 00000000..aae90e08 --- /dev/null +++ b/src/ascend/cat/kernel.h @@ -0,0 +1,94 @@ +#ifndef INFINI_OPS_ASCEND_CAT_KERNEL_H_ +#define INFINI_OPS_ASCEND_CAT_KERNEL_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn/acl_meta.h" 
+#include "aclnnop/aclnn_cat.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/cat.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Cat { + public: + Operator(const Tensor first_input, std::vector rest_inputs, + int64_t dim, Tensor out) + : Cat(first_input, rest_inputs, dim, out), out_cache_(out) { + // Build AclTensorCache for each input tensor. + in_caches_.reserve(input_count_); + in_caches_.emplace_back(first_input); + for (const auto& t : rest_inputs) { + in_caches_.emplace_back(t); + } + } + + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + if (tensor_list_) aclDestroyTensorList(tensor_list_); + } + + void operator()(const Tensor first_input, std::vector rest_inputs, + int64_t /*dim*/, Tensor out) const override { + auto stream = static_cast(stream_); + + // Collect all input tensors in order. + std::vector inputs; + inputs.reserve(input_count_); + inputs.push_back(&first_input); + for (const auto& t : rest_inputs) { + inputs.push_back(&t); + } + + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + // First call: create descriptors, tensor list, and executor. + std::vector acl_tensors(input_count_); + for (size_t i = 0; i < input_count_; ++i) { + acl_tensors[i] = + in_caches_[i].get(const_cast(inputs[i]->data())); + } + + tensor_list_ = aclCreateTensorList( + const_cast(acl_tensors.data()), + static_cast(input_count_)); + + aclnnCatGetWorkspaceSize(tensor_list_, dim_, t_out, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + // Subsequent calls: update data pointers on cached descriptors via + // `aclSetRawTensorAddr`. The executor holds references to the same + // `aclTensor*` objects inside `tensor_list_`, so updating their data + // pointers is sufficient — no `aclSetInputTensorAddr` needed. 
+ for (size_t i = 0; i < input_count_; ++i) { + in_caches_[i].get(const_cast(inputs[i]->data())); + } + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnCat(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable std::vector in_caches_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclTensorList* tensor_list_ = nullptr; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/causal_softmax/kernel.h b/src/ascend/causal_softmax/kernel.h new file mode 100644 index 00000000..a27cb5dc --- /dev/null +++ b/src/ascend/causal_softmax/kernel.h @@ -0,0 +1,158 @@ +#ifndef INFINI_OPS_ASCEND_CAUSAL_SOFTMAX_KERNEL_H_ +#define INFINI_OPS_ASCEND_CAUSAL_SOFTMAX_KERNEL_H_ + +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_copy.h" +#include "aclnn_masked_fill_scalar.h" +#include "aclnn_softmax.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/causal_softmax.h" +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +// Implements causal softmax via three ACLNN calls: +// 1. InplaceCopy(temp, input) — stride-aware copy to contiguous temp +// buffer. +// 2. InplaceMaskedFillScalar(temp, mask, -inf) — apply upper-triangle mask. +// 3. Softmax(temp, dim=-1, out) — softmax over the last dimension. +// +// The boolean causal mask is pre-computed and uploaded to device once in the +// constructor. Its shape (seq_len, total_seq_len) broadcasts over the batch. +template <> +class Operator : public CausalSoftmax { + public: + Operator(const Tensor input, Tensor out) + : CausalSoftmax(input, out), + in_cache_(input), + out_cache_(out) { + // Contiguous temp buffer with the same element count as input. 
+ size_t n_elems = input.numel(); + size_t elem_bytes = kDataTypeToSize.at(dtype_); + aclrtMalloc(&temp_buf_, n_elems * elem_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + // Build a contiguous Tensor descriptor pointing to temp_buf_. + Tensor temp_t{temp_buf_, input.shape(), input.dtype(), input.device()}; + temp_cache_ = ascend::AclTensorCache(temp_t); + + // Causal mask: mask[i][j] = 1 when position j must be masked for query i. + // Shape (seq_len, total_seq_len) – broadcasts over the batch dimension. + size_t mask_elems = seq_len_ * total_seq_len_; + std::vector mask_host(mask_elems, 0); + + for (size_t i = 0; i < seq_len_; ++i) { + auto vis_end = static_cast(total_seq_len_ - seq_len_ + i); + + for (auto j = vis_end + 1; j < static_cast(total_seq_len_); + ++j) { + mask_host[i * total_seq_len_ + j] = 1; + } + } + + aclrtMalloc(&mask_buf_, mask_elems, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMemcpy(mask_buf_, mask_elems, mask_host.data(), mask_elems, + ACL_MEMCPY_HOST_TO_DEVICE); + + std::vector mshape = {static_cast(seq_len_), + static_cast(total_seq_len_)}; + std::vector mstrides = {static_cast(total_seq_len_), 1}; + mask_tensor_ = aclCreateTensor(mshape.data(), mshape.size(), ACL_BOOL, + mstrides.data(), 0, ACL_FORMAT_ND, + mshape.data(), mshape.size(), mask_buf_); + + // Scalar -inf for the masked-fill step. aclCreateScalar stores the pointer + // rather than copying, so neg_inf_storage_ must stay alive with the object. + neg_inf_ = aclCreateScalar(&neg_inf_storage_, ACL_FLOAT); + // Workspaces are allocated lazily on first operator() call. 
+ } + + ~Operator() { + if (copy_exec_) aclDestroyAclOpExecutor(copy_exec_); + if (fill_exec_) aclDestroyAclOpExecutor(fill_exec_); + if (softmax_exec_) aclDestroyAclOpExecutor(softmax_exec_); + aclrtFree(temp_buf_); + aclrtFree(mask_buf_); + aclDestroyTensor(mask_tensor_); + aclDestroyScalar(neg_inf_); + } + + void operator()(const Tensor input, Tensor out) const override { + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_temp = temp_cache_.get(temp_buf_); + auto t_out = out_cache_.get(out.data()); + auto stream = static_cast(stream_); + + // Step 1: copy input (possibly non-contiguous) into contiguous temp. + if (!copy_exec_) { + aclnnInplaceCopyGetWorkspaceSize(t_temp, t_in, ©_ws_, ©_exec_); + aclSetAclOpExecutorRepeatable(copy_exec_); + } else { + aclSetInputTensorAddr(copy_exec_, 0, t_temp, temp_buf_); + aclSetInputTensorAddr(copy_exec_, 1, t_in, + const_cast(input.data())); + } + auto& copy_arena = ascend::workspacePool().ensure(stream, copy_ws_); + aclnnInplaceCopy(copy_arena.buf, copy_ws_, copy_exec_, stream); + + // Step 2: mask upper-triangle positions with -inf in-place. + // mask_tensor_ and neg_inf_ have stable addresses — first-call only. + if (!fill_exec_) { + aclnnInplaceMaskedFillScalarGetWorkspaceSize( + t_temp, mask_tensor_, neg_inf_, &fill_ws_, &fill_exec_); + aclSetAclOpExecutorRepeatable(fill_exec_); + } + auto& fill_arena = ascend::workspacePool().ensure(stream, fill_ws_); + aclnnInplaceMaskedFillScalar(fill_arena.buf, fill_ws_, fill_exec_, stream); + + // Step 3: softmax over the last dimension → out. 
+ if (!softmax_exec_) { + constexpr int64_t kLastDim = -1; + aclnnSoftmaxGetWorkspaceSize(t_temp, kLastDim, t_out, &softmax_ws_, + &softmax_exec_); + aclSetAclOpExecutorRepeatable(softmax_exec_); + } else { + aclSetOutputTensorAddr(softmax_exec_, 0, t_out, out.data()); + } + auto& softmax_arena = ascend::workspacePool().ensure(stream, softmax_ws_); + aclnnSoftmax(softmax_arena.buf, softmax_ws_, softmax_exec_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable ascend::AclTensorCache temp_cache_; + + float neg_inf_storage_ = -std::numeric_limits::infinity(); + + void* temp_buf_ = nullptr; + + void* mask_buf_ = nullptr; + + aclTensor* mask_tensor_ = nullptr; + + aclScalar* neg_inf_ = nullptr; + + mutable aclOpExecutor* copy_exec_ = nullptr; + + mutable uint64_t copy_ws_ = 0; + + mutable aclOpExecutor* fill_exec_ = nullptr; + + mutable uint64_t fill_ws_ = 0; + + mutable aclOpExecutor* softmax_exec_ = nullptr; + + mutable uint64_t softmax_ws_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/common.h b/src/ascend/common.h new file mode 100644 index 00000000..8b1a5624 --- /dev/null +++ b/src/ascend/common.h @@ -0,0 +1,175 @@ +#ifndef INFINI_OPS_ASCEND_COMMON_H_ +#define INFINI_OPS_ASCEND_COMMON_H_ + +#include +#include + +#include "acl/acl.h" +#include "aclnn/acl_meta.h" +#include "ascend/data_type_.h" +#include "tensor.h" + +namespace infini::ops::ascend { + +// Build an aclTensor descriptor from an InfiniOps Tensor. +// +// When `transpose_last2` is true the last two dimensions are swapped in the +// descriptor (shape and strides) without copying data. This is used by GEMM +// and Matmul to express a transpose via the view. 
+inline aclTensor* buildAclTensor(const Tensor& t, + bool transpose_last2 = false) { + std::vector shape(t.shape().begin(), t.shape().end()); + std::vector strides(t.strides().begin(), t.strides().end()); + + if (transpose_last2 && shape.size() >= 2) { + auto n = shape.size(); + std::swap(shape[n - 2], shape[n - 1]); + std::swap(strides[n - 2], strides[n - 1]); + } + + // Compute the minimum physical storage needed for this strided view. + // For contiguous tensors this equals `numel()`; for non-contiguous (gapped) + // tensors it may be larger; for broadcast (stride-0) tensors it may be + // smaller. Passing the view shape as the storage shape causes + // "ViewShape overlap" errors in ACLNN for non-contiguous inputs. + int64_t storage_elems = 1; + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] == 0) { + storage_elems = 0; + break; + } + if (strides[i] > 0 && shape[i] > 1) { + storage_elems += static_cast(shape[i] - 1) * strides[i]; + } + } + std::vector storage_shape = {storage_elems}; + + return aclCreateTensor( + shape.data(), static_cast(shape.size()), toAclDtype(t.dtype()), + strides.data(), + /*storageOffset=*/0, ACL_FORMAT_ND, storage_shape.data(), + static_cast(storage_shape.size()), const_cast(t.data())); +} + +// Pre-computed tensor metadata for descriptor reuse. +// +// Stores shape, strides, storage_shape, and dtype once (avoiding per-call heap +// allocations). The aclTensor descriptor is created on the first `get()` call +// and its data pointer is updated in-place via `aclSetRawTensorAddr` on +// subsequent calls. +class AclTensorCache { + public: + AclTensorCache() = default; + + // Construct from explicit metadata (for device buffers not wrapped in Tensor). + // Computes contiguous strides from shape. 
+  // Describes a densely packed tensor of the given `shape` and `dtype`.
+  // Strides are derived from `shape` (the innermost dimension gets stride 1)
+  // and the storage size is the product of all dimensions. The ACL descriptor
+  // is created here only when `data` is non-null; otherwise `get()` creates it
+  // on first use.
+  // NOTE(review): template arguments (e.g. the element type of `std::vector`
+  // and the targets of `static_cast`) are missing throughout this view.
+  AclTensorCache(std::vector shape, aclDataType dtype, void* data)
+      : shape_(std::move(shape)), dtype_(dtype) {
+    strides_.resize(shape_.size());
+    int64_t stride = 1;
+    // Walk from the innermost dimension outwards, accumulating the stride.
+    for (int i = static_cast(shape_.size()) - 1; i >= 0; --i) {
+      strides_[i] = stride;
+      stride *= shape_[i];
+    }
+    // After the loop, `stride` holds the product of all dimensions.
+    storage_shape_ = {stride};
+
+    if (data) {
+      tensor_ = aclCreateTensor(
+          shape_.data(), static_cast(shape_.size()), dtype_,
+          strides_.data(),
+          /*storageOffset=*/0, ACL_FORMAT_ND, storage_shape_.data(),
+          static_cast(storage_shape_.size()), data);
+    }
+  }
+
+  // Captures shape, strides and dtype from `t`. With `transpose_last2`, the
+  // last two dimensions and their strides are swapped, which describes a
+  // transposed view of the same memory. No ACL descriptor is created here;
+  // `get()` creates it on first use.
+  explicit AclTensorCache(const Tensor& t, bool transpose_last2 = false)
+      : dtype_{toAclDtype(t.dtype())} {
+    shape_.assign(t.shape().begin(), t.shape().end());
+    strides_.assign(t.strides().begin(), t.strides().end());
+
+    if (transpose_last2 && shape_.size() >= 2) {
+      auto n = shape_.size();
+      std::swap(shape_[n - 2], shape_[n - 1]);
+      std::swap(strides_[n - 2], strides_[n - 1]);
+    }
+
+    // Storage extent in elements: 1 + sum((shape[i] - 1) * strides[i]) over
+    // dimensions with a positive stride and more than one element, or 0 as
+    // soon as any dimension is empty.
+    int64_t storage_elems = 1;
+    for (size_t i = 0; i < shape_.size(); ++i) {
+      if (shape_[i] == 0) {
+        storage_elems = 0;
+        break;
+      }
+      if (strides_[i] > 0 && shape_[i] > 1) {
+        storage_elems += static_cast(shape_[i] - 1) * strides_[i];
+      }
+    }
+    storage_shape_ = {storage_elems};
+  }
+
+  // Releases the descriptor if one was created.
+  ~AclTensorCache() {
+    if (tensor_) {
+      aclDestroyTensor(tensor_);
+    }
+  }
+
+  // Copying is disabled; the destructor destroys `tensor_`, so two copies
+  // would destroy the same descriptor.
+  AclTensorCache(const AclTensorCache&) = delete;
+
+  AclTensorCache& operator=(const AclTensorCache&) = delete;
+
+  // Move construction takes over the descriptor and clears it in `o`.
+  AclTensorCache(AclTensorCache&& o) noexcept
+      : shape_(std::move(o.shape_)),
+        strides_(std::move(o.strides_)),
+        storage_shape_(std::move(o.storage_shape_)),
+        dtype_(o.dtype_),
+        tensor_(o.tensor_) {
+    o.tensor_ = nullptr;
+  }
+
+  // Move assignment destroys the current descriptor (if any) before taking
+  // over the one held by `o`. Self-assignment is a no-op.
+  AclTensorCache& operator=(AclTensorCache&& o) noexcept {
+    if (this != &o) {
+      if (tensor_) {
+        aclDestroyTensor(tensor_);
+      }
+      shape_ = std::move(o.shape_);
+      strides_ = std::move(o.strides_);
+      storage_shape_ = std::move(o.storage_shape_);
+      dtype_ = o.dtype_;
+      tensor_ = o.tensor_;
+      o.tensor_ = nullptr;
+    }
+
+    return *this;
+  }
+
+  // Update the data pointer and return the cached descriptor.
+  // If a descriptor already exists, only its address is repointed at `data`;
+  // otherwise one is created from the stored shape, strides and dtype.
+  // `tensor_` is `mutable`, which is what lets this `const` method store it.
+  aclTensor* get(void* data) const {
+    if (tensor_) {
+      aclSetRawTensorAddr(tensor_, data);
+
+      return tensor_;
+    }
+
+    tensor_ = aclCreateTensor(
+        shape_.data(), static_cast(shape_.size()), dtype_,
+        strides_.data(),
+        /*storageOffset=*/0, ACL_FORMAT_ND, storage_shape_.data(),
+        static_cast(storage_shape_.size()), data);
+
+    return tensor_;
+  }
+
+ private:
+  // Logical shape passed to `aclCreateTensor`.
+  std::vector shape_;
+
+  // Per-dimension strides passed to `aclCreateTensor`.
+  std::vector strides_;
+
+  // Single-entry storage shape (element count of the backing storage).
+  std::vector storage_shape_;
+
+  aclDataType dtype_{ACL_DT_UNDEFINED};
+
+  // Lazily created descriptor; null until built by a constructor or `get()`.
+  mutable aclTensor* tensor_ = nullptr;
+};
+
+}  // namespace infini::ops::ascend
+
+#endif
diff --git a/src/ascend/data_type_.h b/src/ascend/data_type_.h
new file mode 100644
index 00000000..08b1541b
--- /dev/null
+++ b/src/ascend/data_type_.h
@@ -0,0 +1,61 @@
+#ifndef INFINI_OPS_ASCEND_DATA_TYPE__H_
+#define INFINI_OPS_ASCEND_DATA_TYPE__H_
+
+// NOTE(review): the header name of the following include is missing from
+// this view.
+#include 
+
+#include "acl/acl.h"
+#include "ascend/device_.h"
+#include "data_type.h"
+
+namespace infini::ops::ascend {
+
+// Maps a `DataType` value to the corresponding `aclDataType`. Values without
+// a case trip the assertion; when assertions are compiled out,
+// `ACL_DT_UNDEFINED` is returned instead.
+inline aclDataType toAclDtype(DataType dt) {
+  switch (dt) {
+    case DataType::kFloat16:
+      return ACL_FLOAT16;
+    case DataType::kBFloat16:
+      return ACL_BF16;
+    case DataType::kFloat32:
+      return ACL_FLOAT;
+    case DataType::kInt8:
+      return ACL_INT8;
+    case DataType::kInt16:
+      return ACL_INT16;
+    case DataType::kInt32:
+      return ACL_INT32;
+    case DataType::kInt64:
+      return ACL_INT64;
+    case DataType::kUInt8:
+      return ACL_UINT8;
+    case DataType::kUInt16:
+      return ACL_UINT16;
+    case DataType::kUInt32:
+      return ACL_UINT32;
+    case DataType::kUInt64:
+      return ACL_UINT64;
+    default:
+      assert(false && "unsupported dtype for Ascend backend");
+      return ACL_DT_UNDEFINED;
+  }
+}
+
+// Returns true for integer (signed or unsigned) DataType values.
+inline bool isIntegerDtype(DataType dt) {
+  switch (dt) {
+    case DataType::kInt8:
+    case DataType::kInt16:
+    case DataType::kInt32:
+    case DataType::kInt64:
+    case DataType::kUInt8:
+    case DataType::kUInt16:
+    case DataType::kUInt32:
+    case DataType::kUInt64:
+      return true;
+    default:
+      return false;
+  }
+}
+
+}  // namespace infini::ops::ascend
+
+#endif
diff --git a/src/ascend/device_.h b/src/ascend/device_.h
new file mode 100644
index 00000000..b4ec934d
--- /dev/null
+++ b/src/ascend/device_.h
@@ -0,0 +1,16 @@
+#ifndef INFINI_OPS_ASCEND_DEVICE__H_
+#define INFINI_OPS_ASCEND_DEVICE__H_
+
+// NOTE: Cannot use `#include "device.h"` here — GCC resolves quoted includes
+// relative to the current file first, and `src/ascend/` used to contain a
+// `device.h`. Use `data_type.h` which transitively pulls in `src/device.h`.
+#include "data_type.h"
+
+namespace infini::ops {
+
+// Explicit specialization of `DeviceEnabled` that derives from
+// `std::true_type`.
+// NOTE(review): the template argument list of this specialization is missing
+// from this view — presumably the Ascend device type; confirm.
+template <>
+struct DeviceEnabled : std::true_type {};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/ascend/flash_attention/kernel.h b/src/ascend/flash_attention/kernel.h
new file mode 100644
index 00000000..d8545d90
--- /dev/null
+++ b/src/ascend/flash_attention/kernel.h
@@ -0,0 +1,362 @@
+#ifndef INFINI_OPS_ASCEND_FLASH_ATTENTION_KERNEL_H_
+#define INFINI_OPS_ASCEND_FLASH_ATTENTION_KERNEL_H_
+
+// NOTE(review): the header names of the following three includes are missing
+// from this view.
+#include 
+#include 
+#include 
+
+#include "acl/acl.h"
+#include "aclnn/aclnn_base.h"
+#include "aclnnop/aclnn_fused_infer_attention_score_v4.h"
+#include "ascend/common.h"
+#include "ascend/workspace_pool_.h"
+#include "base/flash_attention.h"
+#include "operator.h"
+
+namespace infini::ops {
+
+namespace detail {
+
+// Extract cu_seqlens differences to a host aclIntArray.
+// cu_seqlens = [0, s1, s1+s2, ...] -> per_seq_lens = [s1, s2, ...].
+// Used by paged decode (actualSeqLengthsKv = per-sequence KV lengths).
+//
+// When cu_seqlens is a CPU tensor (device type kCpu), the data pointer is
+// already on the host and can be read directly — no D2H sync needed.
+inline aclIntArray* extractSeqLengths(const Tensor& cu_seqlens,
+                                      aclrtStream stream) {
+  // The values are read through an `int64_t` pointer below, so any other
+  // element width would be misinterpreted.
+  assert(cu_seqlens.element_size() == sizeof(int64_t) &&
+         "`cu_seqlens` must hold 64-bit integers");
+
+  const size_t n = static_cast<size_t>(cu_seqlens.numel());
+
+  // Fewer than two boundaries describe zero sequences. Return an empty array
+  // up front: `n - 1` would wrap around for an empty tensor, and there is
+  // nothing worth copying from the device either.
+  if (n < 2) {
+    const int64_t placeholder = 0;
+
+    // `aclCreateIntArray` is handed a pointer to a local here; like the
+    // per-sequence path below, this relies on it copying the values.
+    return aclCreateIntArray(&placeholder, 0);
+  }
+
+  const int64_t* cu_host_ptr = nullptr;
+  std::vector<int64_t> cu_host_buf;
+
+  if (cu_seqlens.device().type() == Device::Type::kCpu) {
+    // Host tensor: read in place.
+    cu_host_ptr = static_cast<const int64_t*>(cu_seqlens.data());
+  } else {
+    // Device tensor: stage a host copy and wait for it to land.
+    const size_t bytes = n * sizeof(int64_t);
+    cu_host_buf.resize(n);
+
+    aclError ret =
+        aclrtMemcpyAsync(cu_host_buf.data(), bytes, cu_seqlens.data(), bytes,
+                         ACL_MEMCPY_DEVICE_TO_HOST, stream);
+    assert(ret == ACL_SUCCESS && "aclrtMemcpyAsync failed (cu_seqlens D2H)");
+    ret = aclrtSynchronizeStream(stream);
+    assert(ret == ACL_SUCCESS &&
+           "aclrtSynchronizeStream failed (cu_seqlens D2H)");
+    (void)ret;  // Silences the unused warning when assertions are disabled.
+    cu_host_ptr = cu_host_buf.data();
+  }
+
+  // Adjacent differences of the cumulative boundaries.
+  std::vector<int64_t> lengths(n - 1);
+  for (size_t i = 0; i < lengths.size(); ++i) {
+    lengths[i] = cu_host_ptr[i + 1] - cu_host_ptr[i];
+  }
+
+  return aclCreateIntArray(lengths.data(),
+                           static_cast<uint64_t>(lengths.size()));
+}
+
+// Extract cumulative end positions from cu_seqlens to a host aclIntArray.
+// cu_seqlens = [0, s1, s1+s2, ...] -> cum_lens = [s1, s1+s2, ...].
+// FIA V4 TND varlen uses cumulative end positions, matching the vllm-ascend
+// convention for npu_fused_infer_attention_score actual_seq_lengths.
+//
+// When cu_seqlens is a CPU tensor, reads directly from host memory.
+// A tensor with fewer than two entries yields an empty array.
+inline aclIntArray* cumSeqLengths(const Tensor& cu_seqlens,
+                                  aclrtStream stream) {
+  // The values are read through an `int64_t` pointer below, so any other
+  // element width would be misinterpreted.
+  assert(cu_seqlens.element_size() == sizeof(int64_t) &&
+         "`cu_seqlens` must hold 64-bit integers");
+
+  const size_t n = static_cast<size_t>(cu_seqlens.numel());
+
+  // Guard against `n - 1` wrapping around (and against forming `ptr + 1` on
+  // an empty buffer) when there are no sequence boundaries to report.
+  if (n < 2) {
+    const int64_t placeholder = 0;
+
+    return aclCreateIntArray(&placeholder, 0);
+  }
+
+  const int64_t* cu_host_ptr = nullptr;
+  std::vector<int64_t> cu_host_buf;
+
+  if (cu_seqlens.device().type() == Device::Type::kCpu) {
+    // Host tensor: read in place.
+    cu_host_ptr = static_cast<const int64_t*>(cu_seqlens.data());
+  } else {
+    // Device tensor: stage a host copy and wait for it to land.
+    const size_t bytes = n * sizeof(int64_t);
+    cu_host_buf.resize(n);
+
+    aclError ret =
+        aclrtMemcpyAsync(cu_host_buf.data(), bytes, cu_seqlens.data(), bytes,
+                         ACL_MEMCPY_DEVICE_TO_HOST, stream);
+    assert(ret == ACL_SUCCESS && "aclrtMemcpyAsync failed (cu_seqlens D2H)");
+    ret = aclrtSynchronizeStream(stream);
+    assert(ret == ACL_SUCCESS &&
+           "aclrtSynchronizeStream failed (cu_seqlens D2H)");
+    (void)ret;  // Silences the unused warning when assertions are disabled.
+    cu_host_ptr = cu_host_buf.data();
+  }
+
+  // Skip the leading 0; return [s1, s1+s2, ...].
+  return aclCreateIntArray(cu_host_ptr + 1, static_cast<uint64_t>(n - 1));
+}
+
+// Allocate a 2048x2048 lower-triangular UINT8 causal mask on device.
+// Required for sparseMode >= 2.
+inline aclTensor* makeCausalMask(void** mask_buf, aclrtStream stream) {
+  // The device allocation is written to `*mask_buf` and a descriptor for it
+  // is returned; neither is released here.
+  // NOTE(review): the results of `aclrtMalloc`, `aclrtMemcpyAsync` and
+  // `aclrtSynchronizeStream` are not checked in this function.
+  constexpr int64_t kMaskDim = 2048;
+  const int64_t mask_elems = kMaskDim * kMaskDim;
+  const size_t mask_bytes = static_cast(mask_elems);  // uint8_t
+
+  aclrtMalloc(mask_buf, mask_bytes, ACL_MEM_MALLOC_NORMAL_ONLY);
+
+  // Build the mask on the host first, then upload it in one copy.
+  std::vector host_mask(mask_elems);
+  for (int64_t r = 0; r < kMaskDim; ++r) {
+    for (int64_t c = 0; c < kMaskDim; ++c) {
+      // 1 = masked out (upper triangle); 0 = attend (lower triangle).
+      host_mask[r * kMaskDim + c] = (c > r) ? 1 : 0;
+    }
+  }
+  // The synchronize keeps `host_mask` alive until the copy has completed.
+  aclrtMemcpyAsync(*mask_buf, mask_bytes, host_mask.data(), mask_bytes,
+                   ACL_MEMCPY_HOST_TO_DEVICE, stream);
+  aclrtSynchronizeStream(stream);
+
+  std::vector mask_shape = {kMaskDim, kMaskDim};
+  std::vector mask_strides = {kMaskDim, 1};
+  std::vector mask_storage = {mask_elems};
+  return aclCreateTensor(mask_shape.data(), 2, ACL_UINT8, mask_strides.data(),
+                         0, ACL_FORMAT_ND, mask_storage.data(), 1, *mask_buf);
+}
+
+}  // namespace detail
+
+// `FlashAttention` specialization that dispatches to
+// `aclnnFusedInferAttentionScoreV4`, with a prefill path (TND layout) and a
+// paged-decode path (BNSD layout).
+// NOTE(review): the template argument list after `Operator`, and the element
+// types of `std::optional` / `std::vector` and the targets of the casts, are
+// missing from this view.
+template <>
+class Operator : public FlashAttention {
+ public:
+  // Decides once whether this instance runs paged decode (`block_table`
+  // supplied and `block_size > 0`) or prefill, and pre-builds the tensor
+  // descriptors that path needs.
+  Operator(const Tensor query, const Tensor key, const Tensor value,
+           std::optional cu_seqlens_q,
+           std::optional cu_seqlens_kv,
+           std::optional block_table, int64_t num_heads,
+           int64_t num_kv_heads, int64_t head_size, double scale, bool causal,
+           int64_t window_left, int64_t window_right, int64_t block_size,
+           Tensor output)
+      : FlashAttention(query, key, value, cu_seqlens_q, cu_seqlens_kv,
+                       block_table, num_heads, num_kv_heads, head_size, scale,
+                       causal, window_left, window_right, block_size, output) {
+    paged_ = block_table.has_value() && block_size > 0;
+    aclDataType acl_dt = ascend::toAclDtype(query.dtype());
+
+    if (!paged_) {
+      // Prefill: cache Q and output (TND layout).
+      prefill_q_cache_ = ascend::AclTensorCache(query);
+      prefill_out_cache_ = ascend::AclTensorCache(output);
+
+      // Pre-compute causal mask once (sparse_mode >= 2).
+      // NOTE(review): `sm` is always 3 or 4 here, so the `sm >= 2` test below
+      // can never be false. The mask upload uses a null stream.
+      if (causal) {
+        int64_t sm = (window_left >= 0) ? 4 : 3;
+        if (sm >= 2) {
+          causal_mask_ = detail::makeCausalMask(&causal_mask_buf_, nullptr);
+        }
+      }
+    } else {
+      // Decode: cache Q/output (BNSD), block_table.
+      const int64_t N = query.size(1);
+      const int64_t D = query.size(2);
+      const int64_t B = query.size(0);
+
+      // Q and output are described as {B, N, 1, D}.
+      decode_q_cache_ = ascend::AclTensorCache(
+          {B, N, 1, D}, acl_dt, const_cast(query.data()));
+      decode_out_cache_ = ascend::AclTensorCache(
+          {B, N, 1, D}, acl_dt, output.data());
+      block_table_cache_ = ascend::AclTensorCache(block_table.value());
+
+      // Pre-compute KV reshape metadata.
+      // The last two dimensions of `key` are folded into one, and the
+      // resulting 3-D view is described with dense row-major strides.
+      const int64_t nb = key.size(0);
+      const int64_t bsz = key.size(1);
+      const int64_t NkvD = key.size(2) * key.size(3);
+      kv_shape_ = {nb, bsz, NkvD};
+      kv_strides_ = {bsz * NkvD, NkvD, 1};
+      kv_storage_shape_ = {nb * bsz * NkvD};
+      kv_acl_dt_ = acl_dt;
+    }
+  }
+
+  // Releases the causal-mask descriptor and its device buffer, if allocated.
+  ~Operator() {
+    if (causal_mask_) aclDestroyTensor(causal_mask_);
+    if (causal_mask_buf_) aclrtFree(causal_mask_buf_);
+  }
+
+  // Runs attention on `stream_`. The prefill/decode choice comes from
+  // `paged_` (fixed at construction), not from this call's `block_table` or
+  // `block_size`.
+  // NOTE(review): ACL status codes are only checked through `assert`, so
+  // failures go unnoticed when assertions are compiled out.
+  void operator()(const Tensor query, const Tensor key, const Tensor value,
+                  std::optional cu_seqlens_q,
+                  std::optional cu_seqlens_kv,
+                  std::optional block_table, int64_t num_heads,
+                  int64_t num_kv_heads, int64_t head_size, double scale,
+                  bool causal, int64_t window_left, int64_t window_right,
+                  int64_t block_size, Tensor output) const override {
+    auto stream = static_cast(stream_);
+    const bool paged = paged_;
+
+    // Translate (causal, window_left, window_right) into the sparse mode and
+    // token-window arguments:
+    //   causal + window_left >= 0 -> sparse_mode 4, pre_tokens = window_left
+    //   causal, no left window    -> sparse_mode 3
+    //   non-causal                -> sparse_mode 0, windows applied if >= 0
+    // 2147483647 is the default for both windows — presumably "unbounded";
+    // confirm against the aclnn documentation.
+    int64_t sparse_mode;
+    int64_t pre_tokens = 2147483647;
+    int64_t next_tokens = 2147483647;
+    if (causal) {
+      if (window_left >= 0) {
+        sparse_mode = 4;
+        pre_tokens = window_left;
+        next_tokens = 0;
+      } else {
+        sparse_mode = 3;
+        next_tokens = 0;
+      }
+    } else {
+      sparse_mode = 0;
+      if (window_left >= 0) pre_tokens = window_left;
+      if (window_right >= 0) next_tokens = window_right;
+    }
+
+    if (!paged) {
+      // --- Prefill ---
+      int64_t T = query.size(0);
+
+      // cumSeqLengths / extractSeqLengths automatically skip D2H when
+      // cu_seqlens is a CPU tensor (see detail:: helpers above).
+      // NOTE(review): when `cu_seqlens_kv` is absent, the KV length falls
+      // back to `T`, which is taken from `query` — confirm this is intended
+      // when the key/value token count differs.
+      aclIntArray* seq_q =
+          cu_seqlens_q.has_value()
+              ? detail::cumSeqLengths(cu_seqlens_q.value(), stream)
+              : aclCreateIntArray(&T, 1);
+      aclIntArray* seq_kv =
+          cu_seqlens_kv.has_value()
+              ? detail::cumSeqLengths(cu_seqlens_kv.value(), stream)
+              : aclCreateIntArray(&T, 1);
+
+      aclTensor* t_q = prefill_q_cache_.get(const_cast(query.data()));
+      // K/V descriptors go into TensorList which takes ownership — must be
+      // per-call (cannot cache).
+      aclTensor* t_k = ascend::buildAclTensor(key);
+      aclTensor* t_v = ascend::buildAclTensor(value);
+      aclTensor* t_out = prefill_out_cache_.get(output.data());
+
+      const aclTensor* k_arr[] = {t_k};
+      const aclTensor* v_arr[] = {t_v};
+      aclTensorList* key_list = aclCreateTensorList(k_arr, 1);
+      aclTensorList* val_list = aclCreateTensorList(v_arr, 1);
+
+      uint64_t ws_needed = 0;
+      aclOpExecutor* executor = nullptr;
+      aclError gws = aclnnFusedInferAttentionScoreV4GetWorkspaceSize(
+          t_q, key_list, val_list,
+          nullptr,       // pseShift
+          causal_mask_,  // attenMask (pre-computed, or nullptr)
+          seq_q,         // actualSeqLengths
+          seq_kv,        // actualSeqLengthsKv
+          nullptr, nullptr, nullptr, nullptr,
+          nullptr,           // deqScale1..quantOffset2
+          nullptr, nullptr,  // antiquantScale, antiquantOffset
+          nullptr,           // blockTable
+          nullptr, nullptr,  // queryPaddingSize, kvPaddingSize
+          nullptr, nullptr, nullptr,
+          nullptr,  // key/value antiquant scale/offset
+          nullptr, nullptr,
+          nullptr,  // keySharedPrefix, valueSharedPrefix, actualSharedPrefixLen
+          nullptr, nullptr,
+          nullptr,           // queryRope, keyRope, keyRopeAntiquantScale
+          nullptr, nullptr,  // dequantScaleQuery, learnableSink
+          num_heads, scale, pre_tokens, next_tokens, const_cast("TND"),
+          num_kv_heads, sparse_mode,
+          0,         // innerPrecise
+          0,         // blockSize (unused for prefill)
+          0, false,  // antiquantMode, softmaxLseFlag
+          0, 0, 0,   // keyAntiquantMode, valueAntiquantMode, queryQuantMode
+          t_out, nullptr, &ws_needed, &executor);
+      assert(
+          gws == ACL_SUCCESS &&
+          "aclnnFusedInferAttentionScoreV4GetWorkspaceSize failed (prefill)");
+
+      // Obtain a workspace of at least `ws_needed` bytes for this stream.
+      auto& arena = ascend::workspacePool().ensure(stream, ws_needed);
+      aclError ret = aclnnFusedInferAttentionScoreV4(arena.buf, ws_needed,
+                                                     executor, stream);
+      assert(ret == ACL_SUCCESS &&
+             "aclnnFusedInferAttentionScoreV4 failed (prefill)");
+
+      // t_q and t_out are owned by caches — do NOT destroy.
+      // t_k and t_v are owned by TensorLists.
+      aclDestroyTensorList(key_list);
+      aclDestroyTensorList(val_list);
+      aclDestroyIntArray(seq_q);
+      aclDestroyIntArray(seq_kv);
+      return;
+    }
+
+    // --- Paged decode ---
+    assert(cu_seqlens_kv.has_value() &&
+           "`FlashAttention` paged decode requires `cu_seqlens_kv`");
+
+    aclTensor* t_query = decode_q_cache_.get(const_cast(query.data()));
+    aclTensor* t_output = decode_out_cache_.get(output.data());
+
+    // K/V descriptors go into TensorList which takes ownership — must be
+    // per-call. Use pre-computed metadata to avoid heap allocs.
+    aclTensor* t_key = aclCreateTensor(
+        kv_shape_.data(), static_cast(kv_shape_.size()), kv_acl_dt_,
+        kv_strides_.data(), 0, ACL_FORMAT_ND, kv_storage_shape_.data(),
+        static_cast(kv_storage_shape_.size()),
+        const_cast(key.data()));
+    aclTensor* t_value = aclCreateTensor(
+        kv_shape_.data(), static_cast(kv_shape_.size()), kv_acl_dt_,
+        kv_strides_.data(), 0, ACL_FORMAT_ND, kv_storage_shape_.data(),
+        static_cast(kv_storage_shape_.size()),
+        const_cast(value.data()));
+
+    // extractSeqLengths skips D2H when cu_seqlens_kv is a CPU tensor.
+    aclIntArray* seq_kv =
+        detail::extractSeqLengths(cu_seqlens_kv.value(), stream);
+    aclTensor* t_block_table =
+        block_table_cache_.get(const_cast(block_table.value().data()));
+
+    const aclTensor* k_arr[] = {t_key};
+    const aclTensor* v_arr[] = {t_value};
+    aclTensorList* key_list = aclCreateTensorList(k_arr, 1);
+    aclTensorList* val_list = aclCreateTensorList(v_arr, 1);
+
+    uint64_t ws_needed = 0;
+    aclOpExecutor* executor = nullptr;
+    aclError gws = aclnnFusedInferAttentionScoreV4GetWorkspaceSize(
+        t_query, key_list, val_list,
+        nullptr,  // pseShift
+        nullptr,  // attenMask (sparseMode ignored for Q_S=1)
+        nullptr,  // actualSeqLengths (ignored for Q_S=1)
+        seq_kv,   // actualSeqLengthsKv (mandatory for paged)
+        nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
+        t_block_table,  // blockTable
+        nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
+        nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, num_heads, scale,
+        static_cast(2147483647), static_cast(2147483647),
+        const_cast("BNSD"), num_kv_heads,
+        0,           // sparseMode=0 (ignored for Q_S=1)
+        0,           // innerPrecise
+        block_size,  // blockSize
+        0, false,    // antiquantMode, softmaxLseFlag
+        0, 0, 0,     // keyAntiquantMode, valueAntiquantMode, queryQuantMode
+        t_output, nullptr, &ws_needed, &executor);
+    assert(gws == ACL_SUCCESS &&
+           "aclnnFusedInferAttentionScoreV4GetWorkspaceSize failed (decode)");
+
+    auto& arena = ascend::workspacePool().ensure(stream, ws_needed);
+    aclError ret =
+        aclnnFusedInferAttentionScoreV4(arena.buf, ws_needed, executor, stream);
+    assert(ret == ACL_SUCCESS &&
+           "aclnnFusedInferAttentionScoreV4 failed (decode)");
+
+    // t_query, t_output, t_block_table owned by caches — do NOT destroy.
+    // t_key, t_value owned by TensorLists.
+    aclDestroyTensorList(key_list);
+    aclDestroyTensorList(val_list);
+    aclDestroyIntArray(seq_kv);
+  }
+
+ private:
+  // True when this instance runs the paged-decode path.
+  bool paged_ = false;
+
+  // Descriptor caches; `mutable` because `operator()` is `const`.
+  mutable ascend::AclTensorCache prefill_q_cache_;
+
+  mutable ascend::AclTensorCache prefill_out_cache_;
+
+  mutable ascend::AclTensorCache decode_q_cache_;
+
+  mutable ascend::AclTensorCache decode_out_cache_;
+
+  mutable ascend::AclTensorCache block_table_cache_;
+
+  // Causal mask descriptor and its device buffer (prefill + causal only).
+  aclTensor* causal_mask_ = nullptr;
+
+  void* causal_mask_buf_ = nullptr;
+
+  // Pre-computed K/V view metadata for the paged-decode path.
+  std::vector kv_shape_;
+
+  std::vector kv_strides_;
+
+  std::vector kv_storage_shape_;
+
+  aclDataType kv_acl_dt_ = ACL_DT_UNDEFINED;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/ascend/gemm/kernel.h b/src/ascend/gemm/kernel.h
new file mode 100644
index 00000000..a59d6249
--- /dev/null
+++ b/src/ascend/gemm/kernel.h
@@ -0,0 +1,102 @@
+#ifndef INFINI_OPS_ASCEND_GEMM_KERNEL_H_
+#define INFINI_OPS_ASCEND_GEMM_KERNEL_H_
+
+#include "acl/acl.h"
+#include "aclnn/aclnn_base.h"
+#include "aclnnop/aclnn_addmm.h"
+#include "aclnnop/aclnn_baddbmm.h"
+#include "ascend/common.h"
+#include "ascend/workspace_pool_.h"
+#include "base/gemm.h"
+#include "operator.h"
+
+namespace infini::ops {
+
+// `Gemm` specialization built on `aclnnAddmm` (or `aclnnBaddbmm` when
+// `batch_count_ > 1`), with `c` passed both as the `self` input and as the
+// output.
+// NOTE(review): the template argument list after `Operator` and the element
+// types of the `std::optional` parameters are missing from this view.
+template <>
+class Operator : public Gemm {
+ public:
+  // Captures `alpha` / `beta` (each defaulting to 1.0f) in ACL scalars and
+  // builds descriptor caches. The `a` / `b` caches use `trans_a_` /
+  // `trans_b_`, which are not declared in this class.
+  Operator(const Tensor a, const Tensor b, std::optional alpha,
+           std::optional beta, std::optional trans_a,
+           std::optional trans_b, Tensor c)
+      : Gemm(a, b, alpha, beta, trans_a, trans_b, c),
+        batched_{batch_count_ > 1},
+        alpha_val_{alpha.value_or(1.0f)},
+        beta_val_{beta.value_or(1.0f)},
+        self_cache_(c),
+        a_cache_(a, trans_a_),
+        b_cache_(b, trans_b_),
+        out_cache_(c) {
+    // The scalars are created from the member storage initialized above.
+    alpha_scalar_ = aclCreateScalar(&alpha_val_, ACL_FLOAT);
+    beta_scalar_ = aclCreateScalar(&beta_val_, ACL_FLOAT);
+  }
+
+  ~Operator() {
+    if (executor_) aclDestroyAclOpExecutor(executor_);
+    aclDestroyScalar(alpha_scalar_);
+    aclDestroyScalar(beta_scalar_);
+  }
+
+  // The first call builds the executor and marks it repeatable; later calls
+  // only refresh the tensor addresses before launching.
+  // NOTE(review): the per-call `alpha`, `beta`, `trans_a` and `trans_b`
+  // arguments are not read here — the values captured at construction are
+  // used. Return codes of the aclnn calls are not checked.
+  void operator()(const Tensor a, const Tensor b, std::optional alpha,
+                  std::optional beta, std::optional trans_a,
+                  std::optional trans_b, Tensor c) const override {
+    auto stream = static_cast(stream_);
+
+    auto t_self = self_cache_.get(c.data());
+    auto t_a = a_cache_.get(const_cast(a.data()));
+    auto t_b = b_cache_.get(const_cast(b.data()));
+    auto t_out = out_cache_.get(c.data());
+
+    if (!executor_) {
+      // The literal 0 is presumably the cube math type — confirm against the
+      // aclnn documentation.
+      if (batched_) {
+        aclnnBaddbmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_,
+                                     alpha_scalar_, t_out, 0, &ws_size_,
+                                     &executor_);
+      } else {
+        aclnnAddmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_,
+                                   alpha_scalar_, t_out, 0, &ws_size_,
+                                   &executor_);
+      }
+      aclSetAclOpExecutorRepeatable(executor_);
+    } else {
+      aclSetInputTensorAddr(executor_, 0, t_self, c.data());
+      aclSetInputTensorAddr(executor_, 1, t_a, const_cast(a.data()));
+      aclSetInputTensorAddr(executor_, 2, t_b, const_cast(b.data()));
+      aclSetOutputTensorAddr(executor_, 0, t_out, c.data());
+    }
+
+    auto& arena = ascend::workspacePool().ensure(stream, ws_size_);
+
+    if (batched_) {
+      aclnnBaddbmm(arena.buf, ws_size_, executor_, stream);
+    } else {
+      aclnnAddmm(arena.buf, ws_size_, executor_, stream);
+    }
+  }
+
+ private:
+  // Selects `aclnnBaddbmm` over `aclnnAddmm`.
+  bool batched_;
+
+  // Backing storage for the ACL scalars below.
+  float alpha_val_;
+
+  float beta_val_;
+
+  aclScalar* alpha_scalar_ = nullptr;
+
+  aclScalar* beta_scalar_ = nullptr;
+
+  mutable ascend::AclTensorCache self_cache_;
+
+  mutable ascend::AclTensorCache a_cache_;
+
+  mutable ascend::AclTensorCache b_cache_;
+
+  mutable ascend::AclTensorCache out_cache_;
+
+  // Executor and workspace size, filled in on the first call.
+  mutable aclOpExecutor* executor_ = nullptr;
+
+  mutable uint64_t ws_size_ = 0;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/ascend/linear/kernel.h b/src/ascend/linear/kernel.h
new file mode 100644
index 00000000..ec0f4ec6
--- /dev/null
+++ b/src/ascend/linear/kernel.h
@@ -0,0 +1,122 @@
+#ifndef INFINI_OPS_ASCEND_LINEAR_KERNEL_H_
+#define INFINI_OPS_ASCEND_LINEAR_KERNEL_H_
+
+#include "acl/acl.h"
+#include "aclnn/aclnn_base.h"
+#include "aclnnop/aclnn_addmm.h"
+#include "aclnnop/aclnn_baddbmm.h"
+#include "aclnnop/aclnn_matmul.h"
+#include "ascend/common.h"
+#include "ascend/workspace_pool_.h"
+#include "base/linear.h"
+#include "operator.h"
+
+namespace infini::ops {
+
+// `Linear` specialization: with a bias it runs `aclnnAddmm` / `aclnnBaddbmm`
+// using alpha = beta = 1.0f and the bias as the `self` input; without a bias
+// it runs `aclnnMatmul`.
+// NOTE(review): the template argument list after `Operator` and the element
+// type of `std::optional` are missing from this view.
+template <>
+class Operator : public Linear {
+ public:
+  // `batched_` is set when `out` has more than two dimensions. The bias
+  // cache and the two scalars are only created when `has_bias_` is set.
+  Operator(const Tensor a, const Tensor b, std::optional bias,
+           bool trans_a, bool trans_b, Tensor out)
+      : Linear(a, b, bias, trans_a, trans_b, out),
+        batched_{out.ndim() > 2},
+        a_cache_(a, trans_a),
+        b_cache_(b, trans_b),
+        out_cache_(out) {
+    if (has_bias_) {
+      bias_cache_ = ascend::AclTensorCache(*bias);
+      alpha_scalar_ = aclCreateScalar(&alpha_storage_, ACL_FLOAT);
+      beta_scalar_ = aclCreateScalar(&beta_storage_, ACL_FLOAT);
+    }
+  }
+
+  ~Operator() {
+    if (executor_) aclDestroyAclOpExecutor(executor_);
+    if (alpha_scalar_) aclDestroyScalar(alpha_scalar_);
+    if (beta_scalar_) aclDestroyScalar(beta_scalar_);
+  }
+
+  // The first call builds a repeatable executor for the selected kernel;
+  // later calls only refresh tensor addresses before launching.
+  // NOTE(review): return codes of the aclnn calls are not checked.
+  void operator()(const Tensor a, const Tensor b, std::optional bias,
+                  bool trans_a, bool trans_b, Tensor out) const override {
+    auto stream = static_cast(stream_);
+    auto t_a = a_cache_.get(const_cast(a.data()));
+    auto t_b = b_cache_.get(const_cast(b.data()));
+    auto t_out = out_cache_.get(out.data());
+
+    if (has_bias_) {
+      auto t_bias = bias_cache_.get(const_cast(bias->data()));
+
+      if (!executor_) {
+        if (batched_) {
+          aclnnBaddbmmGetWorkspaceSize(t_bias, t_a, t_b, beta_scalar_,
+                                       alpha_scalar_, t_out, 0, &ws_size_,
+                                       &executor_);
+        } else {
+          aclnnAddmmGetWorkspaceSize(t_bias, t_a, t_b, beta_scalar_,
+                                     alpha_scalar_, t_out, 0, &ws_size_,
+                                     &executor_);
+        }
+        aclSetAclOpExecutorRepeatable(executor_);
+      } else {
+        aclSetInputTensorAddr(executor_, 0, t_bias,
+                              const_cast(bias->data()));
+        aclSetInputTensorAddr(executor_, 1, t_a,
+                              const_cast(a.data()));
+        aclSetInputTensorAddr(executor_, 2, t_b,
+                              const_cast(b.data()));
+        aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
+      }
+
+      auto& arena = ascend::workspacePool().ensure(stream, ws_size_);
+
+      if (batched_) {
+        aclnnBaddbmm(arena.buf, ws_size_, executor_, stream);
+      } else {
+        aclnnAddmm(arena.buf, ws_size_, executor_, stream);
+      }
+    } else {
+      if (!executor_) {
+        // NOTE(review): the meaning of cube math type 1 is not shown in this
+        // file — confirm against the aclnnMatmul documentation.
+        int8_t cube_math_type = 1;
+        aclnnMatmulGetWorkspaceSize(t_a, t_b, t_out, cube_math_type,
+                                    &ws_size_, &executor_);
+        aclSetAclOpExecutorRepeatable(executor_);
+      } else {
+        aclSetInputTensorAddr(executor_, 0, t_a,
+                              const_cast(a.data()));
+        aclSetInputTensorAddr(executor_, 1, t_b,
+                              const_cast(b.data()));
+        aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
+      }
+
+      auto& arena = ascend::workspacePool().ensure(stream, ws_size_);
+      aclnnMatmul(arena.buf, ws_size_, executor_, stream);
+    }
+  }
+
+ private:
+  // Selects `aclnnBaddbmm` over `aclnnAddmm` on the bias path.
+  bool batched_;
+
+  mutable ascend::AclTensorCache bias_cache_;
+
+  mutable ascend::AclTensorCache a_cache_;
+
+  mutable ascend::AclTensorCache b_cache_;
+
+  mutable ascend::AclTensorCache out_cache_;
+
+  // Backing storage for the ACL scalars below; both stay at 1.0f.
+  float alpha_storage_ = 1.0f;
+
+  float beta_storage_ = 1.0f;
+
+  aclScalar* alpha_scalar_ = nullptr;
+
+  aclScalar* beta_scalar_ = nullptr;
+
+  // Executor and workspace size, filled in on the first call.
+  mutable aclOpExecutor* executor_ = nullptr;
+
+  mutable uint64_t ws_size_ = 0;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/ascend/matmul/kernel.h b/src/ascend/matmul/kernel.h
new file mode 100644
index 00000000..2d98c23f
--- /dev/null
+++ b/src/ascend/matmul/kernel.h
@@ -0,0 +1,63 @@
+#ifndef INFINI_OPS_ASCEND_MATMUL_KERNEL_H_
+#define INFINI_OPS_ASCEND_MATMUL_KERNEL_H_
+
+#include "acl/acl.h"
+#include "aclnn/aclnn_base.h"
+#include "aclnnop/aclnn_matmul.h"
+#include "ascend/common.h"
+#include "ascend/workspace_pool_.h"
+#include "base/matmul.h"
+#include "operator.h"
+
+namespace infini::ops {
+
+// `Matmul` specialization built on `aclnnMatmul`. Transposition is expressed
+// through the descriptor caches for `a` and `b`.
+// NOTE(review): the template argument list after `Operator` is missing from
+// this view.
+template <>
+class Operator : public Matmul {
+ public:
+  Operator(const Tensor a, const Tensor b, Tensor c, bool trans_a, bool trans_b)
+      : Matmul(a, b, c, trans_a, trans_b),
+        a_cache_(a, trans_a),
+        b_cache_(b, trans_b),
+        out_cache_(c) {}
+
+  ~Operator() {
+    if (executor_) aclDestroyAclOpExecutor(executor_);
+  }
+
+  // The first call builds a repeatable executor; later calls only refresh
+  // the tensor addresses before launching.
+  // NOTE(review): return codes of the aclnn calls are not checked.
+  void operator()(const Tensor a, const Tensor b, Tensor c, bool trans_a,
+                  bool trans_b) const override {
+    auto stream = static_cast(stream_);
+    auto t_a = a_cache_.get(const_cast(a.data()));
+    auto t_b = b_cache_.get(const_cast(b.data()));
+    auto t_out = out_cache_.get(c.data());
+
+    if (!executor_) {
+      int8_t cube_math_type = 1;
+      aclnnMatmulGetWorkspaceSize(t_a, t_b, t_out, cube_math_type, &ws_size_,
+                                  &executor_);
+      aclSetAclOpExecutorRepeatable(executor_);
+    } else {
+      aclSetInputTensorAddr(executor_, 0, t_a, const_cast(a.data()));
+      aclSetInputTensorAddr(executor_, 1, t_b, const_cast(b.data()));
+      aclSetOutputTensorAddr(executor_, 0, t_out, c.data());
+    }
+
+    auto& arena = ascend::workspacePool().ensure(stream, ws_size_);
+    aclnnMatmul(arena.buf, ws_size_, executor_, stream);
+  }
+
+ private:
+  mutable ascend::AclTensorCache a_cache_;
+
+  mutable ascend::AclTensorCache b_cache_;
+
+  mutable ascend::AclTensorCache out_cache_;
+
+  // Executor and workspace size, filled in on the first call.
+  mutable aclOpExecutor* executor_ = nullptr;
+
+  mutable uint64_t ws_size_ = 0;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/ascend/mul/kernel.h b/src/ascend/mul/kernel.h
new file mode 100644
index 00000000..38a09869
--- /dev/null
+++ b/src/ascend/mul/kernel.h
@@ -0,0 +1,63 @@
+#ifndef INFINI_OPS_ASCEND_MUL_KERNEL_H_
+#define INFINI_OPS_ASCEND_MUL_KERNEL_H_
+
+#include "acl/acl.h"
+#include "aclnn/aclnn_base.h"
+#include "aclnn_mul.h"
+#include "ascend/common.h"
+#include "ascend/workspace_pool_.h"
+#include "base/mul.h"
+#include "operator.h"
+
+namespace infini::ops {
+
+// `Mul` specialization built on `aclnnMul`.
+// NOTE(review): the template argument list after `Operator` is missing from
+// this view.
+template <>
+class Operator : public Mul {
+ public:
+  Operator(const Tensor input, const Tensor other, Tensor out)
+      : Mul(input, other, out),
+        in_cache_(input),
+        oth_cache_(other),
+        out_cache_(out) {}
+
+  ~Operator() {
+    if (executor_) aclDestroyAclOpExecutor(executor_);
+  }
+
+  // The first call builds a repeatable executor; later calls only refresh
+  // the tensor addresses before launching.
+  // NOTE(review): return codes of the aclnn calls are not checked.
+  void operator()(const Tensor input, const Tensor other,
+                  Tensor out) const override {
+    auto stream = static_cast(stream_);
+    auto t_in = in_cache_.get(const_cast(input.data()));
+    auto t_oth = oth_cache_.get(const_cast(other.data()));
+    auto t_out = out_cache_.get(out.data());
+
+    if (!executor_) {
+      aclnnMulGetWorkspaceSize(t_in, t_oth, t_out, &ws_size_, &executor_);
+      aclSetAclOpExecutorRepeatable(executor_);
+    } else {
+      aclSetInputTensorAddr(executor_, 0, t_in,
+                            const_cast(input.data()));
+      aclSetInputTensorAddr(executor_, 1, t_oth,
+                            const_cast(other.data()));
+      aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
+    }
+
+    auto& arena = ascend::workspacePool().ensure(stream, ws_size_);
+    aclnnMul(arena.buf, ws_size_, executor_, stream);
+  }
+
+ private:
+  mutable ascend::AclTensorCache in_cache_;
+
+  mutable ascend::AclTensorCache oth_cache_;
+
+  mutable ascend::AclTensorCache out_cache_;
+
+  // Executor and workspace size, filled in on the first call.
+  mutable aclOpExecutor* executor_ = nullptr;
+
+  mutable uint64_t ws_size_ = 0;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/ascend/reshape_and_cache/kernel.h b/src/ascend/reshape_and_cache/kernel.h
new file mode 100644
index 00000000..3bc0360c
--- /dev/null
+++ b/src/ascend/reshape_and_cache/kernel.h
@@ -0,0 +1,110 @@
+#ifndef INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_H_
+#define INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_H_
+
+// NOTE(review): the header names of the following two includes are missing
+// from this view.
+#include 
+#include 
+
+#include "acl/acl.h"
+#include "aclnn/aclnn_base.h"
+#include "aclnnop/aclnn_index_copy.h"
+#include "ascend/common.h"
+#include "ascend/workspace_pool_.h"
+#include "base/reshape_and_cache.h"
+#include "operator.h"
+
+namespace infini::ops {
+
+// Device-side scatter via aclnnInplaceIndexCopy.
+//
+// The previous implementation copied slot_mapping D2H (aclrtSynchronizeStream),
+// then issued per-token D2D memcpy in a host loop. For batch=256, this meant
+// ~100 us sync + ~500 us host loop overhead. aclnnInplaceIndexCopy performs
+// the scatter entirely on the NPU with two ACLNN calls (one for K, one for V),
+// eliminating all D2H synchronisation and host-side loops.
+//
+// Requirement: slot_mapping must contain only non-negative values. Padding
+// tokens (slot < 0) must be filtered by the caller before invoking this
+// operator.
+// NOTE(review): the template argument list after `Operator` and the targets
+// of the casts below are missing from this view.
+template <>
+class Operator
+    : public ReshapeAndCache {
+ public:
+  // Builds descriptors for `key`, `value` and `slot_mapping`, plus two
+  // flattened views of `kv_cache_out`: one starting at its base address and
+  // one offset by `stride(0)` elements.
+  Operator(const Tensor key, const Tensor value, const Tensor kv_cache,
+           const Tensor slot_mapping, Tensor kv_cache_out)
+      : ReshapeAndCache(key, value, kv_cache, slot_mapping, kv_cache_out),
+        key_cache_(key),
+        value_cache_(value),
+        slot_cache_(slot_mapping) {
+    // NOTE(review): the block count is read from dimension 1 of `kv_cache`,
+    // which presumably means dimension 0 selects K vs. V — confirm.
+    auto num_blocks = static_cast(kv_cache.size(1));
+    auto bs = static_cast(block_size_);
+    int64_t total_slots = num_blocks * bs;
+    int64_t nkv = static_cast(num_kv_heads_);
+    int64_t hs = static_cast(head_size_);
+
+    aclDataType acl_dt = ascend::toAclDtype(key.dtype());
+
+    // Flattened K cache view: [total_slots, num_kv_heads, head_size].
+    // K cache is kv_cache_out[0], starting at offset 0.
+    kv_k_cache_ = ascend::AclTensorCache(
+        {total_slots, nkv, hs}, acl_dt, kv_cache_out.data());
+
+    // V cache is kv_cache_out[1], offset by stride(0) elements.
+    v_offset_bytes_ = static_cast(kv_cache_out.stride(0)) *
+                      kv_cache_out.element_size();
+    kv_v_cache_ = ascend::AclTensorCache(
+        {total_slots, nkv, hs}, acl_dt,
+        static_cast(kv_cache_out.data()) + v_offset_bytes_);
+  }
+
+  // Scatters `key` and `value` rows into the K and V views of `kv_cache_out`
+  // at the indices given by `slot_mapping`, using one `aclnnInplaceIndexCopy`
+  // call per view.
+  // NOTE(review): return codes of the aclnn calls are not checked.
+  void operator()(const Tensor key, const Tensor value, const Tensor kv_cache,
+                  const Tensor slot_mapping,
+                  Tensor kv_cache_out) const override {
+    auto stream = static_cast(stream_);
+
+    // The V view uses the byte offset computed at construction.
+    void* kv_k_data = kv_cache_out.data();
+    void* kv_v_data =
+        static_cast(kv_cache_out.data()) + v_offset_bytes_;
+
+    auto t_kv_k = kv_k_cache_.get(kv_k_data);
+    auto t_kv_v = kv_v_cache_.get(kv_v_data);
+    auto t_key = key_cache_.get(const_cast(key.data()));
+    auto t_value = value_cache_.get(const_cast(value.data()));
+    auto t_slot = slot_cache_.get(const_cast(slot_mapping.data()));
+
+    // K cache scatter: kv_k[slot_mapping[i]] = key[i] along dim 0.
+    // Executor caching is not used here because aclnnInplaceIndexCopy is an
+    // inplace operation where self is both input and output; the executor
+    // reuse via aclSetInputTensorAddr does not update the output reference.
+    uint64_t k_ws = 0;
+    aclOpExecutor* k_exec = nullptr;
+    aclnnInplaceIndexCopyGetWorkspaceSize(t_kv_k, 0, t_slot, t_key,
+                                          &k_ws, &k_exec);
+    auto& k_arena = ascend::workspacePool().ensure(stream, k_ws);
+    aclnnInplaceIndexCopy(k_arena.buf, k_ws, k_exec, stream);
+
+    // V cache scatter: kv_v[slot_mapping[i]] = value[i] along dim 0.
+    uint64_t v_ws = 0;
+    aclOpExecutor* v_exec = nullptr;
+    aclnnInplaceIndexCopyGetWorkspaceSize(t_kv_v, 0, t_slot, t_value,
+                                          &v_ws, &v_exec);
+    auto& v_arena = ascend::workspacePool().ensure(stream, v_ws);
+    aclnnInplaceIndexCopy(v_arena.buf, v_ws, v_exec, stream);
+  }
+
+ private:
+  // Flattened K and V views of `kv_cache_out`.
+  mutable ascend::AclTensorCache kv_k_cache_;
+
+  mutable ascend::AclTensorCache kv_v_cache_;
+
+  mutable ascend::AclTensorCache key_cache_;
+
+  mutable ascend::AclTensorCache value_cache_;
+
+  mutable ascend::AclTensorCache slot_cache_;
+
+  // Byte offset of the V view from the start of `kv_cache_out`.
+  size_t v_offset_bytes_ = 0;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/ascend/rms_norm/kernel.h b/src/ascend/rms_norm/kernel.h
new file mode 100644
index 00000000..4061936b
--- /dev/null
+++ b/src/ascend/rms_norm/kernel.h
@@ -0,0 +1,87 @@
+#ifndef INFINI_OPS_ASCEND_RMS_NORM_KERNEL_H_
+#define INFINI_OPS_ASCEND_RMS_NORM_KERNEL_H_
+
+// NOTE(review): the header name of the following include is missing from
+// this view.
+#include 
+
+#include "acl/acl.h"
+#include "aclnn/aclnn_base.h"
+#include "aclnn_rms_norm.h"
+#include "ascend/common.h"
+#include "ascend/workspace_pool_.h"
+#include "base/rms_norm.h"
+#include "operator.h"
+
+namespace infini::ops {
+
+// `RmsNorm` specialization built on `aclnnRmsNorm`.
+// NOTE(review): the template argument list after `Operator` and the targets
+// of the casts below are missing from this view.
+template <>
+class Operator : public RmsNorm {
+ public:
+  // Builds descriptor caches for the inputs/output and allocates the device
+  // buffer and descriptor for the `rstd` side output.
+  // NOTE(review): the result of `aclrtMalloc` is not checked.
+  Operator(const Tensor input, const Tensor weight, float eps, Tensor out)
+      : RmsNorm(input, weight, eps, out),
+        in_cache_(input),
+        weight_cache_(weight),
+        out_cache_(out) {
+    // aclnnRmsNorm writes rstd as a required side output.
+    // Allocate a persistent device buffer for it.
+    rstd_shape_ = {static_cast(batch_size_),
+                   static_cast(nhead_)};
+    size_t rstd_bytes = batch_size_ * nhead_ * sizeof(float);
+    aclrtMalloc(&rstd_data_, rstd_bytes, ACL_MEM_MALLOC_NORMAL_ONLY);
+
+    // The rstd descriptor has a stable data pointer.
+    // `rstd_shape_` is passed as both the view shape and the storage shape.
+    rstd_tensor_ = aclCreateTensor(rstd_shape_.data(), 2, ACL_FLOAT,
+                                   /*strides=*/nullptr, 0, ACL_FORMAT_ND,
+                                   rstd_shape_.data(), 2, rstd_data_);
+  }
+
+  ~Operator() {
+    if (executor_) aclDestroyAclOpExecutor(executor_);
+    if (rstd_tensor_) aclDestroyTensor(rstd_tensor_);
+    if (rstd_data_) aclrtFree(rstd_data_);
+  }
+
+  // The first call builds a repeatable executor; later calls only refresh
+  // the input/output addresses before launching.
+  // NOTE(review): return codes of the aclnn calls are not checked.
+  void operator()(const Tensor input, const Tensor weight, float eps,
+                  Tensor out) const override {
+    auto t_in = in_cache_.get(const_cast(input.data()));
+    auto t_weight = weight_cache_.get(const_cast(weight.data()));
+    auto t_out = out_cache_.get(out.data());
+
+    if (!executor_) {
+      aclnnRmsNormGetWorkspaceSize(t_in, t_weight, eps, t_out, rstd_tensor_,
+                                   &ws_size_, &executor_);
+      aclSetAclOpExecutorRepeatable(executor_);
+    } else {
+      aclSetInputTensorAddr(executor_, 0, t_in,
+                            const_cast(input.data()));
+      aclSetInputTensorAddr(executor_, 1, t_weight,
+                            const_cast(weight.data()));
+      aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
+      // rstd at output index 1 has a stable address — no update needed.
+    }
+
+    auto stream = static_cast(stream_);
+    auto& arena = ascend::workspacePool().ensure(stream, ws_size_);
+    aclnnRmsNorm(arena.buf, ws_size_, executor_, stream);
+  }
+
+ private:
+  mutable ascend::AclTensorCache in_cache_;
+
+  mutable ascend::AclTensorCache weight_cache_;
+
+  mutable ascend::AclTensorCache out_cache_;
+
+  // Executor and workspace size, filled in on the first call.
+  mutable aclOpExecutor* executor_ = nullptr;
+
+  mutable uint64_t ws_size_ = 0;
+
+  // Shape, device buffer and descriptor of the `rstd` side output.
+  std::vector rstd_shape_;
+
+  void* rstd_data_ = nullptr;
+
+  aclTensor* rstd_tensor_ = nullptr;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/ascend/rotary_embedding/kernel.h b/src/ascend/rotary_embedding/kernel.h
new file mode 100644
index 00000000..659f91d2
--- /dev/null
+++ b/src/ascend/rotary_embedding/kernel.h
@@ -0,0 +1,273 @@
+#ifndef INFINI_OPS_ASCEND_ROTARY_EMBEDDING_KERNEL_H_
+#define INFINI_OPS_ASCEND_ROTARY_EMBEDDING_KERNEL_H_
+
+// NOTE(review): the header names of the following four includes are missing
+// from this view.
+#include 
+#include 
+#include 
+#include 
+
+#include "acl/acl.h"
+#include "aclnn/aclnn_base.h"
+#include "aclnnop/aclnn_apply_rotary_pos_emb_v2.h"
+#include "aclnnop/aclnn_index_select.h"
+#include "ascend/common.h"
+#include "ascend/workspace_pool_.h"
+#include "base/rotary_embedding.h"
+#include "operator.h"
+
+namespace infini::ops {
+
+// Rotary position embedding via aclnnApplyRotaryPosEmbV2.
+//
+// V2 handles Q and K simultaneously in a single inplace call (layout=4, TND).
+// The `rotaryMode` parameter accepts "half", "interleave", or "quarter", but
+// CANN currently only supports "half" (neox style). Passing "interleave" or
+// "quarter" returns ACLNN_ERR_PARAM_INVALID.
+//
+// fp16 note: V2 accumulates with ~4 ULP error for float16 (max diff ~0.008),
+// which exceeds strict atol=0.001 tests but is acceptable for inference.
+// bfloat16 passes with atol=0.005.
+//
+// Restrictions:
+// - rotary_dim must equal head_size (partial rotation not supported).
+// - is_neox_style must be true (rotaryMode="half" only).
+// All mainstream models (LLaMA, Qwen, Mistral, DeepSeek) satisfy both.
+template <> +class Operator + : public RotaryEmbedding { + public: + Operator(const Tensor positions, const Tensor query, const Tensor key, + const Tensor cos_sin_cache, int64_t head_size, int64_t rotary_dim, + bool is_neox_style, Tensor query_out, Tensor key_out) + : RotaryEmbedding(positions, query, key, cos_sin_cache, head_size, + rotary_dim, is_neox_style, query_out, key_out) { + assert(rotary_dim == head_size && + "Ascend `RotaryEmbedding` requires rotary_dim == head_size " + "(partial rotation not supported)"); + assert(is_neox_style && + "Ascend `RotaryEmbedding` requires neox style — " + "aclnnApplyRotaryPosEmbV2 rotaryMode only supports \"half\"; " + "\"interleave\" and \"quarter\" return ACLNN_ERR_PARAM_INVALID"); + + const int64_t max_seq_len = cos_sin_cache.size(0); + const int64_t D = head_size_; + const int64_t half_D = D / 2; + const size_t elem_sz = cos_sin_cache.element_size(); + + // One-time: D2H copy cos_sin_cache, split cos/sin, expand, upload. + // cos_sin_cache layout per row: [c0..c_{D/2-1}, s0..s_{D/2-1}]. + size_t table_bytes = static_cast(max_seq_len * D) * elem_sz; + std::vector cache_host(table_bytes); + aclrtMemcpy(cache_host.data(), table_bytes, cos_sin_cache.data(), + table_bytes, ACL_MEMCPY_DEVICE_TO_HOST); + + // Pre-expand into separate cos/sin tables [max_seq_len, D]. + // neox: cos = [c0..c_{hD-1}, c0..c_{hD-1}] (halves duplicated) + // interleave: cos = [c0,c0, c1,c1, ..., c_{hD-1},c_{hD-1}] + std::vector cos_host(table_bytes); + std::vector sin_host(table_bytes); + + for (int64_t p = 0; p < max_seq_len; ++p) { + for (int64_t j = 0; j < half_D; ++j) { + const auto* c_src = + cache_host.data() + + static_cast(p * D + j) * elem_sz; + const auto* s_src = + cache_host.data() + + static_cast(p * D + half_D + j) * elem_sz; + + // Neox expansion: [c0..c_{hD-1}, c0..c_{hD-1}] (halves duplicated). 
+ std::memcpy( + cos_host.data() + static_cast(p * D + j) * elem_sz, + c_src, elem_sz); + std::memcpy( + cos_host.data() + + static_cast(p * D + half_D + j) * elem_sz, + c_src, elem_sz); + std::memcpy( + sin_host.data() + static_cast(p * D + j) * elem_sz, + s_src, elem_sz); + std::memcpy( + sin_host.data() + + static_cast(p * D + half_D + j) * elem_sz, + s_src, elem_sz); + } + } + + // Upload expanded tables to device (one-time). + aclrtMalloc(&cos_table_dev_, table_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMalloc(&sin_table_dev_, table_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMemcpy(cos_table_dev_, table_bytes, cos_host.data(), table_bytes, + ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(sin_table_dev_, table_bytes, sin_host.data(), table_bytes, + ACL_MEMCPY_HOST_TO_DEVICE); + + const int64_t T = num_tokens_; + const int64_t Nq = num_heads_; + const int64_t Nkv = num_kv_heads_; + aclDataType acl_dt = ascend::toAclDtype(query.dtype()); + + // Gathered cos/sin buffers [T, D] — filled by aclnnIndexSelect each call. + size_t gathered_bytes = static_cast(T * D) * elem_sz; + aclrtMalloc(&cos_dev_, gathered_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMalloc(&sin_dev_, gathered_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + // IndexSelect descriptors: table ptrs stable, positions ptr varies. + cos_table_cache_ = ascend::AclTensorCache( + {max_seq_len, D}, acl_dt, cos_table_dev_); + sin_table_cache_ = ascend::AclTensorCache( + {max_seq_len, D}, acl_dt, sin_table_dev_); + idx_cache_ = ascend::AclTensorCache( + {T}, ACL_INT64, const_cast(positions.data())); + cos_out_cache_ = ascend::AclTensorCache({T, D}, acl_dt, cos_dev_); + sin_out_cache_ = ascend::AclTensorCache({T, D}, acl_dt, sin_dev_); + + // V2 descriptors: cos/sin [T, 1, D], Q [T, Nq, D], K [T, Nkv, D]. 
+ cos_v2_cache_ = ascend::AclTensorCache({T, 1, D}, acl_dt, cos_dev_); + sin_v2_cache_ = ascend::AclTensorCache({T, 1, D}, acl_dt, sin_dev_); + q_cache_ = ascend::AclTensorCache( + {T, Nq, D}, acl_dt, const_cast(query_out.data())); + k_cache_ = ascend::AclTensorCache( + {T, Nkv, D}, acl_dt, const_cast(key_out.data())); + } + + ~Operator() { + if (idx_cos_exec_) aclDestroyAclOpExecutor(idx_cos_exec_); + if (idx_sin_exec_) aclDestroyAclOpExecutor(idx_sin_exec_); + if (v2_exec_) aclDestroyAclOpExecutor(v2_exec_); + + if (cos_table_dev_) aclrtFree(cos_table_dev_); + if (sin_table_dev_) aclrtFree(sin_table_dev_); + if (cos_dev_) aclrtFree(cos_dev_); + if (sin_dev_) aclrtFree(sin_dev_); + } + + void operator()(const Tensor positions, const Tensor query, const Tensor key, + const Tensor cos_sin_cache, int64_t head_size, + int64_t rotary_dim, bool is_neox_style, Tensor query_out, + Tensor key_out) const override { + auto stream = static_cast(stream_); + + const int64_t T = query.size(0); + const int64_t Nq = query.size(1); + const int64_t Nkv = key.size(1); + const int64_t D = head_size; + + // Step 1: Gather cos/sin by positions via aclnnIndexSelect (async). 
+ { + auto t_cos_table = cos_table_cache_.get(cos_table_dev_); + auto t_sin_table = sin_table_cache_.get(sin_table_dev_); + auto t_idx = idx_cache_.get(const_cast(positions.data())); + auto t_cos_out = cos_out_cache_.get(cos_dev_); + auto t_sin_out = sin_out_cache_.get(sin_dev_); + + if (!idx_cos_exec_) { + aclnnIndexSelectGetWorkspaceSize(t_cos_table, 0, t_idx, t_cos_out, + &idx_cos_ws_, &idx_cos_exec_); + aclSetAclOpExecutorRepeatable(idx_cos_exec_); + } else { + aclSetInputTensorAddr(idx_cos_exec_, 1, t_idx, + const_cast(positions.data())); + } + + if (!idx_sin_exec_) { + aclnnIndexSelectGetWorkspaceSize(t_sin_table, 0, t_idx, t_sin_out, + &idx_sin_ws_, &idx_sin_exec_); + aclSetAclOpExecutorRepeatable(idx_sin_exec_); + } else { + aclSetInputTensorAddr(idx_sin_exec_, 1, t_idx, + const_cast(positions.data())); + } + + uint64_t ws_max = idx_cos_ws_ > idx_sin_ws_ ? idx_cos_ws_ : idx_sin_ws_; + auto& arena = ascend::workspacePool().ensure(stream, ws_max); + + aclnnIndexSelect(arena.buf, idx_cos_ws_, idx_cos_exec_, stream); + aclnnIndexSelect(arena.buf, idx_sin_ws_, idx_sin_exec_, stream); + } + + // Step 2: Copy q→q_out, k→k_out if not inplace (V2 operates inplace). + size_t elem_sz = query.element_size(); + + if (query.data() != query_out.data()) { + aclrtMemcpyAsync(query_out.data(), + static_cast(T * Nq * D) * elem_sz, query.data(), + static_cast(T * Nq * D) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + + if (key.data() != key_out.data()) { + aclrtMemcpyAsync(key_out.data(), + static_cast(T * Nkv * D) * elem_sz, key.data(), + static_cast(T * Nkv * D) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + + // Step 3: Apply V2 RoPE inplace on q_out and k_out. 
+ auto t_cos = cos_v2_cache_.get(cos_dev_); + auto t_sin = sin_v2_cache_.get(sin_dev_); + auto t_q = q_cache_.get(query_out.data()); + auto t_k = k_cache_.get(key_out.data()); + + if (!v2_exec_) { + aclnnApplyRotaryPosEmbV2GetWorkspaceSize( + t_q, t_k, t_cos, t_sin, /*layout=*/4, const_cast("half"), + &v2_ws_, &v2_exec_); + aclSetAclOpExecutorRepeatable(v2_exec_); + } else { + aclSetInputTensorAddr(v2_exec_, 0, t_q, query_out.data()); + aclSetInputTensorAddr(v2_exec_, 1, t_k, key_out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, v2_ws_); + aclnnApplyRotaryPosEmbV2(arena.buf, v2_ws_, v2_exec_, stream); + } + + private: + // Pre-expanded cos/sin tables on device: [max_seq_len, D]. + void* cos_table_dev_ = nullptr; + + void* sin_table_dev_ = nullptr; + + // Device buffers for gathered [T, D] cos/sin. + void* cos_dev_ = nullptr; + + void* sin_dev_ = nullptr; + + // IndexSelect descriptors. + mutable ascend::AclTensorCache cos_table_cache_; + + mutable ascend::AclTensorCache sin_table_cache_; + + mutable ascend::AclTensorCache idx_cache_; + + mutable ascend::AclTensorCache cos_out_cache_; + + mutable ascend::AclTensorCache sin_out_cache_; + + // V2 descriptors. + mutable ascend::AclTensorCache cos_v2_cache_; + + mutable ascend::AclTensorCache sin_v2_cache_; + + mutable ascend::AclTensorCache q_cache_; + + mutable ascend::AclTensorCache k_cache_; + + // Cached executors. 
+ mutable aclOpExecutor* idx_cos_exec_ = nullptr; + + mutable uint64_t idx_cos_ws_ = 0; + + mutable aclOpExecutor* idx_sin_exec_ = nullptr; + + mutable uint64_t idx_sin_ws_ = 0; + + mutable aclOpExecutor* v2_exec_ = nullptr; + + mutable uint64_t v2_ws_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/runtime_.h b/src/ascend/runtime_.h new file mode 100644 index 00000000..dca74258 --- /dev/null +++ b/src/ascend/runtime_.h @@ -0,0 +1,44 @@ +#ifndef INFINI_OPS_ASCEND_RUNTIME__H_ +#define INFINI_OPS_ASCEND_RUNTIME__H_ + +// clang-format off +#include "acl/acl.h" +// clang-format on + +#include "ascend/device_.h" +#include "runtime.h" + +namespace infini::ops { + +template <> +struct Runtime + : DeviceRuntime> { + using Stream = aclrtStream; + + static constexpr Device::Type kDeviceType = Device::Type::kAscend; + + static constexpr auto Malloc = [](void** ptr, size_t size) { + return aclrtMalloc(ptr, size, ACL_MEM_MALLOC_HUGE_FIRST); + }; + + static constexpr auto Free = aclrtFree; + + static constexpr auto Memcpy = [](void* dst, const void* src, size_t count, + aclrtMemcpyKind kind) { + return aclrtMemcpy(dst, count, src, count, kind); + }; + + static constexpr auto MemcpyHostToDevice = ACL_MEMCPY_HOST_TO_DEVICE; + + static constexpr auto MemcpyDeviceToHost = ACL_MEMCPY_DEVICE_TO_HOST; + + static constexpr auto Memset = [](void* ptr, int value, size_t count) { + return aclrtMemset(ptr, count, value, count); + }; +}; + +static_assert(Runtime::Validate()); + +} // namespace infini::ops + +#endif diff --git a/src/ascend/swiglu/kernel.h b/src/ascend/swiglu/kernel.h new file mode 100644 index 00000000..b3159898 --- /dev/null +++ b/src/ascend/swiglu/kernel.h @@ -0,0 +1,99 @@ +#ifndef INFINI_OPS_ASCEND_SWIGLU_KERNEL_H_ +#define INFINI_OPS_ASCEND_SWIGLU_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_mul.h" +#include "aclnn_silu.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include 
"base/swiglu.h" +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +// Implements SwiGLU as two ACLNN calls: silu(gate) into a temp buffer, +// then elementwise mul(input, temp) into out. +// aclnnSiluMul was not used because it fuses silu_AND_mul on the same +// tensor (x * silu(x)), whereas SwiGLU requires input * silu(gate) — +// two distinct inputs. +template <> +class Operator : public Swiglu { + public: + Operator(const Tensor input, const Tensor gate, Tensor out) + : Swiglu(input, gate, out), + in_cache_(input), + gate_cache_(gate), + out_cache_(out) { + size_t nbytes = input.numel() * kDataTypeToSize.at(input.dtype()); + aclrtMalloc(&temp_buf_, nbytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + // Build temp cache from gate geometry (contiguous, same shape/dtype). + Tensor temp_t{temp_buf_, gate.shape(), gate.dtype(), gate.device()}; + temp_cache_ = ascend::AclTensorCache(temp_t); + } + + ~Operator() { + if (silu_exec_) aclDestroyAclOpExecutor(silu_exec_); + if (mul_exec_) aclDestroyAclOpExecutor(mul_exec_); + aclrtFree(temp_buf_); + } + + void operator()(const Tensor input, const Tensor gate, + Tensor out) const override { + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_gate = gate_cache_.get(const_cast(gate.data())); + auto t_out = out_cache_.get(out.data()); + auto t_temp = temp_cache_.get(temp_buf_); + auto stream = static_cast(stream_); + + // Step 1: silu(gate) -> temp. + if (!silu_exec_) { + aclnnSiluGetWorkspaceSize(t_gate, t_temp, &silu_ws_, &silu_exec_); + aclSetAclOpExecutorRepeatable(silu_exec_); + } else { + aclSetInputTensorAddr(silu_exec_, 0, t_gate, + const_cast(gate.data())); + aclSetOutputTensorAddr(silu_exec_, 0, t_temp, temp_buf_); + } + auto& silu_arena = ascend::workspacePool().ensure(stream, silu_ws_); + aclnnSilu(silu_arena.buf, silu_ws_, silu_exec_, stream); + + // Step 2: mul(input, temp) -> out. 
+ if (!mul_exec_) { + aclnnMulGetWorkspaceSize(t_in, t_temp, t_out, &mul_ws_, &mul_exec_); + aclSetAclOpExecutorRepeatable(mul_exec_); + } else { + aclSetInputTensorAddr(mul_exec_, 0, t_in, + const_cast(input.data())); + aclSetInputTensorAddr(mul_exec_, 1, t_temp, temp_buf_); + aclSetOutputTensorAddr(mul_exec_, 0, t_out, out.data()); + } + auto& mul_arena = ascend::workspacePool().ensure(stream, mul_ws_); + aclnnMul(mul_arena.buf, mul_ws_, mul_exec_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache gate_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable ascend::AclTensorCache temp_cache_; + + void* temp_buf_ = nullptr; + + mutable aclOpExecutor* silu_exec_ = nullptr; + + mutable uint64_t silu_ws_ = 0; + + mutable aclOpExecutor* mul_exec_ = nullptr; + + mutable uint64_t mul_ws_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/workspace_pool_.h b/src/ascend/workspace_pool_.h new file mode 100644 index 00000000..3960017f --- /dev/null +++ b/src/ascend/workspace_pool_.h @@ -0,0 +1,99 @@ +#ifndef INFINI_OPS_ASCEND_WORKSPACE_POOL__H_ +#define INFINI_OPS_ASCEND_WORKSPACE_POOL__H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "acl/acl.h" + +namespace infini::ops::ascend { + +struct WorkspaceArena { + void* buf = nullptr; + + uint64_t capacity = 0; +}; + +class WorkspacePool { + public: + WorkspaceArena& ensure(aclrtStream stream, uint64_t needed) { + // Thread-local fast path: skip mutex when the same stream's arena already + // has enough capacity. After warmup (first call per operator), workspace + // sizes are fixed and this path is always taken. + // + // NOTE: Only the most recent stream is cached. If a single thread + // alternates between multiple streams (e.g. TP>1 driven by one thread), + // every stream switch falls back to the slow path. Replace with a + // small thread-local map if multi-stream-per-thread becomes common. 
+ thread_local aclrtStream last_stream = nullptr; + thread_local WorkspaceArena* last_arena = nullptr; + + if (stream == last_stream && last_arena != nullptr && + needed <= last_arena->capacity) { + return *last_arena; + } + + // Slow path: look up arena in the map under lock. + // Arenas are heap-allocated via `unique_ptr` so that pointers remain stable + // across `unordered_map` rehashes (which invalidate value references). + std::lock_guard lock(mutex_); + auto& slot = arenas_[stream]; + if (!slot) { + slot = std::make_unique(); + } + auto* arena = slot.get(); + if (needed > arena->capacity) { + if (arena->capacity > 0) { + aclrtSynchronizeStream(stream); + aclrtFree(arena->buf); + } + if (needed > 0) { + auto ret = + aclrtMalloc(&arena->buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY); + assert(ret == ACL_SUCCESS && "`WorkspacePool`: `aclrtMalloc` failed"); + } + arena->capacity = needed; + } + last_stream = stream; + last_arena = arena; + return *arena; + } + + ~WorkspacePool() { + for (auto& [stream, arena] : arenas_) { + if (arena && arena->capacity > 0) { + // The CANN runtime may already be torn down when this static + // destructor runs. aclrtGetDevice fails in that case — skip the + // free to avoid glibc "double free" abort. 
+ int32_t dev_id = -1; + if (aclrtGetDevice(&dev_id) == ACL_SUCCESS) { + aclrtFree(arena->buf); + } else { + fprintf(stderr, + "[InfiniOps] `WorkspacePool`: CANN runtime already finalized, " + "skipping `aclrtFree` (%" PRIu64 " bytes leaked).\n", + arena->capacity); + } + } + } + } + + private: + std::unordered_map> arenas_; + + std::mutex mutex_; +}; + +inline WorkspacePool& workspacePool() { + static WorkspacePool pool; + return pool; +} + +} // namespace infini::ops::ascend + +#endif diff --git a/src/base/add_rms_norm.h b/src/base/add_rms_norm.h new file mode 100644 index 00000000..9e0a3e3d --- /dev/null +++ b/src/base/add_rms_norm.h @@ -0,0 +1,55 @@ +#ifndef INFINI_OPS_BASE_ADD_RMS_NORM_H_ +#define INFINI_OPS_BASE_ADD_RMS_NORM_H_ + +#include +#include + +#include "operator.h" +#include "tensor.h" + +namespace infini::ops { + +class AddRmsNorm : public Operator { + public: + AddRmsNorm(const Tensor x1, const Tensor x2, const Tensor weight, float eps, + Tensor y_out, Tensor x_out) + : x1_strides_{x1.strides()}, + x2_strides_{x2.strides()}, + y_out_strides_{y_out.strides()}, + x_out_strides_{x_out.strides()}, + eps_{eps}, + dim_{y_out.size(-1)}, + ndim_{y_out.ndim()}, + batch_size_{ndim_ == 2 ? y_out.size(-2) : y_out.size(-3)}, + nhead_{ndim_ == 2 ? 
1 : y_out.size(-2)} { + assert(x1.dtype() == x2.dtype() && x1.dtype() == weight.dtype() && + x1.dtype() == y_out.dtype() && x1.dtype() == x_out.dtype()); + } + + virtual void operator()(const Tensor x1, const Tensor x2, + const Tensor weight, float eps, Tensor y_out, + Tensor x_out) const = 0; + + protected: + Tensor::Strides x1_strides_; + + Tensor::Strides x2_strides_; + + Tensor::Strides y_out_strides_; + + Tensor::Strides x_out_strides_; + + float eps_{1e-6f}; + + Tensor::Size dim_{0}; + + Tensor::Size ndim_{0}; + + Tensor::Size batch_size_{0}; + + Tensor::Size nhead_{1}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/cast.h b/src/base/cast.h new file mode 100644 index 00000000..29f1f40c --- /dev/null +++ b/src/base/cast.h @@ -0,0 +1,52 @@ +#ifndef INFINI_OPS_BASE_CAST_H_ +#define INFINI_OPS_BASE_CAST_H_ + +#include "operator.h" + +namespace infini::ops { + +class Cast : public Operator { + public: + Cast(const Tensor input, Tensor out) + : ndim_{out.ndim()}, + output_size_{out.numel()}, + input_dtype_{input.dtype()}, + out_dtype_{out.dtype()}, + input_shape_{input.shape()}, + out_shape_{out.shape()}, + input_strides_{input.strides()}, + out_strides_{out.strides()}, + is_input_contiguous_{input.IsContiguous()}, + is_out_contiguous_{out.IsContiguous()} { + assert(input.numel() == out.numel() && + "the input and output of `Cast` must have the same number of " + "elements"); + } + + virtual void operator()(const Tensor input, Tensor out) const = 0; + + protected: + Tensor::Size ndim_{0}; + + Tensor::Size output_size_{0}; + + const DataType input_dtype_; + + const DataType out_dtype_; + + Tensor::Shape input_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides input_strides_; + + Tensor::Strides out_strides_; + + bool is_input_contiguous_{false}; + + bool is_out_contiguous_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/cat.h b/src/base/cat.h new file mode 100644 index 00000000..69642c30 --- /dev/null +++ 
b/src/base/cat.h @@ -0,0 +1,91 @@ +#ifndef INFINI_OPS_BASE_CAT_H_ +#define INFINI_OPS_BASE_CAT_H_ + +#include +#include +#include +#include +#include + +#include "operator.h" + +namespace infini::ops { + +class Cat : public Operator { + public: + Cat(const Tensor first_input, std::vector rest_inputs, int64_t dim, + Tensor out) + : dim_{static_cast(dim >= 0 ? dim : dim + out.ndim())}, + input_count_{1 + rest_inputs.size()}, + dtype_{first_input.dtype()}, + ndim_{out.ndim()}, + output_size_{out.numel()} { + assert(dim_ < ndim_ && "cat dim out of range"); + assert(out.dtype() == dtype_ && + "operator `Cat` requires all tensors to have the same dtype"); + + for (const auto& t : rest_inputs) { + assert(t.dtype() == dtype_ && + "operator `Cat` requires all tensors to have the same dtype"); + assert(t.ndim() == ndim_ && + "operator `Cat` requires all tensors to have the same ndim"); + } + + // Collect all input tensors. + inputs_.reserve(input_count_); + inputs_.push_back(first_input); + + for (auto& t : rest_inputs) { + inputs_.push_back(std::move(t)); + } + + // Build cumulative sizes along the cat dimension. + cum_dim_sizes_.resize(input_count_); + cum_dim_sizes_[0] = inputs_[0].size(dim_); + + for (size_t i = 1; i < input_count_; ++i) { + cum_dim_sizes_[i] = cum_dim_sizes_[i - 1] + inputs_[i].size(dim_); + } + + // Compute outer_size (product of dims before cat dim) and inner_size + // (product of dims after cat dim). 
+ outer_size_ = 1; + + for (size_t i = 0; i < dim_; ++i) { + outer_size_ *= out.size(i); + } + + inner_size_ = 1; + + for (size_t i = dim_ + 1; i < ndim_; ++i) { + inner_size_ *= out.size(i); + } + } + + virtual void operator()(const Tensor first_input, + std::vector rest_inputs, int64_t dim, + Tensor out) const = 0; + + protected: + size_t dim_{0}; + + size_t input_count_{0}; + + const DataType dtype_; + + size_t ndim_{0}; + + size_t output_size_{0}; + + size_t outer_size_{1}; + + size_t inner_size_{1}; + + std::vector inputs_; + + std::vector cum_dim_sizes_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/flash_attention.h b/src/base/flash_attention.h new file mode 100644 index 00000000..734e9a22 --- /dev/null +++ b/src/base/flash_attention.h @@ -0,0 +1,104 @@ +#ifndef INFINI_OPS_BASE_FLASH_ATTENTION_H_ +#define INFINI_OPS_BASE_FLASH_ATTENTION_H_ + +#include +#include +#include + +#include "operator.h" + +namespace infini::ops { + +class FlashAttention : public Operator { + public: + FlashAttention(const Tensor query, const Tensor key, const Tensor value, + std::optional cu_seqlens_q, + std::optional cu_seqlens_kv, + std::optional block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + bool causal, int64_t window_left, int64_t window_right, + int64_t block_size, Tensor output) + : num_tokens_{query.size(0)}, + num_heads_{num_heads}, + num_kv_heads_{num_kv_heads}, + head_size_{head_size}, + scale_{scale}, + causal_{causal}, + window_left_{window_left}, + window_right_{window_right}, + block_size_{block_size}, + dtype_{query.dtype()}, + query_shape_{query.shape()}, + key_shape_{key.shape()}, + value_shape_{value.shape()}, + output_shape_{output.shape()}, + query_strides_{query.strides()}, + key_strides_{key.strides()}, + value_strides_{value.strides()}, + output_strides_{output.strides()}, + has_cu_seqlens_q_{cu_seqlens_q.has_value()}, + has_cu_seqlens_kv_{cu_seqlens_kv.has_value()}, + 
has_block_table_{block_table.has_value()} { + assert(num_heads % num_kv_heads == 0 && + "`FlashAttention` requires num_heads divisible by num_kv_heads"); + assert(query.ndim() == 3 && + "`FlashAttention` requires query to be 3D [T, N, D]"); + } + + virtual void operator()(const Tensor query, const Tensor key, + const Tensor value, + std::optional cu_seqlens_q, + std::optional cu_seqlens_kv, + std::optional block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + bool causal, int64_t window_left, + int64_t window_right, int64_t block_size, + Tensor output) const = 0; + + protected: + Tensor::Size num_tokens_{0}; + + int64_t num_heads_{0}; + + int64_t num_kv_heads_{0}; + + int64_t head_size_{0}; + + double scale_{0.0}; + + bool causal_{false}; + + int64_t window_left_{-1}; + + int64_t window_right_{-1}; + + int64_t block_size_{0}; + + const DataType dtype_; + + Tensor::Shape query_shape_; + + Tensor::Shape key_shape_; + + Tensor::Shape value_shape_; + + Tensor::Shape output_shape_; + + Tensor::Strides query_strides_; + + Tensor::Strides key_strides_; + + Tensor::Strides value_strides_; + + Tensor::Strides output_strides_; + + bool has_cu_seqlens_q_{false}; + + bool has_cu_seqlens_kv_{false}; + + bool has_block_table_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/linear.h b/src/base/linear.h new file mode 100644 index 00000000..a8a7523e --- /dev/null +++ b/src/base/linear.h @@ -0,0 +1,83 @@ +#ifndef INFINI_OPS_BASE_LINEAR_H_ +#define INFINI_OPS_BASE_LINEAR_H_ + +#include +#include + +#include "operator.h" + +namespace infini::ops { + +class Linear : public Operator { + public: + Linear(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) + : has_bias_{bias.has_value()}, + trans_a_{trans_a}, + trans_b_{trans_b}, + m_{out.size(-2)}, + n_{out.size(-1)}, + k_{trans_a_ ? 
a.size(-2) : a.size(-1)}, + a_type_{a.dtype()}, + b_type_{b.dtype()}, + out_type_{out.dtype()}, + a_strides_{a.strides()}, + b_strides_{b.strides()}, + out_strides_{out.strides()}, + lda_{std::max(a.stride(-2), a.stride(-1))}, + ldb_{std::max(b.stride(-2), b.stride(-1))}, + ldc_{std::max(out.stride(-2), out.stride(-1))}, + batch_count_{out.strides().size() > 2 ? out.size(-3) : 1}, + batch_stride_a_{a.strides().size() > 2 ? a.stride(-3) : 0}, + batch_stride_b_{b.strides().size() > 2 ? b.stride(-3) : 0}, + batch_stride_c_{out.strides().size() > 2 ? out.stride(-3) : 0} { + // TODO: Check constraints. + } + + virtual void operator()(const Tensor a, const Tensor b, + std::optional bias, bool trans_a, + bool trans_b, Tensor out) const = 0; + + protected: + bool has_bias_{false}; + + bool trans_a_{false}; + + bool trans_b_{false}; + + Tensor::Size m_{0}; + + Tensor::Size n_{0}; + + Tensor::Size k_{0}; + + const DataType a_type_; + + const DataType b_type_; + + const DataType out_type_; + + Tensor::Strides a_strides_; + + Tensor::Strides b_strides_; + + Tensor::Strides out_strides_; + + Tensor::Stride lda_{0}; + + Tensor::Stride ldb_{0}; + + Tensor::Stride ldc_{0}; + + Tensor::Size batch_count_{1}; + + Tensor::Stride batch_stride_a_{0}; + + Tensor::Stride batch_stride_b_{0}; + + Tensor::Stride batch_stride_c_{0}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/matmul.h b/src/base/matmul.h new file mode 100644 index 00000000..cdada846 --- /dev/null +++ b/src/base/matmul.h @@ -0,0 +1,84 @@ +#ifndef INFINI_OPS_BASE_MATMUL_H_ +#define INFINI_OPS_BASE_MATMUL_H_ + +#include + +#include "operator.h" + +namespace infini::ops { + +class Matmul : public Operator { + public: + Matmul(const Tensor a, const Tensor b, Tensor c, bool trans_a, bool trans_b) + : trans_a_{trans_a}, + trans_b_{trans_b}, + m_{c.size(-2)}, + n_{c.size(-1)}, + k_{trans_a_ ? 
a.size(-2) : a.size(-1)}, + a_type_{a.dtype()}, + b_type_{b.dtype()}, + c_type_{c.dtype()}, + a_strides_{a.strides()}, + b_strides_{b.strides()}, + c_strides_{c.strides()}, + lda_{std::max(a.stride(-2), a.stride(-1))}, + ldb_{std::max(b.stride(-2), b.stride(-1))}, + ldc_{std::max(c.stride(-2), c.stride(-1))}, + batch_count_{c.strides().size() > 2 ? c.size(-3) : 1}, + batch_stride_a_{a.strides().size() > 2 ? a.stride(-3) : 0}, + batch_stride_b_{b.strides().size() > 2 ? b.stride(-3) : 0}, + batch_stride_c_{c.strides().size() > 2 ? c.stride(-3) : 0} { + // TODO: Check constraints. + } + + Matmul(const Tensor a, const Tensor b, Tensor c) + : Matmul{a, b, c, false, false} {} + + virtual void operator()(const Tensor a, const Tensor b, Tensor c, + bool trans_a, bool trans_b) const = 0; + + virtual void operator()(const Tensor a, const Tensor b, Tensor c) const { + return operator()(a, b, c, false, false); + } + + protected: + bool trans_a_{false}; + + bool trans_b_{false}; + + Tensor::Size m_{0}; + + Tensor::Size n_{0}; + + Tensor::Size k_{0}; + + const DataType a_type_; + + const DataType b_type_; + + const DataType c_type_; + + Tensor::Strides a_strides_; + + Tensor::Strides b_strides_; + + Tensor::Strides c_strides_; + + Tensor::Stride lda_{0}; + + Tensor::Stride ldb_{0}; + + Tensor::Stride ldc_{0}; + + Tensor::Size batch_count_{1}; + + Tensor::Stride batch_stride_a_{0}; + + Tensor::Stride batch_stride_b_{0}; + + Tensor::Stride batch_stride_c_{0}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/mul.h b/src/base/mul.h new file mode 100644 index 00000000..9e7be223 --- /dev/null +++ b/src/base/mul.h @@ -0,0 +1,67 @@ +#ifndef INFINI_OPS_BASE_MUL_H_ +#define INFINI_OPS_BASE_MUL_H_ + +#include "operator.h" + +namespace infini::ops { + +class Mul : public Operator { + public: + Mul(const Tensor input, const Tensor other, Tensor out) + : ndim_{out.ndim()}, + output_size_{out.numel()}, + input_type_{input.dtype()}, + other_type_{other.dtype()}, + 
out_type_{out.dtype()}, + input_shape_{input.shape()}, + other_shape_{other.shape()}, + out_shape_{out.shape()}, + input_strides_{input.strides()}, + other_strides_{other.strides()}, + out_strides_{out.strides()}, + is_input_contiguous_{input.IsContiguous()}, + is_other_contiguous_{other.IsContiguous()}, + is_out_contiguous_{out.IsContiguous()} { + assert(!out.HasBroadcastDim() && + "the output of `Mul` should NOT have broadcasted dim!"); + assert(input_type_ == other_type_ && other_type_ == out_type_ && + "operator `Mul` requires all input and output tensors to have the " + "same dtype"); + } + + virtual void operator()(const Tensor input, const Tensor other, + Tensor out) const = 0; + + protected: + Tensor::Size ndim_{0}; + + Tensor::Size output_size_{0}; + + const DataType input_type_; + + const DataType other_type_; + + const DataType out_type_; + + Tensor::Shape input_shape_; + + Tensor::Shape other_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides input_strides_; + + Tensor::Strides other_strides_; + + Tensor::Strides out_strides_; + + bool is_input_contiguous_{false}; + + bool is_other_contiguous_{false}; + + bool is_out_contiguous_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/reshape_and_cache.h b/src/base/reshape_and_cache.h new file mode 100644 index 00000000..4aabe083 --- /dev/null +++ b/src/base/reshape_and_cache.h @@ -0,0 +1,74 @@ +#ifndef INFINI_OPS_BASE_RESHAPE_AND_CACHE_H_ +#define INFINI_OPS_BASE_RESHAPE_AND_CACHE_H_ + +#include +#include + +#include "operator.h" + +namespace infini::ops { + +class ReshapeAndCache : public Operator { + public: + ReshapeAndCache(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, Tensor kv_cache_out) + : num_tokens_{key.size(0)}, + num_kv_heads_{key.size(1)}, + head_size_{key.size(2)}, + block_size_{kv_cache.size(2)}, + key_dtype_{key.dtype()}, + key_shape_{key.shape()}, + value_shape_{value.shape()}, + kv_cache_shape_{kv_cache.shape()}, + 
slot_mapping_shape_{slot_mapping.shape()}, + key_strides_{key.strides()}, + value_strides_{value.strides()}, + kv_cache_strides_{kv_cache.strides()}, + slot_mapping_strides_{slot_mapping.strides()}, + kv_cache_out_strides_{kv_cache_out.strides()} { + assert(key.shape() == value.shape() && + "`ReshapeAndCache` requires key and value same shape"); + assert(kv_cache.ndim() == 5 && + "`ReshapeAndCache` requires kv_cache to be 5D [2, num_blocks, " + "block_size, num_kv_heads, head_size]"); + assert(slot_mapping.ndim() == 1 && + "`ReshapeAndCache` requires slot_mapping to be 1D"); + } + + virtual void operator()(const Tensor key, const Tensor value, + const Tensor kv_cache, const Tensor slot_mapping, + Tensor kv_cache_out) const = 0; + + protected: + Tensor::Size num_tokens_{0}; + + Tensor::Size num_kv_heads_{0}; + + Tensor::Size head_size_{0}; + + Tensor::Size block_size_{0}; + + const DataType key_dtype_; + + Tensor::Shape key_shape_; + + Tensor::Shape value_shape_; + + Tensor::Shape kv_cache_shape_; + + Tensor::Shape slot_mapping_shape_; + + Tensor::Strides key_strides_; + + Tensor::Strides value_strides_; + + Tensor::Strides kv_cache_strides_; + + Tensor::Strides slot_mapping_strides_; + + Tensor::Strides kv_cache_out_strides_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/rotary_embedding.h b/src/base/rotary_embedding.h new file mode 100644 index 00000000..70989fa8 --- /dev/null +++ b/src/base/rotary_embedding.h @@ -0,0 +1,80 @@ +#ifndef INFINI_OPS_BASE_ROTARY_EMBEDDING_H_ +#define INFINI_OPS_BASE_ROTARY_EMBEDDING_H_ + +#include +#include + +#include "operator.h" + +namespace infini::ops { + +class RotaryEmbedding : public Operator { + public: + RotaryEmbedding(const Tensor positions, const Tensor query, const Tensor key, + const Tensor cos_sin_cache, int64_t head_size, + int64_t rotary_dim, bool is_neox_style, Tensor query_out, + Tensor key_out) + : num_tokens_{query.size(0)}, + num_heads_{static_cast(query.size(1))}, + 
num_kv_heads_{static_cast(key.size(1))}, + head_size_{head_size}, + rotary_dim_{rotary_dim}, + is_neox_style_{is_neox_style}, + query_shape_{query.shape()}, + key_shape_{key.shape()}, + cos_sin_cache_shape_{cos_sin_cache.shape()}, + query_out_shape_{query_out.shape()}, + key_out_shape_{key_out.shape()}, + query_strides_{query.strides()}, + key_strides_{key.strides()}, + query_out_strides_{query_out.strides()}, + key_out_strides_{key_out.strides()} { + assert(query.ndim() == 3 && + "`RotaryEmbedding` requires query to be 3D [T, N, D]"); + assert(key.ndim() == 3 && + "`RotaryEmbedding` requires key to be 3D [T, N_kv, D]"); + assert(rotary_dim <= head_size && + "`RotaryEmbedding` requires rotary_dim <= head_size"); + } + + virtual void operator()(const Tensor positions, const Tensor query, + const Tensor key, const Tensor cos_sin_cache, + int64_t head_size, int64_t rotary_dim, + bool is_neox_style, Tensor query_out, + Tensor key_out) const = 0; + + protected: + Tensor::Size num_tokens_{0}; + + int64_t num_heads_{0}; + + int64_t num_kv_heads_{0}; + + int64_t head_size_{0}; + + int64_t rotary_dim_{0}; + + bool is_neox_style_{true}; + + Tensor::Shape query_shape_; + + Tensor::Shape key_shape_; + + Tensor::Shape cos_sin_cache_shape_; + + Tensor::Shape query_out_shape_; + + Tensor::Shape key_out_shape_; + + Tensor::Strides query_strides_; + + Tensor::Strides key_strides_; + + Tensor::Strides query_out_strides_; + + Tensor::Strides key_out_strides_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/swiglu.h b/src/base/swiglu.h index 023b14a2..bf0bd844 100644 --- a/src/base/swiglu.h +++ b/src/base/swiglu.h @@ -9,28 +9,28 @@ namespace infini::ops { class Swiglu : public Operator { public: - Swiglu(const Tensor input, const Tensor gate, Tensor out) + Swiglu(const Tensor input, const Tensor other, Tensor out) : ndim_{out.ndim()}, output_size_{out.numel()}, input_type_{input.dtype()}, - gate_type_{gate.dtype()}, + other_type_{other.dtype()}, 
out_type_{out.dtype()}, input_shape_{input.shape()}, - gate_shape_{gate.shape()}, + other_shape_{other.shape()}, out_shape_{out.shape()}, input_strides_{input.strides()}, - gate_strides_{gate.strides()}, + other_strides_{other.strides()}, out_strides_{out.strides()}, is_input_contiguous_{input.IsContiguous()}, - is_gate_contiguous_{gate.IsContiguous()}, + is_other_contiguous_{other.IsContiguous()}, is_out_contiguous_{out.IsContiguous()} { assert( - input_type_ == gate_type_ && gate_type_ == out_type_ && + input_type_ == other_type_ && other_type_ == out_type_ && "operator `Swiglu` requires all input and output tensors to have the " "same dtype"); } - virtual void operator()(const Tensor input, const Tensor gate, + virtual void operator()(const Tensor input, const Tensor other, Tensor out) const = 0; protected: @@ -40,25 +40,25 @@ class Swiglu : public Operator { const DataType input_type_; - const DataType gate_type_; + const DataType other_type_; const DataType out_type_; Tensor::Shape input_shape_; - Tensor::Shape gate_shape_; + Tensor::Shape other_shape_; Tensor::Shape out_shape_; Tensor::Strides input_strides_; - Tensor::Strides gate_strides_; + Tensor::Strides other_strides_; Tensor::Strides out_strides_; bool is_input_contiguous_{false}; - bool is_gate_contiguous_{false}; + bool is_other_contiguous_{false}; bool is_out_contiguous_{false}; }; diff --git a/src/common/generic_utils.h b/src/common/generic_utils.h index 795f2fb7..1213b2ff 100644 --- a/src/common/generic_utils.h +++ b/src/common/generic_utils.h @@ -21,6 +21,25 @@ constexpr auto CeilDiv(const X& x, const Y& y) { return (x + y - 1) / y; } +// Aligned vector type for vectorized memory access. +// +// Maps (T, VEC_SIZE) to a POD type with the same size as T[VEC_SIZE] and +// natural alignment. Used for 128-bit coalesced load/store in CUDA kernels. 
+template +struct AlignedVec { + using type = struct alignas(sizeof(T) * VEC_SIZE) { T data[VEC_SIZE]; }; +}; + +// Compute the optimal vectorization factor for type T. +// Target: 128-bit (16-byte) loads where possible. +template +constexpr int OptimalVecSize() { + constexpr int kTargetBytes = 16; + constexpr int vec = kTargetBytes / sizeof(T); + + return vec > 0 ? vec : 1; +} + } // namespace infini::ops::utils #endif diff --git a/src/cpu/cast/cast.h b/src/cpu/cast/cast.h new file mode 100644 index 00000000..67c8367c --- /dev/null +++ b/src/cpu/cast/cast.h @@ -0,0 +1,57 @@ +#ifndef INFINI_OPS_CPU_CAST_CAST_H_ +#define INFINI_OPS_CPU_CAST_CAST_H_ + +#include "base/cast.h" +#include "common/generic_utils.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Cast { + public: + Operator(const Tensor input, Tensor out) : Cast{input, out} {} + + void operator()(const Tensor input, Tensor out) const override { + DispatchFunc( + input_dtype_, + [&](auto in_tag) { + using InT = typename decltype(in_tag)::type; + DispatchFunc( + out_dtype_, + [&](auto out_tag) { + using OutT = typename decltype(out_tag)::type; + Compute(input, out); + }, + "`Operator::operator()` (out)"); + }, + "`Operator::operator()` (in)"); + } + + private: + template + void Compute(const Tensor input, Tensor out) const { + const auto* in_ptr = static_cast(input.data()); + auto* out_ptr = static_cast(out.data()); + + auto get_idx = [&](Tensor::Size i, bool is_contig, const auto* shape, + const auto* strides) { + return is_contig ? 
i : utils::IndexToOffset(i, ndim_, shape, strides); + }; + +#pragma omp parallel for + for (Tensor::Size i = 0; i < output_size_; ++i) { + auto in_idx = get_idx(i, is_input_contiguous_, input_shape_.data(), + input_strides_.data()); + auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), + out_strides_.data()); + + out_ptr[out_idx] = + Caster::template Cast(in_ptr[in_idx]); + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/cat/cat.h b/src/cpu/cat/cat.h new file mode 100644 index 00000000..bf6be7b1 --- /dev/null +++ b/src/cpu/cat/cat.h @@ -0,0 +1,57 @@ +#ifndef INFINI_OPS_CPU_CAT_CAT_H_ +#define INFINI_OPS_CPU_CAT_CAT_H_ + +#include +#include + +#include "base/cat.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Cat, + Caster { + public: + Operator(const Tensor first_input, std::vector rest_inputs, + int64_t dim, Tensor out) + : Cat{first_input, std::move(rest_inputs), dim, out} {} + + void operator()(const Tensor first_input, std::vector rest_inputs, + int64_t dim, Tensor out) const override { + DispatchFunc( + dtype_, + [&](auto tag) { + using T = typename decltype(tag)::type; + Compute(out); + }, + "`Operator::operator()`"); + } + + private: + template + void Compute(Tensor out) const { + auto* out_ptr = static_cast(out.data()); + + for (size_t outer = 0; outer < outer_size_; ++outer) { + size_t out_offset = 0; + + for (size_t i = 0; i < input_count_; ++i) { + const auto* in_ptr = static_cast(inputs_[i].data()); + size_t dim_size = inputs_[i].size(dim_); + size_t copy_count = dim_size * inner_size_; + + std::memcpy( + out_ptr + outer * cum_dim_sizes_.back() * inner_size_ + out_offset, + in_ptr + outer * dim_size * inner_size_, + copy_count * sizeof(T)); + + out_offset += copy_count; + } + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/linear/linear.h b/src/cpu/linear/linear.h new file mode 100644 index 00000000..ab107c61 --- /dev/null +++ 
b/src/cpu/linear/linear.h @@ -0,0 +1,71 @@ +#ifndef INFINI_OPS_CPU_LINEAR_LINEAR_H_ +#define INFINI_OPS_CPU_LINEAR_LINEAR_H_ + +#include + +#include "base/linear.h" +#include "common/generic_utils.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Linear, + Caster { + public: + Operator(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) + : Linear{a, b, bias, trans_a, trans_b, out} {} + + void operator()(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) const override { + DispatchFunc( + out.dtype(), + [&](auto tag) { + using T = typename decltype(tag)::type; + Compute(a, b, bias, out); + }, + "`Operator::operator()`"); + } + + private: + template + void Compute(const Tensor a, const Tensor b, std::optional bias, + Tensor out) const { + const auto* A = static_cast(a.data()); + const auto* B = static_cast(b.data()); + auto* Out = static_cast(out.data()); + const T* Bias = bias ? static_cast(bias->data()) : nullptr; + + for (Tensor::Size batch = 0; batch < batch_count_; ++batch) { + const auto* A_batch = A + batch * batch_stride_a_; + const auto* B_batch = B + batch * batch_stride_b_; + auto* Out_batch = Out + batch * batch_stride_c_; + + for (Tensor::Size i = 0; i < m_; ++i) { + + for (Tensor::Size j = 0; j < n_; ++j) { + float sum = 0.0f; + + for (Tensor::Size l = 0; l < k_; ++l) { + float a_val = Cast( + A_batch[trans_a_ ? (l * lda_ + i) : (i * lda_ + l)]); + float b_val = Cast( + B_batch[trans_b_ ? 
(j * ldb_ + l) : (l * ldb_ + j)]); + sum += a_val * b_val; + } + + if (Bias) { + sum += Cast(Bias[j]); + } + + Out_batch[i * ldc_ + j] = Cast(sum); + } + } + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/matmul/matmul.h b/src/cpu/matmul/matmul.h new file mode 100644 index 00000000..d0468fcc --- /dev/null +++ b/src/cpu/matmul/matmul.h @@ -0,0 +1,84 @@ +#ifndef INFINI_OPS_CPU_MATMUL_H_ +#define INFINI_OPS_CPU_MATMUL_H_ + +#include + +#include "base/matmul.h" +#include "common/generic_utils.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Matmul, + Caster { + public: + Operator(const Tensor a, const Tensor b, Tensor c, bool trans_a, + bool trans_b) + : Matmul{a, b, c, trans_a, trans_b} { + // TODO: Check constraints. + } + + Operator(const Tensor a, const Tensor b, Tensor c) + : Operator{a, b, c, false, false} {} + + void operator()(const Tensor a, const Tensor b, Tensor c, bool trans_a, + bool trans_b) const override { + DispatchFunc( + c.dtype(), + [&](auto tag) { + using T = typename decltype(tag)::type; + Compute(a, b, c, trans_a, trans_b); + }, + "`Operator::operator()`"); + } + + private: + template + void Compute(const Tensor a, const Tensor b, Tensor c, bool trans_a, + bool trans_b) const { + const auto* A = static_cast(a.data()); + const auto* B = static_cast(b.data()); + auto* C = static_cast(c.data()); + + Tensor::Stride stride_a_m = trans_a + ? a_strides_[a_strides_.size() - 1] + : a_strides_[a_strides_.size() - 2]; + Tensor::Stride stride_a_k = trans_a + ? a_strides_[a_strides_.size() - 2] + : a_strides_[a_strides_.size() - 1]; + Tensor::Stride stride_b_k = trans_b + ? b_strides_[b_strides_.size() - 1] + : b_strides_[b_strides_.size() - 2]; + Tensor::Stride stride_b_n = trans_b + ? 
b_strides_[b_strides_.size() - 2] + : b_strides_[b_strides_.size() - 1]; + Tensor::Stride stride_c_m = c_strides_[c_strides_.size() - 2]; + Tensor::Stride stride_c_n = c_strides_[c_strides_.size() - 1]; + + for (Tensor::Size batch = 0; batch < batch_count_; ++batch) { + const auto* A_batch = A + batch * batch_stride_a_; + const auto* B_batch = B + batch * batch_stride_b_; + auto* C_batch = C + batch * batch_stride_c_; + + for (Tensor::Size i = 0; i < m_; ++i) { + for (Tensor::Size j = 0; j < n_; ++j) { + float sum = 0.0f; + + for (Tensor::Size l = 0; l < k_; ++l) { + float a_val = Cast(A_batch[i * stride_a_m + l * stride_a_k]); + float b_val = Cast(B_batch[l * stride_b_k + j * stride_b_n]); + sum += a_val * b_val; + } + + Tensor::Size idx = i * stride_c_m + j * stride_c_n; + C_batch[idx] = Cast(sum); + } + } + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/mul/mul.h b/src/cpu/mul/mul.h new file mode 100644 index 00000000..0bdefb96 --- /dev/null +++ b/src/cpu/mul/mul.h @@ -0,0 +1,63 @@ +#ifndef INFINI_OPS_CPU_MUL_MUL_H_ +#define INFINI_OPS_CPU_MUL_MUL_H_ + +#include + +#include "base/mul.h" +#include "common/generic_utils.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Mul, + Caster { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Mul{input, other, out} {} + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + DispatchFunc( + out_type_, + [&](auto tag) { + using T = typename decltype(tag)::type; + Compute(input, other, out); + }, + "`Operator::operator()`"); + } + + private: + template + void Compute(const Tensor input, const Tensor other, Tensor out) const { + using ComputeType = std::conditional_t || + IsFP16, + float, T>; + + const auto* input_ptr = static_cast(input.data()); + const auto* other_ptr = static_cast(other.data()); + auto* out_ptr = static_cast(out.data()); + + auto get_idx = [&](Tensor::Size i, bool is_contig, 
const auto* shape, + const auto* strides) { + return is_contig ? i : utils::IndexToOffset(i, ndim_, shape, strides); + }; + +#pragma omp parallel for + for (Tensor::Size i = 0; i < output_size_; ++i) { + auto input_idx = get_idx(i, is_input_contiguous_, input_shape_.data(), + input_strides_.data()); + auto other_idx = get_idx(i, is_other_contiguous_, other_shape_.data(), + other_strides_.data()); + auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), + out_strides_.data()); + + out_ptr[out_idx] = Cast(Cast(input_ptr[input_idx]) * + Cast(other_ptr[other_idx])); + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/swiglu/swiglu.h b/src/cpu/swiglu/swiglu.h index 57dccf18..5eb45c2f 100644 --- a/src/cpu/swiglu/swiglu.h +++ b/src/cpu/swiglu/swiglu.h @@ -15,26 +15,26 @@ class Operator : public Swiglu, public: using Swiglu::Swiglu; - void operator()(const Tensor input, const Tensor gate, + void operator()(const Tensor input, const Tensor other, Tensor out) const override { DispatchFunc( out_type_, [&](auto tag) { using T = typename decltype(tag)::type; - Compute(input, gate, out); + Compute(input, other, out); }, "Operator::operator()"); } private: template - void Compute(const Tensor input, const Tensor gate, Tensor out) const { + void Compute(const Tensor input, const Tensor other, Tensor out) const { using ComputeType = std::conditional_t || IsFP16, float, T>; const auto* input_ptr = static_cast(input.data()); - const auto* gate_ptr = static_cast(gate.data()); + const auto* other_ptr = static_cast(other.data()); auto* out_ptr = static_cast(out.data()); auto get_idx = [&](Tensor::Size i, bool is_contig, const auto* shape, @@ -46,16 +46,16 @@ class Operator : public Swiglu, for (Tensor::Size i = 0; i < output_size_; ++i) { auto input_idx = get_idx(i, is_input_contiguous_, input_shape_.data(), input_strides_.data()); - auto gate_idx = get_idx(i, is_gate_contiguous_, gate_shape_.data(), - gate_strides_.data()); + auto gate_idx = get_idx(i, 
is_other_contiguous_, other_shape_.data(), + other_strides_.data()); auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), out_strides_.data()); - const ComputeType gate_val = Cast(gate_ptr[gate_idx]); - const ComputeType sigmoid_gate = static_cast( - 1.0 / (1.0 + std::exp(-static_cast(gate_val)))); - const ComputeType swish_gate = gate_val * sigmoid_gate; + const ComputeType other_val = Cast(other_ptr[gate_idx]); + const ComputeType sigmoid_other = static_cast( + 1.0 / (1.0 + std::exp(-static_cast(other_val)))); + const ComputeType swish_other = other_val * sigmoid_other; out_ptr[out_idx] = - Cast(Cast(input_ptr[input_idx]) * swish_gate); + Cast(Cast(input_ptr[input_idx]) * swish_other); } } }; diff --git a/src/cpu/templates/binary_elementwise.h b/src/cpu/templates/binary_elementwise.h new file mode 100644 index 00000000..773dcf1d --- /dev/null +++ b/src/cpu/templates/binary_elementwise.h @@ -0,0 +1,67 @@ +#ifndef INFINI_OPS_CPU_TEMPLATES_BINARY_ELEMENTWISE_H_ +#define INFINI_OPS_CPU_TEMPLATES_BINARY_ELEMENTWISE_H_ + +#include + +#include "common/generic_utils.h" +#include "cpu/caster_.h" +#include "dispatcher.h" +#include "tensor.h" + +namespace infini::ops { + +// CPU binary elementwise brick. +// +// `Op` is a host-side functor: `T operator()(const T&, const T&) const`. +// Handles non-contiguous tensors via `IndexToOffset` and promotes FP16/BF16 +// to float for computation. 
+template +void CpuBinaryElementwise(const Tensor a, const Tensor b, Tensor out, + Tensor::Size output_size, Tensor::Size ndim, + bool a_contig, bool b_contig, bool out_contig, + const Tensor::Shape& a_shape, + const Tensor::Shape& b_shape, + const Tensor::Shape& out_shape, + const Tensor::Strides& a_strides, + const Tensor::Strides& b_strides, + const Tensor::Strides& out_strides, DataType dtype, + Op op) { + DispatchFunc( + dtype, + [&](auto tag) { + using T = typename decltype(tag)::type; + using ComputeType = + std::conditional_t || + IsFP16, + float, T>; + + const auto* a_ptr = static_cast(a.data()); + const auto* b_ptr = static_cast(b.data()); + auto* out_ptr = static_cast(out.data()); + +#pragma omp parallel for + for (Tensor::Size i = 0; i < output_size; ++i) { + auto ai = a_contig + ? i + : utils::IndexToOffset(i, ndim, a_shape.data(), + a_strides.data()); + auto bi = b_contig + ? i + : utils::IndexToOffset(i, ndim, b_shape.data(), + b_strides.data()); + auto oi = out_contig + ? i + : utils::IndexToOffset(i, ndim, out_shape.data(), + out_strides.data()); + + out_ptr[oi] = Caster::Cast( + op(Caster::Cast(a_ptr[ai]), + Caster::Cast(b_ptr[bi]))); + } + }, + "CpuBinaryElementwise"); +} + +} // namespace infini::ops + +#endif diff --git a/src/cpu/templates/reduce_transform.h b/src/cpu/templates/reduce_transform.h new file mode 100644 index 00000000..7eb9e720 --- /dev/null +++ b/src/cpu/templates/reduce_transform.h @@ -0,0 +1,103 @@ +#ifndef INFINI_OPS_CPU_TEMPLATES_REDUCE_TRANSFORM_H_ +#define INFINI_OPS_CPU_TEMPLATES_REDUCE_TRANSFORM_H_ + +#include +#include + +#include "cpu/caster_.h" +#include "dispatcher.h" +#include "tensor.h" + +namespace infini::ops { + +// CPU reduce-then-transform brick. +// +// Iterates over [batch, head] slices. For each slice, reduces over `dim` +// elements, then applies a transform using the reduction result. +// +// `ReduceOp` must define: +// `float Init()` — identity element. 
+// `float Accumulate(float acc, float value)` — fold one element. +// `float Finalize(float acc, size_t count)` — post-process total. +// +// `TransformOp` must define: +// `T Apply(T x, float reduced, size_t i)` — per-element transform. +template +void CpuReduceThenTransform( + const Tensor in, Tensor out, size_t batch_size, size_t nhead, + size_t dim, DataType dtype, const Tensor::Strides& in_strides, + const Tensor::Strides& out_strides, ReduceOp reduce_op, + TransformOp transform_op) { + auto stride_in_batch = in_strides.size() > 1 ? in_strides[0] : 0; + auto stride_in_head = + in_strides.size() > 1 ? in_strides[1] : in_strides[0]; + auto stride_out_batch = out_strides.size() > 1 ? out_strides[0] : 0; + auto stride_out_head = + out_strides.size() > 1 ? out_strides[1] : out_strides[0]; + + DispatchFunc( + dtype, + [&](auto tag) { + using T = typename decltype(tag)::type; + + const auto* in_ptr = static_cast(in.data()); + auto* out_ptr = static_cast(out.data()); + + for (size_t bi = 0; bi < batch_size; ++bi) { + + for (size_t hi = 0; hi < nhead; ++hi) { + auto in_row = in_ptr + bi * stride_in_batch + hi * stride_in_head; + auto out_row = + out_ptr + bi * stride_out_batch + hi * stride_out_head; + + // Reduction phase. + float acc = reduce_op.Init(); + + for (size_t k = 0; k < dim; ++k) { + float v = Caster::Cast(in_row[k]); + acc = reduce_op.Accumulate(acc, v); + } + + float reduced = reduce_op.Finalize(acc, dim); + + // Transform phase. 
+ for (size_t k = 0; k < dim; ++k) { + out_row[k] = + transform_op.template Apply(in_row[k], reduced, k); + } + } + } + }, + "CpuReduceThenTransform"); +} + +// ---------- Built-in ops matching the CUDA counterparts --------------------- + +struct CpuMeanSquareReduce { + float Init() const { return 0.f; } + + float Accumulate(float acc, float v) const { return acc + v * v; } + + float Finalize(float acc, size_t count) const { + return 1.f / std::sqrt(acc / static_cast(count) + epsilon); + } + + float epsilon; +}; + +struct CpuRmsNormTransform { + template + T Apply(T x, float rms, size_t i) const { + const auto* w = static_cast(weight); + + return Caster::Cast( + Caster::Cast(x) * + Caster::Cast(w[i]) * rms); + } + + const void* weight; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/templates/unary_elementwise.h b/src/cpu/templates/unary_elementwise.h new file mode 100644 index 00000000..5f15d9b2 --- /dev/null +++ b/src/cpu/templates/unary_elementwise.h @@ -0,0 +1,61 @@ +#ifndef INFINI_OPS_CPU_TEMPLATES_UNARY_ELEMENTWISE_H_ +#define INFINI_OPS_CPU_TEMPLATES_UNARY_ELEMENTWISE_H_ + +#include + +#include "common/generic_utils.h" +#include "cpu/caster_.h" +#include "dispatcher.h" +#include "tensor.h" + +namespace infini::ops { + +// CPU unary elementwise brick with dual-dtype dispatch. +// +// `Op` is a host-side functor called as `op.template operator()(x)`, +// allowing the functor to know both input and output types. Handles +// non-contiguous tensors via `IndexToOffset`. 
+template +void CpuUnaryElementwise(const Tensor in, Tensor out, + Tensor::Size output_size, Tensor::Size ndim, + bool in_contig, bool out_contig, + const Tensor::Shape& in_shape, + const Tensor::Shape& out_shape, + const Tensor::Strides& in_strides, + const Tensor::Strides& out_strides, + DataType input_dtype, DataType output_dtype, Op op) { + DispatchFunc( + input_dtype, + [&](auto in_tag) { + using TIn = typename decltype(in_tag)::type; + + DispatchFunc( + output_dtype, + [&](auto out_tag) { + using TOut = typename decltype(out_tag)::type; + + const auto* in_ptr = static_cast(in.data()); + auto* out_ptr = static_cast(out.data()); + +#pragma omp parallel for + for (Tensor::Size i = 0; i < output_size; ++i) { + auto ii = in_contig + ? i + : utils::IndexToOffset(i, ndim, in_shape.data(), + in_strides.data()); + auto oi = out_contig + ? i + : utils::IndexToOffset(i, ndim, out_shape.data(), + out_strides.data()); + + out_ptr[oi] = op.template operator()(in_ptr[ii]); + } + }, + "CpuUnaryElementwise (out)"); + }, + "CpuUnaryElementwise (in)"); +} + +} // namespace infini::ops + +#endif diff --git a/src/cuda/add/dsl.h b/src/cuda/add/dsl.h new file mode 100644 index 00000000..b2ee583e --- /dev/null +++ b/src/cuda/add/dsl.h @@ -0,0 +1,42 @@ +#ifndef INFINI_OPS_CUDA_ADD_DSL_H_ +#define INFINI_OPS_CUDA_ADD_DSL_H_ + +#include "cuda/templates/binary_elementwise.cuh" +#include "base/add.h" + +namespace infini::ops { + +// Device-side binary functor for `Add` (DSL). 
+template +struct DslAddOp { + template + __device__ __forceinline__ T operator()(const T& a, const T& b) const { + using ComputeType = float; + auto va = Caster::template Cast(a); + auto vb = Caster::template Cast(b); + return Caster::template Cast((va + vb)); + } +}; + +template +class DslCudaAdd : public Add { + public: + DslCudaAdd(const Tensor input, const Tensor other, Tensor out) + : Add{input, other, out}, + brick_{input, other, out, ndim_} {} + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + brick_.template Run( + stream_, input, other, out, output_size_, ndim_, + is_input_contiguous_, is_other_contiguous_, is_out_contiguous_, + out_type_); + } + + private: + BinaryElementwiseBrick brick_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/add/kernel.h b/src/cuda/add/kernel.h index 95d82c91..8bf78c1b 100644 --- a/src/cuda/add/kernel.h +++ b/src/cuda/add/kernel.h @@ -1,104 +1,31 @@ #ifndef INFINI_OPS_CUDA_ADD_KERNEL_H_ #define INFINI_OPS_CUDA_ADD_KERNEL_H_ -#include -#include -#include -#include - #include "base/add.h" -#include "common/generic_utils.h" #include "cuda/add/kernel.cuh" -#include "cuda/kernel_commons.cuh" -#include "cuda/runtime_utils.h" +#include "cuda/templates/binary_elementwise.cuh" namespace infini::ops { +// CudaAdd uses BinaryElementwiseBrick for automatic vectorized dispatch +// on contiguous tensors (128-bit coalesced load/store). 
template class CudaAdd : public Add { public: CudaAdd(const Tensor input, const Tensor other, Tensor out) - : Add{input, other, out} { - size_t shape_size = ndim_ * sizeof(*d_input_shape_); - size_t strides_size = ndim_ * sizeof(*d_input_strides_); - const size_t metadata_size = 3 * (shape_size + strides_size); - std::vector metadata(metadata_size); - - Backend::Malloc((void**)&d_metadata_, metadata_size); - - size_t offset = 0; - d_input_shape_ = reinterpret_cast(d_metadata_ + offset); - std::memcpy(metadata.data() + offset, input_shape_.data(), shape_size); - offset += shape_size; - - d_other_shape_ = reinterpret_cast(d_metadata_ + offset); - std::memcpy(metadata.data() + offset, other_shape_.data(), shape_size); - offset += shape_size; - - d_out_shape_ = reinterpret_cast(d_metadata_ + offset); - std::memcpy(metadata.data() + offset, out_shape_.data(), shape_size); - offset += shape_size; - - d_input_strides_ = reinterpret_cast(d_metadata_ + offset); - std::memcpy(metadata.data() + offset, input_strides_.data(), strides_size); - offset += strides_size; - - d_other_strides_ = reinterpret_cast(d_metadata_ + offset); - std::memcpy(metadata.data() + offset, other_strides_.data(), strides_size); - offset += strides_size; - - d_out_strides_ = reinterpret_cast(d_metadata_ + offset); - std::memcpy(metadata.data() + offset, out_strides_.data(), strides_size); - - Backend::Memcpy(d_metadata_, metadata.data(), metadata_size, - Backend::MemcpyHostToDevice); - } - - ~CudaAdd() { Backend::Free(d_metadata_); } + : Add{input, other, out}, + brick_{input, other, out, ndim_} {} void operator()(const Tensor input, const Tensor other, Tensor out) const override { - int block_size = RuntimeUtils::GetOptimalBlockSize(); - DispatchFunc( - {static_cast(out_type_), block_size}, - [&](auto list_tag) { - using T = TypeMapType(list_tag)>; - constexpr int kBlockSize = ListGet<1>(list_tag); - - auto cuda_stream = - static_cast(stream_ ? 
stream_ : 0); - dim3 blockDims( - std::min(static_cast(block_size), output_size_)); - dim3 gridDims(utils::CeilDiv(output_size_, blockDims.x)); - - T* d_out = reinterpret_cast(out.data()); - const T* d_input = reinterpret_cast(input.data()); - const T* d_other = reinterpret_cast(other.data()); - - AddKernel - <<>>( - d_out, d_input, d_other, d_out_shape_, d_input_shape_, - d_other_shape_, d_out_strides_, d_input_strides_, - d_other_strides_, output_size_, ndim_, is_out_contiguous_, - is_input_contiguous_, is_other_contiguous_); - }, - "CudaAdd::operator()"); + brick_.template Run( + stream_, input, other, out, output_size_, ndim_, + is_input_contiguous_, is_other_contiguous_, is_out_contiguous_, + out_type_); } private: - std::byte* d_metadata_{nullptr}; - - Tensor::Size* d_input_shape_{nullptr}; - - Tensor::Size* d_other_shape_{nullptr}; - - Tensor::Size* d_out_shape_{nullptr}; - - Tensor::Stride* d_input_strides_{nullptr}; - - Tensor::Stride* d_other_strides_{nullptr}; - - Tensor::Stride* d_out_strides_{nullptr}; + BinaryElementwiseBrick brick_; }; } // namespace infini::ops diff --git a/src/cuda/add_rms_norm/kernel.cuh b/src/cuda/add_rms_norm/kernel.cuh new file mode 100644 index 00000000..fe97ad0f --- /dev/null +++ b/src/cuda/add_rms_norm/kernel.cuh @@ -0,0 +1,76 @@ +#ifndef INFINI_OPS_CUDA_ADD_RMS_NORM_KERNEL_CUH_ +#define INFINI_OPS_CUDA_ADD_RMS_NORM_KERNEL_CUH_ + +#include +#include +#include + +#include "cuda/caster.cuh" +#include "cuda/kernel_commons.cuh" + +namespace infini::ops { + +// Single-pass AddRmsNorm with shared memory caching. +// +// Pass 1: Compute residual = x1 + x2, write x_out, cache in shared memory, +// accumulate sum-of-squares. +// Pass 2: Read residual from shared memory (not x_out global), normalize. 
+template +__global__ void AddRmsNormKernel( + TData* __restrict__ y_out, int64_t stride_y_out_batch, + int64_t stride_y_out_nhead, TData* __restrict__ x_out, + int64_t stride_x_out_batch, int64_t stride_x_out_nhead, + const TData* __restrict__ x1, int64_t stride_x1_batch, + int64_t stride_x1_nhead, const TData* __restrict__ x2, + int64_t stride_x2_batch, int64_t stride_x2_nhead, + const TWeight* __restrict__ w, size_t nhead, size_t dim, float epsilon) { + // Dynamic shared memory for caching residual values. + extern __shared__ char smem_raw[]; + TCompute* res_cache = reinterpret_cast(smem_raw); + + size_t batch_idx = blockIdx.x / nhead; + size_t head_idx = blockIdx.x % nhead; + + auto y_out_ptr = + y_out + batch_idx * stride_y_out_batch + head_idx * stride_y_out_nhead; + auto x_out_ptr = + x_out + batch_idx * stride_x_out_batch + head_idx * stride_x_out_nhead; + auto x1_ptr = x1 + batch_idx * stride_x1_batch + head_idx * stride_x1_nhead; + auto x2_ptr = x2 + batch_idx * stride_x2_batch + head_idx * stride_x2_nhead; + + // Pass 1: Compute residual, cache in shared memory, write x_out, + // accumulate sum-of-squares. + TCompute ss = 0; + + for (size_t i = threadIdx.x; i < dim; i += block_size) { + TCompute val = Caster::template Cast(x1_ptr[i]) + + Caster::template Cast(x2_ptr[i]); + res_cache[i] = val; + x_out_ptr[i] = Caster::template Cast(val); + ss += val * val; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + ss = BlockReduce(temp_storage).Sum(ss); + + __shared__ TCompute rms; + + if (threadIdx.x == 0) { + rms = rsqrtf(ss / static_cast(dim) + epsilon); + } + + __syncthreads(); + + // Pass 2: Normalize using cached residual (no second global read). 
+ for (size_t i = threadIdx.x; i < dim; i += block_size) { + y_out_ptr[i] = Caster::template Cast( + res_cache[i] * + Caster::template Cast(w[i]) * rms); + } +} + +} // namespace infini::ops + +#endif diff --git a/src/cuda/add_rms_norm/kernel.h b/src/cuda/add_rms_norm/kernel.h new file mode 100644 index 00000000..3731c3fe --- /dev/null +++ b/src/cuda/add_rms_norm/kernel.h @@ -0,0 +1,74 @@ +#ifndef INFINI_OPS_CUDA_ADD_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_CUDA_ADD_RMS_NORM_KERNEL_H_ + +#include +#include + +#include "base/add_rms_norm.h" +#include "cuda/add_rms_norm/kernel.cuh" +#include "cuda/kernel_commons.cuh" +#include "cuda/runtime_utils.h" +#include "data_type.h" +#include "dispatcher.h" + +namespace infini::ops { + +template +class CudaAddRmsNorm : public AddRmsNorm { + public: + using AddRmsNorm::AddRmsNorm; + + void operator()(const Tensor x1, const Tensor x2, const Tensor weight, + float eps, Tensor y_out, Tensor x_out) const override { + auto cuda_stream = + static_cast(stream_ ? stream_ : 0); + + auto stride_x1_batch = x1_strides_.size() > 1 ? x1_strides_[0] : 0; + auto stride_x1_nhead = + x1_strides_.size() > 1 ? x1_strides_[1] : x1_strides_[0]; + auto stride_x2_batch = x2_strides_.size() > 1 ? x2_strides_[0] : 0; + auto stride_x2_nhead = + x2_strides_.size() > 1 ? x2_strides_[1] : x2_strides_[0]; + auto stride_y_out_batch = + y_out_strides_.size() > 1 ? y_out_strides_[0] : 0; + auto stride_y_out_nhead = + y_out_strides_.size() > 1 ? y_out_strides_[1] : y_out_strides_[0]; + auto stride_x_out_batch = + x_out_strides_.size() > 1 ? x_out_strides_[0] : 0; + auto stride_x_out_nhead = + x_out_strides_.size() > 1 ? 
x_out_strides_[1] : x_out_strides_[0]; + + uint32_t num_blocks = static_cast(batch_size_ * nhead_); + + assert(x1.dtype() == x2.dtype() && x1.dtype() == weight.dtype() && + x1.dtype() == y_out.dtype() && x1.dtype() == x_out.dtype()); + + int block_size = RuntimeUtils::GetOptimalBlockSize(); + + DispatchFunc, ReducedFloatTypes>, + AllCudaBlockSizes>( + {static_cast(y_out.dtype()), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + + size_t smem_bytes = dim_ * sizeof(float); + + AddRmsNormKernel + <<>>( + reinterpret_cast(y_out.data()), stride_y_out_batch, + stride_y_out_nhead, reinterpret_cast(x_out.data()), + stride_x_out_batch, stride_x_out_nhead, + reinterpret_cast(x1.data()), stride_x1_batch, + stride_x1_nhead, reinterpret_cast(x2.data()), + stride_x2_batch, stride_x2_nhead, + reinterpret_cast(weight.data()), nhead_, dim_, + eps_); + }, + "CudaAddRmsNorm::operator()"); + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/cast/dsl.h b/src/cuda/cast/dsl.h new file mode 100644 index 00000000..86773543 --- /dev/null +++ b/src/cuda/cast/dsl.h @@ -0,0 +1,39 @@ +#ifndef INFINI_OPS_CUDA_CAST_DSL_H_ +#define INFINI_OPS_CUDA_CAST_DSL_H_ + +#include "cuda/templates/unary_elementwise.cuh" +#include "base/cast.h" + +namespace infini::ops { + +// Device-side unary functor for `Cast` (DSL). 
+template +struct DslCastOp { + template + __device__ __forceinline__ TOut operator()(const TIn& x) const { + auto va = Caster::template Cast(x); + return Caster::template Cast(va); + } +}; + +template +class DslCudaCast : public Cast { + public: + DslCudaCast(const Tensor input, Tensor out) + : Cast{input, out}, + brick_{input, out, ndim_} {} + + void operator()(const Tensor input, Tensor out) const override { + brick_.template Run( + stream_, input, out, output_size_, ndim_, + is_input_contiguous_, is_out_contiguous_, + input_dtype_, out_dtype_); + } + + private: + UnaryElementwiseBrick brick_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/cat/kernel.cuh b/src/cuda/cat/kernel.cuh new file mode 100644 index 00000000..187fc37a --- /dev/null +++ b/src/cuda/cat/kernel.cuh @@ -0,0 +1,52 @@ +#ifndef INFINI_OPS_CUDA_CAT_KERNEL_CUH_ +#define INFINI_OPS_CUDA_CAT_KERNEL_CUH_ + +#include "cuda/kernel_commons.cuh" + +namespace infini::ops { + +template +__global__ void CatKernel(T* __restrict__ out, + const void* const* __restrict__ inputs, + const size_t* __restrict__ cum_sizes, + size_t input_count, size_t outer_size, + size_t inner_size, size_t total_dim_size, + size_t output_size) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (idx >= output_size) { + return; + } + + // Decompose flat index into (outer, dim_and_inner). + size_t slice_size = total_dim_size * inner_size; + size_t outer = idx / slice_size; + size_t rem = idx % slice_size; + size_t dim_idx = rem / inner_size; + size_t inner = rem % inner_size; + + // Find which input tensor this element belongs to via cumulative sizes. + size_t input_idx = 0; + + for (size_t i = 0; i < input_count; ++i) { + if (dim_idx < cum_sizes[i]) { + input_idx = i; + break; + } + } + + // Compute the local dimension index within the input tensor. + size_t local_dim = dim_idx - (input_idx > 0 ? cum_sizes[input_idx - 1] : 0); + size_t input_dim_size = + cum_sizes[input_idx] - (input_idx > 0 ? 
cum_sizes[input_idx - 1] : 0); + + const T* in_ptr = static_cast(inputs[input_idx]); + size_t in_offset = outer * input_dim_size * inner_size + + local_dim * inner_size + inner; + + out[idx] = in_ptr[in_offset]; +} + +} // namespace infini::ops + +#endif diff --git a/src/cuda/cat/kernel.h b/src/cuda/cat/kernel.h new file mode 100644 index 00000000..e8e0e103 --- /dev/null +++ b/src/cuda/cat/kernel.h @@ -0,0 +1,89 @@ +#ifndef INFINI_OPS_CUDA_CAT_KERNEL_H_ +#define INFINI_OPS_CUDA_CAT_KERNEL_H_ + +#include +#include +#include +#include + +#include "base/cat.h" +#include "common/generic_utils.h" +#include "cuda/cat/kernel.cuh" +#include "cuda/kernel_commons.cuh" +#include "cuda/runtime_utils.h" + +namespace infini::ops { + +template +class CudaCat : public Cat { + public: + CudaCat(const Tensor first_input, std::vector rest_inputs, + int64_t dim, Tensor out) + : Cat{first_input, std::move(rest_inputs), dim, out} { + // Allocate device memory for input pointers and cumulative sizes. + size_t ptrs_size = input_count_ * sizeof(const void*); + size_t cum_size = input_count_ * sizeof(size_t); + size_t metadata_size = ptrs_size + cum_size; + + std::vector metadata(metadata_size); + + Backend::Malloc((void**)&d_metadata_, metadata_size); + + // Copy input data pointers. + std::vector input_ptrs(input_count_); + + for (size_t i = 0; i < input_count_; ++i) { + input_ptrs[i] = inputs_[i].data(); + } + + std::memcpy(metadata.data(), input_ptrs.data(), ptrs_size); + + // Copy cumulative dimension sizes. 
+ std::memcpy(metadata.data() + ptrs_size, cum_dim_sizes_.data(), cum_size); + + Backend::Memcpy(d_metadata_, metadata.data(), metadata_size, + Backend::MemcpyHostToDevice); + + d_inputs_ = reinterpret_cast(d_metadata_); + d_cum_sizes_ = reinterpret_cast(d_metadata_ + ptrs_size); + } + + ~CudaCat() { Backend::Free(d_metadata_); } + + void operator()(const Tensor first_input, std::vector rest_inputs, + int64_t dim, Tensor out) const override { + int block_size = RuntimeUtils::GetOptimalBlockSize(); + DispatchFunc( + {static_cast(dtype_), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + + auto cuda_stream = + static_cast(stream_ ? stream_ : 0); + dim3 blockDims( + std::min(static_cast(block_size), output_size_)); + dim3 gridDims(utils::CeilDiv(output_size_, blockDims.x)); + + T* d_out = reinterpret_cast(out.data()); + size_t total_dim_size = cum_dim_sizes_.back(); + + CatKernel + <<>>( + d_out, d_inputs_, d_cum_sizes_, input_count_, outer_size_, + inner_size_, total_dim_size, output_size_); + }, + "CudaCat::operator()"); + } + + private: + std::byte* d_metadata_{nullptr}; + + const void** d_inputs_{nullptr}; + + size_t* d_cum_sizes_{nullptr}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/causal_softmax/kernel.h b/src/cuda/causal_softmax/kernel.h index 7c7ac871..cffa0713 100644 --- a/src/cuda/causal_softmax/kernel.h +++ b/src/cuda/causal_softmax/kernel.h @@ -7,6 +7,7 @@ #include "base/causal_softmax.h" #include "cuda/causal_softmax/kernel.cuh" #include "cuda/kernel_commons.cuh" +#include "cuda/runtime_utils.h" #include "data_type.h" #include "dispatcher.h" diff --git a/src/cuda/flash_attention/kernel.h b/src/cuda/flash_attention/kernel.h new file mode 100644 index 00000000..4433637b --- /dev/null +++ b/src/cuda/flash_attention/kernel.h @@ -0,0 +1,901 @@ +#ifndef INFINI_OPS_CUDA_FLASH_ATTENTION_KERNEL_H_ +#define INFINI_OPS_CUDA_FLASH_ATTENTION_KERNEL_H_ + +#include 
+#include +#include +#include + +#include "base/flash_attention.h" +#include "flashinfer/allocator.h" +#include "flashinfer/attention/decode.cuh" +#include "flashinfer/attention/default_decode_params.cuh" +#include "flashinfer/attention/default_prefill_params.cuh" +#include "flashinfer/attention/mask.cuh" +#include "flashinfer/attention/prefill.cuh" +#include "flashinfer/attention/scheduler.cuh" +#include "flashinfer/attention/variants.cuh" +#include "flashinfer/page.cuh" +#include "flashinfer/pos_enc.cuh" + +namespace infini::ops { + +// FlashAttention via FlashInfer header-only API. +// +// Supports four modes, selected by the presence of optional tensors: +// 1. Paged decode: `block_table` present — batch decode with paged KV cache +// 2. Batch prefill: `cu_seqlens_q` present — multiple packed sequences +// 3. Single decode: `num_tokens == 1` — single token, contiguous KV +// 4. Single prefill: default — single sequence, contiguous KV +// +// Batch prefill uses `BatchPrefillWithRaggedKVCacheDispatched` with the +// `PrefillPlan` scheduler (split-KV disabled). Paged decode uses +// `BatchDecodeWithPagedKVCacheDispatched` with the `DecodePlan` scheduler. +template +class CudaFlashAttention : public FlashAttention { + // FlashInfer recommends 128 MB for each scheduler workspace buffer. + static constexpr size_t kIntWorkspaceBytes = 128 * 1024 * 1024; + static constexpr size_t kFloatWorkspaceBytes = 128 * 1024 * 1024; + + // Scratch region after the two large buffers, used for small metadata + // arrays (`d_qo_indptr`, `d_kv_indptr`, page indices, etc.). + static constexpr size_t kScratchBytes = 8 * 1024 * 1024; // 8 MB. + + // Pinned host staging buffer for FlashInfer scheduler. + static constexpr size_t kPinnedBytes = kIntWorkspaceBytes; + + public: + template + CudaFlashAttention(Args&&... args) : FlashAttention(std::forward(args)...) 
{ + cudaMalloc(&default_workspace_, workspace_size_in_bytes()); + assert(default_workspace_ && "failed to allocate device workspace"); + cudaMallocHost(&pinned_workspace_, kPinnedBytes); + assert(pinned_workspace_ && "failed to allocate pinned host workspace"); + } + + ~CudaFlashAttention() override { + if (default_workspace_) { + cudaFree(default_workspace_); + default_workspace_ = nullptr; + } + + if (pinned_workspace_) { + cudaFreeHost(pinned_workspace_); + pinned_workspace_ = nullptr; + } + } + + // Non-copyable, non-movable (pinned memory ownership). + CudaFlashAttention(const CudaFlashAttention&) = delete; + CudaFlashAttention& operator=(const CudaFlashAttention&) = delete; + + std::size_t workspace_size_in_bytes() const override { + // int_workspace (128 MB) + float_workspace (128 MB) + scratch (8 MB). + return kIntWorkspaceBytes + kFloatWorkspaceBytes + kScratchBytes; + } + + void operator()(const Tensor query, const Tensor key, const Tensor value, + std::optional cu_seqlens_q, + std::optional cu_seqlens_kv, + std::optional block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + bool causal, int64_t window_left, int64_t window_right, + int64_t block_size, Tensor output) const override { + auto cuda_stream = + static_cast(stream_ ? stream_ : 0); + + if (block_table.has_value()) { + // Paged decode: block_table present. + DispatchHeadDimPagedDecode(query, key, value, cu_seqlens_q.value(), + cu_seqlens_kv.value(), block_table.value(), + output, num_heads, num_kv_heads, head_size, + scale, window_left, block_size, cuda_stream); + } else if (cu_seqlens_q.has_value()) { + // Batch prefill: cu_seqlens present, packed sequences. + auto mask_mode = causal ? 
flashinfer::MaskMode::kCausal + : flashinfer::MaskMode::kNone; + DispatchHeadDimBatchPrefill(query, key, value, cu_seqlens_q.value(), + cu_seqlens_kv.value(), output, num_heads, + num_kv_heads, head_size, scale, window_left, + mask_mode, cuda_stream); + } else if (num_tokens_ == 1) { + // Single decode: single token query, full KV cache. + DispatchHeadDimDecode(query, key, value, output, num_heads, num_kv_heads, + head_size, scale, window_left, cuda_stream); + } else if (causal) { + DispatchHeadDimPrefill(query, key, value, output, num_heads, num_kv_heads, + head_size, scale, window_left, + flashinfer::MaskMode::kCausal, cuda_stream); + } else { + DispatchHeadDimPrefill(query, key, value, output, num_heads, num_kv_heads, + head_size, scale, window_left, + flashinfer::MaskMode::kNone, cuda_stream); + } + } + + private: + // ---- Prefill path (query seq_len > 1) --------------------------------- + + void DispatchHeadDimPrefill(const Tensor& query, const Tensor& key, + const Tensor& value, Tensor& output, + int64_t num_heads, int64_t num_kv_heads, + int64_t head_size, double scale, + int64_t window_left, + flashinfer::MaskMode mask_mode, + typename Backend::Stream stream) const { + switch (head_size) { + case 64: + DispatchMaskModePrefill<64>(query, key, value, output, num_heads, + num_kv_heads, scale, window_left, + mask_mode, stream); + break; + case 128: + DispatchMaskModePrefill<128>(query, key, value, output, num_heads, + num_kv_heads, scale, window_left, + mask_mode, stream); + break; + case 256: + DispatchMaskModePrefill<256>(query, key, value, output, num_heads, + num_kv_heads, scale, window_left, + mask_mode, stream); + break; + default: + assert(false && "unsupported head dimension for FlashAttention"); + } + } + + template + void DispatchMaskModePrefill(const Tensor& query, const Tensor& key, + const Tensor& value, Tensor& output, + int64_t num_heads, int64_t num_kv_heads, + double scale, int64_t window_left, + flashinfer::MaskMode mask_mode, + typename 
Backend::Stream stream) const { + switch (mask_mode) { + case flashinfer::MaskMode::kCausal: + DispatchDtypePrefill( + query, key, value, output, num_heads, num_kv_heads, scale, + window_left, stream); + break; + case flashinfer::MaskMode::kNone: + DispatchDtypePrefill( + query, key, value, output, num_heads, num_kv_heads, scale, + window_left, stream); + break; + default: + assert(false && "unsupported mask mode for FlashAttention"); + } + } + + template + void DispatchDtypePrefill(const Tensor& query, const Tensor& key, + const Tensor& value, Tensor& output, + int64_t num_heads, int64_t num_kv_heads, + double scale, int64_t window_left, + typename Backend::Stream stream) const { + DispatchFunc( + dtype_, + [&](auto type_tag) { + using DType = typename decltype(type_tag)::type; + LaunchPrefill( + query, key, value, output, num_heads, num_kv_heads, scale, + window_left, stream); + }, + "CudaFlashAttention::prefill"); + } + + template + void LaunchPrefill(const Tensor& query, const Tensor& key, + const Tensor& value, Tensor& output, int64_t num_heads, + int64_t num_kv_heads, double scale, int64_t window_left, + typename Backend::Stream stream) const { + using AttentionVariant = + flashinfer::DefaultAttention; + + flashinfer::SinglePrefillParams params; + params.q = reinterpret_cast(const_cast(query.data())); + params.k = reinterpret_cast(const_cast(key.data())); + params.v = reinterpret_cast(const_cast(value.data())); + params.o = reinterpret_cast(output.data()); + params.lse = nullptr; + params.maybe_alibi_slopes = nullptr; + params.maybe_custom_mask = nullptr; + + params.qo_len = static_cast(num_tokens_); + params.kv_len = static_cast(key.size(0)); + params.num_qo_heads = static_cast(num_heads); + params.num_kv_heads = static_cast(num_kv_heads); + params.group_size = flashinfer::uint_fastdiv( + static_cast(num_heads / num_kv_heads)); + params.head_dim = HEAD_DIM; + + // Strides for NHD layout [seq_len, num_heads, head_dim]. 
+ params.q_stride_n = static_cast(num_heads * HEAD_DIM); + params.q_stride_h = HEAD_DIM; + params.k_stride_n = static_cast(num_kv_heads * HEAD_DIM); + params.k_stride_h = HEAD_DIM; + params.v_stride_n = static_cast(num_kv_heads * HEAD_DIM); + params.v_stride_h = HEAD_DIM; + + params.sm_scale = static_cast(scale); + params.window_left = static_cast(window_left); + params.logits_soft_cap = 0.0f; + params.rope_rcp_scale = 1.0f; + params.rope_rcp_theta = 1.0f; + params.partition_kv = 0; + + cudaError_t err = + flashinfer::SinglePrefillWithKVCacheDispatched< + HEAD_DIM, HEAD_DIM, flashinfer::PosEncodingMode::kNone, + /*USE_FP16_QK_REDUCTION=*/false, MASK_MODE, AttentionVariant>( + params, /*tmp=*/nullptr, stream); + + assert(err == cudaSuccess && + "FlashInfer SinglePrefillWithKVCacheDispatched failed"); + (void)err; + } + + // ---- Decode path (query seq_len == 1) --------------------------------- + + void DispatchHeadDimDecode(const Tensor& query, const Tensor& key, + const Tensor& value, Tensor& output, + int64_t num_heads, int64_t num_kv_heads, + int64_t head_size, double scale, + int64_t window_left, + typename Backend::Stream stream) const { + switch (head_size) { + case 64: + DispatchDtypeDecode<64>(query, key, value, output, num_heads, + num_kv_heads, scale, window_left, stream); + break; + case 128: + DispatchDtypeDecode<128>(query, key, value, output, num_heads, + num_kv_heads, scale, window_left, stream); + break; + case 256: + DispatchDtypeDecode<256>(query, key, value, output, num_heads, + num_kv_heads, scale, window_left, stream); + break; + default: + assert(false && "unsupported head dimension for FlashAttention decode"); + } + } + + template + void DispatchDtypeDecode(const Tensor& query, const Tensor& key, + const Tensor& value, Tensor& output, + int64_t num_heads, int64_t num_kv_heads, + double scale, int64_t window_left, + typename Backend::Stream stream) const { + DispatchFunc( + dtype_, + [&](auto type_tag) { + using DType = typename 
decltype(type_tag)::type; + LaunchDecode(query, key, value, output, num_heads, + num_kv_heads, scale, window_left, + stream); + }, + "CudaFlashAttention::decode"); + } + + template + void LaunchDecode(const Tensor& query, const Tensor& key, + const Tensor& value, Tensor& output, int64_t num_heads, + int64_t num_kv_heads, double scale, int64_t window_left, + typename Backend::Stream stream) const { + using AttentionVariant = + flashinfer::DefaultAttention; + + uint32_t kv_len = static_cast(key.size(0)); + + flashinfer::SingleDecodeParams params( + reinterpret_cast(const_cast(query.data())), + reinterpret_cast(const_cast(key.data())), + reinterpret_cast(const_cast(value.data())), + reinterpret_cast(output.data()), + /*maybe_alibi_slopes=*/nullptr, kv_len, + static_cast(num_heads), static_cast(num_kv_heads), + flashinfer::QKVLayout::kNHD, HEAD_DIM, + static_cast(window_left), + /*logits_soft_cap=*/0.0f, static_cast(scale), + /*rope_scale=*/1.0f, /*rope_theta=*/1e4f); + + // Decode needs a temporary buffer for partial results. + // Size: num_qo_heads * HEAD_DIM * sizeof(DType). + // For single decode this is small enough to use nullptr (non-partitioned). 
+ cudaError_t err = + flashinfer::SingleDecodeWithKVCacheDispatched< + HEAD_DIM, flashinfer::PosEncodingMode::kNone, AttentionVariant>( + params, /*tmp=*/nullptr, stream); + + assert(err == cudaSuccess && + "FlashInfer SingleDecodeWithKVCacheDispatched failed"); + (void)err; + } + + // ---- Batch prefill (loop over sequences) -------------------------------- + + void DispatchHeadDimBatchPrefill( + const Tensor& query, const Tensor& key, const Tensor& value, + const Tensor& cu_seqlens_q, const Tensor& cu_seqlens_kv, Tensor& output, + int64_t num_heads, int64_t num_kv_heads, int64_t head_size, double scale, + int64_t window_left, flashinfer::MaskMode mask_mode, + typename Backend::Stream stream) const { + switch (head_size) { + case 64: + DispatchMaskModeBatchPrefill<64>(query, key, value, cu_seqlens_q, + cu_seqlens_kv, output, num_heads, + num_kv_heads, scale, window_left, + mask_mode, stream); + break; + case 128: + DispatchMaskModeBatchPrefill<128>(query, key, value, cu_seqlens_q, + cu_seqlens_kv, output, num_heads, + num_kv_heads, scale, window_left, + mask_mode, stream); + break; + case 256: + DispatchMaskModeBatchPrefill<256>(query, key, value, cu_seqlens_q, + cu_seqlens_kv, output, num_heads, + num_kv_heads, scale, window_left, + mask_mode, stream); + break; + default: + assert(false && "unsupported head dimension for FlashAttention"); + } + } + + template + void DispatchMaskModeBatchPrefill( + const Tensor& query, const Tensor& key, const Tensor& value, + const Tensor& cu_seqlens_q, const Tensor& cu_seqlens_kv, Tensor& output, + int64_t num_heads, int64_t num_kv_heads, double scale, + int64_t window_left, flashinfer::MaskMode mask_mode, + typename Backend::Stream stream) const { + switch (mask_mode) { + case flashinfer::MaskMode::kCausal: + DispatchDtypeBatchPrefill( + query, key, value, cu_seqlens_q, cu_seqlens_kv, output, num_heads, + num_kv_heads, scale, window_left, stream); + break; + case flashinfer::MaskMode::kNone: + DispatchDtypeBatchPrefill( + 
query, key, value, cu_seqlens_q, cu_seqlens_kv, output, num_heads, + num_kv_heads, scale, window_left, stream); + break; + default: + assert(false && "unsupported mask mode for FlashAttention"); + } + } + + template + void DispatchDtypeBatchPrefill( + const Tensor& query, const Tensor& key, const Tensor& value, + const Tensor& cu_seqlens_q, const Tensor& cu_seqlens_kv, Tensor& output, + int64_t num_heads, int64_t num_kv_heads, double scale, + int64_t window_left, typename Backend::Stream stream) const { + DispatchFunc( + dtype_, + [&](auto type_tag) { + using DType = typename decltype(type_tag)::type; + LaunchBatchPrefill( + query, key, value, cu_seqlens_q, cu_seqlens_kv, output, num_heads, + num_kv_heads, scale, window_left, stream); + }, + "CudaFlashAttention::batch_prefill"); + } + + // Batch prefill using FlashInfer's native batch kernel with scheduler. + template + void LaunchBatchPrefill(const Tensor& query, const Tensor& key, + const Tensor& value, const Tensor& cu_seqlens_q, + const Tensor& cu_seqlens_kv, Tensor& output, + int64_t num_heads, int64_t num_kv_heads, double scale, + int64_t window_left, + typename Backend::Stream stream) const { + // Copy cu_seqlens (int64) from device to host, then narrow to int32. + auto batch_size_plus_one = cu_seqlens_q.size(0); + auto batch_size = static_cast(batch_size_plus_one - 1); + + std::vector h_cu_q_i64(batch_size_plus_one); + std::vector h_cu_kv_i64(batch_size_plus_one); + cudaMemcpyAsync(h_cu_q_i64.data(), cu_seqlens_q.data(), + batch_size_plus_one * sizeof(int64_t), + cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(h_cu_kv_i64.data(), cu_seqlens_kv.data(), + batch_size_plus_one * sizeof(int64_t), + cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); + + // Convert to int32 for FlashInfer scheduler (IdType = int32_t). 
+ std::vector h_cu_q(batch_size_plus_one); + std::vector h_cu_kv(batch_size_plus_one); + + for (size_t i = 0; i < batch_size_plus_one; ++i) { + h_cu_q[i] = static_cast(h_cu_q_i64[i]); + h_cu_kv[i] = static_cast(h_cu_kv_i64[i]); + } + + uint32_t total_num_rows = static_cast(h_cu_q[batch_size]); + + // Partition pre-allocated device workspace into sub-regions. + void* active_workspace = workspace_ ? workspace_ : default_workspace_; + size_t active_workspace_size = workspace_ ? workspace_size_in_bytes_ + : workspace_size_in_bytes(); + char* ws = static_cast(active_workspace); + size_t ws_offset = 0; + + void* int_buf = ws + ws_offset; + ws_offset += kIntWorkspaceBytes; + + // Run PrefillPlan with split-KV disabled for simplicity. + flashinfer::PrefillPlanInfo plan_info; + cudaError_t plan_err = flashinfer::PrefillPlan( + /*float_buffer=*/nullptr, + /*float_workspace_size_in_bytes=*/0, int_buf, pinned_workspace_, + kIntWorkspaceBytes, plan_info, h_cu_q.data(), h_cu_kv.data(), + total_num_rows, batch_size, + static_cast(num_heads), static_cast(num_kv_heads), + /*head_dim_qk=*/HEAD_DIM, /*head_dim_vo=*/HEAD_DIM, + /*page_size=*/1, + /*enable_cuda_graph=*/false, /*sizeof_dtype_o=*/sizeof(DType), + static_cast(window_left), + /*fixed_split_size=*/0, /*disable_split_kv=*/true, + /*num_colocated_ctas=*/0, stream); + + assert(plan_err == cudaSuccess && "FlashInfer PrefillPlan failed"); + (void)plan_err; + + // Upload cu_seqlens as int32 to device from the scratch region. + // Skip float workspace region (unused for prefill) to reach scratch. 
+ ws_offset += kFloatWorkspaceBytes; + + int32_t* d_qo_indptr = reinterpret_cast(ws + ws_offset); + ws_offset += batch_size_plus_one * sizeof(int32_t); + + int32_t* d_kv_indptr = reinterpret_cast(ws + ws_offset); + ws_offset += batch_size_plus_one * sizeof(int32_t); + + assert(ws_offset <= active_workspace_size && + "FlashAttention batch prefill workspace overflow"); + + cudaMemcpyAsync(d_qo_indptr, h_cu_q.data(), + batch_size_plus_one * sizeof(int32_t), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(d_kv_indptr, h_cu_kv.data(), + batch_size_plus_one * sizeof(int32_t), + cudaMemcpyHostToDevice, stream); + + using AttentionVariant = + flashinfer::DefaultAttention; + using Params = + flashinfer::BatchPrefillRaggedParams; + + uint32_t q_stride_n = static_cast(num_heads * HEAD_DIM); + uint32_t k_stride_n = static_cast(num_kv_heads * HEAD_DIM); + + Params params; + params.q = + reinterpret_cast(const_cast(query.data())); + params.k = + reinterpret_cast(const_cast(key.data())); + params.v = + reinterpret_cast(const_cast(value.data())); + params.o = reinterpret_cast(output.data()); + params.lse = nullptr; + params.maybe_custom_mask = nullptr; + params.maybe_alibi_slopes = nullptr; + params.maybe_q_rope_offset = nullptr; + params.maybe_k_rope_offset = nullptr; + params.maybe_mask_indptr = nullptr; + params.q_indptr = d_qo_indptr; + params.kv_indptr = d_kv_indptr; + params.num_qo_heads = static_cast(num_heads); + params.num_kv_heads = static_cast(num_kv_heads); + params.group_size = flashinfer::uint_fastdiv( + static_cast(num_heads / num_kv_heads)); + params.q_stride_n = q_stride_n; + params.q_stride_h = HEAD_DIM; + params.k_stride_n = k_stride_n; + params.k_stride_h = HEAD_DIM; + params.v_stride_n = k_stride_n; + params.v_stride_h = HEAD_DIM; + params.sm_scale = static_cast(scale); + params.window_left = static_cast(window_left); + params.logits_soft_cap = 0.0f; + params.rope_rcp_scale = 1.0f; + params.rope_rcp_theta = 1.0f; + + // Fill scheduling metadata from 
plan_info. + params.padded_batch_size = + static_cast(plan_info.padded_batch_size); + params.partition_kv = plan_info.split_kv; + params.max_total_num_rows = total_num_rows; + params.total_num_rows = plan_info.enable_cuda_graph + ? flashinfer::GetPtrFromBaseOffset( + int_buf, plan_info.total_num_rows_offset) + : nullptr; + params.request_indices = flashinfer::GetPtrFromBaseOffset( + int_buf, plan_info.request_indices_offset); + params.qo_tile_indices = flashinfer::GetPtrFromBaseOffset( + int_buf, plan_info.qo_tile_indices_offset); + params.kv_tile_indices = flashinfer::GetPtrFromBaseOffset( + int_buf, plan_info.kv_tile_indices_offset); + params.merge_indptr = plan_info.split_kv + ? flashinfer::GetPtrFromBaseOffset( + int_buf, plan_info.merge_indptr_offset) + : nullptr; + params.o_indptr = flashinfer::GetPtrFromBaseOffset( + int_buf, plan_info.o_indptr_offset); + params.kv_chunk_size_ptr = flashinfer::GetPtrFromBaseOffset( + int_buf, plan_info.kv_chunk_size_ptr_offset); + params.block_valid_mask = plan_info.split_kv + ? flashinfer::GetPtrFromBaseOffset( + int_buf, plan_info.block_valid_mask_offset) + : nullptr; + params.maybe_prefix_len_ptr = nullptr; + params.maybe_token_pos_in_items_ptr = nullptr; + params.token_pos_in_items_len = 0; + params.maybe_max_item_len_ptr = nullptr; + + // Dispatch on CTA_TILE_Q determined by the plan. + uint32_t cta_tile_q = static_cast(plan_info.cta_tile_q); + + switch (cta_tile_q) { + case 128: + LaunchBatchPrefillKernel<128, HEAD_DIM, MASK_MODE, DType, + AttentionVariant>(params, stream); + break; + case 64: + LaunchBatchPrefillKernel<64, HEAD_DIM, MASK_MODE, DType, + AttentionVariant>(params, stream); + break; + case 16: + LaunchBatchPrefillKernel<16, HEAD_DIM, MASK_MODE, DType, + AttentionVariant>(params, stream); + break; + default: + assert(false && "unsupported CTA_TILE_Q from PrefillPlan"); + } + + } + + // Helper to dispatch batch prefill kernel with a compile-time CTA_TILE_Q. 
+ template + static void LaunchBatchPrefillKernel( + flashinfer::BatchPrefillRaggedParams& + params, + typename Backend::Stream stream) { + cudaError_t err = + flashinfer::BatchPrefillWithRaggedKVCacheDispatched< + CTA_TILE_Q, HEAD_DIM_VAL, HEAD_DIM_VAL, + flashinfer::PosEncodingMode::kNone, + /*USE_FP16_QK_REDUCTION=*/false, MASK_MODE_VAL, + AttentionVariant>(params, /*tmp_v=*/nullptr, + /*tmp_s=*/nullptr, + /*enable_pdl=*/false, stream); + + assert(err == cudaSuccess && + "FlashInfer BatchPrefillWithRaggedKVCacheDispatched failed"); + (void)err; + } + + // ---- Paged decode (batch via scheduler) ---------------------------------- + + void DispatchHeadDimPagedDecode( + const Tensor& query, const Tensor& key, const Tensor& value, + const Tensor& cu_seqlens_q, const Tensor& cu_seqlens_kv, + const Tensor& block_table, Tensor& output, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + int64_t window_left, int64_t block_size, + typename Backend::Stream stream) const { + switch (head_size) { + case 64: + DispatchDtypePagedDecode<64>(query, key, value, cu_seqlens_q, + cu_seqlens_kv, block_table, output, + num_heads, num_kv_heads, scale, + window_left, block_size, stream); + break; + case 128: + DispatchDtypePagedDecode<128>(query, key, value, cu_seqlens_q, + cu_seqlens_kv, block_table, output, + num_heads, num_kv_heads, scale, + window_left, block_size, stream); + break; + case 256: + DispatchDtypePagedDecode<256>(query, key, value, cu_seqlens_q, + cu_seqlens_kv, block_table, output, + num_heads, num_kv_heads, scale, + window_left, block_size, stream); + break; + default: + assert(false && + "unsupported head dimension for FlashAttention paged decode"); + } + } + + template + void DispatchDtypePagedDecode( + const Tensor& query, const Tensor& key, const Tensor& value, + const Tensor& cu_seqlens_q, const Tensor& cu_seqlens_kv, + const Tensor& block_table, Tensor& output, int64_t num_heads, + int64_t num_kv_heads, double scale, int64_t 
window_left, + int64_t block_size, typename Backend::Stream stream) const { + DispatchFunc( + dtype_, + [&](auto type_tag) { + using DType = typename decltype(type_tag)::type; + LaunchPagedDecode( + query, key, value, cu_seqlens_q, cu_seqlens_kv, block_table, + output, num_heads, num_kv_heads, scale, window_left, block_size, + stream); + }, + "CudaFlashAttention::paged_decode"); + } + + // Batch paged decode using FlashInfer's native batch kernel with scheduler. + template + void LaunchPagedDecode(const Tensor& query, const Tensor& key, + const Tensor& value, const Tensor& cu_seqlens_q, + const Tensor& cu_seqlens_kv, + const Tensor& block_table, Tensor& output, + int64_t num_heads, int64_t num_kv_heads, double scale, + int64_t window_left, int64_t block_size, + typename Backend::Stream stream) const { + // Copy metadata to host. + auto num_reqs = static_cast(block_table.size(0)); + auto max_blocks_per_req = block_table.size(1); + + // cu_seqlens_kv is int64 on device. + std::vector h_cu_kv_i64(num_reqs + 1); + cudaMemcpyAsync(h_cu_kv_i64.data(), cu_seqlens_kv.data(), + (num_reqs + 1) * sizeof(int64_t), cudaMemcpyDeviceToHost, + stream); + cudaStreamSynchronize(stream); + + // Build page indptr and last_page_len arrays for paged_kv_t. + // block_table has shape [num_reqs, max_blocks_per_req] on device. + std::vector h_page_indptr(num_reqs + 1); + std::vector h_last_page_len(num_reqs); + h_page_indptr[0] = 0; + + for (uint32_t i = 0; i < num_reqs; ++i) { + int64_t kv_len = h_cu_kv_i64[i + 1] - h_cu_kv_i64[i]; + uint32_t num_pages = + kv_len > 0 + ? static_cast((kv_len + block_size - 1) / block_size) + : 0; + h_page_indptr[i + 1] = h_page_indptr[i] + static_cast(num_pages); + + if (kv_len > 0) { + int32_t last_len = static_cast(kv_len % block_size); + h_last_page_len[i] = last_len == 0 + ? 
static_cast(block_size) + : last_len; + } else { + h_last_page_len[i] = 0; + } + } + + int32_t total_pages = h_page_indptr[num_reqs]; + + // Flatten block_table into a contiguous page indices array on device. + // block_table is [num_reqs, max_blocks_per_req] int32 on device; we need + // a flat [total_pages] array with only the valid entries. + std::vector h_block_table(num_reqs * max_blocks_per_req); + cudaMemcpyAsync(h_block_table.data(), block_table.data(), + h_block_table.size() * sizeof(int32_t), + cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); + + std::vector h_page_indices(total_pages); + int32_t idx = 0; + + for (uint32_t i = 0; i < num_reqs; ++i) { + int32_t num_pages = + h_page_indptr[i + 1] - h_page_indptr[i]; + + for (int32_t j = 0; j < num_pages; ++j) { + h_page_indices[idx++] = + h_block_table[i * max_blocks_per_req + j]; + } + } + + // Partition pre-allocated device workspace into sub-regions. + void* active_workspace = workspace_ ? workspace_ : default_workspace_; + size_t active_workspace_size = workspace_ ? workspace_size_in_bytes_ + : workspace_size_in_bytes(); + char* ws = static_cast(active_workspace); + size_t ws_offset = 0; + + void* int_buf = ws + ws_offset; + ws_offset += kIntWorkspaceBytes; + + void* float_buf = ws + ws_offset; + ws_offset += kFloatWorkspaceBytes; + + // Small metadata arrays from the scratch region. 
+ int32_t* d_page_indices = reinterpret_cast(ws + ws_offset); + ws_offset += std::max(total_pages, 1) * sizeof(int32_t); + + int32_t* d_page_indptr = reinterpret_cast(ws + ws_offset); + ws_offset += (num_reqs + 1) * sizeof(int32_t); + + int32_t* d_last_page_len = reinterpret_cast(ws + ws_offset); + ws_offset += num_reqs * sizeof(int32_t); + + assert(ws_offset <= active_workspace_size && + "FlashAttention paged decode workspace overflow"); + + if (total_pages > 0) { + cudaMemcpyAsync(d_page_indices, h_page_indices.data(), + total_pages * sizeof(int32_t), cudaMemcpyHostToDevice, + stream); + } + + cudaMemcpyAsync(d_page_indptr, h_page_indptr.data(), + (num_reqs + 1) * sizeof(int32_t), cudaMemcpyHostToDevice, + stream); + cudaMemcpyAsync(d_last_page_len, h_last_page_len.data(), + num_reqs * sizeof(int32_t), cudaMemcpyHostToDevice, + stream); + + // KV cache layout: [num_blocks, block_size, num_kv_heads, head_dim] (NHD). + auto* kv_data = + reinterpret_cast(const_cast(key.data())); + + flashinfer::paged_kv_t paged_kv( + static_cast(num_kv_heads), + static_cast(block_size), HEAD_DIM, num_reqs, + flashinfer::QKVLayout::kNHD, kv_data, kv_data, d_page_indices, + d_page_indptr, d_last_page_len); + + // Device workspace was partitioned above; use pinned host member. + + using AttentionVariant = + flashinfer::DefaultAttention; + using Params = + flashinfer::BatchDecodeParams; + + uint32_t group_size = static_cast(num_heads / num_kv_heads); + + // Dispatch on GQA group size for DecodePlan + kernel launch. The group + // size must be a compile-time constant for the work estimation function. 
+ switch (group_size) { + case 1: + LaunchPagedDecodeInner( + query, output, paged_kv, float_buf, kFloatWorkspaceBytes, int_buf, + pinned_workspace_, kIntWorkspaceBytes, h_page_indptr.data(), num_reqs, + num_heads, scale, window_left, block_size, stream); + break; + case 2: + LaunchPagedDecodeInner( + query, output, paged_kv, float_buf, kFloatWorkspaceBytes, int_buf, + pinned_workspace_, kIntWorkspaceBytes, h_page_indptr.data(), num_reqs, + num_heads, scale, window_left, block_size, stream); + break; + case 4: + LaunchPagedDecodeInner( + query, output, paged_kv, float_buf, kFloatWorkspaceBytes, int_buf, + pinned_workspace_, kIntWorkspaceBytes, h_page_indptr.data(), num_reqs, + num_heads, scale, window_left, block_size, stream); + break; + case 8: + LaunchPagedDecodeInner( + query, output, paged_kv, float_buf, kFloatWorkspaceBytes, int_buf, + pinned_workspace_, kIntWorkspaceBytes, h_page_indptr.data(), num_reqs, + num_heads, scale, window_left, block_size, stream); + break; + default: + assert(false && "unsupported GQA group size for paged decode"); + } + + } + + // Inner helper for paged decode, templated on compile-time GROUP_SIZE. + template + static void LaunchPagedDecodeInner( + const Tensor& query, Tensor& output, + flashinfer::paged_kv_t& paged_kv, void* float_buf, + size_t float_ws, void* int_buf, void* pinned_buf, size_t int_ws, + int32_t* page_indptr_h, uint32_t num_reqs, int64_t num_heads, + double scale, int64_t window_left, int64_t block_size, + typename Backend::Stream stream) { + // Work estimation function with compile-time GROUP_SIZE. 
+ cudaError_t (*work_estimation_func)( + bool&, uint32_t&, uint32_t&, uint32_t&, uint32_t&, uint32_t, + int32_t*, uint32_t, uint32_t, bool, cudaStream_t) = + flashinfer::BatchDecodeWithPagedKVCacheWorkEstimationDispatched< + GROUP_SIZE, HEAD_DIM, flashinfer::PosEncodingMode::kNone, + AttentionVariant, Params>; + + flashinfer::DecodePlanInfo plan_info; + cudaError_t plan_err = flashinfer::DecodePlan< + HEAD_DIM, flashinfer::PosEncodingMode::kNone, AttentionVariant, + Params>( + float_buf, float_ws, int_buf, pinned_buf, int_ws, plan_info, + page_indptr_h, num_reqs, static_cast(num_heads), + static_cast(block_size), + /*enable_cuda_graph=*/false, stream, work_estimation_func); + + assert(plan_err == cudaSuccess && "FlashInfer DecodePlan failed"); + (void)plan_err; + + // Fill BatchDecodeParams. + uint32_t q_stride_n = static_cast(num_heads * HEAD_DIM); + + Params params( + reinterpret_cast(const_cast(query.data())), + /*q_rope_offset=*/nullptr, paged_kv, + reinterpret_cast(output.data()), + /*lse=*/nullptr, /*maybe_alibi_slopes=*/nullptr, + static_cast(num_heads), + static_cast(q_stride_n), + static_cast(HEAD_DIM), + static_cast(window_left), + /*logits_soft_cap=*/0.0f, static_cast(scale), + /*rope_scale=*/1.0f, /*rope_theta=*/1e4f); + + // Fill scheduling metadata from plan_info. + params.padded_batch_size = + static_cast(plan_info.padded_batch_size); + params.partition_kv = plan_info.split_kv; + params.request_indices = flashinfer::GetPtrFromBaseOffset( + int_buf, plan_info.request_indices_offset); + params.kv_tile_indices = flashinfer::GetPtrFromBaseOffset( + int_buf, plan_info.kv_tile_indices_offset); + params.o_indptr = flashinfer::GetPtrFromBaseOffset( + int_buf, plan_info.o_indptr_offset); + params.kv_chunk_size_ptr = flashinfer::GetPtrFromBaseOffset( + int_buf, plan_info.kv_chunk_size_ptr_offset); + params.block_valid_mask = plan_info.split_kv + ? 
flashinfer::GetPtrFromBaseOffset( + int_buf, plan_info.block_valid_mask_offset) + : nullptr; + + // Temporary buffers for split-KV reduction. + DType* tmp_v = plan_info.split_kv + ? flashinfer::GetPtrFromBaseOffset( + float_buf, plan_info.v_offset) + : nullptr; + float* tmp_s = plan_info.split_kv + ? flashinfer::GetPtrFromBaseOffset( + float_buf, plan_info.s_offset) + : nullptr; + + cudaError_t err = + flashinfer::BatchDecodeWithPagedKVCacheDispatched< + HEAD_DIM, flashinfer::PosEncodingMode::kNone, AttentionVariant>( + params, tmp_v, tmp_s, /*enable_pdl=*/false, stream); + + assert(err == cudaSuccess && + "FlashInfer BatchDecodeWithPagedKVCacheDispatched failed"); + (void)err; + } + + // Device workspace, allocated once in the constructor. Used as fallback + // when the handle does not provide a workspace buffer. + mutable void* default_workspace_{nullptr}; + + // Pinned host staging buffer, allocated once in the constructor. + mutable void* pinned_workspace_{nullptr}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/linear/kernel.cuh b/src/cuda/linear/kernel.cuh new file mode 100644 index 00000000..242f3dbd --- /dev/null +++ b/src/cuda/linear/kernel.cuh @@ -0,0 +1,20 @@ +#ifndef INFINI_OPS_CUDA_LINEAR_KERNEL_CUH_ +#define INFINI_OPS_CUDA_LINEAR_KERNEL_CUH_ + +#include "cuda/kernel_commons.cuh" + +namespace infini::ops { + +template +__global__ void BiasAddKernel(T* out, const T* bias, size_t rows, size_t cols) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (idx < rows * cols) { + size_t col = idx % cols; + out[idx] += bias[col]; + } +} + +} // namespace infini::ops + +#endif diff --git a/src/cuda/linear/kernel.h b/src/cuda/linear/kernel.h new file mode 100644 index 00000000..48ae27b8 --- /dev/null +++ b/src/cuda/linear/kernel.h @@ -0,0 +1,220 @@ +#ifndef INFINI_OPS_CUDA_LINEAR_KERNEL_H_ +#define INFINI_OPS_CUDA_LINEAR_KERNEL_H_ + +#include +#include +#include + +// clang-format off +#include "cublasLt.h" +// clang-format on + 
+#include "base/linear.h" +#include "cuda/linear/kernel.cuh" +#include "cuda/runtime_utils.h" +#include "nvidia/blas_utils.h" + +namespace infini::ops { + +// Linear operator using cuBLASLt with heuristic algorithm selection. +// Computes out = a @ b (+ bias), with optional transpose. +template +class CudaLinear : public Linear { + public: + CudaLinear(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) + : Linear{a, b, bias, trans_a, trans_b, out}, + a_is_col_major_{a.stride(-1) == 1}, + b_is_col_major_{b.stride(-1) == 1}, + swap_a_and_b_{out.stride(-1) == 1} {} + + void operator()(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) const override { + auto cuda_stream = + static_cast(stream_ ? stream_ : 0); + + float alpha = 1.0f; + float beta = 0.0f; + + auto op_a = GetOpA(trans_a, trans_b); + auto op_b = GetOpB(trans_a, trans_b); + + auto matmul_m = static_cast(swap_a_and_b_ ? n_ : m_); + auto matmul_n = static_cast(swap_a_and_b_ ? m_ : n_); + auto matmul_k = static_cast(k_); + + const auto* a_ptr = swap_a_and_b_ ? b.data() : a.data(); + const auto* b_ptr = swap_a_and_b_ ? a.data() : b.data(); + auto a_dtype = + BlasUtils::GetDataType( + swap_a_and_b_ ? b.dtype() : a.dtype()); + auto b_dtype = + BlasUtils::GetDataType( + swap_a_and_b_ ? a.dtype() : b.dtype()); + auto c_dtype = + BlasUtils::GetDataType(out.dtype()); + auto a_ld = static_cast(swap_a_and_b_ ? ldb_ : lda_); + auto b_ld = static_cast(swap_a_and_b_ ? lda_ : ldb_); + auto c_ld = static_cast(ldc_); + auto a_batch_stride = static_cast( + swap_a_and_b_ ? batch_stride_b_ : batch_stride_a_); + auto b_batch_stride = static_cast( + swap_a_and_b_ ? batch_stride_a_ : batch_stride_b_); + auto c_batch_stride = static_cast(batch_stride_c_); + + // Create cuBLASLt matmul descriptor. 
+ cublasLtMatmulDesc_t op_desc{}; + auto status = cublasLtMatmulDescCreate( + &op_desc, + BlasUtils::GetComputeType(out.dtype()), + CUDA_R_32F); + assert(status == CUBLAS_STATUS_SUCCESS && + "failed to create cuBLASLt matmul descriptor"); + + status = cublasLtMatmulDescSetAttribute( + op_desc, CUBLASLT_MATMUL_DESC_TRANSA, &op_a, sizeof(op_a)); + assert(status == CUBLAS_STATUS_SUCCESS); + + status = cublasLtMatmulDescSetAttribute( + op_desc, CUBLASLT_MATMUL_DESC_TRANSB, &op_b, sizeof(op_b)); + assert(status == CUBLAS_STATUS_SUCCESS); + + // Create matrix layouts. + cublasLtMatrixLayout_t a_layout{}; + status = cublasLtMatrixLayoutCreate( + &a_layout, a_dtype, + op_a == CUBLAS_OP_N ? matmul_m : matmul_k, + op_a == CUBLAS_OP_N ? matmul_k : matmul_m, a_ld); + assert(status == CUBLAS_STATUS_SUCCESS); + + cublasLtMatrixLayout_t b_layout{}; + status = cublasLtMatrixLayoutCreate( + &b_layout, b_dtype, + op_b == CUBLAS_OP_N ? matmul_k : matmul_n, + op_b == CUBLAS_OP_N ? matmul_n : matmul_k, b_ld); + assert(status == CUBLAS_STATUS_SUCCESS); + + cublasLtMatrixLayout_t c_layout{}; + status = cublasLtMatrixLayoutCreate( + &c_layout, c_dtype, matmul_m, matmul_n, c_ld); + assert(status == CUBLAS_STATUS_SUCCESS); + + if (batch_count_ > 1) { + SetStridedBatchAttributes(a_layout, a_batch_stride); + SetStridedBatchAttributes(b_layout, b_batch_stride); + SetStridedBatchAttributes(c_layout, c_batch_stride); + } + + // Search for optimal algorithm. 
+ cublasLtMatmulPreference_t preference{}; + status = cublasLtMatmulPreferenceCreate(&preference); + assert(status == CUBLAS_STATUS_SUCCESS); + + size_t workspace_size = workspace_size_in_bytes_; + status = cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, sizeof(workspace_size)); + assert(status == CUBLAS_STATUS_SUCCESS); + + cublasLtMatmulHeuristicResult_t heuristic{}; + int returned_results = 0; + status = cublasLtMatmulAlgoGetHeuristic( + GetHandle(), op_desc, a_layout, b_layout, c_layout, c_layout, + preference, 1, &heuristic, &returned_results); + assert(status == CUBLAS_STATUS_SUCCESS && returned_results > 0 && + "failed to find a cuBLASLt algorithm for Linear"); + + // Execute. + status = cublasLtMatmul( + GetHandle(), op_desc, &alpha, a_ptr, a_layout, b_ptr, b_layout, + &beta, out.data(), c_layout, out.data(), c_layout, + &heuristic.algo, workspace_, workspace_size_in_bytes_, cuda_stream); + assert(status == CUBLAS_STATUS_SUCCESS && "cuBLASLt Linear matmul failed"); + + // Cleanup. + cublasLtMatmulPreferenceDestroy(preference); + cublasLtMatrixLayoutDestroy(c_layout); + cublasLtMatrixLayoutDestroy(b_layout); + cublasLtMatrixLayoutDestroy(a_layout); + cublasLtMatmulDescDestroy(op_desc); + + // Bias add. 
+ if (has_bias_ && bias.has_value()) { + LaunchBiasAdd(out, bias.value(), cuda_stream); + } + } + + private: + void LaunchBiasAdd(Tensor out, const Tensor bias, + typename Backend::Stream stream) const { + size_t rows = batch_count_ * m_; + size_t cols = n_; + size_t total = rows * cols; + int block_size = RuntimeUtils::GetOptimalBlockSize(); + + DispatchFunc( + out.dtype(), + [&](auto tag) { + using T = typename decltype(tag)::type; + dim3 blockDims(block_size); + dim3 gridDims((total + block_size - 1) / block_size); + + BiasAddKernel<<>>( + reinterpret_cast(out.data()), + reinterpret_cast(bias.data()), rows, cols); + }, + "CudaLinear::BiasAdd"); + } + + void SetStridedBatchAttributes(cublasLtMatrixLayout_t layout, + int64_t batch_stride) const { + int batch_count = static_cast(batch_count_); + auto status = cublasLtMatrixLayoutSetAttribute( + layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batch_count, sizeof(batch_count)); + assert(status == CUBLAS_STATUS_SUCCESS); + + status = cublasLtMatrixLayoutSetAttribute( + layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &batch_stride, sizeof(batch_stride)); + assert(status == CUBLAS_STATUS_SUCCESS); + } + + cublasOperation_t GetOpA(bool trans_a, bool trans_b) const { + if (swap_a_and_b_) { + return (b_is_col_major_ == trans_b) ? CUBLAS_OP_T : CUBLAS_OP_N; + } + + return (a_is_col_major_ != trans_a) ? CUBLAS_OP_T : CUBLAS_OP_N; + } + + cublasOperation_t GetOpB(bool trans_a, bool trans_b) const { + if (swap_a_and_b_) { + return (a_is_col_major_ == trans_a) ? CUBLAS_OP_T : CUBLAS_OP_N; + } + + return (b_is_col_major_ != trans_b) ? 
CUBLAS_OP_T : CUBLAS_OP_N; + } + + static cublasLtHandle_t& GetHandle() { + static cublasLtHandle_t handle = []() { + cublasLtHandle_t h{}; + auto status = cublasLtCreate(&h); + assert(status == CUBLAS_STATUS_SUCCESS && + "failed to create cuBLASLt handle"); + return h; + }(); + + return handle; + } + + bool a_is_col_major_{false}; + + bool b_is_col_major_{false}; + + bool swap_a_and_b_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/matmul/blas.h b/src/cuda/matmul/blas.h new file mode 100644 index 00000000..997791a0 --- /dev/null +++ b/src/cuda/matmul/blas.h @@ -0,0 +1,97 @@ +#ifndef INFINI_OPS_CUDA_MATMUL_BLAS_H_ +#define INFINI_OPS_CUDA_MATMUL_BLAS_H_ + +#include + +#include "base/matmul.h" +#include "cuda/blas_utils.h" + +namespace infini::ops { + +template +class BlasMatmul : public Matmul { + public: + BlasMatmul(const Tensor a, const Tensor b, Tensor c, bool trans_a, + bool trans_b) + : Matmul{a, b, c, trans_a, trans_b}, + a_is_col_major_{a.stride(-1) == 1}, + b_is_col_major_{b.stride(-1) == 1}, + swap_a_and_b_{c.stride(-1) == 1} { + // TODO: Check constraints. + } + + BlasMatmul(const Tensor a, const Tensor b, Tensor c) + : BlasMatmul{a, b, c, false, false} {} + + void operator()(const Tensor a, const Tensor b, Tensor c, bool trans_a, + bool trans_b) const override { + Backend::BlasSetStream(GetHandle(), + static_cast(stream_)); + + auto op_a{GetOpA(trans_a, trans_b)}; + auto op_b{GetOpB(trans_a, trans_b)}; + + const float alpha{1.0f}; + const float beta{0.0f}; + + Backend::BlasGemmStridedBatchedEx( + GetHandle(), op_a, op_b, swap_a_and_b_ ? n_ : m_, + swap_a_and_b_ ? m_ : n_, k_, &alpha, + swap_a_and_b_ ? b.data() : a.data(), + BlasUtils::GetDataType(swap_a_and_b_ ? b.dtype() + : a.dtype()), + swap_a_and_b_ ? ldb_ : lda_, + swap_a_and_b_ ? batch_stride_b_ : batch_stride_a_, + swap_a_and_b_ ? a.data() : b.data(), + BlasUtils::GetDataType(swap_a_and_b_ ? a.dtype() + : b.dtype()), + swap_a_and_b_ ? lda_ : ldb_, + swap_a_and_b_ ? 
batch_stride_a_ : batch_stride_b_, &beta, c.data(), + BlasUtils::GetDataType(c.dtype()), ldc_, + batch_stride_c_, batch_count_, + BlasUtils::GetComputeType(c.dtype()), + Backend::BLAS_GEMM_DEFAULT); + } + + private: + auto GetOpA(bool trans_a, bool trans_b) const { + if (swap_a_and_b_) { + return (b_is_col_major_ == trans_b) ? Backend::BLAS_OP_T + : Backend::BLAS_OP_N; + } + + return (a_is_col_major_ != trans_a) ? Backend::BLAS_OP_T + : Backend::BLAS_OP_N; + } + + auto GetOpB(bool trans_a, bool trans_b) const { + if (swap_a_and_b_) { + return (a_is_col_major_ == trans_a) ? Backend::BLAS_OP_T + : Backend::BLAS_OP_N; + } + + return (b_is_col_major_ != trans_b) ? Backend::BLAS_OP_T + : Backend::BLAS_OP_N; + } + + // TODO: This static singleton is not thread-safe under concurrent access + // from multiple host threads. Add proper synchronization in the future. + static typename Backend::BlasHandle& GetHandle() { + static typename Backend::BlasHandle handle = []() { + typename Backend::BlasHandle h; + Backend::BlasCreate(&h); + return h; + }(); + return handle; + } + + bool a_is_col_major_{false}; + + bool b_is_col_major_{false}; + + bool swap_a_and_b_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/mul/dsl.h b/src/cuda/mul/dsl.h new file mode 100644 index 00000000..4082271f --- /dev/null +++ b/src/cuda/mul/dsl.h @@ -0,0 +1,42 @@ +#ifndef INFINI_OPS_CUDA_MUL_DSL_H_ +#define INFINI_OPS_CUDA_MUL_DSL_H_ + +#include "cuda/templates/binary_elementwise.cuh" +#include "base/mul.h" + +namespace infini::ops { + +// Device-side binary functor for `Mul` (DSL). 
+template +struct DslMulOp { + template + __device__ __forceinline__ T operator()(const T& a, const T& b) const { + using ComputeType = float; + auto va = Caster::template Cast(a); + auto vb = Caster::template Cast(b); + return Caster::template Cast((va * vb)); + } +}; + +template +class DslCudaMul : public Mul { + public: + DslCudaMul(const Tensor input, const Tensor other, Tensor out) + : Mul{input, other, out}, + brick_{input, other, out, ndim_} {} + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + brick_.template Run( + stream_, input, other, out, output_size_, ndim_, + is_input_contiguous_, is_other_contiguous_, is_out_contiguous_, + out_type_); + } + + private: + BinaryElementwiseBrick brick_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/reshape_and_cache/kernel.cuh b/src/cuda/reshape_and_cache/kernel.cuh new file mode 100644 index 00000000..ce406f21 --- /dev/null +++ b/src/cuda/reshape_and_cache/kernel.cuh @@ -0,0 +1,71 @@ +#ifndef INFINI_OPS_CUDA_RESHAPE_AND_CACHE_KERNEL_CUH_ +#define INFINI_OPS_CUDA_RESHAPE_AND_CACHE_KERNEL_CUH_ + +#include +#include + +namespace infini::ops { + +// Writes key and value tensors into a paged KV cache using a slot mapping. +// +// Each thread block processes one token. Threads within the block cooperatively +// write all (num_kv_heads * head_size) elements for that token into both the +// key cache and value cache. +// +// KV cache layout: [2, num_blocks, block_size, num_kv_heads, head_size] +// - Index 0 along dim 0 is the key cache. +// - Index 1 along dim 0 is the value cache. +// +// Key/value layout: [num_tokens, num_kv_heads, head_size] +// +// Slot mapping: [num_tokens] — maps each token to a flat slot index in the +// cache. `block_idx = slot / block_size`, `block_offset = slot % block_size`. 
+template +__global__ void ReshapeAndCacheKernel( + const T* __restrict__ key, const T* __restrict__ value, + T* __restrict__ kv_cache_out, const int64_t* __restrict__ slot_mapping, + size_t num_kv_heads, size_t head_size, size_t block_size, + size_t num_blocks) { + const size_t token_idx = blockIdx.x; + const int64_t slot = slot_mapping[token_idx]; + + // Padding tokens have slot_mapping == -1; skip them. + if (slot < 0) { + return; + } + + const size_t block_idx = static_cast(slot) / block_size; + const size_t block_offset = static_cast(slot) % block_size; + + const size_t elems_per_token = num_kv_heads * head_size; + + // Compute base offsets into the contiguous KV cache. + // Cache shape: [2, num_blocks, block_size, num_kv_heads, head_size] + // Strides: [num_blocks*block_size*num_kv_heads*head_size, + // block_size*num_kv_heads*head_size, + // num_kv_heads*head_size, + // head_size, + // 1] + const size_t cache_block_stride = block_size * num_kv_heads * head_size; + const size_t cache_kv_stride = num_blocks * cache_block_stride; + + const size_t key_cache_base = + block_idx * cache_block_stride + block_offset * num_kv_heads * head_size; + const size_t value_cache_base = cache_kv_stride + key_cache_base; + + // Source offset for this token: key/value shape is [num_tokens, num_kv_heads, + // head_size], contiguous. 
+ const size_t src_base = token_idx * elems_per_token; + + for (size_t i = threadIdx.x; i < elems_per_token; i += BLOCK_SIZE) { + const T k = key[src_base + i]; + const T v = value[src_base + i]; + + kv_cache_out[key_cache_base + i] = k; + kv_cache_out[value_cache_base + i] = v; + } +} + +} // namespace infini::ops + +#endif diff --git a/src/cuda/reshape_and_cache/kernel.h b/src/cuda/reshape_and_cache/kernel.h new file mode 100644 index 00000000..1a23f884 --- /dev/null +++ b/src/cuda/reshape_and_cache/kernel.h @@ -0,0 +1,62 @@ +#ifndef INFINI_OPS_CUDA_RESHAPE_AND_CACHE_KERNEL_H_ +#define INFINI_OPS_CUDA_RESHAPE_AND_CACHE_KERNEL_H_ + +#include +#include + +#include "base/reshape_and_cache.h" +#include "common/generic_utils.h" +#include "cuda/kernel_commons.cuh" +#include "cuda/reshape_and_cache/kernel.cuh" +#include "cuda/runtime_utils.h" + +namespace infini::ops { + +template +class CudaReshapeAndCache : public ReshapeAndCache { + public: + CudaReshapeAndCache(const Tensor key, const Tensor value, + const Tensor kv_cache, const Tensor slot_mapping, + Tensor kv_cache_out) + : ReshapeAndCache{key, value, kv_cache, slot_mapping, kv_cache_out} {} + + void operator()(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, + Tensor kv_cache_out) const override { + int block_size_cfg = + RuntimeUtils::GetOptimalBlockSize(); + + DispatchFunc( + {static_cast(key_dtype_), block_size_cfg}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + + auto cuda_stream = + static_cast(stream_ ? stream_ : 0); + + // One thread block per token. 
+ dim3 gridDims(num_tokens_); + dim3 blockDims(std::min(static_cast(block_size_cfg), + num_kv_heads_ * head_size_)); + + const T* d_key = reinterpret_cast(key.data()); + const T* d_value = reinterpret_cast(value.data()); + T* d_kv_cache_out = reinterpret_cast(kv_cache_out.data()); + const int64_t* d_slot_mapping = + reinterpret_cast(slot_mapping.data()); + + const size_t num_blocks = kv_cache_shape_[1]; + + ReshapeAndCacheKernel + <<>>( + d_key, d_value, d_kv_cache_out, d_slot_mapping, num_kv_heads_, + head_size_, block_size_, num_blocks); + }, + "CudaReshapeAndCache::operator()"); + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/rms_norm/dsl.h b/src/cuda/rms_norm/dsl.h new file mode 100644 index 00000000..d4f59988 --- /dev/null +++ b/src/cuda/rms_norm/dsl.h @@ -0,0 +1,62 @@ +#ifndef INFINI_OPS_CUDA_RMS_NORM_DSL_H_ +#define INFINI_OPS_CUDA_RMS_NORM_DSL_H_ + +#include "cuda/templates/reduce_transform.cuh" +#include "base/rms_norm.h" + +namespace infini::ops { + +// Reduce op for `RmsNorm` (DSL). +struct DslRmsNormReduce { + template + __device__ __forceinline__ float Accumulate(const TData* ptr, + size_t count) const { + float ss = 0; + + for (size_t i = threadIdx.x; i < count; i += block_size) { + float v = Caster::template Cast(ptr[i]); + ss += v * v; + } + + return ss; + } + + __device__ __forceinline__ float Finalize(float total, + size_t count) const { + return rsqrtf(total / static_cast(count) + epsilon); + } + + float epsilon; +}; + +// Transform op for `RmsNorm` (DSL). 
+struct DslRmsNormTransform { + template + __device__ __forceinline__ TData Apply(TData x, float reduced, + size_t i) const { + return Caster::template Cast( + Caster::template Cast(x) * + Caster::template Cast(static_cast(weight)[i]) * reduced); + } + + const void* weight; +}; + +template +class DslCudaRmsNorm : public RmsNorm { + public: + using RmsNorm::RmsNorm; + + void operator()(const Tensor input, const Tensor weight, float eps, + Tensor out) const override { + LaunchReduceThenTransform, ReducedFloatTypes>>( + stream_, input, out, batch_size_, nhead_, dim_, + out.dtype(), input_strides_, out_strides_, + DslRmsNormReduce{eps}, + DslRmsNormTransform{weight.data()}); + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/rms_norm/kernel.cuh b/src/cuda/rms_norm/kernel.cuh index 980f776e..f10e7ded 100644 --- a/src/cuda/rms_norm/kernel.cuh +++ b/src/cuda/rms_norm/kernel.cuh @@ -10,24 +10,14 @@ namespace infini::ops { -namespace { - -template -__device__ __forceinline__ TCompute SumSquared(const TData* data_ptr, - size_t count) { - TCompute ss = 0; - for (size_t i = threadIdx.x; i < count; i += block_size) { - TCompute value = Caster::template Cast(data_ptr[i]); - ss += value * value; - } - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - return BlockReduce(temp_storage).Sum(ss); -} - -} // namespace - +// Single-pass RmsNorm kernel with shared memory caching. +// +// Pass 1: Load x from global memory into shared memory, accumulate +// sum-of-squares in registers, then block-reduce. +// Pass 2: Read x from shared memory (NOT global), apply rms * weight, +// write y to global memory. +// +// This halves global memory traffic compared to the two-pass approach. 
template __global__ void RmsNormKernel(TData* __restrict__ y, int64_t stride_y_batch, @@ -36,26 +26,41 @@ __global__ void RmsNormKernel(TData* __restrict__ y, int64_t stride_y_batch, int64_t stride_x_batch, int64_t stride_x_nhead, const TWeight* __restrict__ w, size_t nhead, size_t dim, float epsilon) { + extern __shared__ char smem_raw[]; + TCompute* x_cache = reinterpret_cast(smem_raw); + size_t batch_idx = blockIdx.x / nhead; size_t head_idx = blockIdx.x % nhead; auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead; auto x_ptr = x + batch_idx * stride_x_batch + head_idx * stride_x_nhead; - auto w_ptr = w; - TCompute ss = SumSquared(x_ptr, dim); + // Pass 1: Load x into shared memory and compute sum-of-squares. + TCompute ss = 0; + + for (size_t i = threadIdx.x; i < dim; i += block_size) { + TCompute val = Caster::template Cast(x_ptr[i]); + x_cache[i] = val; + ss += val * val; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + TCompute total = BlockReduce(temp_storage).Sum(ss); __shared__ TCompute rms; + if (threadIdx.x == 0) { - rms = Caster::template Cast( - rsqrtf(ss / Caster::template Cast(dim) + epsilon)); + rms = rsqrtf(total / static_cast(dim) + epsilon); } + __syncthreads(); + // Pass 2: Transform using cached x from shared memory. for (size_t i = threadIdx.x; i < dim; i += block_size) { y_ptr[i] = Caster::template Cast( - Caster::template Cast(x_ptr[i]) * - Caster::template Cast(w_ptr[i]) * rms); + x_cache[i] * + Caster::template Cast(w[i]) * rms); } } diff --git a/src/cuda/rms_norm/kernel.h b/src/cuda/rms_norm/kernel.h index 14146edc..5cdc73d7 100644 --- a/src/cuda/rms_norm/kernel.h +++ b/src/cuda/rms_norm/kernel.h @@ -43,8 +43,11 @@ class CudaRmsNorm : public RmsNorm { using T = TypeMapType(list_tag)>; constexpr int kBlockSize = ListGet<1>(list_tag); + // Dynamic shared memory for caching x values (single-pass). 
+ size_t smem_bytes = dim_ * sizeof(float); + RmsNormKernel - <<>>( + <<>>( reinterpret_cast(out.data()), stride_out_batch, stride_out_nhead, reinterpret_cast(input.data()), stride_input_batch, stride_input_nhead, diff --git a/src/cuda/rotary_embedding/kernel.cuh b/src/cuda/rotary_embedding/kernel.cuh new file mode 100644 index 00000000..102d234b --- /dev/null +++ b/src/cuda/rotary_embedding/kernel.cuh @@ -0,0 +1,110 @@ +#ifndef INFINI_OPS_CUDA_ROTARY_EMBEDDING_KERNEL_CUH_ +#define INFINI_OPS_CUDA_ROTARY_EMBEDDING_KERNEL_CUH_ + +#include +#include + +#include "cuda/caster.cuh" +#include "cuda/kernel_commons.cuh" + +namespace infini::ops { + +// Applies rotary position embeddings to query and key tensors. +// +// Each thread block handles one token. Threads within the block iterate over +// (head, rot_offset) pairs to apply the rotation formula: +// arr[x_idx] = x * cos - y * sin +// arr[y_idx] = y * cos + x * sin +// +// Supports two index patterns: +// - NeoX style: x_idx = rot_offset, y_idx = half_rotary_dim + rot_offset +// - GPT-J style: x_idx = 2 * rot_offset, y_idx = 2 * rot_offset + 1 +template +__global__ void RotaryEmbeddingKernel( + TData* __restrict__ query_out, TData* __restrict__ key_out, + const TData* __restrict__ query, const TData* __restrict__ key, + const TData* __restrict__ cos_sin_cache, + const int64_t* __restrict__ positions, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, int64_t rotary_dim, + int64_t query_stride_token, int64_t query_stride_head, + int64_t key_stride_token, int64_t key_stride_head, + int64_t query_out_stride_token, int64_t query_out_stride_head, + int64_t key_out_stride_token, int64_t key_out_stride_head, + bool is_neox_style) { + int64_t token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + int64_t half_rotary_dim = rotary_dim / 2; + + // Pointer to the cos/sin row for this token's position. + // Cache layout: [max_seq_len, rotary_dim] where first half is cos, second + // half is sin. 
+ const TData* cos_ptr = cos_sin_cache + pos * rotary_dim; + const TData* sin_ptr = cos_ptr + half_rotary_dim; + + int64_t total_heads = num_heads + num_kv_heads; + int64_t total_work = total_heads * half_rotary_dim; + + for (int64_t i = threadIdx.x; i < total_work; i += kBlockSize) { + int64_t head_idx = i / half_rotary_dim; + int64_t rot_offset = i % half_rotary_dim; + + TCompute cos_val = + Caster::template Cast(cos_ptr[rot_offset]); + TCompute sin_val = + Caster::template Cast(sin_ptr[rot_offset]); + + int64_t x_idx, y_idx; + + if (is_neox_style) { + x_idx = rot_offset; + y_idx = half_rotary_dim + rot_offset; + } else { + x_idx = 2 * rot_offset; + y_idx = 2 * rot_offset + 1; + } + + if (head_idx < num_heads) { + // Apply to query. + const TData* q_in = + query + token_idx * query_stride_token + head_idx * query_stride_head; + TData* q_out = query_out + token_idx * query_out_stride_token + + head_idx * query_out_stride_head; + + TCompute x = Caster::template Cast(q_in[x_idx]); + TCompute y = Caster::template Cast(q_in[y_idx]); + q_out[x_idx] = Caster::template Cast(x * cos_val - y * sin_val); + q_out[y_idx] = Caster::template Cast(y * cos_val + x * sin_val); + + // Copy non-rotary dimensions if needed. + if (rot_offset == 0 && rotary_dim < head_size) { + for (int64_t d = rotary_dim; d < head_size; ++d) { + q_out[d] = q_in[d]; + } + } + } else { + // Apply to key. + int64_t kv_head_idx = head_idx - num_heads; + const TData* k_in = + key + token_idx * key_stride_token + kv_head_idx * key_stride_head; + TData* k_out = key_out + token_idx * key_out_stride_token + + kv_head_idx * key_out_stride_head; + + TCompute x = Caster::template Cast(k_in[x_idx]); + TCompute y = Caster::template Cast(k_in[y_idx]); + k_out[x_idx] = Caster::template Cast(x * cos_val - y * sin_val); + k_out[y_idx] = Caster::template Cast(y * cos_val + x * sin_val); + + // Copy non-rotary dimensions if needed. 
+ if (rot_offset == 0 && rotary_dim < head_size) { + for (int64_t d = rotary_dim; d < head_size; ++d) { + k_out[d] = k_in[d]; + } + } + } + } +} + +} // namespace infini::ops + +#endif diff --git a/src/cuda/rotary_embedding/kernel.h b/src/cuda/rotary_embedding/kernel.h new file mode 100644 index 00000000..44b95eda --- /dev/null +++ b/src/cuda/rotary_embedding/kernel.h @@ -0,0 +1,60 @@ +#ifndef INFINI_OPS_CUDA_ROTARY_EMBEDDING_KERNEL_H_ +#define INFINI_OPS_CUDA_ROTARY_EMBEDDING_KERNEL_H_ + +#include +#include + +#include "base/rotary_embedding.h" +#include "cuda/kernel_commons.cuh" +#include "cuda/rotary_embedding/kernel.cuh" +#include "cuda/runtime_utils.h" +#include "dispatcher.h" + +namespace infini::ops { + +template +class CudaRotaryEmbedding : public RotaryEmbedding { + public: + using RotaryEmbedding::RotaryEmbedding; + + void operator()(const Tensor positions, const Tensor query, const Tensor key, + const Tensor cos_sin_cache, int64_t head_size, + int64_t rotary_dim, bool is_neox_style, Tensor query_out, + Tensor key_out) const override { + auto cuda_stream = + static_cast(stream_ ? 
stream_ : 0); + + uint32_t num_blocks = static_cast(num_tokens_); + int block_size = RuntimeUtils::GetOptimalBlockSize(); + + assert(query.dtype() == key.dtype() && + "query and key must have the same dtype"); + + DispatchFunc, ReducedFloatTypes>, + AllCudaBlockSizes>( + {static_cast(query.dtype()), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + + RotaryEmbeddingKernel + <<>>( + reinterpret_cast(query_out.data()), + reinterpret_cast(key_out.data()), + reinterpret_cast(query.data()), + reinterpret_cast(key.data()), + reinterpret_cast(cos_sin_cache.data()), + reinterpret_cast(positions.data()), + num_heads_, num_kv_heads_, head_size_, rotary_dim_, + query_strides_[0], query_strides_[1], key_strides_[0], + key_strides_[1], query_out_strides_[0], + query_out_strides_[1], key_out_strides_[0], + key_out_strides_[1], is_neox_style_); + }, + "CudaRotaryEmbedding::operator()"); + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/swiglu/dsl.h b/src/cuda/swiglu/dsl.h new file mode 100644 index 00000000..54991dbf --- /dev/null +++ b/src/cuda/swiglu/dsl.h @@ -0,0 +1,43 @@ +#ifndef INFINI_OPS_CUDA_SWIGLU_DSL_H_ +#define INFINI_OPS_CUDA_SWIGLU_DSL_H_ + +#include "cuda/templates/binary_elementwise.cuh" +#include "base/swiglu.h" + +namespace infini::ops { + +// Device-side binary functor for `Swiglu` (DSL). 
+template +struct DslSwigluOp { + template + __device__ __forceinline__ T operator()(const T& a, const T& b) const { + using ComputeType = float; + auto va = Caster::template Cast(a); + auto vb = Caster::template Cast(b); + auto t2 = vb / (static_cast(1) + expf(-vb)); + return Caster::template Cast((va * t2)); + } +}; + +template +class DslCudaSwiglu : public Swiglu { + public: + DslCudaSwiglu(const Tensor input, const Tensor other, Tensor out) + : Swiglu{input, other, out}, + brick_{input, other, out, ndim_} {} + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + brick_.template Run( + stream_, input, other, out, output_size_, ndim_, + is_input_contiguous_, is_other_contiguous_, is_out_contiguous_, + out_type_); + } + + private: + BinaryElementwiseBrick brick_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/swiglu/kernel.cuh b/src/cuda/swiglu/kernel.cuh index 36b9f975..3150dab4 100644 --- a/src/cuda/swiglu/kernel.cuh +++ b/src/cuda/swiglu/kernel.cuh @@ -24,6 +24,27 @@ __device__ __forceinline__ T Sigmoid(const T& x) { } } +// Device-side SwiGLU functor for BinaryElementwiseBrick. +// SwiGLU(input, gate) = input * gate * sigmoid(gate). +template +struct SwigluOp { + template + __device__ __forceinline__ T operator()(const T& input, + const T& gate) const { + if constexpr (IsFP16 || IsBFloat16) { + float gf = Caster::template Cast(gate); + float uf = Caster::template Cast(input); + float sf = __frcp_rn(__fadd_rn(1.0f, __expf(-gf))); + return Caster::template Cast( + __fmul_rn(__fmul_rn(gf, sf), uf)); + } else if constexpr (std::is_same_v) { + return __fmul_rn(__fmul_rn(gate, Sigmoid(gate)), input); + } else { + return gate * Sigmoid(gate) * input; + } + } +}; + // SwiGLU(x, gate) = Swish(x) * gate = (x * sigmoid(x)) * gate. 
template __global__ void SwigluKernel(T* __restrict__ out, const T* __restrict__ a, diff --git a/src/cuda/swiglu/kernel.h b/src/cuda/swiglu/kernel.h index 5fcfe73b..3a0d87a8 100644 --- a/src/cuda/swiglu/kernel.h +++ b/src/cuda/swiglu/kernel.h @@ -1,104 +1,31 @@ #ifndef INFINI_OPS_CUDA_SWIGLU_KERNEL_H_ #define INFINI_OPS_CUDA_SWIGLU_KERNEL_H_ -#include -#include -#include -#include - #include "base/swiglu.h" -#include "common/generic_utils.h" -#include "cuda/runtime_utils.h" #include "cuda/swiglu/kernel.cuh" +#include "cuda/templates/binary_elementwise.cuh" namespace infini::ops { +// CudaSwiglu uses BinaryElementwiseBrick for automatic vectorized dispatch +// on contiguous tensors (128-bit coalesced load/store). template class CudaSwiglu : public Swiglu { public: - CudaSwiglu(const Tensor input, const Tensor gate, Tensor out) - : Swiglu{input, gate, out} { - size_t shape_size = ndim_ * sizeof(*d_input_shape_); - size_t strides_size = ndim_ * sizeof(*d_input_strides_); - - const size_t metadata_size = 3 * (shape_size + strides_size); - std::vector metadata(metadata_size); - - Backend::Malloc((void**)&d_metadata_, metadata_size); - - size_t offset = 0; - d_input_shape_ = reinterpret_cast(d_metadata_ + offset); - std::memcpy(metadata.data() + offset, input_shape_.data(), shape_size); - offset += shape_size; - - d_gate_shape_ = reinterpret_cast(d_metadata_ + offset); - std::memcpy(metadata.data() + offset, gate_shape_.data(), shape_size); - offset += shape_size; - - d_out_shape_ = reinterpret_cast(d_metadata_ + offset); - std::memcpy(metadata.data() + offset, out_shape_.data(), shape_size); - offset += shape_size; - - d_input_strides_ = reinterpret_cast(d_metadata_ + offset); - std::memcpy(metadata.data() + offset, input_strides_.data(), strides_size); - offset += strides_size; - - d_gate_strides_ = reinterpret_cast(d_metadata_ + offset); - std::memcpy(metadata.data() + offset, gate_strides_.data(), strides_size); - offset += strides_size; - - d_out_strides_ = 
reinterpret_cast(d_metadata_ + offset); - std::memcpy(metadata.data() + offset, out_strides_.data(), strides_size); - - Backend::Memcpy(d_metadata_, metadata.data(), metadata_size, - Backend::MemcpyHostToDevice); - } + CudaSwiglu(const Tensor input, const Tensor other, Tensor out) + : Swiglu{input, other, out}, + brick_{input, other, out, ndim_} {} - ~CudaSwiglu() { Backend::Free(d_metadata_); } - - void operator()(const Tensor input, const Tensor gate, + void operator()(const Tensor input, const Tensor other, Tensor out) const override { - int block_size = RuntimeUtils::GetOptimalBlockSize(); - DispatchFunc( - {static_cast(out_type_), block_size}, - [&](auto list_tag) { - using T = TypeMapType(list_tag)>; - constexpr int kBlockSize = ListGet<1>(list_tag); - - auto cuda_stream = - static_cast(stream_ ? stream_ : 0); - dim3 blockDims( - std::min(static_cast(block_size), output_size_)); - dim3 gridDims(utils::CeilDiv(output_size_, blockDims.x)); - - T* d_out = reinterpret_cast(out.data()); - const T* d_input = reinterpret_cast(input.data()); - const T* d_gate = reinterpret_cast(gate.data()); - - SwigluKernel - <<>>( - d_out, d_input, d_gate, d_out_shape_, d_input_shape_, - d_gate_shape_, d_out_strides_, d_input_strides_, - d_gate_strides_, output_size_, ndim_, is_out_contiguous_, - is_input_contiguous_, is_gate_contiguous_); - }, - "CudaSwiglu::operator()"); + brick_.template Run( + stream_, input, other, out, output_size_, ndim_, + is_input_contiguous_, is_other_contiguous_, is_out_contiguous_, + out_type_); } private: - std::byte* d_metadata_{nullptr}; - - Tensor::Size* d_input_shape_{nullptr}; - - Tensor::Size* d_gate_shape_{nullptr}; - - Tensor::Size* d_out_shape_{nullptr}; - - Tensor::Stride* d_input_strides_{nullptr}; - - Tensor::Stride* d_gate_strides_{nullptr}; - - Tensor::Stride* d_out_strides_{nullptr}; + BinaryElementwiseBrick brick_; }; } // namespace infini::ops diff --git a/src/cuda/templates/binary_elementwise.cuh 
b/src/cuda/templates/binary_elementwise.cuh new file mode 100644 index 00000000..e12f1190 --- /dev/null +++ b/src/cuda/templates/binary_elementwise.cuh @@ -0,0 +1,214 @@ +#ifndef INFINI_OPS_CUDA_TEMPLATES_BINARY_ELEMENTWISE_CUH_ +#define INFINI_OPS_CUDA_TEMPLATES_BINARY_ELEMENTWISE_CUH_ + +#include +#include +#include +#include + +#include "common/generic_utils.h" +#include "cuda/kernel_commons.cuh" +#include "cuda/runtime_utils.h" +#include "dispatcher.h" +#include "tensor.h" + +namespace infini::ops { + +// Vectorized binary elementwise kernel for contiguous tensors. +// +// Processes VEC_SIZE elements per thread using vectorized load/store for +// higher memory bandwidth utilization. Elements left over when the +// total element count is not divisible by VEC_SIZE are handled by a scalar tail loop. +template +__global__ void BinaryElementwiseVecKernel(T* __restrict__ out, + const T* __restrict__ a, + const T* __restrict__ b, + size_t output_size) { + size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + size_t stride = gridDim.x * blockDim.x; + size_t vec_count = output_size / VEC_SIZE; + + using VecT = typename utils::AlignedVec::type; + const VecT* a_vec = reinterpret_cast(a); + const VecT* b_vec = reinterpret_cast(b); + VecT* out_vec = reinterpret_cast(out); + + Op op{}; + + for (size_t i = tid; i < vec_count; i += stride) { + VecT va = a_vec[i]; + VecT vb = b_vec[i]; + const T* pa = reinterpret_cast(&va); + const T* pb = reinterpret_cast(&vb); + VecT vout; + T* po = reinterpret_cast(&vout); + + #pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + po[j] = op(pa[j], pb[j]); + } + + out_vec[i] = vout; + } + + // Handle remaining elements. + size_t tail_start = vec_count * VEC_SIZE; + + for (size_t i = tail_start + tid; i < output_size; i += stride) { + out[i] = op(a[i], b[i]); + } +} + +// Generic binary elementwise GPU kernel (non-contiguous path). +// +// `Op` is a device-side functor with signature `T operator()(const T&, const T&)`. 
+template
+__global__ void BinaryElementwiseKernel(
+    T* __restrict__ out, const T* __restrict__ a, const T* __restrict__ b,
+    const size_t* __restrict__ out_shape, const size_t* __restrict__ a_shape,
+    const size_t* __restrict__ b_shape,
+    const ptrdiff_t* __restrict__ out_strides,
+    const ptrdiff_t* __restrict__ a_strides,
+    const ptrdiff_t* __restrict__ b_strides, size_t output_size, size_t ndim,
+    bool out_contig, bool a_contig, bool b_contig) {
+  // Each thread handles at most one output element. Widen `blockIdx.x` before
+  // multiplying: it and `blockDim.x` are 32-bit, so their product would wrap
+  // once the launch covers 2^32 or more elements.
+  size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+
+  if (idx < output_size) {
+    // A contiguous operand is addressed by the flat index directly; otherwise
+    // the flat index is translated through that operand's shape and strides.
+    size_t out_idx =
+        out_contig ? idx : IndexToOffset(idx, ndim, out_shape, out_strides);
+    size_t a_idx =
+        a_contig ? idx : IndexToOffset(idx, ndim, a_shape, a_strides);
+    size_t b_idx =
+        b_contig ? idx : IndexToOffset(idx, ndim, b_shape, b_strides);
+
+    out[out_idx] = Op{}(a[a_idx], b[b_idx]);
+  }
+}
+
+// Manages device metadata (shapes/strides) for a binary elementwise operator
+// and provides a templated `Run` method for dtype-dispatched kernel launch.
+template +class BinaryElementwiseBrick { + public: + BinaryElementwiseBrick(const Tensor a, const Tensor b, const Tensor out, + Tensor::Size ndim) { + size_t shape_bytes = ndim * sizeof(Tensor::Size); + size_t stride_bytes = ndim * sizeof(Tensor::Stride); + size_t total = 3 * (shape_bytes + stride_bytes); + std::vector staging(total); + + Backend::Malloc((void**)&d_metadata_, total); + + size_t offset = 0; + + d_a_shape_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(staging.data() + offset, a.shape().data(), shape_bytes); + offset += shape_bytes; + + d_b_shape_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(staging.data() + offset, b.shape().data(), shape_bytes); + offset += shape_bytes; + + d_out_shape_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(staging.data() + offset, out.shape().data(), shape_bytes); + offset += shape_bytes; + + d_a_strides_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(staging.data() + offset, a.strides().data(), stride_bytes); + offset += stride_bytes; + + d_b_strides_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(staging.data() + offset, b.strides().data(), stride_bytes); + offset += stride_bytes; + + d_out_strides_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(staging.data() + offset, out.strides().data(), stride_bytes); + + Backend::Memcpy(d_metadata_, staging.data(), total, + Backend::MemcpyHostToDevice); + } + + ~BinaryElementwiseBrick() { Backend::Free(d_metadata_); } + + BinaryElementwiseBrick(const BinaryElementwiseBrick&) = delete; + BinaryElementwiseBrick& operator=(const BinaryElementwiseBrick&) = delete; + + // Launch the elementwise kernel with dtype dispatch. + // + // When all three tensors are contiguous, uses a vectorized kernel with + // 128-bit coalesced loads for higher memory bandwidth. Falls back to + // the scalar kernel with per-element IndexToOffset for non-contiguous. 
+ template class Op> + void Run(void* stream, const Tensor a, const Tensor b, Tensor out, + Tensor::Size output_size, Tensor::Size ndim, bool a_contig, + bool b_contig, bool out_contig, DataType dtype) const { + int block_size = RuntimeUtils::GetOptimalBlockSize(); + bool all_contig = a_contig && b_contig && out_contig; + + DispatchFunc( + {static_cast(dtype), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + + auto cuda_stream = + static_cast(stream ? stream : 0); + + if (all_contig) { + // Vectorized path: 128-bit loads, grid-stride loop. + constexpr int kVecSize = utils::OptimalVecSize(); + size_t vec_count = output_size / kVecSize; + size_t total_threads = vec_count > 0 ? vec_count : output_size; + dim3 blockDims(std::min(static_cast(block_size), + total_threads)); + dim3 gridDims( + std::min(utils::CeilDiv(total_threads, blockDims.x), + static_cast(65535))); + + BinaryElementwiseVecKernel, T, + kBlockSize, kVecSize> + <<>>( + reinterpret_cast(out.data()), + reinterpret_cast(a.data()), + reinterpret_cast(b.data()), output_size); + } else { + // Scalar path with IndexToOffset for non-contiguous tensors. 
+ dim3 blockDims( + std::min(static_cast(block_size), output_size)); + dim3 gridDims(utils::CeilDiv(output_size, blockDims.x)); + + BinaryElementwiseKernel, T, kBlockSize> + <<>>( + reinterpret_cast(out.data()), + reinterpret_cast(a.data()), + reinterpret_cast(b.data()), d_out_shape_, + d_a_shape_, d_b_shape_, d_out_strides_, d_a_strides_, + d_b_strides_, output_size, ndim, out_contig, a_contig, + b_contig); + } + }, + "BinaryElementwiseBrick::Run"); + } + + private: + std::byte* d_metadata_{nullptr}; + + Tensor::Size* d_a_shape_{nullptr}; + + Tensor::Size* d_b_shape_{nullptr}; + + Tensor::Size* d_out_shape_{nullptr}; + + Tensor::Stride* d_a_strides_{nullptr}; + + Tensor::Stride* d_b_strides_{nullptr}; + + Tensor::Stride* d_out_strides_{nullptr}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/templates/reduce_transform.cuh b/src/cuda/templates/reduce_transform.cuh new file mode 100644 index 00000000..84e5e1d4 --- /dev/null +++ b/src/cuda/templates/reduce_transform.cuh @@ -0,0 +1,144 @@ +#ifndef INFINI_OPS_CUDA_TEMPLATES_REDUCE_TRANSFORM_CUH_ +#define INFINI_OPS_CUDA_TEMPLATES_REDUCE_TRANSFORM_CUH_ + +#include +#include +#include + +#include "cuda/caster.cuh" +#include "cuda/kernel_commons.cuh" +#include "cuda/runtime_utils.h" +#include "dispatcher.h" +#include "tensor.h" + +namespace infini::ops { + +// Generic reduce-then-transform GPU kernel. +// +// One CUDA block processes one logical unit (e.g. one [batch, head] slice). +// The reduction runs over `reduce_dim` elements using CUB `BlockReduce`, +// then the transform writes back `reduce_dim` elements using all threads. +// +// Template parameters: +// `ReduceOp` — functor: `TCompute operator()(const TData* ptr, size_t count)` +// returns per-thread partial result for BlockReduce::Sum. +// `TransformOp` — functor: `TData operator()(TData x, TCompute reduced, size_t i)` +// applied per element after reduction. 
+template +__global__ void ReduceThenTransformKernel( + TData* __restrict__ out, int64_t stride_out_batch, int64_t stride_out_head, + const TData* __restrict__ in, int64_t stride_in_batch, + int64_t stride_in_head, size_t nhead, size_t reduce_dim, + ReduceOp reduce_op, TransformOp transform_op) { + size_t batch_idx = blockIdx.x / nhead; + size_t head_idx = blockIdx.x % nhead; + + auto out_ptr = out + batch_idx * stride_out_batch + head_idx * stride_out_head; + auto in_ptr = in + batch_idx * stride_in_batch + head_idx * stride_in_head; + + // Reduction phase: each thread accumulates a partial sum, then block-reduce. + TCompute partial = reduce_op.template Accumulate( + in_ptr, reduce_dim); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + TCompute total = BlockReduce(temp_storage).Sum(partial); + + // Thread 0 post-processes the reduction result and shares via shared memory. + __shared__ TCompute reduced; + + if (threadIdx.x == 0) { + reduced = reduce_op.Finalize(total, reduce_dim); + } + + __syncthreads(); + + // Transform phase: all threads apply the transform in parallel. + for (size_t i = threadIdx.x; i < reduce_dim; i += block_size) { + out_ptr[i] = transform_op.template Apply(in_ptr[i], reduced, i); + } +} + +// Launches a reduce-then-transform kernel with dtype dispatch. +// +// `ReduceOp` and `TransformOp` are host-side structs that carry any extra +// state (weights, epsilon, etc.) and define device-side methods. +template +void LaunchReduceThenTransform( + void* stream, const Tensor in, Tensor out, size_t batch_size, + size_t nhead, size_t reduce_dim, DataType dtype, + const Tensor::Strides& in_strides, const Tensor::Strides& out_strides, + ReduceOp reduce_op, TransformOp transform_op) { + auto cuda_stream = + static_cast(stream ? stream : 0); + + auto stride_in_batch = in_strides.size() > 1 ? in_strides[0] : 0; + auto stride_in_head = + in_strides.size() > 1 ? 
in_strides[1] : in_strides[0]; + auto stride_out_batch = out_strides.size() > 1 ? out_strides[0] : 0; + auto stride_out_head = + out_strides.size() > 1 ? out_strides[1] : out_strides[0]; + + uint32_t num_blocks = static_cast(batch_size * nhead); + int block_size = RuntimeUtils::GetOptimalBlockSize(); + + DispatchFunc( + {static_cast(dtype), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + + ReduceThenTransformKernel + <<>>( + reinterpret_cast(out.data()), stride_out_batch, + stride_out_head, reinterpret_cast(in.data()), + stride_in_batch, stride_in_head, nhead, reduce_dim, reduce_op, + transform_op); + }, + "LaunchReduceThenTransform"); +} + +// ---------- Built-in reduce/transform ops for common patterns --------------- + +// Reduce op: mean of squares (for RmsNorm). +struct MeanSquareReduce { + template + __device__ __forceinline__ float Accumulate(const TData* ptr, + size_t count) const { + float ss = 0; + + for (size_t i = threadIdx.x; i < count; i += block_size) { + float v = Caster::template Cast(ptr[i]); + ss += v * v; + } + + return ss; + } + + __device__ __forceinline__ float Finalize(float total, + size_t count) const { + return rsqrtf(total / static_cast(count) + epsilon); + } + + float epsilon; +}; + +// Transform op: multiply by weight and reduced RMS value (for RmsNorm). 
+struct RmsNormTransform { + template + __device__ __forceinline__ TData Apply(TData x, float rms, + size_t i) const { + return Caster::template Cast( + Caster::template Cast(x) * + Caster::template Cast(weight[i]) * rms); + } + + const void* weight; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/templates/unary_elementwise.cuh b/src/cuda/templates/unary_elementwise.cuh new file mode 100644 index 00000000..c8d57898 --- /dev/null +++ b/src/cuda/templates/unary_elementwise.cuh @@ -0,0 +1,198 @@ +#ifndef INFINI_OPS_CUDA_TEMPLATES_UNARY_ELEMENTWISE_CUH_ +#define INFINI_OPS_CUDA_TEMPLATES_UNARY_ELEMENTWISE_CUH_ + +#include +#include +#include +#include + +#include "common/generic_utils.h" +#include "cuda/kernel_commons.cuh" +#include "cuda/runtime_utils.h" +#include "dispatcher.h" +#include "tensor.h" + +namespace infini::ops { + +// Vectorized unary elementwise kernel for contiguous tensors. +// +// Uses vectorized load/store with grid-stride loop. VEC_SIZE is chosen +// based on the *input* type to target 128-bit loads. +template +__global__ void UnaryElementwiseVecKernel(TOut* __restrict__ out, + const TIn* __restrict__ in, + size_t output_size) { + Op op{}; + size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + size_t stride = gridDim.x * blockDim.x; + size_t vec_count = output_size / VEC_SIZE; + + using InVec = typename utils::AlignedVec::type; + const InVec* in_vec = reinterpret_cast(in); + + // Use output vectorization when sizeof matches (same type cast) or + // when VEC_SIZE output elements fit naturally. + using OutVec = typename utils::AlignedVec::type; + OutVec* out_vec = reinterpret_cast(out); + + for (size_t i = tid; i < vec_count; i += stride) { + InVec vin = in_vec[i]; + const TIn* pin = reinterpret_cast(&vin); + OutVec vout; + TOut* po = reinterpret_cast(&vout); + + #pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + po[j] = op.template operator()(pin[j]); + } + + out_vec[i] = vout; + } + + // Handle remaining elements. 
+  size_t tail_start = vec_count * VEC_SIZE;
+
+  for (size_t i = tail_start + tid; i < output_size; i += stride) {
+    out[i] = op.template operator()(in[i]);
+  }
+}
+
+// Generic unary elementwise GPU kernel (non-contiguous path).
+//
+// `Op` is a device-side functor with signature `TOut operator()(const TIn&)`.
+template
+__global__ void UnaryElementwiseKernel(
+    TOut* __restrict__ out, const TIn* __restrict__ in,
+    const size_t* __restrict__ out_shape, const size_t* __restrict__ in_shape,
+    const ptrdiff_t* __restrict__ out_strides,
+    const ptrdiff_t* __restrict__ in_strides, size_t output_size, size_t ndim,
+    bool out_contig, bool in_contig) {
+  // Each thread handles at most one output element. Widen `blockIdx.x` before
+  // multiplying: it and `blockDim.x` are 32-bit, so their product would wrap
+  // once the launch covers 2^32 or more elements.
+  size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+
+  if (idx < output_size) {
+    // A contiguous operand is addressed by the flat index directly; otherwise
+    // the flat index is translated through that operand's shape and strides.
+    size_t out_idx =
+        out_contig ? idx : IndexToOffset(idx, ndim, out_shape, out_strides);
+    size_t in_idx =
+        in_contig ? idx : IndexToOffset(idx, ndim, in_shape, in_strides);
+
+    out[out_idx] = Op{}.template operator()(in[in_idx]);
+  }
+}
+
+// Manages device metadata (shapes/strides) for a unary elementwise operator
+// and provides a templated `Run` method for dual-dtype-dispatched kernel launch.
+template +class UnaryElementwiseBrick { + public: + UnaryElementwiseBrick(const Tensor input, Tensor out, Tensor::Size ndim) { + size_t shape_bytes = ndim * sizeof(Tensor::Size); + size_t stride_bytes = ndim * sizeof(Tensor::Stride); + size_t total = 2 * (shape_bytes + stride_bytes); + std::vector staging(total); + + Backend::Malloc((void**)&d_metadata_, total); + + size_t offset = 0; + + d_in_shape_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(staging.data() + offset, input.shape().data(), shape_bytes); + offset += shape_bytes; + + d_out_shape_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(staging.data() + offset, out.shape().data(), shape_bytes); + offset += shape_bytes; + + d_in_strides_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(staging.data() + offset, input.strides().data(), stride_bytes); + offset += stride_bytes; + + d_out_strides_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(staging.data() + offset, out.strides().data(), stride_bytes); + + Backend::Memcpy(d_metadata_, staging.data(), total, + Backend::MemcpyHostToDevice); + } + + ~UnaryElementwiseBrick() { Backend::Free(d_metadata_); } + + UnaryElementwiseBrick(const UnaryElementwiseBrick&) = delete; + + UnaryElementwiseBrick& operator=(const UnaryElementwiseBrick&) = delete; + + // Launch the elementwise kernel with dual-dtype dispatch. + // + // `InputTypeList` and `OutputTypeList` are the compile-time lists of + // supported `DataType` values for input and output respectively. + // `Op` is a device-side functor templated on `Device::Type kDev` with + // a member `template TOut operator()(const TIn&)`. 
+ template class Op> + void Run(void* stream, const Tensor input, Tensor out, + Tensor::Size output_size, Tensor::Size ndim, bool in_contig, + bool out_contig, DataType input_dtype, + DataType output_dtype) const { + int block_size = RuntimeUtils::GetOptimalBlockSize(); + + DispatchFunc( + {static_cast(input_dtype), static_cast(output_dtype), + block_size}, + [&](auto list_tag) { + using TIn = TypeMapType(list_tag)>; + using TOut = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<2>(list_tag); + + auto cuda_stream = + static_cast(stream ? stream : 0); + + if (in_contig && out_contig) { + // Vectorized path: 128-bit loads on input type. + constexpr int kVecSize = utils::OptimalVecSize(); + size_t vec_count = output_size / kVecSize; + size_t total_threads = vec_count > 0 ? vec_count : output_size; + dim3 blockDims(std::min(static_cast(block_size), + total_threads)); + dim3 gridDims( + std::min(utils::CeilDiv(total_threads, blockDims.x), + static_cast(65535))); + + UnaryElementwiseVecKernel, TIn, TOut, + kBlockSize, kVecSize> + <<>>( + reinterpret_cast(out.data()), + reinterpret_cast(input.data()), output_size); + } else { + dim3 blockDims( + std::min(static_cast(block_size), output_size)); + dim3 gridDims(utils::CeilDiv(output_size, blockDims.x)); + + UnaryElementwiseKernel, TIn, TOut, + kBlockSize> + <<>>( + reinterpret_cast(out.data()), + reinterpret_cast(input.data()), d_out_shape_, + d_in_shape_, d_out_strides_, d_in_strides_, output_size, + ndim, out_contig, in_contig); + } + }, + "UnaryElementwiseBrick::Run"); + } + + private: + std::byte* d_metadata_{nullptr}; + + Tensor::Size* d_in_shape_{nullptr}; + + Tensor::Size* d_out_shape_{nullptr}; + + Tensor::Stride* d_in_strides_{nullptr}; + + Tensor::Stride* d_out_strides_{nullptr}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/hash.h b/src/hash.h index efb34f75..4721f33f 100644 --- a/src/hash.h +++ b/src/hash.h @@ -2,6 +2,7 @@ #define INFINI_OPS_HASH_H_ #include +#include template inline 
void HashCombine(std::size_t& seed, const T& v) { @@ -9,4 +10,12 @@ inline void HashCombine(std::size_t& seed, const T& v) { seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); } +template +inline void HashCombine(std::size_t& seed, const std::vector& v) { + HashCombine(seed, v.size()); + for (const auto& elem : v) { + HashCombine(seed, elem); + } +} + #endif diff --git a/src/impl.h b/src/impl.h new file mode 100644 index 00000000..9a8be014 --- /dev/null +++ b/src/impl.h @@ -0,0 +1,17 @@ +#ifndef INFINI_OPS_IMPL_H_ +#define INFINI_OPS_IMPL_H_ + +#include + +namespace infini::ops { + +// Global implementation index constants for the common case: +// a hand-written default and a DSL-generated alternative. +struct Impl { + static constexpr std::size_t kDefault = 0; + static constexpr std::size_t kDsl = 1; +}; + +} // namespace infini::ops + +#endif diff --git a/src/nvidia/add/kernel.h b/src/nvidia/add/kernel.h deleted file mode 100644 index d11c89d6..00000000 --- a/src/nvidia/add/kernel.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef INFINI_OPS_NVIDIA_ADD_KERNEL_H_ -#define INFINI_OPS_NVIDIA_ADD_KERNEL_H_ - -#include - -#include "cuda/add/kernel.h" -#include "nvidia/caster.cuh" -#include "nvidia/runtime_.h" - -namespace infini::ops { - -template <> -class Operator - : public CudaAdd> { - public: - using CudaAdd>::CudaAdd; -}; - -} // namespace infini::ops - -#endif diff --git a/src/nvidia/causal_softmax/kernel.h b/src/nvidia/causal_softmax/kernel.h deleted file mode 100644 index c0b30770..00000000 --- a/src/nvidia/causal_softmax/kernel.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef INFINI_OPS_NVIDIA_CAUSAL_SOFTMAX_KERNEL_H_ -#define INFINI_OPS_NVIDIA_CAUSAL_SOFTMAX_KERNEL_H_ - -#include - -#include "cuda/causal_softmax/kernel.h" -#include "nvidia/caster.cuh" -#include "nvidia/runtime_.h" - -namespace infini::ops { - -template <> -class Operator - : public CudaCausalSoftmax> { - public: - using CudaCausalSoftmax>::CudaCausalSoftmax; -}; - -} // namespace infini::ops - 
-#endif diff --git a/src/nvidia/gemm/cublas.h b/src/nvidia/gemm/cublas.h deleted file mode 100644 index 35bdd77a..00000000 --- a/src/nvidia/gemm/cublas.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef INFINI_OPS_NVIDIA_GEMM_CUBLAS_H_ -#define INFINI_OPS_NVIDIA_GEMM_CUBLAS_H_ - -#include "cuda/gemm/blas.h" -#include "nvidia/blas.h" -#include "nvidia/gemm/registry.h" - -namespace infini::ops { - -template <> -class Operator - : public BlasGemm> { - public: - using BlasGemm>::BlasGemm; -}; - -} // namespace infini::ops - -#endif diff --git a/src/nvidia/gemm/cublaslt.h b/src/nvidia/gemm/cublaslt.h index 38de8507..7c0a6142 100644 --- a/src/nvidia/gemm/cublaslt.h +++ b/src/nvidia/gemm/cublaslt.h @@ -16,7 +16,7 @@ namespace infini::ops { template <> -class Operator : public Gemm { +class Operator : public Gemm { public: Operator(const Tensor a, const Tensor b, std::optional alpha, std::optional beta, std::optional trans_a, diff --git a/src/nvidia/gemm/registry.h b/src/nvidia/gemm/registry.h index a13591dc..f9dacb6c 100644 --- a/src/nvidia/gemm/registry.h +++ b/src/nvidia/gemm/registry.h @@ -5,9 +5,18 @@ namespace infini::ops { +// Gemm-specific implementation indices. +// cuBLAS is the default for stability (matches reference implementations). +// cuBLASLt uses heuristic algorithm selection and is 2-3x faster on +// typical LLM shapes — select with `implementation="cublaslt"`. 
+struct GemmImpl { + static constexpr std::size_t kCublas = 0; + static constexpr std::size_t kCublasLt = 1; +}; + template <> struct ActiveImplementationsImpl { - using type = List<0, 1>; + using type = List; }; } // namespace infini::ops diff --git a/src/nvidia/matmul/cublas.h b/src/nvidia/matmul/cublas.h new file mode 100644 index 00000000..0bdc5aa6 --- /dev/null +++ b/src/nvidia/matmul/cublas.h @@ -0,0 +1,19 @@ +#ifndef INFINI_OPS_NVIDIA_MATMUL_CUBLAS_H_ +#define INFINI_OPS_NVIDIA_MATMUL_CUBLAS_H_ + +#include "cuda/matmul/blas.h" +#include "nvidia/blas.h" +#include "nvidia/matmul/registry.h" + +namespace infini::ops { + +template <> +class Operator + : public BlasMatmul> { + public: + using BlasMatmul>::BlasMatmul; +}; + +} // namespace infini::ops + +#endif diff --git a/src/nvidia/matmul/cublaslt.h b/src/nvidia/matmul/cublaslt.h new file mode 100644 index 00000000..09cf0778 --- /dev/null +++ b/src/nvidia/matmul/cublaslt.h @@ -0,0 +1,188 @@ +#ifndef INFINI_OPS_NVIDIA_MATMUL_CUBLASLT_H_ +#define INFINI_OPS_NVIDIA_MATMUL_CUBLASLT_H_ + +#include +#include + +// clang-format off +#include "cublasLt.h" +// clang-format on + +#include "base/matmul.h" +#include "nvidia/blas_utils.h" +#include "nvidia/matmul/registry.h" +#include "nvidia/runtime_.h" + +namespace infini::ops { + +template <> +class Operator : public Matmul { + public: + Operator(const Tensor a, const Tensor b, Tensor c, bool trans_a, + bool trans_b) + : Matmul{a, b, c, trans_a, trans_b}, + a_is_col_major_{a.stride(-1) == 1}, + b_is_col_major_{b.stride(-1) == 1}, + swap_a_and_b_{c.stride(-1) == 1} {} + + Operator(const Tensor a, const Tensor b, Tensor c) + : Operator{a, b, c, false, false} {} + + void operator()(const Tensor a, const Tensor b, Tensor c, bool trans_a, + bool trans_b) const override { + const auto op_a{GetOpA(trans_a, trans_b)}; + const auto op_b{GetOpB(trans_a, trans_b)}; + const auto matmul_m{static_cast(swap_a_and_b_ ? n_ : m_)}; + const auto matmul_n{static_cast(swap_a_and_b_ ? 
m_ : n_)}; + const auto matmul_k{static_cast(k_)}; + + const auto* a_ptr{swap_a_and_b_ ? b.data() : a.data()}; + const auto* b_ptr{swap_a_and_b_ ? a.data() : b.data()}; + const auto a_dtype{BlasUtils::GetDataType( + swap_a_and_b_ ? b.dtype() : a.dtype())}; + const auto b_dtype{BlasUtils::GetDataType( + swap_a_and_b_ ? a.dtype() : b.dtype())}; + const auto c_dtype{ + BlasUtils::GetDataType(c.dtype())}; + const auto a_ld{static_cast(swap_a_and_b_ ? ldb_ : lda_)}; + const auto b_ld{static_cast(swap_a_and_b_ ? lda_ : ldb_)}; + const auto c_ld{static_cast(ldc_)}; + const auto a_batch_stride{static_cast( + swap_a_and_b_ ? batch_stride_b_ : batch_stride_a_)}; + const auto b_batch_stride{static_cast( + swap_a_and_b_ ? batch_stride_a_ : batch_stride_b_)}; + const auto c_batch_stride{static_cast(batch_stride_c_)}; + + cublasLtMatmulDesc_t op_desc{}; + auto status = cublasLtMatmulDescCreate( + &op_desc, BlasUtils::GetComputeType(c.dtype()), + CUDA_R_32F); + assert(status == CUBLAS_STATUS_SUCCESS && + "failed to create cuBLASLt matmul descriptor"); + + status = cublasLtMatmulDescSetAttribute( + op_desc, CUBLASLT_MATMUL_DESC_TRANSA, &op_a, sizeof(op_a)); + assert(status == CUBLAS_STATUS_SUCCESS && + "failed to set cuBLASLt transa attribute"); + + status = cublasLtMatmulDescSetAttribute( + op_desc, CUBLASLT_MATMUL_DESC_TRANSB, &op_b, sizeof(op_b)); + assert(status == CUBLAS_STATUS_SUCCESS && + "failed to set cuBLASLt transb attribute"); + + cublasLtMatrixLayout_t a_layout{}; + status = cublasLtMatrixLayoutCreate( + &a_layout, a_dtype, op_a == CUBLAS_OP_N ? matmul_m : matmul_k, + op_a == CUBLAS_OP_N ? matmul_k : matmul_m, a_ld); + assert(status == CUBLAS_STATUS_SUCCESS && + "failed to create cuBLASLt A layout"); + + cublasLtMatrixLayout_t b_layout{}; + status = cublasLtMatrixLayoutCreate( + &b_layout, b_dtype, op_b == CUBLAS_OP_N ? matmul_k : matmul_n, + op_b == CUBLAS_OP_N ? 
matmul_n : matmul_k, b_ld); + assert(status == CUBLAS_STATUS_SUCCESS && + "failed to create cuBLASLt B layout"); + + cublasLtMatrixLayout_t c_layout{}; + status = cublasLtMatrixLayoutCreate(&c_layout, c_dtype, matmul_m, matmul_n, + c_ld); + assert(status == CUBLAS_STATUS_SUCCESS && + "failed to create cuBLASLt C layout"); + + if (batch_count_ > 1) { + SetStridedBatchAttributes(a_layout, a_batch_stride); + SetStridedBatchAttributes(b_layout, b_batch_stride); + SetStridedBatchAttributes(c_layout, c_batch_stride); + } + + cublasLtMatmulPreference_t preference{}; + status = cublasLtMatmulPreferenceCreate(&preference); + assert(status == CUBLAS_STATUS_SUCCESS && + "failed to create cuBLASLt preference"); + + const auto workspace_size{workspace_size_in_bytes_}; + status = cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, + sizeof(workspace_size)); + assert(status == CUBLAS_STATUS_SUCCESS && + "failed to set cuBLASLt workspace preference"); + + cublasLtMatmulHeuristicResult_t heuristic{}; + int returned_results{0}; + status = cublasLtMatmulAlgoGetHeuristic( + GetHandle(), op_desc, a_layout, b_layout, c_layout, c_layout, + preference, 1, &heuristic, &returned_results); + assert(status == CUBLAS_STATUS_SUCCESS && returned_results > 0 && + "failed to find a cuBLASLt matmul algorithm"); + + const float alpha{1.0f}; + const float beta{0.0f}; + status = cublasLtMatmul( + GetHandle(), op_desc, &alpha, a_ptr, a_layout, b_ptr, b_layout, &beta, + c.data(), c_layout, c.data(), c_layout, &heuristic.algo, workspace_, + workspace_size_in_bytes_, + static_cast::Stream>(stream_)); + assert(status == CUBLAS_STATUS_SUCCESS && "cuBLASLt matmul launch failed"); + + cublasLtMatmulPreferenceDestroy(preference); + cublasLtMatrixLayoutDestroy(c_layout); + cublasLtMatrixLayoutDestroy(b_layout); + cublasLtMatrixLayoutDestroy(a_layout); + cublasLtMatmulDescDestroy(op_desc); + } + + private: + static cublasLtHandle_t& GetHandle() { + 
static cublasLtHandle_t handle = []() { + cublasLtHandle_t h{}; + auto status = cublasLtCreate(&h); + assert(status == CUBLAS_STATUS_SUCCESS && + "failed to create cuBLASLt handle"); + return h; + }(); + return handle; + } + + void SetStridedBatchAttributes(cublasLtMatrixLayout_t layout, + int64_t batch_stride) const { + const int batch_count{static_cast(batch_count_)}; + auto status = cublasLtMatrixLayoutSetAttribute( + layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, + sizeof(batch_count)); + assert(status == CUBLAS_STATUS_SUCCESS && + "failed to set cuBLASLt batch count"); + + status = cublasLtMatrixLayoutSetAttribute( + layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &batch_stride, + sizeof(batch_stride)); + assert(status == CUBLAS_STATUS_SUCCESS && + "failed to set cuBLASLt batch stride"); + } + + cublasOperation_t GetOpA(bool trans_a, bool trans_b) const { + if (swap_a_and_b_) { + return (b_is_col_major_ == trans_b) ? CUBLAS_OP_T : CUBLAS_OP_N; + } + + return (a_is_col_major_ != trans_a) ? CUBLAS_OP_T : CUBLAS_OP_N; + } + + cublasOperation_t GetOpB(bool trans_a, bool trans_b) const { + if (swap_a_and_b_) { + return (a_is_col_major_ == trans_a) ? CUBLAS_OP_T : CUBLAS_OP_N; + } + + return (b_is_col_major_ != trans_b) ? 
CUBLAS_OP_T : CUBLAS_OP_N; + } + + bool a_is_col_major_{false}; + + bool b_is_col_major_{false}; + + bool swap_a_and_b_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/nvidia/matmul/registry.h b/src/nvidia/matmul/registry.h new file mode 100644 index 00000000..5b13c8e4 --- /dev/null +++ b/src/nvidia/matmul/registry.h @@ -0,0 +1,15 @@ +#ifndef INFINI_OPS_NVIDIA_MATMUL_REGISTRY_H_ +#define INFINI_OPS_NVIDIA_MATMUL_REGISTRY_H_ + +#include "base/matmul.h" + +namespace infini::ops { + +template <> +struct ActiveImplementationsImpl { + using type = List<0, 1>; +}; + +} // namespace infini::ops + +#endif diff --git a/src/nvidia/rms_norm/kernel.h b/src/nvidia/rms_norm/kernel.h deleted file mode 100644 index 7499b81d..00000000 --- a/src/nvidia/rms_norm/kernel.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef INFINI_OPS_NVIDIA_RMS_NORM_KERNEL_H_ -#define INFINI_OPS_NVIDIA_RMS_NORM_KERNEL_H_ - -#include - -#include "cuda/rms_norm/kernel.h" -#include "nvidia/caster.cuh" -#include "nvidia/runtime_.h" - -namespace infini::ops { - -template <> -class Operator - : public CudaRmsNorm> { - public: - using CudaRmsNorm>::CudaRmsNorm; -}; - -} // namespace infini::ops - -#endif diff --git a/src/nvidia/swiglu/kernel.h b/src/nvidia/swiglu/kernel.h deleted file mode 100644 index 8e393521..00000000 --- a/src/nvidia/swiglu/kernel.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef INFINI_OPS_NVIDIA_SWIGLU_KERNEL_H_ -#define INFINI_OPS_NVIDIA_SWIGLU_KERNEL_H_ - -#include - -#include "cuda/swiglu/kernel.h" -#include "nvidia/caster.cuh" -#include "nvidia/runtime_.h" - -namespace infini::ops { - -template <> -class Operator - : public CudaSwiglu> { - public: - using CudaSwiglu>::CudaSwiglu; -}; - -} // namespace infini::ops - -#endif diff --git a/src/operator.h b/src/operator.h index 76efd7a9..65ea99bc 100644 --- a/src/operator.h +++ b/src/operator.h @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -36,6 +37,14 @@ struct CacheKey { tensors.push_back(t); } + void 
Absorb(const std::vector& ts) { + HashCombine(hash, ts.size()); + for (const auto& t : ts) { + HashCombine(hash, t); + tensors.push_back(t); + } + } + template void Absorb(const T& v) { HashCombine(hash, v); @@ -43,10 +52,30 @@ struct CacheKey { } }; +// Check whether a value is present in a compile-time List. +template +constexpr bool ListContains(std::size_t value, List) { + return ((value == static_cast(values)) || ...); +} + +// Return the first element of a compile-time List. +template +constexpr std::size_t ListFirst(List) { + return static_cast(head); +} + template auto DispatchImplementation(std::size_t implementation_index, Functor&& func, std::string_view context_str, - List, Args&&... args) { + List list, + Args&&... args) { + // Fall back to the first available implementation when the requested + // index does not exist (e.g., operator has only a DSL implementation + // but the caller uses the default index 0). + if (!ListContains(implementation_index, list)) { + implementation_index = ListFirst(list); + } + return DispatchFunc(implementation_indices)...>( implementation_index, std::forward(func), context_str, @@ -176,10 +205,10 @@ class Operator : public OperatorBase { auto it{cache.find(key)}; if (it == cache.end()) { - it = cache - .emplace(std::move(key), - make(config, std::forward(args)...)) - .first; + // Pass args as lvalue refs so they remain valid for the `operator()` call + // below. Forwarding rvalue temporaries into `make()` would leave the args + // in a moved-from (empty) state before operator() can use them. 
+ it = cache.emplace(std::move(key), make(config, args...)).first; } auto& op{it->second}; diff --git a/src/pybind11_utils.h b/src/pybind11_utils.h index 0f5e73b9..27fca4f5 100644 --- a/src/pybind11_utils.h +++ b/src/pybind11_utils.h @@ -116,6 +116,31 @@ inline Tensor TensorFromPybind11Handle(py::handle obj) { return Tensor{data, std::move(shape), dtype, device, std::move(strides)}; } +inline std::optional OptionalTensorFromPybind11Handle( + const std::optional& obj) { + if (!obj.has_value() || obj->is_none()) return std::nullopt; + return TensorFromPybind11Handle(*obj); +} + +inline std::optional TensorFromPybind11Handle( + std::optional obj) { + if (!obj.has_value() || obj->is_none()) { + return std::nullopt; + } + + return TensorFromPybind11Handle(obj->cast()); +} + +inline std::vector VectorTensorFromPybind11Handle( + const std::vector& objs) { + std::vector result; + result.reserve(objs.size()); + for (const auto& obj : objs) { + result.push_back(TensorFromPybind11Handle(obj)); + } + return result; +} + } // namespace infini::ops #endif diff --git a/tests/benchmark_all.py b/tests/benchmark_all.py new file mode 100644 index 00000000..dbc62d07 --- /dev/null +++ b/tests/benchmark_all.py @@ -0,0 +1,349 @@ +"""Comprehensive performance benchmark for all CUDA operators. 
+ +Run with: pytest tests/benchmark_all.py --benchmark -v -s --devices cuda +""" + +import pytest +import torch +import torch.utils.benchmark as benchmark + +import infini.ops + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" +) + + +def _bench(fn, label, sub_label, min_run_time=2): + """Benchmark a function and return the measurement.""" + timer = benchmark.Timer( + stmt="fn()", + globals={"fn": fn}, + label=label, + sub_label=sub_label, + ) + + return timer.blocked_autorange(min_run_time=min_run_time) + + +# ---- Add ---- + +@pytest.mark.benchmark +@pytest.mark.parametrize("shape", [(4, 4, 5632), (1, 32, 4096), (64, 32, 128)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_bench_add(shape, dtype): + a = torch.randn(shape, dtype=dtype, device="cuda") + b = torch.randn(shape, dtype=dtype, device="cuda") + out = torch.empty(shape, dtype=dtype, device="cuda") + + m = _bench(lambda: infini.ops.add(a, b, out), "Add", f"{shape} {dtype}") + print(f" Add {shape} {dtype}: {m.median*1e3:.3f} ms") + + +# ---- Mul ---- + +@pytest.mark.benchmark +@pytest.mark.parametrize("shape", [(4, 4, 5632), (1, 32, 4096), (64, 32, 128)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_bench_mul(shape, dtype): + a = torch.randn(shape, dtype=dtype, device="cuda") + b = torch.randn(shape, dtype=dtype, device="cuda") + out = torch.empty(shape, dtype=dtype, device="cuda") + + m = _bench(lambda: infini.ops.mul(a, b, out), "Mul", f"{shape} {dtype}") + print(f" Mul {shape} {dtype}: {m.median*1e3:.3f} ms") + + +# ---- Cast ---- + +@pytest.mark.benchmark +@pytest.mark.parametrize( + "shape, in_dtype, out_dtype", + [ + ((4, 4, 5632), torch.float32, torch.float16), + ((4, 4, 5632), torch.float16, torch.float32), + ((1, 32, 4096), torch.float32, torch.bfloat16), + ((1, 32, 4096), torch.bfloat16, torch.float32), + ], +) +def test_bench_cast(shape, in_dtype, 
out_dtype): + inp = torch.randn(shape, dtype=in_dtype, device="cuda") + out = torch.empty(shape, dtype=out_dtype, device="cuda") + + m = _bench( + lambda: infini.ops.cast(inp, out), "Cast", f"{shape} {in_dtype}->{out_dtype}" + ) + print(f" Cast {shape} {in_dtype}->{out_dtype}: {m.median*1e3:.3f} ms") + + +# ---- Swiglu ---- + +@pytest.mark.benchmark +@pytest.mark.parametrize("shape", [(4, 4, 5632), (1, 32, 4096)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_bench_swiglu(shape, dtype): + inp = torch.rand(shape, dtype=dtype, device="cuda") + gate = torch.rand(shape, dtype=dtype, device="cuda") + out = torch.empty(shape, dtype=dtype, device="cuda") + + m = _bench( + lambda: infini.ops.swiglu(inp, gate, out), "Swiglu", f"{shape} {dtype}" + ) + print(f" Swiglu {shape} {dtype}: {m.median*1e3:.3f} ms") + + +# ---- RmsNorm ---- + +@pytest.mark.benchmark +@pytest.mark.parametrize("shape", [(2, 4, 2048), (1, 32, 4096), (4, 48, 64)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_bench_rms_norm(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + weight = torch.randn(shape[-1], dtype=dtype, device="cuda") + out = torch.empty(shape, dtype=dtype, device="cuda") + + m = _bench( + lambda: infini.ops.rms_norm(inp, weight, 1e-6, out), + "RmsNorm", + f"{shape} {dtype}", + ) + print(f" RmsNorm {shape} {dtype}: {m.median*1e3:.3f} ms") + + +# ---- CausalSoftmax ---- + +@pytest.mark.benchmark +@pytest.mark.parametrize("shape", [(2, 4, 64, 64), (1, 32, 128, 128)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_bench_causal_softmax(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + out = torch.empty(shape, dtype=dtype, device="cuda") + + m = _bench( + lambda: infini.ops.causal_softmax(inp, out), + "CausalSoftmax", + f"{shape} {dtype}", + ) + print(f" CausalSoftmax {shape} {dtype}: {m.median*1e3:.3f} ms") + + +# 
---- AddRmsNorm ---- + +@pytest.mark.benchmark +@pytest.mark.parametrize("shape", [(2, 4, 2048), (1, 32, 4096)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_bench_add_rms_norm(shape, dtype): + x1 = torch.randn(shape, dtype=dtype, device="cuda") + x2 = torch.randn(shape, dtype=dtype, device="cuda") + weight = torch.randn(shape[-1], dtype=dtype, device="cuda") + y_out = torch.empty(shape, dtype=dtype, device="cuda") + x_out = torch.empty(shape, dtype=dtype, device="cuda") + + m = _bench( + lambda: infini.ops.add_rms_norm(x1, x2, weight, 1e-6, y_out, x_out), + "AddRmsNorm", + f"{shape} {dtype}", + ) + print(f" AddRmsNorm {shape} {dtype}: {m.median*1e3:.3f} ms") + + +# ---- Cat ---- + +@pytest.mark.benchmark +@pytest.mark.parametrize( + "shapes, dim", + [ + ([(4, 128), (4, 128), (4, 128)], 0), + ([(4, 1024), (4, 2048), (4, 512)], 1), + ([(2, 32, 4096), (2, 32, 4096)], 0), + ], +) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_bench_cat(shapes, dim, dtype): + tensors = [torch.randn(s, dtype=dtype, device="cuda") for s in shapes] + + out_shape = list(shapes[0]) + out_shape[dim] = sum(s[dim] for s in shapes) + out = torch.empty(out_shape, dtype=dtype, device="cuda") + + first = tensors[0] + rest = tensors[1:] + + m = _bench( + lambda: infini.ops.cat(first, rest, dim, out), + "Cat", + f"{shapes} dim={dim} {dtype}", + ) + print(f" Cat {shapes} dim={dim}: {m.median*1e3:.3f} ms") + + +# ---- Gemm ---- + +@pytest.mark.benchmark +@pytest.mark.parametrize( + "M, N, K", + [(1024, 1024, 1024), (4096, 4096, 4096), (1, 4096, 4096)], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_bench_gemm(M, N, K, dtype): + a = torch.randn(M, K, dtype=dtype, device="cuda") + b = torch.randn(K, N, dtype=dtype, device="cuda") + c = torch.empty(M, N, dtype=dtype, device="cuda") + + m = _bench( + lambda: infini.ops.gemm(a, b, c), "Gemm", f"({M},{N},{K}) {dtype}" + ) + + tflops = 2 * M * N * K / m.median / 
1e12 + print(f" Gemm ({M},{N},{K}) {dtype}: {m.median*1e3:.3f} ms ({tflops:.1f} TFLOPS)") + + +# ---- Matmul ---- + +@pytest.mark.benchmark +@pytest.mark.parametrize( + "M, N, K", + [(1024, 1024, 1024), (4096, 4096, 4096), (1, 4096, 4096)], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_bench_matmul(M, N, K, dtype): + a = torch.randn(M, K, dtype=dtype, device="cuda") + b = torch.randn(K, N, dtype=dtype, device="cuda") + c = torch.empty(M, N, dtype=dtype, device="cuda") + + m = _bench( + lambda: infini.ops.matmul(a, b, c, False, False), + "Matmul", + f"({M},{N},{K}) {dtype}", + ) + + tflops = 2 * M * N * K / m.median / 1e12 + print( + f" Matmul ({M},{N},{K}) {dtype}: {m.median*1e3:.3f} ms ({tflops:.1f} TFLOPS)" + ) + + +# ---- Linear ---- + +@pytest.mark.benchmark +@pytest.mark.parametrize( + "M, N, K, has_bias", + [(1024, 4096, 4096, False), (1024, 4096, 4096, True), (1, 4096, 4096, False)], +) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_bench_linear(M, N, K, has_bias, dtype): + a = torch.randn(M, K, dtype=dtype, device="cuda") + b = torch.randn(K, N, dtype=dtype, device="cuda") + bias = torch.randn(N, dtype=dtype, device="cuda") if has_bias else None + out = torch.empty(M, N, dtype=dtype, device="cuda") + + m = _bench( + lambda: infini.ops.linear(a, b, bias, False, False, out), + "Linear", + f"({M},{N},{K}) bias={has_bias} {dtype}", + ) + print( + f" Linear ({M},{N},{K}) bias={has_bias}: {m.median*1e3:.3f} ms" + ) + + +# ---- RotaryEmbedding ---- + +@pytest.mark.benchmark +@pytest.mark.parametrize( + "num_tokens, num_heads, head_size", + [(128, 32, 128), (1, 32, 128), (512, 32, 64)], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_bench_rotary_embedding(num_tokens, num_heads, head_size, dtype): + positions = torch.arange(num_tokens, device="cuda") + query = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device="cuda") + key = torch.randn(num_tokens, num_heads, 
head_size, dtype=dtype, device="cuda") + cos_sin = torch.randn(8192, head_size, dtype=dtype, device="cuda") + q_out = torch.empty_like(query) + k_out = torch.empty_like(key) + + m = _bench( + lambda: infini.ops.rotary_embedding( + positions, query, key, cos_sin, head_size, head_size, True, q_out, k_out + ), + "RotaryEmbed", + f"T={num_tokens} H={num_heads} D={head_size} {dtype}", + ) + print( + f" RotaryEmbed T={num_tokens} H={num_heads} D={head_size} {dtype}: " + f"{m.median*1e3:.3f} ms" + ) + + +# ---- ReshapeAndCache ---- + +@pytest.mark.benchmark +@pytest.mark.parametrize( + "num_tokens, num_kv_heads, head_size, block_size, num_blocks", + [(128, 8, 128, 16, 64), (32, 32, 128, 16, 32)], +) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_bench_reshape_and_cache( + num_tokens, num_kv_heads, head_size, block_size, num_blocks, dtype +): + key = torch.randn(num_tokens, num_kv_heads, head_size, dtype=dtype, device="cuda") + value = torch.randn_like(key) + kv_cache = torch.zeros( + 2, num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device="cuda" + ) + slot_mapping = torch.randint( + 0, num_blocks * block_size, (num_tokens,), dtype=torch.int64, device="cuda" + ) + kv_cache_out = kv_cache.clone() + + m = _bench( + lambda: infini.ops.reshape_and_cache( + key, value, kv_cache, slot_mapping, kv_cache_out + ), + "ReshapeAndCache", + f"T={num_tokens} Nkv={num_kv_heads} D={head_size} {dtype}", + ) + print( + f" ReshapeAndCache T={num_tokens} Nkv={num_kv_heads}: {m.median*1e3:.3f} ms" + ) + + +# ---- FlashAttention ---- + +@pytest.mark.benchmark +@pytest.mark.parametrize( + "seq_len, num_heads, num_kv_heads, head_size", + [ + (128, 32, 32, 128), + (512, 32, 32, 128), + (2048, 32, 32, 128), + (128, 32, 8, 128), + (512, 32, 8, 128), + ], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_bench_flash_attention(seq_len, num_heads, num_kv_heads, head_size, dtype): + q = torch.randn(seq_len, num_heads, head_size, dtype=dtype, 
device="cuda") + k = torch.randn(seq_len, num_kv_heads, head_size, dtype=dtype, device="cuda") + v = torch.randn(seq_len, num_kv_heads, head_size, dtype=dtype, device="cuda") + o = torch.empty(seq_len, num_heads, head_size, dtype=dtype, device="cuda") + scale = 1.0 / head_size**0.5 + + m = _bench( + lambda: infini.ops.flash_attention( + q, k, v, None, None, None, + num_heads, num_kv_heads, head_size, scale, + True, -1, -1, 0, o, + ), + "FlashAttn", + f"S={seq_len} H={num_heads}/{num_kv_heads} D={head_size} {dtype}", + ) + + # FLOPs: 2 * S * S * H * D (for QK^T) + 2 * S * S * H * D (for attn @ V) + flops = 4 * seq_len * seq_len * num_heads * head_size + tflops = flops / m.median / 1e12 + print( + f" FlashAttn S={seq_len} H={num_heads}/{num_kv_heads} {dtype}: " + f"{m.median*1e3:.3f} ms ({tflops:.1f} TFLOPS)" + ) diff --git a/tests/benchmark_dsl.py b/tests/benchmark_dsl.py new file mode 100644 index 00000000..9d4c1150 --- /dev/null +++ b/tests/benchmark_dsl.py @@ -0,0 +1,152 @@ +"""Performance benchmark comparing DSL-generated vs hand-written kernels. + +Measures the execution time of DSL-generated and hand-written (default) +implementations for each operator on CUDA, printing a comparison summary. 
+""" + +import pytest +import torch +import torch.utils.benchmark as benchmark + +import infini.ops + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" +) + + +# --------------------------------------------------------------------------- +# Setup helpers +# --------------------------------------------------------------------------- + + +def _setup_binary(shape, dtype, device): + """Create input, other, and output tensors for binary operators.""" + input = torch.randn(shape, dtype=dtype, device=device) + other = torch.randn(shape, dtype=dtype, device=device) + out = torch.empty(shape, dtype=dtype, device=device) + + return input, other, out + + +def _setup_rms_norm(shape, dtype, device): + """Create input, weight, output tensors and epsilon for RmsNorm.""" + input = torch.randn(shape, dtype=dtype, device=device) + weight = torch.randn(shape[-1], dtype=dtype, device=device) + out = torch.empty(shape, dtype=dtype, device=device) + eps = 1e-6 + + return input, weight, out, eps + + +# --------------------------------------------------------------------------- +# Benchmark runner +# --------------------------------------------------------------------------- + + +def _run_benchmark(fn, label, sub_label, num_warmup=10): + """Run warmup iterations then measure with ``torch.utils.benchmark.Timer``.""" + + for _ in range(num_warmup): + fn() + + timer = benchmark.Timer( + stmt="fn()", + globals={"fn": fn}, + label=label, + sub_label=sub_label, + ) + + return timer.blocked_autorange(min_run_time=1) + + +def _print_comparison(op_name, shape, dtype, default_result, dsl_result): + """Print a one-line comparison of default vs DSL timings.""" + default_ms = default_result.median * 1e3 + dsl_ms = dsl_result.median * 1e3 + ratio = default_ms / dsl_ms + + print( + f"{op_name}: default={default_ms:.3f}ms, dsl={dsl_ms:.3f}ms, " + f"ratio={ratio:.2f}x (shape={shape}, dtype={dtype})" + ) + + +# 
--------------------------------------------------------------------------- +# Benchmarks +# --------------------------------------------------------------------------- + + +@pytest.mark.benchmark +@pytest.mark.parametrize("shape", [(4, 4, 5632), (1024, 1024)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_benchmark_add(shape, dtype): + """Benchmark Add operator: default (hand-written) vs DSL implementation.""" + device = "cuda" + input, other, out = _setup_binary(shape, dtype, device) + + label = f"Add {shape} {dtype}" + + default_result = _run_benchmark( + lambda: infini.ops.add(input, other, out, implementation="default"), + label, + "default", + ) + + dsl_result = _run_benchmark( + lambda: infini.ops.add(input, other, out, implementation="dsl"), + label, + "dsl", + ) + + _print_comparison("Add", shape, dtype, default_result, dsl_result) + + +@pytest.mark.benchmark +@pytest.mark.parametrize("shape", [(4, 4, 5632), (1024, 1024)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_benchmark_rms_norm(shape, dtype): + """Benchmark RmsNorm operator: default (hand-written) vs DSL implementation.""" + device = "cuda" + input, weight, out, eps = _setup_rms_norm(shape, dtype, device) + + label = f"RmsNorm {shape} {dtype}" + + default_result = _run_benchmark( + lambda: infini.ops.rms_norm(input, weight, eps, out, implementation="default"), + label, + "default", + ) + + dsl_result = _run_benchmark( + lambda: infini.ops.rms_norm(input, weight, eps, out, implementation="dsl"), + label, + "dsl", + ) + + _print_comparison("RmsNorm", shape, dtype, default_result, dsl_result) + + +@pytest.mark.benchmark +@pytest.mark.parametrize("shape", [(4, 4, 5632), (1024, 1024)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_benchmark_swiglu(shape, dtype): + """Benchmark Swiglu operator: default (hand-written) vs DSL implementation.""" + device = "cuda" + 
input, gate, out = _setup_binary(shape, dtype, device) + + label = f"Swiglu {shape} {dtype}" + + default_result = _run_benchmark( + lambda: infini.ops.swiglu(input, gate, out, implementation="default"), + label, + "default", + ) + + dsl_result = _run_benchmark( + lambda: infini.ops.swiglu(input, gate, out, implementation="dsl"), + label, + "dsl", + ) + + _print_comparison("Swiglu", shape, dtype, default_result, dsl_result) diff --git a/tests/conftest.py b/tests/conftest.py index 44654c3d..905e011a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,12 @@ def pytest_addoption(parser): parser.addoption( "--benchmark", action="store_true", help="Run performance benchmarks." ) + parser.addoption( + "--devices", + nargs="+", + default=None, + help="Device(s) to test on (e.g., --devices ascend cpu). Accepts platform names (ascend, nvidia, cambricon, metax, moore, iluvatar) or PyTorch device types (npu, cuda, mlu, musa). Defaults to all available devices.", + ) def pytest_configure(config): @@ -38,11 +44,46 @@ def set_seed_per_test(request): _set_random_seed(_hash(_test_case_path_from_request(request))) +_NPU_UNSUPPORTED_DTYPES = {torch.float64} + +# `torch_npu` does not implement random number generation for `uint16`/`uint32`/`uint64`. 
+for _bits in (16, 32, 64): + _t = getattr(torch, f"uint{_bits}", None) + if _t is not None: + _NPU_UNSUPPORTED_DTYPES.add(_t) + + +@pytest.fixture(autouse=True) +def skip_unsupported_dtype(request): + if not hasattr(request.node, "callspec"): + return + + params = request.node.callspec.params + + if params.get("device") == "npu" and params.get("dtype") in _NPU_UNSUPPORTED_DTYPES: + pytest.skip(f"{params['dtype']} not supported on Ascend 910B") + + def _set_random_seed(seed): random.seed(seed) torch.manual_seed(seed) +_PLATFORM_TO_TORCH_DEVICE = { + "nvidia": "cuda", + "iluvatar": "cuda", + "metax": "cuda", + "cambricon": "mlu", + "moore": "musa", + "ascend": "npu", +} + + +def _resolve_device(name): + """Map a platform name (e.g., ``ascend``) to a PyTorch device type (e.g., ``npu``).""" + return _PLATFORM_TO_TORCH_DEVICE.get(name, name) + + def pytest_generate_tests(metafunc): already_parametrized = _get_parametrized_args(metafunc) @@ -57,7 +98,17 @@ def pytest_generate_tests(metafunc): ) if "device" in metafunc.fixturenames and "device" not in already_parametrized: - metafunc.parametrize("device", get_available_devices()) + cli_devices = metafunc.config.getoption("--devices") + available = get_available_devices() + + if cli_devices: + devices = tuple( + d for d in (_resolve_device(x) for x in cli_devices) if d in available + ) + else: + devices = () + + metafunc.parametrize("device", devices or available) @pytest.hookimpl(tryfirst=True) diff --git a/tests/test_add.py b/tests/test_add.py index 8b8166c3..f5604355 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -2,7 +2,13 @@ import pytest import torch -from tests.utils import Payload, empty_strided, randint_strided, randn_strided +from tests.utils import ( + Payload, + empty_strided, + get_npu_stream, + randint_strided, + randn_strided, +) _INT_DTYPES = (torch.int16, torch.int32, torch.int64) @@ -63,7 +69,10 @@ def test_add( def _add(input, other, out): - infini.ops.add(input, other, out) + if 
input.device.type == "npu": + infini.ops.add(input, other, out, stream=get_npu_stream(input)) + else: + infini.ops.add(input, other, out) return out diff --git a/tests/test_add_dsl.py b/tests/test_add_dsl.py new file mode 100644 index 00000000..681c78b2 --- /dev/null +++ b/tests/test_add_dsl.py @@ -0,0 +1,56 @@ +"""Tests for the DSL-generated Add operator (implementation_index=1). + +Validates that the DSL-generated CUDA and CPU code produces results +identical to PyTorch's `torch.add`. +""" + +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, input_strides, other_strides, out_strides", + ( + ((13, 4), None, None, None), + ((13, 4), (10, 1), (10, 1), (10, 1)), + ((13, 4), (0, 1), None, None), + ((13, 4, 4), None, None, None), + ((16, 5632), None, None, None), + ((4, 4, 5632), None, None, None), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +def test_add_dsl( + shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol +): + input = randn_strided(shape, input_strides, dtype=dtype, device=device) + other = randn_strided(shape, other_strides, dtype=dtype, device=device) + out = empty_strided(shape, out_strides, dtype=dtype, device=device) + + return Payload( + _add_dsl, _torch_add, (input, other, out), {}, rtol=rtol, atol=atol + ) + + +def _add_dsl(input, other, out): + infini.ops.add(input, other, out, implementation="dsl") + + return out + + +def _torch_add(input, other, out): + res = torch.add(input, other) + out.copy_(res) + + return out diff --git a/tests/test_add_rms_norm.py b/tests/test_add_rms_norm.py new file mode 100644 index 00000000..a8197b11 --- /dev/null +++ b/tests/test_add_rms_norm.py @@ -0,0 +1,89 @@ +import infini.ops +import pytest +import torch + +from tests.utils import 
Payload, empty_strided, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, weight_shape, x1_strides, x2_strides, weight_strides, y_out_strides, x_out_strides", + ( + ((1, 64), (64,), None, None, None, None, None), + ((2, 128), (128,), None, None, None, None, None), + ((4, 48, 64), (64,), None, None, None, None, None), + ((2, 4, 2048), (2048,), None, None, None, None, None), + ((1, 64), (64,), (64, 1), (64, 1), (1,), (64, 1), (64, 1)), + ( + (4, 48, 64), + (64,), + (3072, 64, 1), + (3072, 64, 1), + (1,), + (3072, 64, 1), + (3072, 64, 1), + ), + ), +) +@pytest.mark.parametrize("eps", (1e-6, 1e-5)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-4, 1e-4), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 2e-2, 1e-2), + ), +) +def test_add_rms_norm( + shape, + weight_shape, + x1_strides, + x2_strides, + weight_strides, + y_out_strides, + x_out_strides, + eps, + dtype, + device, + rtol, + atol, +): + x1 = randn_strided(shape, x1_strides, dtype=dtype, device=device) + x2 = randn_strided(shape, x2_strides, dtype=dtype, device=device) + weight = randn_strided(weight_shape, weight_strides, dtype=dtype, device=device) + y_out = empty_strided(shape, y_out_strides, dtype=dtype, device=device) + x_out = empty_strided(shape, x_out_strides, dtype=dtype, device=device) + + return Payload( + _add_rms_norm, + _torch_add_rms_norm, + (x1, x2, weight), + {"eps": eps, "y_out": y_out, "x_out": x_out}, + rtol=rtol, + atol=atol, + ) + + +def _add_rms_norm(x1, x2, weight, *, eps=1e-6, y_out=None, x_out=None): + infini.ops.add_rms_norm(x1, x2, weight, eps, y_out, x_out) + + return y_out + + +def _torch_add_rms_norm(x1, x2, weight, *, eps=1e-6, y_out=None, x_out=None): + # Compute residual = x1 + x2. + residual = x1.float() + x2.float() + + if x_out is not None: + x_out.copy_(residual.to(x1.dtype)) + + # Compute rms_norm(residual) * weight. 
+ rms = torch.sqrt(torch.mean(residual * residual, dim=-1, keepdim=True) + eps) + result = (residual / rms).to(x1.dtype) * weight + + if y_out is not None: + y_out.copy_(result) + else: + y_out = result + + return y_out diff --git a/tests/test_cast.py b/tests/test_cast.py new file mode 100644 index 00000000..24b50ee9 --- /dev/null +++ b/tests/test_cast.py @@ -0,0 +1,65 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, input_strides, out_strides", + ( + ((13, 4), None, None), + ((13, 4), (10, 1), (10, 1)), + ((13, 4, 4), None, None), + ((16, 5632), None, None), + ((4, 4, 5632), None, None), + ), +) +@pytest.mark.parametrize( + ("input_dtype", "out_dtype", "rtol", "atol"), + ( + (torch.float16, torch.float32, 1e-3, 1e-3), + (torch.float32, torch.float16, 1e-3, 1e-3), + (torch.bfloat16, torch.float32, 1e-2, 5e-3), + (torch.float32, torch.bfloat16, 1e-2, 5e-3), + (torch.float16, torch.bfloat16, 1e-2, 5e-3), + (torch.bfloat16, torch.float16, 1e-2, 5e-3), + ), +) +def test_cast( + shape, + input_strides, + out_strides, + input_dtype, + out_dtype, + device, + rtol, + atol, +): + input = randn_strided(shape, input_strides, dtype=input_dtype, device=device) + out = empty_strided(shape, out_strides, dtype=out_dtype, device=device) + + return Payload( + _cast, + _torch_cast, + (input, out), + {}, + rtol=rtol, + atol=atol, + ) + + +def _cast(input, out): + if input.device.type == "npu": + infini.ops.cast(input, out, stream=get_npu_stream(input)) + else: + infini.ops.cast(input, out) + + return out + + +def _torch_cast(input, out): + out.copy_(input.to(out.dtype)) + + return out diff --git a/tests/test_cast_dsl.py b/tests/test_cast_dsl.py new file mode 100644 index 00000000..e6e41fb4 --- /dev/null +++ b/tests/test_cast_dsl.py @@ -0,0 +1,66 @@ +"""Tests for the DSL-generated Cast operator (implementation_index=1). 
+ +Validates that the DSL-generated CUDA and CPU code produces results +identical to PyTorch's `Tensor.to(dtype)`. +""" + +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, input_strides, out_strides", + ( + ((13, 4), None, None), + ((13, 4), (10, 1), (10, 1)), + ((13, 4, 4), None, None), + ((16, 5632), None, None), + ((4, 4, 5632), None, None), + ), +) +@pytest.mark.parametrize( + ("input_dtype", "out_dtype", "rtol", "atol"), + ( + (torch.float32, torch.float16, 1e-3, 1e-3), + (torch.float16, torch.float32, 1e-3, 1e-3), + (torch.bfloat16, torch.float32, 1e-2, 5e-3), + (torch.float32, torch.bfloat16, 1e-2, 5e-3), + ), +) +def test_cast_dsl( + shape, + input_strides, + out_strides, + input_dtype, + out_dtype, + device, + rtol, + atol, +): + input = randn_strided(shape, input_strides, dtype=input_dtype, device=device) + out = empty_strided(shape, out_strides, dtype=out_dtype, device=device) + + return Payload( + _cast_dsl, + _torch_cast, + (input, out), + {}, + rtol=rtol, + atol=atol, + ) + + +def _cast_dsl(input, out): + infini.ops.cast(input, out, implementation="dsl") + + return out + + +def _torch_cast(input, out): + out.copy_(input.to(out.dtype)) + + return out diff --git a/tests/test_cat.py b/tests/test_cat.py new file mode 100644 index 00000000..68a9dfa8 --- /dev/null +++ b/tests/test_cat.py @@ -0,0 +1,51 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shapes, dim", + ( + (((4, 3), (4, 5)), 1), + (((2, 3), (4, 3)), 0), + (((2, 3, 4), (2, 5, 4)), 1), + (((2, 3, 4), (2, 3, 6)), 2), + (((2, 3, 4), (2, 3, 4), (2, 3, 4)), 0), + (((1, 8), (3, 8), (2, 8)), 0), + (((3, 1), (3, 2), (3, 4)), 1), + (((2, 3, 4), (2, 3, 4)), -1), + (((2, 3, 4), (2, 3, 4)), -2), + (((16, 128), (16, 256)), 1), + ), +) +def 
def _torch_cat(inputs, dim, out):
    """Reference concatenation: join `inputs` along `dim` and store it in `out`.

    `dim` may be negative, in which case `torch.cat` counts it from the last
    axis. `out` is filled in place and returned.
    """
    out.copy_(torch.cat(tuple(inputs), dim))

    return out
def _ref_rope(
    positions, query, key, cos_sin_cache, head_size, rotary_dim, is_neox_style
):
    """Apply rotary position embedding to `query` and `key` in pure PyTorch.

    Only the first `rotary_dim` features of every head are rotated; any
    remaining features are copied through unchanged. This holds for both
    layouts, so partial rotary (`rotary_dim < head_size`) also works in the
    interleaved case.

    Args:
        positions: Integer tensor; entry `t` selects the row of
            `cos_sin_cache` used for token `t`.
        query: Tensor indexed as `[token, head, feature]`.
        key: Tensor indexed as `[token, head, feature]`.
        cos_sin_cache: 2-D tensor whose rows hold `rotary_dim // 2` cosines
            followed by `rotary_dim // 2` sines.
        head_size: Unused here; kept so the signature mirrors the operator.
        rotary_dim: Number of leading features per head that are rotated.
        is_neox_style: If true, pair feature `i` with `i + rotary_dim // 2`;
            otherwise pair the even feature `2i` with the odd feature `2i + 1`.

    Returns:
        A `(query, key)` tuple of new tensors; the inputs are not modified.
    """
    half = rotary_dim // 2
    num_tokens = query.size(0)
    rows = positions[:num_tokens].long().to(cos_sin_cache.device)

    # Gather the per-token factors once and add a head axis so that they
    # broadcast over every head of a `[token, head, feature]` tensor.
    cos = cos_sin_cache[rows, :half].unsqueeze(1)
    sin = cos_sin_cache[rows, half:rotary_dim].unsqueeze(1)

    def rotate(x):
        out = x.clone()

        if is_neox_style:
            first = x[..., :half]
            second = x[..., half:rotary_dim]
            out[..., :half] = cos * first - sin * second
            out[..., half:rotary_dim] = cos * second + sin * first
        else:
            # Restrict the interleaved pairs to the rotary span; slicing the
            # whole head would mismatch the `half`-wide cos/sin factors.
            even = x[..., 0:rotary_dim:2]
            odd = x[..., 1:rotary_dim:2]
            out[..., 0:rotary_dim:2] = cos * even - sin * odd
            out[..., 1:rotary_dim:2] = cos * odd + sin * even

        return out

    return rotate(query), rotate(key)
hidden_size = hidden.size(-1) + + # Save residual. + residual = hidden.clone() + + # 1. Input RMSNorm. + normed = torch.empty_like(hidden) + infini.ops.rms_norm(hidden, input_norm_w, eps, normed, **kw) + + # 2. QKV projection: [T, D] @ [D, (N+2*Nkv)*H] -> [T, (N+2*Nkv)*H]. + qkv_dim = (num_heads + 2 * num_kv_heads) * head_size + qkv = torch.empty(num_tokens, qkv_dim, dtype=dtype, device=device) + infini.ops.gemm(normed, qkv_proj_w, 1.0, 0.0, False, False, qkv, **kw) + + # Split Q, K, V. + q = ( + qkv[:, : num_heads * head_size] + .reshape( + num_tokens, + num_heads, + head_size, + ) + .contiguous() + ) + k = ( + qkv[:, num_heads * head_size : (num_heads + num_kv_heads) * head_size] + .reshape( + num_tokens, + num_kv_heads, + head_size, + ) + .contiguous() + ) + v = ( + qkv[:, (num_heads + num_kv_heads) * head_size :] + .reshape( + num_tokens, + num_kv_heads, + head_size, + ) + .contiguous() + ) + + # 3. RoPE. + q_rot = torch.empty_like(q) + k_rot = torch.empty_like(k) + infini.ops.rotary_embedding( + positions, + q, + k, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + q_rot, + k_rot, + **kw, + ) + + # 4. Flash attention (single-sequence prefill, causal). + attn_out = torch.empty( + num_tokens, + num_heads, + head_size, + dtype=dtype, + device=device, + ) + infini.ops.flash_attention( + q_rot, + k_rot, + v, + None, + None, + None, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + 0, + attn_out, + **kw, + ) + + # 5. O projection: [T, N*H] @ [N*H, D] -> [T, D]. + attn_2d = attn_out.reshape(num_tokens, num_heads * head_size) + o_out = torch.empty(num_tokens, hidden_size, dtype=dtype, device=device) + infini.ops.gemm(attn_2d, o_proj_w, 1.0, 0.0, False, False, o_out, **kw) + + # 6. Residual add. + after_attn = torch.empty_like(residual) + infini.ops.add(residual, o_out, after_attn, **kw) + + # 7. Post-attention RMSNorm. 
+ residual2 = after_attn.clone() + normed2 = torch.empty_like(after_attn) + infini.ops.rms_norm(after_attn, post_norm_w, eps, normed2, **kw) + + # 8. Gate + up projections. + gate = torch.empty(num_tokens, intermediate_size, dtype=dtype, device=device) + up = torch.empty(num_tokens, intermediate_size, dtype=dtype, device=device) + infini.ops.gemm(normed2, gate_proj_w, 1.0, 0.0, False, False, gate, **kw) + infini.ops.gemm(normed2, up_proj_w, 1.0, 0.0, False, False, up, **kw) + + # 9. SwiGLU: ``up * silu(gate)``. + ffn = torch.empty(num_tokens, intermediate_size, dtype=dtype, device=device) + infini.ops.swiglu(up, gate, ffn, **kw) + + # 10. Down projection: [T, FFN] @ [FFN, D] -> [T, D]. + down = torch.empty(num_tokens, hidden_size, dtype=dtype, device=device) + infini.ops.gemm(ffn, down_proj_w, 1.0, 0.0, False, False, down, **kw) + + # 11. Second residual add. + output = torch.empty_like(residual2) + infini.ops.add(residual2, down, output, **kw) + + return output + + +def _reference_layer( + hidden, + positions, + cos_sin_cache, + input_norm_w, + qkv_proj_w, + o_proj_w, + gate_proj_w, + up_proj_w, + down_proj_w, + post_norm_w, + num_heads, + num_kv_heads, + head_size, + rotary_dim, + intermediate_size, + is_neox_style, + eps, + scale, + num_tokens, +): + """PyTorch float32 reference for one LLaMA decoder layer.""" + # Compute in float32 on CPU for accuracy. + h = hidden.float().cpu() + pos = positions.cpu() + csc = cos_sin_cache.float().cpu() + inw = input_norm_w.float().cpu() + qkvw = qkv_proj_w.float().cpu() + ow = o_proj_w.float().cpu() + gw = gate_proj_w.float().cpu() + uw = up_proj_w.float().cpu() + dw = down_proj_w.float().cpu() + pnw = post_norm_w.float().cpu() + + # 1. Input RMSNorm. + residual = h.clone() + normed = _ref_rms_norm(h, inw, eps) + + # 2. QKV projection. 
+ qkv = normed @ qkvw + + q = qkv[:, : num_heads * head_size].reshape(num_tokens, num_heads, head_size) + k = qkv[:, num_heads * head_size : (num_heads + num_kv_heads) * head_size].reshape( + num_tokens, + num_kv_heads, + head_size, + ) + v = qkv[:, (num_heads + num_kv_heads) * head_size :].reshape( + num_tokens, + num_kv_heads, + head_size, + ) + + # 3. RoPE. + q_rot, k_rot = _ref_rope( + pos, + q, + k, + csc, + head_size, + rotary_dim, + is_neox_style, + ) + + # 4. SDPA. + attn_out = _ref_sdpa( + q_rot, k_rot, v, num_heads, num_kv_heads, head_size, scale, causal=True + ) + + # 5. O projection. + attn_2d = attn_out.reshape(num_tokens, num_heads * head_size) + o_out = attn_2d @ ow + + # 6. Residual add. + after_attn = residual + o_out + + # 7. Post-attention RMSNorm. + residual2 = after_attn.clone() + normed2 = _ref_rms_norm(after_attn, pnw, eps) + + # 8. Gate + up projections. + gate = normed2 @ gw + up = normed2 @ uw + + # 9. SwiGLU: ``up * silu(gate)``. + ffn = up * (gate * torch.sigmoid(gate)) + + # 10. Down projection. + down = ffn @ dw + + # 11. Second residual add. 
def _make_rope_cache(max_seq_len, rotary_dim, dtype, device):
    """Build a RoPE cos/sin cache whose entries lie in [-1, 1].

    Row `p` holds `cos(p * f)` for every frequency `f`, followed by
    `sin(p * f)`, where the frequencies are `10000 ** (-2i / rotary_dim)`.
    The angles are computed in float32 and only the final cos/sin values are
    converted to `dtype` and moved to `device`.
    """
    exponents = torch.arange(0, rotary_dim, 2).float() / rotary_dim
    inv_freq = 1.0 / (10000.0**exponents)
    steps = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(steps, inv_freq)  # [max_seq_len, rotary_dim // 2]
    halves = [
        fn(angles).to(dtype=dtype, device=device) for fn in (torch.cos, torch.sin)
    ]

    return torch.cat(halves, dim=-1)
+ qkv_proj_w = _scaled_weight( + hidden_size, + (num_heads + 2 * num_kv_heads) * head_size, + ) + o_proj_w = _scaled_weight(num_heads * head_size, hidden_size) + gate_proj_w = _scaled_weight(hidden_size, intermediate_size) + up_proj_w = _scaled_weight(hidden_size, intermediate_size) + down_proj_w = _scaled_weight(intermediate_size, hidden_size) + input_norm_w = torch.ones(hidden_size, dtype=dtype, device=device) + post_norm_w = torch.ones(hidden_size, dtype=dtype, device=device) + + # Proper cos/sin cache from frequency decomposition (bounded [-1, 1]). + cos_sin_cache = _make_rope_cache(max_seq_len, rotary_dim, dtype, device) + positions = randint_strided( + 0, + max_seq_len, + (num_tokens,), + None, + dtype=torch.int64, + device=device, + ) + + # Input hidden states scaled to prevent value explosion through layers. + hidden = ( + randn_strided( + (num_tokens, hidden_size), + None, + dtype=dtype, + device=device, + ) + / hidden_size**0.5 + ) + + common = dict( + positions=positions, + cos_sin_cache=cos_sin_cache, + input_norm_w=input_norm_w, + qkv_proj_w=qkv_proj_w, + o_proj_w=o_proj_w, + gate_proj_w=gate_proj_w, + up_proj_w=up_proj_w, + down_proj_w=down_proj_w, + post_norm_w=post_norm_w, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + rotary_dim=rotary_dim, + intermediate_size=intermediate_size, + is_neox_style=is_neox_style, + eps=eps, + scale=scale, + num_tokens=num_tokens, + ) + + infini_out = _infiniops_layer(hidden, **common) + ref_out = _reference_layer(hidden, **common) + + max_diff = (infini_out.float() - ref_out.float()).abs().max().item() + assert torch.allclose(infini_out, ref_out, rtol=rtol, atol=atol), ( + f"Max diff: {max_diff}" + ) diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py new file mode 100644 index 00000000..6f439f71 --- /dev/null +++ b/tests/test_flash_attention.py @@ -0,0 +1,715 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, get_npu_stream, 
randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size", + ( + (32, 32, 128), # MHA + (32, 8, 128), # GQA (4x) + (16, 4, 64), # GQA (4x), smaller + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("cuda", "npu")) +def test_flash_attention_prefill_single( + num_heads, + num_kv_heads, + head_size, + dtype, + rtol, + atol, + device, +): + """Single sequence prefill (no block table).""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_tokens = 16 + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_tokens, num_heads, head_size), None, dtype=dtype, device=device + ) + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + output = torch.empty((num_tokens, num_heads, head_size), dtype=dtype, device=device) + + return Payload( + lambda q, k, v, o: _flash_attention( + q, + k, + v, + None, + None, + None, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + 0, + o, + ), + lambda q, k, v, o: _ref_flash_attention( + q, + k, + v, + num_heads, + num_kv_heads, + head_size, + scale, + causal=True, + ), + (query, key, value, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size", + ( + (32, 32, 128), # MHA + (32, 8, 128), # GQA (4x) + (16, 4, 64), # GQA (4x), smaller + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("cuda",)) +def 
test_flash_attention_prefill_single_noncausal( + num_heads, + num_kv_heads, + head_size, + dtype, + rtol, + atol, + device, +): + """Single sequence prefill, non-causal.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + num_tokens = 16 + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_tokens, num_heads, head_size), None, dtype=dtype, device=device + ) + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + output = torch.empty((num_tokens, num_heads, head_size), dtype=dtype, device=device) + + return Payload( + lambda q, k, v, o: _flash_attention( + q, + k, + v, + None, + None, + None, + num_heads, + num_kv_heads, + head_size, + scale, + False, + -1, + 0, + 0, + o, + ), + lambda q, k, v, o: _ref_flash_attention( + q, + k, + v, + num_heads, + num_kv_heads, + head_size, + scale, + causal=False, + ), + (query, key, value, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size", + ((32, 8, 128),), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_flash_attention_prefill_multi( + num_heads, + num_kv_heads, + head_size, + dtype, + rtol, + atol, + device, +): + """Multi-sequence prefill with cu_seqlens.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + seq_lens = [8, 12, 4] + num_tokens = sum(seq_lens) + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_tokens, num_heads, head_size), None, dtype=dtype, device=device + ) + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, 
num_kv_heads, head_size), None, dtype=dtype, device=device + ) + output = torch.empty((num_tokens, num_heads, head_size), dtype=dtype, device=device) + + cu_seqlens_q = torch.tensor( + [0] + [sum(seq_lens[: i + 1]) for i in range(len(seq_lens))], + dtype=torch.int64, + device=device, + ) + cu_seqlens_kv = cu_seqlens_q.clone() + + return Payload( + lambda q, k, v, o: _flash_attention( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + None, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + 0, + o, + ), + lambda q, k, v, o: _ref_flash_attention_multi( + q, + k, + v, + seq_lens, + seq_lens, + num_heads, + num_kv_heads, + head_size, + scale, + causal=True, + ), + (query, key, value, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, block_size", + ( + (32, 8, 128, 128), + (16, 4, 64, 128), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_flash_attention_decode( + num_heads, + num_kv_heads, + head_size, + block_size, + dtype, + rtol, + atol, + device, +): + """Decode phase: single token per request with paged KV cache.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_reqs = 3 + kv_len = 16 # Total KV length per request. + num_blocks_per_req = (kv_len + block_size - 1) // block_size + num_blocks = num_reqs * num_blocks_per_req + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_reqs, num_heads, head_size), None, dtype=dtype, device=device + ) + # Paged KV cache: vLLM standard layout [num_blocks, block_size, KV_N, D]. 
+ kv_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device) + + # Block table: request i uses blocks [i*num_blocks_per_req, ...]. + block_table = torch.zeros( + (num_reqs, num_blocks_per_req), dtype=torch.int32, device=device + ) + for i in range(num_reqs): + for j in range(num_blocks_per_req): + block_table[i, j] = i * num_blocks_per_req + j + + cu_seqlens_q = torch.arange(0, num_reqs + 1, dtype=torch.int64, device=device) + cu_seqlens_kv = torch.tensor( + [i * kv_len for i in range(num_reqs + 1)], dtype=torch.int64, device=device + ) + + return Payload( + lambda q, k, v, o: _flash_attention( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + block_size, + o, + ), + lambda q, k, v, o: _ref_flash_attention_paged( + q, + k, + block_table, + cu_seqlens_q, + cu_seqlens_kv, + num_heads, + num_kv_heads, + head_size, + block_size, + scale, + causal=True, + ), + (query, kv_cache, kv_cache, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size", + ( + (32, 32, 128), # MHA + (32, 8, 128), # GQA (4x) + (16, 4, 64), # GQA (4x), smaller + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("cuda",)) +def test_flash_attention_prefill_multi_cuda( + num_heads, + num_kv_heads, + head_size, + dtype, + rtol, + atol, + device, +): + """Multi-sequence prefill with cu_seqlens on CUDA.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + seq_lens = [8, 12, 4] + num_tokens = sum(seq_lens) + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_tokens, num_heads, head_size), None, dtype=dtype, device=device + ) + 
key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + output = torch.empty((num_tokens, num_heads, head_size), dtype=dtype, device=device) + + cu_seqlens_q = torch.tensor( + [0] + [sum(seq_lens[: i + 1]) for i in range(len(seq_lens))], + dtype=torch.int64, + device=device, + ) + cu_seqlens_kv = cu_seqlens_q.clone() + + return Payload( + lambda q, k, v, o: _flash_attention( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + None, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + 0, + o, + ), + lambda q, k, v, o: _ref_flash_attention_multi( + q, + k, + v, + seq_lens, + seq_lens, + num_heads, + num_kv_heads, + head_size, + scale, + causal=True, + ), + (query, key, value, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, block_size", + ( + (32, 8, 128, 128), + (16, 4, 64, 128), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("cuda",)) +def test_flash_attention_paged_decode_cuda( + num_heads, + num_kv_heads, + head_size, + block_size, + dtype, + rtol, + atol, + device, +): + """Decode phase: single token per request with paged KV cache on CUDA.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + num_reqs = 3 + kv_len = 16 # Total KV length per request. + num_blocks_per_req = (kv_len + block_size - 1) // block_size + num_blocks = num_reqs * num_blocks_per_req + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_reqs, num_heads, head_size), None, dtype=dtype, device=device + ) + # Paged KV cache: vLLM standard layout [num_blocks, block_size, KV_N, D]. 
+ kv_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device) + + # Block table: request i uses blocks [i*num_blocks_per_req, ...]. + block_table = torch.zeros( + (num_reqs, num_blocks_per_req), dtype=torch.int32, device=device + ) + + for i in range(num_reqs): + for j in range(num_blocks_per_req): + block_table[i, j] = i * num_blocks_per_req + j + + cu_seqlens_q = torch.arange(0, num_reqs + 1, dtype=torch.int64, device=device) + cu_seqlens_kv = torch.tensor( + [i * kv_len for i in range(num_reqs + 1)], + dtype=torch.int64, + device=device, + ) + + return Payload( + lambda q, k, v, o: _flash_attention( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + block_size, + o, + ), + lambda q, k, v, o: _ref_flash_attention_paged( + q, + k, + block_table, + cu_seqlens_q, + cu_seqlens_kv, + num_heads, + num_kv_heads, + head_size, + block_size, + scale, + causal=True, + ), + (query, kv_cache, kv_cache, output), + {}, + rtol=rtol, + atol=atol, + ) + + +def _flash_attention( + query, + key, + value, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + causal, + window_left, + window_right, + block_size, + output, +): + if query.device.type == "npu": + infini.ops.flash_attention( + query, + key, + value, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + causal, + window_left, + window_right, + block_size, + output, + stream=get_npu_stream(query), + ) + else: + infini.ops.flash_attention( + query, + key, + value, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + causal, + window_left, + window_right, + block_size, + output, + ) + + return output + + +def _ref_flash_attention( + query, key, value, 
num_heads, num_kv_heads, head_size, scale, causal=True +): + """PyTorch SDPA reference for single-sequence prefill.""" + # [T, N, D] -> [N, T, D] + q = query.transpose(0, 1).float() + k = key.transpose(0, 1).float() + v = value.transpose(0, 1).float() + + # GQA: expand K/V to match num_heads. + if num_kv_heads < num_heads: + ratio = num_heads // num_kv_heads + k = k.repeat_interleave(ratio, dim=0) + v = v.repeat_interleave(ratio, dim=0) + + # [N, T, D] -> [1, N, T, D] for scaled_dot_product_attention. + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + + out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, scale=scale, is_causal=causal + ) + + # [1, N, T, D] -> [T, N, D] -> original dtype. + return out.squeeze(0).transpose(0, 1).to(query.dtype) + + +def _ref_flash_attention_multi( + query, + key, + value, + seq_lens_q, + seq_lens_kv, + num_heads, + num_kv_heads, + head_size, + scale, + causal=True, +): + """PyTorch SDPA reference for multi-sequence prefill.""" + outputs = [] + offset = 0 + for sq, sk in zip(seq_lens_q, seq_lens_kv): + q = query[offset : offset + sq] + k = key[offset : offset + sq] + v = value[offset : offset + sq] + out = _ref_flash_attention( + q, k, v, num_heads, num_kv_heads, head_size, scale, causal + ) + outputs.append(out) + offset += sq + + return torch.cat(outputs, dim=0) + + +def _ref_flash_attention_paged( + query, + kv_cache_arg, + block_table, + cu_seqlens_q, + cu_seqlens_kv, + num_heads, + num_kv_heads, + head_size, + block_size, + scale, + causal=True, +): + """PyTorch SDPA reference for decode with paged KV cache.""" + cu_kv = cu_seqlens_kv.cpu() + bt = block_table.cpu() + cache = kv_cache_arg.cpu() + q_cpu = query.cpu() + num_reqs = bt.size(0) + outputs = [] + + for i in range(num_reqs): + q = q_cpu[i : i + 1] # [1, N, D] + kv_len = int(cu_kv[i + 1] - cu_kv[i]) + + # Gather KV from paged cache. 
+ # cache: [num_blocks, KV_N, block_size, D] + blocks = bt[i] + k_pages = [] + v_pages = [] + remaining = kv_len + for b in blocks: + if remaining <= 0: + break + take = min(remaining, block_size) + # cache layout: [num_blocks, block_size, KV_N, D] + # Slice [take, KV_N, D], transpose to [KV_N, take, D] for cat. + k_pages.append(cache[int(b.item()), :take, :, :].transpose(0, 1)) + v_pages.append(cache[int(b.item()), :take, :, :].transpose(0, 1)) + remaining -= take + k = torch.cat(k_pages, dim=1) # [KV_N, kv_len, D] + v = torch.cat(v_pages, dim=1) + + # Decode: Q_S=1 attends to all past KV positions; causal masking is + # not applicable here (it would mask everything beyond position 0). + out = _ref_flash_attention( + q, # [1, N, D] - already TND format + k.transpose(0, 1), # [KV_N, kv_len, D] -> [kv_len, KV_N, D] + v.transpose(0, 1), + num_heads, + num_kv_heads, + head_size, + scale, + causal=False, + ) + outputs.append(out) + + return torch.cat(outputs, dim=0).to(query.device) diff --git a/tests/test_gemm.py b/tests/test_gemm.py index 40ed35df..2c3adec4 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, randn_strided +from tests.utils import Payload, get_npu_stream, randn_strided @pytest.mark.auto_act_and_assert @@ -59,8 +59,13 @@ def test_gemm( if implementation_index not in active_indices: pytest.skip(f"implementation `{implementation_index}` not active on `{device}`") + # cuBLASLt (implementation_index=1, implementation="cublaslt") is 2-3x + # faster than cuBLAS on typical LLM shapes, but TF32 compute mode + # produces slightly different results for fp16/bf16 that exceed the + # current test tolerances (rtol=1e-2). Use `implementation="cublaslt"` + # in production for better performance. 
if implementation_index == 1 and dtype in (torch.float16, torch.bfloat16): - pytest.skip("cuBLASLt half-precision exceeds current tolerances") + pytest.skip("cuBLASLt TF32 results exceed current tolerances (use for perf, not precision)") a = randn_strided(a_shape, a_strides, dtype=dtype, device=device) b = randn_strided(b_shape, b_strides, dtype=dtype, device=device) @@ -84,16 +89,28 @@ def test_gemm( def _gemm(a, b, alpha, beta, trans_a, trans_b, c, implementation_index=0): - infini.ops.gemm( - a, - b, - alpha, - beta, - trans_a, - trans_b, - c, - implementation_index=implementation_index, - ) + if a.device.type == "npu": + infini.ops.gemm( + a, + b, + alpha, + beta, + trans_a, + trans_b, + c, + stream=get_npu_stream(a), + ) + else: + infini.ops.gemm( + a, + b, + alpha, + beta, + trans_a, + trans_b, + c, + implementation_index=implementation_index, + ) return c diff --git a/tests/test_linear.py b/tests/test_linear.py new file mode 100644 index 00000000..db9608f4 --- /dev/null +++ b/tests/test_linear.py @@ -0,0 +1,87 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "a_shape, b_shape, out_shape", + ( + ((1, 128), (128, 64), (1, 64)), + ((4, 256), (256, 128), (4, 128)), + ((2, 4, 128), (2, 128, 64), (2, 4, 64)), + ), +) +@pytest.mark.parametrize("has_bias", (False, True)) +@pytest.mark.parametrize("trans_a", (False, True)) +@pytest.mark.parametrize("trans_b", (False, True)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-3, 1e-3), + (torch.float16, 5e-2, 5e-2), + (torch.bfloat16, 5e-2, 5e-2), + ), +) +def test_linear( + a_shape, + b_shape, + out_shape, + has_bias, + trans_a, + trans_b, + dtype, + device, + rtol, + atol, +): + if device == "cpu": + pytest.skip("CPU Linear is not implemented") + + a = randn_strided(a_shape, None, dtype=dtype, device=device) + b = randn_strided(b_shape, None, dtype=dtype, device=device) 
def _torch_linear(a, b, bias, trans_a, trans_b, out):
    """Reference linear layer: `out = op(a) @ op(b) (+ bias)`.

    `op` swaps the last two axes when the matching `trans_*` flag is set.
    The product is accumulated in float32 and converted to the dtype of
    `out` before the optional `bias` is added. `out` is filled in place and
    returned.
    """
    lhs, rhs = a, b

    if trans_a:
        lhs = lhs.transpose(-2, -1)

    if trans_b:
        rhs = rhs.transpose(-2, -1)

    product = torch.matmul(lhs.float(), rhs.float()).to(out.dtype)
    result = product if bias is None else product + bias
    out.copy_(result)

    return out
def _torch_matmul(a, b, c, trans_a=False, trans_b=False):
    """Reference matmul: write `op(a) @ op(b)` into `c` and return `c`.

    `op` swaps the last two axes when the matching `trans_*` flag is set.
    """
    lhs = a.transpose(-2, -1) if trans_a else a
    rhs = b.transpose(-2, -1) if trans_b else b

    try:
        torch.matmul(lhs, rhs, out=c)
    except RuntimeError:
        # Fallback for backends that don't support `matmul(out=...)` for
        # certain strided outputs or half-precision types.
        widened = torch.matmul(lhs.float(), rhs.float())
        c.copy_(widened.to(c.dtype))

    return c
+ ) + + if dtype in _INT_DTYPES or dtype in _UINT_DTYPES: + input = randint_strided( + 0, 100, shape, input_strides, dtype=dtype, device=device + ) + other = randint_strided( + 0, 100, shape, other_strides, dtype=dtype, device=device + ) + else: + input = randn_strided(shape, input_strides, dtype=dtype, device=device) + other = randn_strided(shape, other_strides, dtype=dtype, device=device) + + out = empty_strided(shape, out_strides, dtype=dtype, device=device) + + return Payload(_mul, _torch_mul, (input, other, out), {}, rtol=rtol, atol=atol) + + +def _mul(input, other, out): + if input.device.type == "npu": + infini.ops.mul(input, other, out, stream=get_npu_stream(input)) + else: + infini.ops.mul(input, other, out) + + return out + + +def _torch_mul(input, other, out): + if input.dtype in _UINT_DTYPES: + input = input.to(torch.int64) + + if other.dtype in _UINT_DTYPES: + other = other.to(torch.int64) + + res = torch.mul(input, other) + out.copy_(res.to(out.dtype)) + + return out diff --git a/tests/test_mul_dsl.py b/tests/test_mul_dsl.py new file mode 100644 index 00000000..afd55bd1 --- /dev/null +++ b/tests/test_mul_dsl.py @@ -0,0 +1,56 @@ +"""Tests for the DSL-generated Mul operator (implementation_index=1). + +Validates that the DSL-generated CUDA and CPU code produces results +identical to PyTorch's `torch.mul`. 
+""" + +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, input_strides, other_strides, out_strides", + ( + ((13, 4), None, None, None), + ((13, 4), (10, 1), (10, 1), (10, 1)), + ((13, 4), (0, 1), None, None), + ((13, 4, 4), None, None, None), + ((16, 5632), None, None, None), + ((4, 4, 5632), None, None, None), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +def test_mul_dsl( + shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol +): + input = randn_strided(shape, input_strides, dtype=dtype, device=device) + other = randn_strided(shape, other_strides, dtype=dtype, device=device) + out = empty_strided(shape, out_strides, dtype=dtype, device=device) + + return Payload( + _mul_dsl, _torch_mul, (input, other, out), {}, rtol=rtol, atol=atol + ) + + +def _mul_dsl(input, other, out): + infini.ops.mul(input, other, out, implementation="dsl") + + return out + + +def _torch_mul(input, other, out): + res = torch.mul(input, other) + out.copy_(res) + + return out diff --git a/tests/test_reshape_and_cache.py b/tests/test_reshape_and_cache.py new file mode 100644 index 00000000..f409e85c --- /dev/null +++ b/tests/test_reshape_and_cache.py @@ -0,0 +1,90 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload + + +def _reshape_and_cache_ref(key, value, kv_cache, slot_mapping, kv_cache_out): + """Reference implementation: scatter key/value into paged KV cache.""" + kv_cache_out.copy_(kv_cache) + num_tokens = key.size(0) + + for i in range(num_tokens): + slot = slot_mapping[i].item() + + if slot < 0: + continue + + block_size = kv_cache_out.size(2) + block_idx = slot // block_size + block_offset = slot % block_size + + # kv_cache_out shape: [2, num_blocks, block_size, 
num_kv_heads, head_size] + kv_cache_out[0, block_idx, block_offset, :, :] = key[i] + kv_cache_out[1, block_idx, block_offset, :, :] = value[i] + + return kv_cache_out + + +def _reshape_and_cache(key, value, kv_cache, slot_mapping, kv_cache_out): + infini.ops.reshape_and_cache(key, value, kv_cache, slot_mapping, kv_cache_out) + + return kv_cache_out + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_tokens, num_kv_heads, head_size, num_blocks, block_size", + ( + (1, 1, 64, 1, 1), + (4, 8, 64, 4, 16), + (7, 4, 128, 8, 32), + (16, 32, 128, 16, 16), + (3, 2, 64, 2, 8), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 0, 0), + (torch.float16, 0, 0), + (torch.bfloat16, 0, 0), + ), +) +def test_reshape_and_cache( + num_tokens, num_kv_heads, head_size, num_blocks, block_size, dtype, device, + rtol, atol +): + total_slots = num_blocks * block_size + + if num_tokens > total_slots: + pytest.skip("more tokens than available slots") + + key = torch.randn( + num_tokens, num_kv_heads, head_size, dtype=dtype, device=device + ) + value = torch.randn( + num_tokens, num_kv_heads, head_size, dtype=dtype, device=device + ) + + kv_cache = torch.zeros( + 2, num_blocks, block_size, num_kv_heads, head_size, + dtype=dtype, device=device, + ) + + # Build a slot mapping: assign each token a unique random slot. 
+ slots = torch.randperm(total_slots)[:num_tokens].to( + dtype=torch.int64, device=device + ) + + kv_cache_out = kv_cache.clone() + + return Payload( + _reshape_and_cache, + _reshape_and_cache_ref, + (key, value, kv_cache, slots, kv_cache_out), + {}, + rtol=rtol, + atol=atol, + ) diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py index d6d4dff1..ba540a95 100644 --- a/tests/test_rms_norm.py +++ b/tests/test_rms_norm.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, empty_strided, randn_strided +from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided @pytest.mark.auto_act_and_assert @@ -53,7 +53,10 @@ def test_rms_norm( def _rms_norm(input, weight, *, eps=1e-6, out=None): - infini.ops.rms_norm(input, weight, eps, out) + if input.device.type == "npu": + infini.ops.rms_norm(input, weight, eps, out, stream=get_npu_stream(input)) + else: + infini.ops.rms_norm(input, weight, eps, out) return out diff --git a/tests/test_rms_norm_dsl.py b/tests/test_rms_norm_dsl.py new file mode 100644 index 00000000..4fb1c611 --- /dev/null +++ b/tests/test_rms_norm_dsl.py @@ -0,0 +1,82 @@ +"""Tests for the DSL-generated RmsNorm operator (implementation_index=1). + +Validates that the DSL-generated CUDA and CPU code produces results +identical to PyTorch's RMS norm reference implementation. 
+""" + +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "input_shape, weight_shape, input_strides, weight_strides, out_strides", + ( + ((1, 64), (64,), None, None, None), + ((2, 128), (128,), None, None, None), + ((4, 48, 64), (64,), None, None, None), + ((2, 4, 2048), (2048,), None, None, None), + ((1, 64), (64,), (64, 1), (1,), (64, 1)), + ((4, 48, 64), (64,), (3072, 64, 1), (1,), (3072, 64, 1)), + ), +) +@pytest.mark.parametrize("eps", (1e-6, 1e-5)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-4, 1e-4), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 2e-2, 1e-2), + ), +) +def test_rms_norm_dsl( + input_shape, + weight_shape, + input_strides, + weight_strides, + out_strides, + eps, + dtype, + device, + rtol, + atol, +): + input = randn_strided(input_shape, input_strides, dtype=dtype, device=device) + weight = randn_strided(weight_shape, weight_strides, dtype=dtype, device=device) + out = empty_strided(input_shape, out_strides, dtype=dtype, device=device) + + return Payload( + _rms_norm_dsl, + _torch_rms_norm, + (input, weight), + {"eps": eps, "out": out}, + rtol=rtol, + atol=atol, + ) + + +def _rms_norm_dsl(input, weight, *, eps=1e-6, out=None): + infini.ops.rms_norm(input, weight, eps, out, implementation="dsl") + + return out + + +def _torch_rms_norm(input, weight, *, eps=1e-6, out=None): + def _fallback(input, _normalized_shape, weight, *, eps=1e-6): + rms = torch.sqrt(torch.mean(input * input, dim=-1, keepdim=True) + eps) + + return (input / rms) * weight + + rms_norm_fn = getattr(torch.nn.functional, "rms_norm", _fallback) + + result = rms_norm_fn(input, input.shape[-1:], weight=weight, eps=eps) + + if out is not None: + out.copy_(result) + else: + out = result + + return out diff --git a/tests/test_rotary_embedding.py b/tests/test_rotary_embedding.py new file mode 100644 index 
00000000..2e93f5e0 --- /dev/null +++ b/tests/test_rotary_embedding.py @@ -0,0 +1,287 @@ +import infini.ops +import pytest +import torch + +from tests.utils import get_npu_stream, randn_strided, randint_strided + + +def _rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + device, +): + if device == "npu": + infini.ops.rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + stream=get_npu_stream(query), + ) + else: + infini.ops.rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + ) + + return query_out, key_out + + +def _ref_rotary_embedding( + positions, query, key, cos_sin_cache, head_size, rotary_dim, is_neox_style +): + """PyTorch reference for RoPE. + + ``cos_sin_cache`` layout: ``[max_seq_len, rotary_dim]`` where the first + ``rotary_dim // 2`` columns are cos and the rest are sin. 
+ """ + T = query.size(0) + R = rotary_dim + half_R = R // 2 + + cos_sin = cos_sin_cache.float() + cos_half = cos_sin[:, :half_R] + sin_half = cos_sin[:, half_R:] + + def apply_rope(x): + out = x.float().clone() + + for t in range(T): + p = positions[t].item() + c = cos_half[p] + s = sin_half[p] + + if is_neox_style: + x1 = x[t, :, :half_R].float() + x2 = x[t, :, half_R:R].float() + out[t, :, :half_R] = c * x1 - s * x2 + out[t, :, half_R:R] = c * x2 + s * x1 + else: + x1 = x[t, :, 0::2].float() + x2 = x[t, :, 1::2].float() + out[t, :, 0::2] = c * x1 - s * x2 + out[t, :, 1::2] = c * x2 + s * x1 + + return out.to(x.dtype) + + return apply_rope(query), apply_rope(key) + + +def _assert_close(actual, expected, rtol, atol): + assert torch.allclose(actual, expected, rtol=rtol, atol=atol), ( + f"Max diff: {(actual.float() - expected.float()).abs().max().item()}" + ) + + +@pytest.mark.parametrize( + "num_heads, head_size", + ( + (32, 128), + (8, 64), + ), +) +@pytest.mark.parametrize("is_neox_style", (True, False)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("cuda", "npu")) +def test_rotary_embedding_full( + num_heads, head_size, is_neox_style, dtype, rtol, atol, device +): + """Full rotary: ``rotary_dim == head_size``.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + if device == "npu" and not is_neox_style: + pytest.skip( + "Ascend aclnnApplyRotaryPosEmbV2 only supports neox style " + "(rotaryMode='half')" + ) + + # aclnnApplyRotaryPosEmbV2 accumulates with ~4 ULP error for float16. 
+ if device == "npu" and dtype == torch.float16: + atol = 0.01 + + num_kv_heads = num_heads + rotary_dim = head_size + num_tokens = 16 + max_seq_len = 64 + + positions = randint_strided( + 0, + max_seq_len, + (num_tokens,), + None, + dtype=torch.int64, + device=device, + ) + query = randn_strided( + (num_tokens, num_heads, head_size), + None, + dtype=dtype, + device=device, + ) + key = randn_strided( + (num_tokens, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + cos_sin_cache = randn_strided( + (max_seq_len, rotary_dim), + None, + dtype=dtype, + device=device, + ) + query_out = torch.empty_like(query) + key_out = torch.empty_like(key) + + q_out, k_out = _rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + device, + ) + + ref_q, ref_k = _ref_rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + ) + + _assert_close(q_out, ref_q, rtol, atol) + _assert_close(k_out, ref_k, rtol, atol) + + +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, rotary_dim", + ( + (32, 8, 128, 64), + (16, 4, 64, 32), + ), +) +@pytest.mark.parametrize("is_neox_style", (True,)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("cuda", "npu")) +def test_rotary_embedding_partial( + num_heads, + num_kv_heads, + head_size, + rotary_dim, + is_neox_style, + dtype, + rtol, + atol, + device, +): + """Partial rotary: ``rotary_dim < head_size``.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + if device == "npu": + pytest.skip( + "Ascend aclnnApplyRotaryPosEmbV2 requires rotary_dim == head_size" + ) + + num_tokens = 16 + max_seq_len = 64 + + 
positions = randint_strided( + 0, + max_seq_len, + (num_tokens,), + None, + dtype=torch.int64, + device=device, + ) + query = randn_strided( + (num_tokens, num_heads, head_size), + None, + dtype=dtype, + device=device, + ) + key = randn_strided( + (num_tokens, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + cos_sin_cache = randn_strided( + (max_seq_len, rotary_dim), + None, + dtype=dtype, + device=device, + ) + query_out = torch.empty_like(query) + key_out = torch.empty_like(key) + + q_out, k_out = _rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + device, + ) + + ref_q, ref_k = _ref_rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + ) + + _assert_close(q_out, ref_q, rtol, atol) + _assert_close(k_out, ref_k, rtol, atol) diff --git a/tests/test_swiglu.py b/tests/test_swiglu.py index 89c95f77..71eaceb1 100644 --- a/tests/test_swiglu.py +++ b/tests/test_swiglu.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, empty_strided, rand_strided +from tests.utils import Payload, empty_strided, get_npu_stream, rand_strided @pytest.mark.auto_act_and_assert @@ -38,7 +38,10 @@ def test_swiglu( def _swiglu(input, gate, out): - infini.ops.swiglu(input, gate, out) + if input.device.type == "npu": + infini.ops.swiglu(input, gate, out, stream=get_npu_stream(input)) + else: + infini.ops.swiglu(input, gate, out) return out diff --git a/tests/test_swiglu_dsl.py b/tests/test_swiglu_dsl.py new file mode 100644 index 00000000..5627e96a --- /dev/null +++ b/tests/test_swiglu_dsl.py @@ -0,0 +1,54 @@ +"""Tests for the DSL-generated Swiglu operator (implementation_index=1). + +Validates that the DSL-generated code produces results identical to +the reference: SwiGLU(input, gate) = input * silu(gate). 
+""" + +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, rand_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, input_strides, gate_strides, out_strides", + ( + ((13, 4), None, None, None), + ((13, 4), (10, 1), (10, 1), (10, 1)), + ((13, 4, 4), None, None, None), + ((16, 5632), None, None, None), + ((4, 4, 5632), None, None, None), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +def test_swiglu_dsl( + shape, input_strides, gate_strides, out_strides, dtype, device, rtol, atol +): + input = rand_strided(shape, input_strides, dtype=dtype, device=device) + gate = rand_strided(shape, gate_strides, dtype=dtype, device=device) + out = empty_strided(shape, out_strides, dtype=dtype, device=device) + + return Payload( + _swiglu_dsl, _torch_swiglu, (input, gate, out), {}, rtol=rtol, atol=atol + ) + + +def _swiglu_dsl(input, gate, out): + infini.ops.swiglu(input, gate, out, implementation="dsl") + + return out + + +def _torch_swiglu(input, gate, out): + swish_x = gate * torch.sigmoid(gate) + + return torch.mul(input, swish_x, out=out) diff --git a/tests/utils.py b/tests/utils.py index aa4ee429..8412cd61 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -32,12 +32,18 @@ def get_available_devices(): if hasattr(torch, "musa") and torch.musa.is_available(): devices.append("musa") + if hasattr(torch, "npu") and torch.npu.is_available(): + devices.append("npu") + return tuple(devices) with contextlib.suppress(ImportError, ModuleNotFoundError): import torch_mlu # noqa: F401 +with contextlib.suppress(ImportError, ModuleNotFoundError): + import torch_npu # noqa: F401 + def empty_strided(shape, strides, *, dtype=None, device=None): if strides is None: @@ -76,6 +82,14 @@ def randint_strided(low, high, shape, strides, *, dtype=None, device=None): return output +def 
get_npu_stream(tensor): + """Return the current NPU stream handle for `tensor`, or 0 on other devices.""" + if tensor.device.type != "npu": + return 0 + + return torch.npu.current_stream().npu_stream + + def clone_strided(input): output = empty_strided( input.size(), input.stride(), dtype=input.dtype, device=input.device diff --git a/third_party/flashinfer b/third_party/flashinfer new file mode 160000 index 00000000..a1166dc0 --- /dev/null +++ b/third_party/flashinfer @@ -0,0 +1 @@ +Subproject commit a1166dc0169b479aa3220b61759547d04c64e473