Optimizing CUDA kernel with atomics

Please help,

I have been working on optimizing a CUDA kernel that uses atomic adds, comparing its performance against the original. I have achieved roughly a 2x speedup, but once the input size goes beyond 1025 elements (e.g., 1026), I get the following output:

i (0): original (0.000000,0.000000), modified (16416.000000,8208.000000)

Atomic and Optimized kernels do not match.

This is incorrect: the modified/optimized kernel should match the output of the original atomic-add kernel, just with faster execution. I am hoping it is something simple I am doing wrong. Can anyone help?

The code is posted below. I am using CUDA 12.2 (driver 535.54.03) on an A100-SXM4-80GB device:

#include <cuda.h>
#include <chrono>
#include <iostream>
#include <stdio.h>
#include <string>
#include <sstream>
#include <vector>
#include <stdexcept>
#include <cstdlib>
#include <cmath>

// for easy gpu error checking
#define GPU_ERROR_CHECK(ans) do{gpuAssert((ans),__FILE__,__LINE__);}while(0)
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
   if (code != cudaSuccess)
   {
      fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
      printf("\nCUDA KERNEL ERROR: CUDA Kernel reports error: %s\n",cudaGetErrorString(code));
      if (abort) exit(code);
   }
}

__forceinline__ __host__ __device__ float dot(float3 a, float3 b)
{
    return a.x * b.x + a.y * b.y + a.z * b.z;
}

__forceinline__ __host__ __device__ float length(float3 v)
{
    return sqrtf(dot(v, v));
}

__forceinline__ __device__ float2 myKernel(float val) {
    return make_float2(val, val / 2);
}
/**
 * @brief Works for sizes below 1026 samples, i.e. up to and including 1025 samples.
 */
__global__ void optimized_org_kernel(const float3 * __restrict__ pos_a, const float3 * __restrict__ pos_b, 
                                     const uint32_t input_size, const uint32_t output_size, 
                                     float2 * __restrict__ result, const int region, 
                                     const uint32_t first_idx_x, const uint32_t last_idx_x)
{
    // Calculate thread indices
    uint32_t thridx_x = threadIdx.x + blockDim.x * blockIdx.x + first_idx_x;
    uint32_t stride_x = blockDim.x * gridDim.x;
    uint32_t thridx_y = threadIdx.y + blockDim.y * blockIdx.y;
    uint32_t stride_y = blockDim.y * gridDim.y;

    float3 distance3;
    float distance;
    uint32_t output_start, output_end;

    // Local accumulation variables to reduce the number of atomic operations
    float2 local_accum = make_float2(0.0f, 0.0f);

    for (uint32_t x = thridx_x; x < last_idx_x; x += stride_x) {
        // Calculate the distance between points
        distance3.x = pos_a[x].x - pos_b[x].x;
        distance3.y = pos_a[x].y - pos_b[x].y;
        distance3.z = pos_a[x].z - pos_b[x].z;
        distance = length(distance3);

        // Determine the output range for this thread
        output_start = __fdiv_rz(output_size, 2) + __fdiv_rz(distance, output_size) - region;
        output_end = output_start + region;

        // Clamp the values to ensure they stay within bounds
        output_start = max(0u, output_start);
        output_end = min(output_end, output_size);

        for (uint32_t y = thridx_y; y < output_size; y += stride_y) {
            // Only accumulate within the valid range
            if (y >= output_start && y < output_end) {
                float2 lval = myKernel(1.0f);
                local_accum.x += lval.x;
                local_accum.y += lval.y;
            }
        }
    }

    // Write back the accumulated values using atomic operations
    if (local_accum.x != 0.0f || local_accum.y != 0.0f) {
        atomicAdd(&result[thridx_y].x, local_accum.x);
        atomicAdd(&result[thridx_y].y, local_accum.y);
    }
}

__global__ void org_kernel(const float3 * pos_a, const float3 * pos_b, 
                           const uint32_t input_size, const uint32_t output_size, 
                           float2 * result, const int region, 
                           const uint32_t first_idx_x, const uint32_t last_idx_x)
{
    uint32_t thridx_x = threadIdx.x + blockDim.x * blockIdx.x + first_idx_x;
    uint32_t thridx_y = threadIdx.y + blockDim.y * blockIdx.y;
    uint32_t stride_x = blockDim.x * gridDim.x;
    uint32_t stride_y = blockDim.y * gridDim.y;

    float3 distance3 = make_float3(0.0f, 0.0f, 0.0f);
    float distance = 0;

    uint32_t output_start, output_end;

    for(uint32_t x = thridx_x; x < last_idx_x; x += stride_x){
        // distance calcs
        distance3.x = pos_a[x].x - pos_b[x].x;
        distance3.y = pos_a[x].y - pos_b[x].y;
        distance3.z = pos_a[x].z - pos_b[x].z;

        distance = length(distance3);
        output_start = __fdiv_rz(output_size, 2) + __fdiv_rz(distance, output_size) - region;
        output_end = output_start + region;
        
        for(uint32_t y = thridx_y; y < output_size; y += stride_y){
            if((y < output_end) && (y >= output_start)){
                float2 lval = myKernel(1.0f);
                atomicAdd(&result[y].x, lval.x);
                atomicAdd(&result[y].y, lval.y);
            }                        
        }
    }
}

bool eval_arrays_equal(float2 * d_org, float2 * d_mod, uint32_t n) {
    if (d_org == nullptr || d_mod == nullptr) {
        throw std::invalid_argument("Arrays are NULL.");
    }
    if (n < 1) {
        throw std::invalid_argument("Invalid array length, less than 1.");
    }

    float2 * h_org;
    float2 * h_mod;
    size_t sz = n * sizeof(float2);
    h_org = (float2*)malloc(sz);
    h_mod  = (float2*)malloc(sz);

    GPU_ERROR_CHECK(cudaMemcpy(h_org, d_org, sz, cudaMemcpyDeviceToHost));
    GPU_ERROR_CHECK(cudaMemcpy(h_mod, d_mod, sz, cudaMemcpyDeviceToHost));

    GPU_ERROR_CHECK(cudaDeviceSynchronize());
    for (uint32_t i = 0; i < n; ++i) {
        if (h_org[i].x != h_mod[i].x || h_org[i].y != h_mod[i].y) {
            printf("\ti (%i): original (%f,%f), modified (%f,%f)\n", i, h_org[i].x, h_org[i].y, h_mod[i].x, h_mod[i].y);
            free(h_org);
            free(h_mod);
            return false;
        }
    }

    free(h_org);
    free(h_mod);

    // Every element is equal
    return true;
}

void printDeviceProperties() {
    int32_t device;
    cudaError_t error = cudaGetDevice(&device);
    if (error != cudaSuccess) {
        std::cerr << "Failed to get current device: " << cudaGetErrorString(error) << std::endl;
        return;
    }

    cudaDeviceProp deviceProp;
    error = cudaGetDeviceProperties(&deviceProp, device);
    if (error != cudaSuccess) {
        std::cerr << "Failed to get device properties: " << cudaGetErrorString(error) << std::endl;
        return;
    }

    std::cout << "Device " << device << ": \"" << deviceProp.name << "\"" << std::endl;
    std::cout << "  CUDA Capability: " << deviceProp.major << "." << deviceProp.minor << std::endl;
    std::cout << "  Total Global Memory: " << deviceProp.totalGlobalMem / (1024 * 1024) << " MB" << std::endl;
    std::cout << "  Shared Memory per Block: " << deviceProp.sharedMemPerBlock / 1024 << " KB" << std::endl;
    std::cout << "  Registers per Block: " << deviceProp.regsPerBlock << std::endl;
    std::cout << "  Warp Size: " << deviceProp.warpSize << std::endl;
    std::cout << "  Max Threads per Block: " << deviceProp.maxThreadsPerBlock << std::endl;
    std::cout << "  Max Threads Dim: [" << deviceProp.maxThreadsDim[0] << ", "
              << deviceProp.maxThreadsDim[1] << ", " << deviceProp.maxThreadsDim[2] << "]" << std::endl;
    std::cout << "  Max Grid Size: [" << deviceProp.maxGridSize[0] << ", "
              << deviceProp.maxGridSize[1] << ", " << deviceProp.maxGridSize[2] << "]" << std::endl;
    std::cout << "  Clock Rate: " << deviceProp.clockRate / 1000 << " MHz" << std::endl;
    std::cout << "  Total Constant Memory: " << deviceProp.totalConstMem / 1024 << " KB" << std::endl;
    std::cout << "  Multiprocessor Count: " << deviceProp.multiProcessorCount << std::endl;
    std::cout << "  Compute Mode: " << deviceProp.computeMode << std::endl;
}

int main(int argc, char * argv[]) {
    if (argc != 2) {
        fprintf(stderr, "\nPass the number of array elements via command line as follows:\n");
        fprintf(stderr, "./xTest <num_elems>\n\n");
        return EXIT_FAILURE;
    }

    // Dimensions
    const uint32_t BLOCK_WIDTH = 512;
    dim3 nblks(BLOCK_WIDTH,1,1);
    dim3 nthreads(1,BLOCK_WIDTH,1);

    // Retrieve command-line argument
    uint32_t n_values = static_cast<uint32_t>(std::stoi(argv[1]));
   
    uint32_t region = 3;

    uint32_t n_float3s = n_values;
    uint32_t float3_sz = n_float3s * sizeof(float3);
    uint32_t output_sz = n_values * sizeof(float2);

    // Allocate host & device side
    float2 *d_out_org;
    float2 *d_out_mod;
    GPU_ERROR_CHECK(cudaMalloc(&d_out_org, output_sz));
    GPU_ERROR_CHECK(cudaMalloc(&d_out_mod, output_sz));
    GPU_ERROR_CHECK(cudaMemset(d_out_org, 0, output_sz));
    GPU_ERROR_CHECK(cudaMemset(d_out_mod, 0, output_sz));

    // Float3s
    float3 *pos_a, *pos_b;
    float3 *d_pos_a, *d_pos_b;
    pos_a = (float3*)malloc(float3_sz);
    pos_b = (float3*)malloc(float3_sz);
    for(size_t p = 0; p < n_float3s; ++p){
        pos_a[p] = make_float3(1,1,1);
        pos_b[p] = make_float3(0.1,0.1,0.1);
    }
    GPU_ERROR_CHECK(cudaMalloc(&d_pos_a, float3_sz));
    GPU_ERROR_CHECK(cudaMalloc(&d_pos_b, float3_sz));
    GPU_ERROR_CHECK(cudaMemcpy(d_pos_a, pos_a, float3_sz, cudaMemcpyHostToDevice));
    GPU_ERROR_CHECK(cudaMemcpy(d_pos_b, pos_b, float3_sz, cudaMemcpyHostToDevice));
    GPU_ERROR_CHECK(cudaDeviceSynchronize());

    float total_time_org = 0.0f;
    float total_time_mod = 0.0f;

    uint32_t first_idx_x = 0;
    uint32_t last_idx_x  = n_values;

    const uint32_t n_passes = 16;

    for (uint32_t pass = 0; pass < n_passes; ++pass) {
        auto start = std::chrono::high_resolution_clock::now();

        // Original atomic add kernel
        org_kernel<<<nblks,nthreads,0,0>>>(d_pos_a, d_pos_b, n_values, n_values, d_out_org, region, first_idx_x, last_idx_x);
        GPU_ERROR_CHECK(cudaDeviceSynchronize());

        auto stop = std::chrono::high_resolution_clock::now();
        total_time_org += static_cast<float>(std::chrono::duration_cast<std::chrono::nanoseconds>(stop - start).count());

        start = std::chrono::high_resolution_clock::now();
        
        // Optimized atomic add kernel
        optimized_org_kernel<<<nblks,nthreads,0,0>>>(d_pos_a, d_pos_b, n_values, n_values, d_out_mod, region, first_idx_x, last_idx_x);
        GPU_ERROR_CHECK(cudaDeviceSynchronize());

        stop = std::chrono::high_resolution_clock::now();
        total_time_mod += static_cast<float>(std::chrono::duration_cast<std::chrono::nanoseconds>(stop - start).count());
    }

    // Check for fidelity
    if (eval_arrays_equal(d_out_org, d_out_mod, n_values)) {
        printf("\nFidelity achieved.\n");
        printf("\tTotal number of passes: %d\n", n_passes);
        float org_time = (total_time_org / n_passes) / 1000.0f; // accumulated in ns, report in us
        float mod_time = (total_time_mod / n_passes) / 1000.0f;
        printf("\t[ORIGINAL] Time: %8.9f (us.)\n", org_time);
        printf("\t[MODIFIED] Time: %8.9f (us.)\n", mod_time);
        printf("\tSpeedup Factor: %8.9f\n", (org_time / mod_time));
    } else {
        printf("\nAtomic and Optimized kernels do not match.\n");
        return EXIT_FAILURE;
    }

    GPU_ERROR_CHECK(cudaPeekAtLastError());
    GPU_ERROR_CHECK(cudaDeviceSynchronize());
    GPU_ERROR_CHECK(cudaDeviceReset());

    return EXIT_SUCCESS;
}

Thanks to anyone who can point out an error or suggest a direction that might fix this issue.

The original kernel adds a number to multiple columns. The modified kernel sums up multiple numbers and adds the total to a single column. How can this ever be equivalent?

Did you investigate all suggestions already given in the original thread here?

Hi @striker159

Thank you for the reply.

Yes, I have been looking into the suggestions. I don’t think shared memory is useful in my case: I was able to get a shared-memory version working, but the performance was not much better, maybe 3%. So I wanted to look in another direction, hopefully without shared memory, by just lowering the number of necessary atomicAdds if possible.

I think I just solved the problem by using warp-level primitives and only calling atomicAdd once per warp.

The performance is somewhere between 1.5x and 2.8x faster with the mod_kernel compared to the org_kernel. I have posted the code below for anyone who may like to use it for their own project(s).

Thanks to everyone for pointing me in some good directions for optimizing this CUDA kernel.

#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>
#include <iomanip>
#include <stdexcept>
#include <chrono>
#include <cstdlib>

// for easy gpu error checking
#define GPU_ERROR_CHECK(ans) do{gpuAssert((ans),__FILE__,__LINE__);}while(0)
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
   if (code != cudaSuccess)
   {
      fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
      printf("\nCUDA KERNEL ERROR: CUDA Kernel reports error: %s\n",cudaGetErrorString(code));
      if (abort) exit(code);
   }
}
/**
 * @brief CUDA device function that computes the scalar dot product of two float3 values.
 * 
 * @param a The first float3.
 * @param b The second float3.
 * @return Floating-point value that is the scalar dot product.
 */
__forceinline__ __host__ __device__ float dot(float3 a, float3 b) {
    return a.x * b.x + a.y * b.y + a.z * b.z;
}

/**
 * @brief CUDA device function that computes the Euclidean length of the input float3.
 * 
 * @param v The float3 whose x, y, z components the length is computed from.
 * @return Floating-point value that is the Euclidean length of the input float3.
 */
__forceinline__ __host__ __device__ float length(float3 v) {
    return sqrtf(dot(v, v));
}

/**
 * @brief CUDA device function that performs a toy operation for demonstration purposes; the
 * input value is modified and returned as a float2.
 * 
 * @param val The input value being modified.
 * @return float2 version of the input float with modifications applied.
 */
__forceinline__ __host__ __device__ float2 myKernel(float val) {
    return make_float2(val, val / 2);
}

/**
 * @brief CUDA device function that performs a warp-level summation of the input float2 value.
 * @details Allows data to be summed without extra memory space; values are shared
 * directly across the threads of a single warp (32 threads).
 * 
 * @param val The float2 value being summed across the warp.
 * @return float2 value summed across the threads of a single warp.
 */
__inline__ __device__ float2 warpReduceSum(float2 val) {
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        val.x += __shfl_down_sync(0xffffffff, val.x, offset);
        val.y += __shfl_down_sync(0xffffffff, val.y, offset);
    }
    return val;
}

/**
 * @brief CUDA device function that calls the CUDA intrinsic @ref atomicAdd only from the
 * first lane of each warp, i.e. once per warp.
 * 
 * @param address The global address where the value is added and stored.
 * @param val The value being added to the global address.
 */
__inline__ __device__ void atomicAddWarp(float2 *address, float2 val) {
    if (threadIdx.x % warpSize == 0) {
        atomicAdd(&address->x, val.x);
        atomicAdd(&address->y, val.y);
    }
}
__global__ void org_kernel(const float3 * pos_a, const float3 * pos_b, 
                           const uint32_t input_size, const uint32_t output_size, 
                           float2 * result, const int32_t region,
                           const uint32_t first_idx_x, const uint32_t last_idx_x) {
    // Compute indices
    uint32_t thridx_x = threadIdx.x + blockDim.x * blockIdx.x + first_idx_x;
    uint32_t thridx_y = threadIdx.y + blockDim.y * blockIdx.y;
    uint32_t stride_x = blockDim.x * gridDim.x;
    uint32_t stride_y = blockDim.y * gridDim.y;

    float3 distance3 = make_float3(0.0f, 0.0f, 0.0f);
    float distance = 0;

    uint32_t output_start, output_end;

    for(uint32_t x = thridx_x; x < last_idx_x; x += stride_x){
        // distance calcs
        distance3.x = pos_a[x].x - pos_b[x].x;
        distance3.y = pos_a[x].y - pos_b[x].y;
        distance3.z = pos_a[x].z - pos_b[x].z;

        distance = length(distance3);
        output_start = __fdiv_rz(output_size, 2) + __fdiv_rz(distance, output_size) - region;
        output_end = output_start + region;
        
        for(uint32_t y = thridx_y; y < output_size; y += stride_y){
            if((y < output_end) && (y >= output_start)) {
                float2 lval = myKernel(1.0f);
                atomicAdd(&result[y].x, lval.x);
                atomicAdd(&result[y].y, lval.y);
            }                        
        }
    }
}
__global__ void mod_kernel(const float3 * __restrict__ pos_a, const float3 * __restrict__ pos_b, 
                           const uint32_t input_size, const uint32_t output_size, 
                           float2 * __restrict__ result, const int32_t region,
                           const uint32_t first_idx_x, const uint32_t last_idx_x) {
    // Compute indices
    uint32_t thridx_x = threadIdx.x + blockDim.x * blockIdx.x + first_idx_x;
    uint32_t thridx_y = threadIdx.y + blockDim.y * blockIdx.y;
    uint32_t stride_x = blockDim.x * gridDim.x;
    uint32_t stride_y = blockDim.y * gridDim.y;

    if (thridx_x >= last_idx_x) return;

    float3 distance3;
    float distance;
    uint32_t output_start, output_end;

    for (uint32_t x = thridx_x; x < last_idx_x; x += stride_x) {
        // Pre-calculate distance components
        distance3.x = pos_a[x].x - pos_b[x].x;
        distance3.y = pos_a[x].y - pos_b[x].y;
        distance3.z = pos_a[x].z - pos_b[x].z;

        // Compute the distance and the output indices range
        distance = sqrtf(distance3.x * distance3.x + 
                         distance3.y * distance3.y + 
                         distance3.z * distance3.z);

        output_start = __fdividef(output_size, 2) + __fdividef(distance, output_size) - region;
        output_end = output_start + region;

        // Restrict output range to valid indices
        output_start = max(output_start, 0U);
        output_end = min(output_end, output_size);

        for (uint32_t y = thridx_y; y < output_size; y += stride_y) {
            if (y >= output_start && y < output_end) {
                float2 lval = myKernel(1.0f);
                // Reduce across the warp, then issue a single atomicAdd per warp
                float2 warp_sum = warpReduceSum(lval);
                atomicAddWarp(&result[y], warp_sum);
            }
        }
    }
}
bool eval_arrays(const float2 * d_arr1, const float2 * d_arr2, const uint32_t n) {
    if (d_arr1 == nullptr || d_arr2 == nullptr) {
        throw std::invalid_argument("Null array(s).");
    }
    if (n < 1) {
        throw std::invalid_argument("Invalid array length.");
    }

    float2 * h_arr1 = nullptr;
    float2 * h_arr2 = nullptr;
    h_arr1 = new float2[n];
    h_arr2 = new float2[n];
    GPU_ERROR_CHECK(cudaMemcpy(h_arr1, d_arr1, n * sizeof(float2), cudaMemcpyDeviceToHost));
    GPU_ERROR_CHECK(cudaMemcpy(h_arr2, d_arr2, n * sizeof(float2), cudaMemcpyDeviceToHost));

    for (uint32_t i = 0; i < n; ++i) {
        if (h_arr1[i].x != h_arr2[i].x || h_arr1[i].y != h_arr2[i].y) {
            std::cout << "Index: " << i << " Array 1 (" << h_arr1[i].x << "," << h_arr1[i].y << "), Array 2 ("
                      << h_arr2[i].x << "," << h_arr2[i].y << ")\n";
            delete [] h_arr1;
            delete [] h_arr2;
            return false;
        }
    }
    
    delete [] h_arr1;
    delete [] h_arr2;

    // Every element in both arrays was the same
    return true;
}

int main(int argc, char * argv[]) {
    if (argc != 2) {
        std::cerr << "\nPass the number of array elements via command line as follows:\n";
        std::cerr << "./xOptimize <num_elems>\n\n";
        return EXIT_FAILURE;
    }

    // Get number of array elements from command line
    int n_values = std::stoi(argv[1]);
    if (n_values < 1) {
        std::cerr << "Invalid number of array elements: " << n_values << std::endl;
        return EXIT_FAILURE;
    }

    // Defined sizes
    const uint32_t BLOCK_WIDTH = 512;
    size_t float3_sz = n_values * sizeof(float3);
    size_t output_sz = n_values * sizeof(float2);

    // HOST-side positions
    float3 *pos_a = nullptr;
    float3 *pos_b = nullptr;
    pos_a = new float3[n_values];
    pos_b = new float3[n_values];
    for (int i = 0; i < n_values; ++i) {
        pos_a[i] = make_float3(i, i + 1, i + 2);
        pos_b[i] = make_float3(i + 0.5f, i + 1.5f, i + 2.5f);
    }
    // DEVICE-side positions
    float3 *d_pos_a = nullptr;
    float3 *d_pos_b = nullptr;
    GPU_ERROR_CHECK(cudaMalloc(&d_pos_a, float3_sz));
    GPU_ERROR_CHECK(cudaMalloc(&d_pos_b, float3_sz));
    GPU_ERROR_CHECK(cudaMemcpy(d_pos_a, pos_a, float3_sz, cudaMemcpyHostToDevice));
    GPU_ERROR_CHECK(cudaMemcpy(d_pos_b, pos_b, float3_sz, cudaMemcpyHostToDevice));

    // DEVICE-side outputs
    float2 *d_out_org = nullptr;
    float2 *d_out_mod = nullptr;
    GPU_ERROR_CHECK(cudaMalloc(&d_out_org, output_sz));
    GPU_ERROR_CHECK(cudaMalloc(&d_out_mod, output_sz));
    GPU_ERROR_CHECK(cudaMemset(d_out_org, 0, output_sz));
    GPU_ERROR_CHECK(cudaMemset(d_out_mod, 0, output_sz));

    float total_time_org = 0.0f;
    float total_time_mod = 0.0f;
    uint32_t first_idx_x = 0;
    uint32_t last_idx_x = n_values;

    int region = 3;
    dim3 nthreads(BLOCK_WIDTH, 1, 1);
    dim3 nblocks(1, BLOCK_WIDTH, 1);

    const uint32_t n_passes = 16;
    for (uint32_t pass = 0; pass < n_passes; ++pass) {

        auto start = std::chrono::high_resolution_clock::now();

        // Original atomic kernel
        org_kernel<<<nblocks, nthreads>>>(d_pos_a, d_pos_b, n_values, n_values, d_out_org, region, first_idx_x, last_idx_x);
        GPU_ERROR_CHECK(cudaDeviceSynchronize());

        auto stop = std::chrono::high_resolution_clock::now();
        total_time_org += static_cast<float>(std::chrono::duration_cast<std::chrono::nanoseconds>(stop - start).count());

        start = std::chrono::high_resolution_clock::now();

        // Modified atomic kernel
        mod_kernel<<<nblocks, nthreads>>>(d_pos_a, d_pos_b, n_values, n_values, d_out_mod, region, first_idx_x, last_idx_x);
        GPU_ERROR_CHECK(cudaDeviceSynchronize());

        stop = std::chrono::high_resolution_clock::now();
        total_time_mod += static_cast<float>(std::chrono::duration_cast<std::chrono::nanoseconds>(stop - start).count());
    }

    std::cout << std::fixed << std::setprecision(4);
    total_time_org /= (n_passes * 1000.0f);  // accumulated in ns, convert to average us per pass
    total_time_mod /= (n_passes * 1000.0f);

    std::cout << "\nTotal number of passes: " << n_passes << std::endl;
    std::cout << "Original CUDA Kernel Time: " << total_time_org << " (us.)\n";
    std::cout << "Modified CUDA Kernel Time: " << total_time_mod << " (us.)\n";
    std::cout << "Speedup factor: " << (total_time_org / total_time_mod) << std::endl;

    // Check fidelity
    if (eval_arrays(d_out_org, d_out_mod, n_values)) {
        std::cout << "\nFidelity achieved.\n\n";
    } else {
        std::cout << "\nFidelity not achieved.\n\n";
    }

    return EXIT_SUCCESS;
}

When trying to improve a kernel, your main focus should be correctness, not speed. If you don’t want correct results, you could simply remove the kernel, which would give you the greatest speedup.

As already pointed out by others in the first thread, your modified kernels are not equivalent to the original kernel.

This is your original code

        for(uint32_t y = thridx_y; y < output_size; y += stride_y){
            if((y < output_end) && (y >= output_start)) {
                float2 lval = myKernel(1.0f);
                atomicAdd(&result[y].x, lval.x);
            }                        
        }

Let’s plug in some numbers for simplicity.

        for(uint32_t y = 0; y < 5; y += 1){
            if((y < 5) && (y >= 0)) {
                float2 lval = myKernel(1.0f);
                atomicAdd(&result[y].x, lval.x);
            }                        
        }

This means if result is initialized with [0,0,0,0,0], it will be [1,1,1,1,1] after the loop.

But if you accumulate the values for multiple y in any way, and then write to output, you will get different results.

        local_accum = 0
        for(uint32_t y = 0; y < 5; y += 1){
            if((y < 5) && (y >= 0)) {
                float2 lval = myKernel(1.0f);
                local_accum.x += lval.x;
            }                        
        }
        atomicAdd(&result[0].x, local_accum.x);

This will output [5,0,0,0,0], not [1,1,1,1,1].

The simplest way to reduce the number of atomics in the original kernel is to only perform atomicAdd(result.x, 1). Then afterwards use a second kernel which computes result.y = result.x / 2.
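
A rough, untested sketch of what that second kernel could look like (derive_y_kernel is just a placeholder name; it relies on the fact that myKernel(1.0f) always contributes (1.0f, 0.5f)):

__global__ void derive_y_kernel(float2 * __restrict__ result, const uint32_t output_size)
{
    uint32_t i = threadIdx.x + blockDim.x * blockIdx.x;
    if (i < output_size) {
        // myKernel(1.0f) always contributes (1.0f, 0.5f), so y is simply half of x
        result[i].y = result[i].x / 2.0f;
    }
}

Launched once after the main kernel, e.g. derive_y_kernel<<<(n_values + 255) / 256, 256>>>(d_out_org, n_values);, this halves the number of global atomics because the inner loop of the main kernel only has to do atomicAdd(&result[y].x, lval.x).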

Hi @striker159,

Thank you for the reply.

Yes, my first goal has always been to ensure the correctness of the modified code against the original code, and only then to speed it up, if possible. That is why this has been such a problem: speeding it up is not too difficult, but speeding it up while maintaining consistency with the original results is the real issue.

I understand that the local-accumulation code you showed above would give incorrect results when compared with the original.

That is why I have since abandoned that idea.

However, is there any reason that the following snippet of code from above wouldn’t work?

...
__inline__ __device__ float2 warpReduceSum(float2 val) {
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        val.x += __shfl_down_sync(0xffffffff, val.x, offset);
        val.y += __shfl_down_sync(0xffffffff, val.y, offset);
    }
    return val;
}
__inline__ __device__ void atomicAddWarp(float2 *address, float2 val) {
    if (threadIdx.x % warpSize == 0) {
        atomicAdd(&address->x, val.x);
        atomicAdd(&address->y, val.y);
    }
}
__global__ void mod_kernel(const float3 * __restrict__ pos_a, const float3 * __restrict__ pos_b, 
                           const uint32_t input_size, const uint32_t output_size, 
                           float2 * __restrict__ result, const int32_t region,
                           const uint32_t first_idx_x, const uint32_t last_idx_x) {
    // Compute indices
    uint32_t thridx_x = threadIdx.x + blockDim.x * blockIdx.x + first_idx_x;
    uint32_t thridx_y = threadIdx.y + blockDim.y * blockIdx.y;
    uint32_t stride_x = blockDim.x * gridDim.x;
    uint32_t stride_y = blockDim.y * gridDim.y;

    if (thridx_x >= last_idx_x) return;

    float3 distance3;
    float distance;
    uint32_t output_start, output_end;

    for (uint32_t x = thridx_x; x < last_idx_x; x += stride_x) {
        // Pre-calculate distance components
        distance3.x = pos_a[x].x - pos_b[x].x;
        distance3.y = pos_a[x].y - pos_b[x].y;
        distance3.z = pos_a[x].z - pos_b[x].z;

        // Compute the distance and the output indices range
        distance = sqrtf(distance3.x * distance3.x + 
                         distance3.y * distance3.y + 
                         distance3.z * distance3.z);

        output_start = __fdividef(output_size, 2) + __fdividef(distance, output_size) - region;
        output_end = output_start + region;

        // Restrict output range to valid indices
        output_start = max(output_start, 0U);
        output_end = min(output_end, output_size);

        for (uint32_t y = thridx_y; y < output_size; y += stride_y) {
            if (y >= output_start && y < output_end) {
                float2 lval = myKernel(1.0f);
                // Reduce across the warp, then issue a single atomicAdd per warp
                float2 warp_sum = warpReduceSum(lval);
                atomicAddWarp(&result[y], warp_sum);
            }
        }
    }
}
...

Thanks again for the help.

First of all, there is no guarantee that all threads in the warp reach the reduction code.

That aside, the reduction approach has the same problem: values for multiple y are combined.
For example, suppose in thread 0 y=0 and in thread 1 y=1. The warp sum is then 2, and that combined value gets written to a single result entry, while the correct result would be result[0] = 1 and result[1] = 1.
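
For completeness: a warp-wide reduction is only equivalent to the original when all lanes that get combined target the same y. An untested sketch of that idea is below (atomicAddWarpPerY is a made-up name, __match_any_sync requires compute capability 7.0 or newer, and it assumes a 1-D thread block so that threadIdx.x % warpSize is the lane index):

__inline__ __device__ void atomicAddWarpPerY(float2 *result, uint32_t y, float2 val) {
    // Group the currently active lanes by their y value; each group shares one atomic.
    unsigned group  = __match_any_sync(__activemask(), y);
    int      leader = __ffs(group) - 1;                 // lowest lane id in this group
    float2   sum    = make_float2(0.0f, 0.0f);
    // Every lane of a group walks the same mask, so the shuffles stay converged.
    for (unsigned g = group; g != 0; g &= g - 1) {
        int src = __ffs(g) - 1;
        sum.x += __shfl_sync(group, val.x, src);
        sum.y += __shfl_sync(group, val.y, src);
    }
    if ((int)(threadIdx.x % warpSize) == leader) {      // one atomicAdd per distinct y per warp
        atomicAdd(&result[y].x, sum.x);
        atomicAdd(&result[y].y, sum.y);
    }
}

Whether this actually saves anything depends on how many lanes of a warp end up on the same y in your launch configuration.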

Okay, but wouldn’t calling one kernel to operate on the x component and then another kernel to operate on the y component end up costing more time than just calling the original kernel once? Unless the input is very large, which might offset the extra launch overhead.

I guess I am trying to understand if maybe the original kernel is as fast as it can be given the parameters of the problem itself.

After the first x calculations, store the intermediate results in shared memory, use __syncthreads() to synchronize block-wise, then use one or some of the warps for the y processing within the block.
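
A rough, untested sketch of that idea (shared_mem_kernel and MAX_OUTPUT are placeholder names; it assumes a 1-D launch over x where each block covers the full y range for its x values, and that output_size fits the shared-memory histogram):

#define MAX_OUTPUT 4096  // must be >= output_size for this sketch

__global__ void shared_mem_kernel(const float3 * __restrict__ pos_a, const float3 * __restrict__ pos_b,
                                  const uint32_t input_size, const uint32_t output_size,
                                  float2 * __restrict__ result, const int32_t region,
                                  const uint32_t first_idx_x, const uint32_t last_idx_x)
{
    __shared__ float block_counts[MAX_OUTPUT];

    // Zero the per-block histogram cooperatively.
    for (uint32_t y = threadIdx.x; y < output_size; y += blockDim.x)
        block_counts[y] = 0.0f;
    __syncthreads();

    uint32_t thridx_x = threadIdx.x + blockDim.x * blockIdx.x + first_idx_x;
    uint32_t stride_x = blockDim.x * gridDim.x;

    for (uint32_t x = thridx_x; x < last_idx_x; x += stride_x) {
        float3 d = make_float3(pos_a[x].x - pos_b[x].x,
                               pos_a[x].y - pos_b[x].y,
                               pos_a[x].z - pos_b[x].z);
        float distance = length(d);
        uint32_t output_start = __fdiv_rz(output_size, 2) + __fdiv_rz(distance, output_size) - region;
        uint32_t output_end   = min(output_start + region, output_size);
        // Shared-memory atomics are much cheaper than global ones.
        for (uint32_t y = output_start; y < output_end; ++y)
            atomicAdd(&block_counts[y], 1.0f);
    }
    __syncthreads();

    // One global atomicAdd per non-empty y per block instead of one per (x, y) pair.
    for (uint32_t y = threadIdx.x; y < output_size; y += blockDim.x) {
        float c = block_counts[y];
        if (c != 0.0f) {
            float2 lval = myKernel(1.0f);
            atomicAdd(&result[y].x, c * lval.x);
            atomicAdd(&result[y].y, c * lval.y);
        }
    }
}

Since the contributions per y are whole counts of (1.0f, 0.5f), adding c * lval should match the element-wise atomicAdds bit for bit as long as the counts stay well below 2^24.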

Hi @Curefab,

Thank you for the reply. Using shared memory is a good idea, and I have tried various versions of this approach. While the results are valid in that they match the original values, the performance is really no better, or at best the improvement is negligible. This is why I am beginning to think that maybe the original kernel is as fast as it can get.

That is quite likely.

You can make some further tests with changed parameters, or do theoretical calculations of the maximum speed (e.g. how many bytes you read/write vs. the bandwidth of device memory), combined with Nsight Compute, to support this assumption.
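
As a rough back-of-the-envelope example (numbers are only indicative): with n = 1026 the kernels read two float3 arrays (2 x 1026 x 12 B, about 25 KB) and write about 8 KB of float2 output, while an A100-SXM4-80GB has roughly 2 TB/s of HBM bandwidth, so the pure memory traffic takes well under a microsecond. At these sizes the runtime is dominated by kernel launch latency and atomic contention rather than bandwidth, which would explain why the shared-memory variants show little benefit until the input gets much larger.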