[MLIR][RFC][Vector][SPIR-V] Support single element vector in vector.load/store lowering to SPIR-V

Hi All,

I was trying to use the vector-to-spirv conversion, and it seems that lowering vector.load/vector.store is not supported for single-element vectors.

Here is a simple code snippet:

```mlir
// RUN: mlir-opt --convert-vector-to-spirv test.mlir

module attributes {
  gpu.container_module,
  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, api=OpenCL, #spirv.resource_limits<>>
} {
  gpu.module @kernels {
    func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<CrossWorkgroup>>, %arg1 : vector<4xf32>, %arg2 : vector<1xf32>, %arg3 : vector<f32>, %arg4 : f32) {
      %idx = arith.constant 0 : index
      %c0 = arith.constant 0 : i32
      %c0_vector = arith.constant dense<0.0> : vector<1xf32>

      // vector.load (vector::LoadOp)
      %load_1d = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<CrossWorkgroup>>, vector<1xf32>
      %load_0d = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<CrossWorkgroup>>, vector<f32>

      // vector.store (vector::StoreOp)
      vector.store %arg3, %arg0[%idx] : memref<4xf32, #spirv.storage_class<CrossWorkgroup>>, vector<f32>
      vector.store %arg2, %arg0[%idx] : memref<4xf32, #spirv.storage_class<CrossWorkgroup>>, vector<1xf32>

      return
    }
  }
}
```

It seems the vector-to-spirv pass has trouble handling single-element vectors, both the 1-D vector<1xf32> and the 0-D vector<f32> cases above.

It's entirely possible that I'm missing something or not using the pass correctly. If not, I can try to add single-element vector support to the vector-to-spirv pass.

One possible solution: for single-element vectors, we could follow the LLVM lowering path and lower them to vector.extractelement + memref.load/store instead of spirv.load/store.
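As a rough illustration, here is a hand-written sketch of what such a rewrite might produce for the ops in the snippet above. It assumes the resulting memref.load/memref.store would then be lowered by the separate MemRef-to-SPIR-V conversion (--convert-memref-to-spirv). The function name @vector_store_rewritten and the value names are hypothetical. Stores extract the scalar first, while loads (which need the opposite direction) insert the loaded scalar into a one-element or 0-D vector.

```mlir
// Hypothetical, hand-written sketch: one possible expansion of the
// single-element vector.load/vector.store ops into scalar memref accesses.
func.func @vector_store_rewritten(%arg0 : memref<4xf32, #spirv.storage_class<CrossWorkgroup>>,
                                  %arg2 : vector<1xf32>, %arg3 : vector<f32>) {
  %idx = arith.constant 0 : index
  %c0 = arith.constant 0 : i32
  %zero_1d = arith.constant dense<0.0> : vector<1xf32>
  %zero_0d = arith.constant dense<0.0> : vector<f32>

  // vector.load -> vector<1xf32>: scalar load, then insert into a 1-element vector.
  %s1 = memref.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<CrossWorkgroup>>
  %load_1d = vector.insertelement %s1, %zero_1d[%c0 : i32] : vector<1xf32>

  // vector.load -> vector<f32>: scalar load, then insert into a 0-D vector.
  %s0 = memref.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<CrossWorkgroup>>
  %load_0d = vector.insertelement %s0, %zero_0d[] : vector<f32>

  // vector.store of vector<f32>: extract the scalar, then scalar store.
  %e0 = vector.extractelement %arg3[] : vector<f32>
  memref.store %e0, %arg0[%idx] : memref<4xf32, #spirv.storage_class<CrossWorkgroup>>

  // vector.store of vector<1xf32>: extract the scalar, then scalar store.
  %e1 = vector.extractelement %arg2[%c0 : i32] : vector<1xf32>
  memref.store %e1, %arg0[%idx] : memref<4xf32, #spirv.storage_class<CrossWorkgroup>>

  return
}
```

Because the SPIR-V lowering treats one-element vectors as scalars, the insert/extract ops in this sketch should reduce to plain scalar values. That assumption would need to be confirmed against the actual conversion patterns.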

But I am open to other ideas.

Thanks in advance :slight_smile:


@kuhar

In general, the SPIR-V lowering resolves one-element vectors to scalars, so anything more complicated than a simple ‘in-place’ use of one-element vectors needs special handling. You can find a bunch of code that checks for this with isa<spirv::ScalarType>, or similar, in the SPIR-V conversion patterns.

I don’t remember why one-element vector load/store is not handled; we do rely on expansion/emulation for some vector ops like vector.maskedload/store (see VectorEmulateMaskedLoadStore.cpp), but maybe we didn’t need it because memref.load can express the same access patterns.

We could take that expansion route here too if a direct lowering is not straightforward to implement.

In any case, patches welcome :slight_smile:


Thank you so much @kuhar :slight_smile: