From a32b22cf2ce307cfa774c4dd2a74e9003652e7cd Mon Sep 17 00:00:00 2001 From: Shehtab Date: Fri, 13 Mar 2026 03:34:07 -0400 Subject: [PATCH] Patch local kernel to add explicit specialization - Update CMake to only require MPI when NVSHMEM is enabled - Fix newline warning on macro --- CMakeLists.txt | 2 +- .../distributed/csrc/local_data_kernels.cuh | 22 +++++++++++++++---- DGraph/distributed/include/macros.hpp | 2 +- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1125042..1b8f04c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,7 +60,6 @@ list(APPEND "${CMAKE_CURRENT_SOURCE_DIR}/cmake" ) find_package(CUDAToolkit REQUIRED) -find_package(MPI 3.0 REQUIRED COMPONENTS CXX) find_package(Torch 2.6 REQUIRED CONFIG) # Also, torch_python! @@ -79,6 +78,7 @@ find_library(TORCH_PYTHON_LIBRARY find_library(TORCH_PYTHON_LIBRARY torch_python REQUIRED) if (DGRAPH_ENABLE_NVSHMEM) + find_package(MPI 3.0 REQUIRED COMPONENTS CXX) find_package(NVSHMEM 2.5 REQUIRED MODULE) endif () diff --git a/DGraph/distributed/csrc/local_data_kernels.cuh b/DGraph/distributed/csrc/local_data_kernels.cuh index e4f58bc..7b95a72 100644 --- a/DGraph/distributed/csrc/local_data_kernels.cuh +++ b/DGraph/distributed/csrc/local_data_kernels.cuh @@ -252,8 +252,6 @@ namespace Local { } } - - template <typename T> struct FloatAtomicAddOp { @@ -263,6 +261,23 @@ namespace Local } }; + // Add specialization + template <> + struct FloatAtomicAddOp<float4> + { + __device__ __forceinline__ void operator()(float4 *cur_addr, const float4 new_val) + { + // Cast the float4 pointer to a standard float pointer + float *addr_as_float = reinterpret_cast<float *>(cur_addr); + + // Atomically add each component individually + atomicAdd(&addr_as_float[0], new_val.x); + atomicAdd(&addr_as_float[1], new_val.y); + atomicAdd(&addr_as_float[2], new_val.z); + atomicAdd(&addr_as_float[3], new_val.w); + } + }; + template <typename T> struct FloatSetOp { @@ -272,7 +287,6 @@ namespace Local } }; - /** * * Masked Gather Kernel 
 operation that performs the operation: @@ -381,7 +395,7 @@ namespace Local for (; col < num_cols / 4; col += nthreadsx) { const float4 values_vec = reinterpret_cast<const float4 *>(values)[values_offset + input_row * num_cols / 4 + col]; - float4* output_addr = &reinterpret_cast<float4 *>(output)[output_offset + output_row * num_cols / 4 + col]; + float4 *output_addr = &reinterpret_cast<float4 *>(output)[output_offset + output_row * num_cols / 4 + col]; binary_operator(output_addr, values_vec); } } diff --git a/DGraph/distributed/include/macros.hpp b/DGraph/distributed/include/macros.hpp index a1444b6..0b614e8 100644 --- a/DGraph/distributed/include/macros.hpp +++ b/DGraph/distributed/include/macros.hpp @@ -32,4 +32,4 @@ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) \ No newline at end of file + CHECK_CONTIGUOUS(x)