Skip to content

Commit f010b77

Browse files
authored
vulkan : add backend registry / device interfaces (#9721)
* vulkan : add backend registry / device interfaces * llama : print devices used on model load
1 parent 2194200 commit f010b77

File tree

4 files changed

+223
-120
lines changed

4 files changed

+223
-120
lines changed

ggml/include/ggml-vulkan.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);
2424
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
2525
GGML_API ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void);
2626

27+
GGML_API ggml_backend_reg_t ggml_backend_vk_reg(void);
28+
2729
#ifdef __cplusplus
2830
}
2931
#endif

ggml/src/ggml-backend.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
538538
#include "ggml-metal.h"
539539
#endif
540540

541+
#ifdef GGML_USE_VULKAN
542+
#include "ggml-vulkan.h"
543+
#endif
544+
541545
#ifdef GGML_USE_BLAS
542546
#include "ggml-blas.h"
543547
#endif
@@ -557,14 +561,17 @@ struct ggml_backend_registry {
557561
#ifdef GGML_USE_METAL
558562
register_backend(ggml_backend_metal_reg());
559563
#endif
564+
#ifdef GGML_USE_VULKAN
565+
register_backend(ggml_backend_vk_reg());
566+
#endif
560567
#ifdef GGML_USE_BLAS
561568
register_backend(ggml_backend_blas_reg());
562569
#endif
563570
#ifdef GGML_USE_RPC
564571
register_backend(ggml_backend_rpc_reg());
565572
#endif
566573

567-
// TODO: sycl, vulkan, kompute, cann
574+
// TODO: sycl, kompute, cann
568575

569576
register_backend(ggml_backend_cpu_reg());
570577
}

ggml/src/ggml-vulkan.cpp

Lines changed: 204 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1941,7 +1941,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
19411941
if (device->fp16) {
19421942
device_extensions.push_back("VK_KHR_shader_float16_int8");
19431943
}
1944-
device->name = device->properties.deviceName.data();
1944+
device->name = GGML_VK_NAME + std::to_string(idx);
19451945

19461946
device_create_info = {
19471947
vk::DeviceCreateFlags(),
@@ -1968,7 +1968,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
19681968

19691969
device->buffer_type = {
19701970
/* .iface = */ ggml_backend_vk_buffer_type_interface,
1971-
/* .device = */ nullptr,
1971+
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx),
19721972
/* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device },
19731973
};
19741974

@@ -6378,7 +6378,7 @@ ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
63786378
/* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
63796379
/* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
63806380
},
6381-
/* .device = */ nullptr,
6381+
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0),
63826382
/* .context = */ nullptr,
63836383
};
63846384

@@ -6581,9 +6581,135 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
65816581
UNUSED(backend);
65826582
}
65836583

6584-
static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
6585-
// ggml_backend_vk_context * ctx = (ggml_backend_vk_context *) backend->context;
6584+
// TODO: enable async and synchronize
6585+
static ggml_backend_i ggml_backend_vk_interface = {
6586+
/* .get_name = */ ggml_backend_vk_name,
6587+
/* .free = */ ggml_backend_vk_free,
6588+
/* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
6589+
/* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async,
6590+
/* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async,
6591+
/* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
6592+
/* .synchronize = */ NULL, // ggml_backend_vk_synchronize,
6593+
/* .graph_plan_create = */ NULL,
6594+
/* .graph_plan_free = */ NULL,
6595+
/* .graph_plan_update = */ NULL,
6596+
/* .graph_plan_compute = */ NULL,
6597+
/* .graph_compute = */ ggml_backend_vk_graph_compute,
6598+
/* .supports_op = */ NULL,
6599+
/* .supports_buft = */ NULL,
6600+
/* .offload_op = */ NULL,
6601+
/* .event_record = */ NULL,
6602+
/* .event_wait = */ NULL,
6603+
};
6604+
6605+
static ggml_guid_t ggml_backend_vk_guid() {
6606+
static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
6607+
return &guid;
6608+
}
6609+
6610+
ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
6611+
VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
6612+
6613+
ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
6614+
ggml_vk_init(ctx, dev_num);
6615+
6616+
ggml_backend_t vk_backend = new ggml_backend {
6617+
/* .guid = */ ggml_backend_vk_guid(),
6618+
/* .interface = */ ggml_backend_vk_interface,
6619+
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
6620+
/* .context = */ ctx,
6621+
};
6622+
6623+
return vk_backend;
6624+
}
6625+
6626+
bool ggml_backend_is_vk(ggml_backend_t backend) {
6627+
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
6628+
}
6629+
6630+
int ggml_backend_vk_get_device_count() {
6631+
return ggml_vk_get_device_count();
6632+
}
6633+
6634+
void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
6635+
GGML_ASSERT(device < (int) vk_instance.device_indices.size());
6636+
int dev_idx = vk_instance.device_indices[device];
6637+
ggml_vk_get_device_description(dev_idx, description, description_size);
6638+
}
6639+
6640+
void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
6641+
GGML_ASSERT(device < (int) vk_instance.device_indices.size());
6642+
6643+
vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
6644+
6645+
vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
6646+
6647+
for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
6648+
if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
6649+
*total = heap.size;
6650+
*free = heap.size;
6651+
break;
6652+
}
6653+
}
6654+
}
6655+
6656+
//////////////////////////
6657+
6658+
struct ggml_backend_vk_device_context {
6659+
int device;
6660+
std::string name;
6661+
std::string description;
6662+
};
6663+
6664+
static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
6665+
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
6666+
return ctx->name.c_str();
6667+
}
6668+
6669+
static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) {
6670+
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
6671+
return ctx->description.c_str();
6672+
}
6673+
6674+
static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
6675+
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
6676+
ggml_backend_vk_get_device_memory(ctx->device, free, total);
6677+
}
6678+
6679+
static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
6680+
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
6681+
return ggml_backend_vk_buffer_type(ctx->device);
6682+
}
6683+
6684+
static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) {
6685+
UNUSED(dev);
6686+
return ggml_backend_vk_host_buffer_type();
6687+
}
65866688

6689+
static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {
6690+
UNUSED(dev);
6691+
return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
6692+
}
6693+
6694+
static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
6695+
props->name = ggml_backend_vk_device_get_name(dev);
6696+
props->description = ggml_backend_vk_device_get_description(dev);
6697+
props->type = ggml_backend_vk_device_get_type(dev);
6698+
ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
6699+
props->caps = {
6700+
/* async */ false,
6701+
/* host_buffer */ true,
6702+
/* events */ false,
6703+
};
6704+
}
6705+
6706+
static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
6707+
UNUSED(params);
6708+
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
6709+
return ggml_backend_vk_init(ctx->device);
6710+
}
6711+
6712+
static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
65876713
switch (op->op) {
65886714
case GGML_OP_UNARY:
65896715
switch (ggml_get_unary_op(op)) {
@@ -6701,97 +6827,101 @@ static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tenso
67016827
return false;
67026828
}
67036829

6704-
UNUSED(backend);
6705-
}
6706-
6707-
static bool ggml_backend_vk_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
6708-
const int min_batch_size = 32;
6709-
6710-
return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
6711-
(op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
6712-
6713-
UNUSED(backend);
6830+
UNUSED(dev);
67146831
}
67156832

6716-
static bool ggml_backend_vk_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
6833+
static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
67176834
if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) {
67186835
return false;
67196836
}
67206837

6838+
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
67216839
ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
6722-
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
6723-
6724-
return buft_ctx->device == ctx->device;
6725-
}
6726-
6727-
// TODO: enable async and synchronize
6728-
static ggml_backend_i ggml_backend_vk_interface = {
6729-
/* .get_name = */ ggml_backend_vk_name,
6730-
/* .free = */ ggml_backend_vk_free,
6731-
/* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
6732-
/* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async,
6733-
/* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async,
6734-
/* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
6735-
/* .synchronize = */ NULL, // ggml_backend_vk_synchronize,
6736-
/* .graph_plan_create = */ NULL,
6737-
/* .graph_plan_free = */ NULL,
6738-
/* .graph_plan_update = */ NULL,
6739-
/* .graph_plan_compute = */ NULL,
6740-
/* .graph_compute = */ ggml_backend_vk_graph_compute,
6741-
/* .supports_op = */ ggml_backend_vk_supports_op,
6742-
/* .supports_buft = */ ggml_backend_vk_supports_buft,
6743-
/* .offload_op = */ ggml_backend_vk_offload_op,
6744-
/* .event_record = */ NULL,
6745-
/* .event_wait = */ NULL,
6746-
};
67476840

6748-
static ggml_guid_t ggml_backend_vk_guid() {
6749-
static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
6750-
return &guid;
6841+
return buft_ctx->device->idx == ctx->device;
67516842
}
67526843

6753-
ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
6754-
VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
6844+
static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
6845+
const int min_batch_size = 32;
67556846

6756-
ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
6757-
ggml_vk_init(ctx, dev_num);
6847+
return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
6848+
(op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
67586849

6759-
ggml_backend_t vk_backend = new ggml_backend {
6760-
/* .guid = */ ggml_backend_vk_guid(),
6761-
/* .interface = */ ggml_backend_vk_interface,
6762-
/* .device = */ nullptr,
6763-
/* .context = */ ctx,
6764-
};
6850+
UNUSED(dev);
6851+
}
6852+
6853+
static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
6854+
/* .get_name = */ ggml_backend_vk_device_get_name,
6855+
/* .get_description = */ ggml_backend_vk_device_get_description,
6856+
/* .get_memory = */ ggml_backend_vk_device_get_memory,
6857+
/* .get_type = */ ggml_backend_vk_device_get_type,
6858+
/* .get_props = */ ggml_backend_vk_device_get_props,
6859+
/* .init_backend = */ ggml_backend_vk_device_init,
6860+
/* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type,
6861+
/* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
6862+
/* .buffer_from_host_ptr = */ NULL,
6863+
/* .supports_op = */ ggml_backend_vk_device_supports_op,
6864+
/* .supports_buft = */ ggml_backend_vk_device_supports_buft,
6865+
/* .offload_op = */ ggml_backend_vk_device_offload_op,
6866+
/* .event_new = */ NULL,
6867+
/* .event_free = */ NULL,
6868+
/* .event_synchronize = */ NULL,
6869+
};
67656870

6766-
return vk_backend;
6871+
static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
6872+
UNUSED(reg);
6873+
return GGML_VK_NAME;
67676874
}
67686875

6769-
bool ggml_backend_is_vk(ggml_backend_t backend) {
6770-
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
6876+
static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) {
6877+
UNUSED(reg);
6878+
return ggml_backend_vk_get_device_count();
67716879
}
67726880

6773-
int ggml_backend_vk_get_device_count() {
6774-
return ggml_vk_get_device_count();
6775-
}
6881+
static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) {
6882+
static std::vector<ggml_backend_dev_t> devices;
67766883

6777-
void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
6778-
ggml_vk_get_device_description(device, description, description_size);
6779-
}
6884+
static bool initialized = false;
67806885

6781-
void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
6782-
GGML_ASSERT(device < (int) vk_instance.device_indices.size());
6886+
{
6887+
static std::mutex mutex;
6888+
std::lock_guard<std::mutex> lock(mutex);
6889+
if (!initialized) {
6890+
for (size_t i = 0; i < ggml_backend_vk_get_device_count(); i++) {
6891+
ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
6892+
char desc[256];
6893+
ggml_backend_vk_get_device_description(i, desc, sizeof(desc));
6894+
ctx->device = i;
6895+
ctx->name = GGML_VK_NAME + std::to_string(i);
6896+
ctx->description = desc;
6897+
devices.push_back(new ggml_backend_device {
6898+
/* .iface = */ ggml_backend_vk_device_i,
6899+
/* .reg = */ reg,
6900+
/* .context = */ ctx,
6901+
});
6902+
}
6903+
initialized = true;
6904+
}
6905+
}
67836906

6784-
vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
6907+
GGML_ASSERT(device < devices.size());
6908+
return devices[device];
6909+
}
67856910

6786-
vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
6911+
static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
6912+
/* .get_name = */ ggml_backend_vk_reg_get_name,
6913+
/* .get_device_count = */ ggml_backend_vk_reg_get_device_count,
6914+
/* .get_device = */ ggml_backend_vk_reg_get_device,
6915+
/* .get_proc_address = */ NULL,
6916+
};
67876917

6788-
for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
6789-
if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
6790-
*total = heap.size;
6791-
*free = heap.size;
6792-
break;
6793-
}
6794-
}
6918+
ggml_backend_reg_t ggml_backend_vk_reg() {
6919+
static ggml_backend_reg reg = {
6920+
/* .iface = */ ggml_backend_vk_reg_i,
6921+
/* .context = */ nullptr,
6922+
};
6923+
6924+
return &reg;
67956925
}
67966926

67976927
// Extension availability

0 commit comments

Comments
 (0)