@@ -1941,7 +1941,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         if (device->fp16) {
             device_extensions.push_back("VK_KHR_shader_float16_int8");
         }
-        device->name = device->properties.deviceName.data();
+        device->name = GGML_VK_NAME + std::to_string(idx);
 
         device_create_info = {
             vk::DeviceCreateFlags(),
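The device name switches from the driver-reported `deviceName` to the registry-style identifier. A minimal sketch of the resulting scheme, assuming `GGML_VK_NAME` expands to the string literal `"Vulkan"` (its definition lives outside this diff):

```cpp
// Sketch only: shows the naming convention produced by the changed line.
// Assumption: GGML_VK_NAME is "Vulkan", as defined elsewhere in the tree.
#include <string>

#define GGML_VK_NAME "Vulkan"

int main() {
    std::string name = GGML_VK_NAME + std::to_string(1); // "Vulkan1"
    return name == "Vulkan1" ? 0 : 1;
}
```

The driver's human-readable device name remains available through the device description path shown later in this diff.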
@@ -1968,7 +1968,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->buffer_type = {
             /* .iface    = */ ggml_backend_vk_buffer_type_interface,
-            /* .device   = */ nullptr,
+            /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx),
             /* .context  = */ new ggml_backend_vk_buffer_type_context{ device->name, device },
         };
 
@@ -6378,7 +6378,7 @@ ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
             /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
             /* .is_host        = */ ggml_backend_cpu_buffer_type()->iface.is_host,
         },
-        /* .device   = */ nullptr,
+        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0),
         /* .context  = */ nullptr,
     };
 
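With both buffer types now pointing at registry devices instead of `nullptr`, callers can resolve the owning device from a buffer type. A sketch under the assumption that the ggml-backend implementation headers are included; `check_host_buft_device` is a hypothetical helper, not part of this patch:

```cpp
// Hypothetical check: the host buffer type's .device field now resolves to
// the same object the registry returns for index 0 (it used to be nullptr).
#include <cassert>

static void check_host_buft_device() {
    ggml_backend_buffer_type_t buft = ggml_backend_vk_host_buffer_type();
    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0);
    assert(buft->device == dev); // buffer type and registry agree on the device
}
```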
@@ -6581,9 +6581,135 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     UNUSED(backend);
 }
 
-static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
-    // ggml_backend_vk_context * ctx = (ggml_backend_vk_context *) backend->context;
+// TODO: enable async and synchronize
+static ggml_backend_i ggml_backend_vk_interface = {
+    /* .get_name                = */ ggml_backend_vk_name,
+    /* .free                    = */ ggml_backend_vk_free,
+    /* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
+    /* .set_tensor_async        = */ NULL,  // ggml_backend_vk_set_tensor_async,
+    /* .get_tensor_async        = */ NULL,  // ggml_backend_vk_get_tensor_async,
+    /* .cpy_tensor_async        = */ NULL,  // ggml_backend_vk_cpy_tensor_async,
+    /* .synchronize             = */ NULL,  // ggml_backend_vk_synchronize,
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_vk_graph_compute,
+    /* .supports_op             = */ NULL,
+    /* .supports_buft           = */ NULL,
+    /* .offload_op              = */ NULL,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_vk_guid() {
+    static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
+    return &guid;
+}
+
+ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
+    VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
+
+    ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
+    ggml_vk_init(ctx, dev_num);
+
+    ggml_backend_t vk_backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_vk_guid(),
+        /* .interface = */ ggml_backend_vk_interface,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
+        /* .context   = */ ctx,
+    };
+
+    return vk_backend;
+}
+
+bool ggml_backend_is_vk(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
+}
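Together these functions give the usual backend lifecycle. A usage sketch, assuming the public `ggml_backend_free` from ggml-backend.h:

```cpp
// Usage sketch: create a backend on Vulkan device 0, confirm its identity
// via the GUID check, then release it.
int main() {
    ggml_backend_t backend = ggml_backend_vk_init(0);
    if (!ggml_backend_is_vk(backend)) {
        return 1;
    }
    // ... allocate buffers and compute ggml graphs here ...
    ggml_backend_free(backend); // dispatches to ggml_backend_vk_free via the interface
    return 0;
}
```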
+
+int ggml_backend_vk_get_device_count() {
+    return ggml_vk_get_device_count();
+}
+
+void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
+    GGML_ASSERT(device < (int) vk_instance.device_indices.size());
+    int dev_idx = vk_instance.device_indices[device];
+    ggml_vk_get_device_description(dev_idx, description, description_size);
+}
+
+void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
+    GGML_ASSERT(device < (int) vk_instance.device_indices.size());
+
+    vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
+
+    vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
+
+    for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
+        if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
+            *total = heap.size;
+            *free = heap.size;
+            break;
+        }
+    }
+}
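Core Vulkan has no portable free-memory query, so the function above reports the device-local heap size for both values. A hypothetical enumeration sketch using only the functions defined in this diff:

```cpp
// Hypothetical sketch: list every Vulkan device with its description and
// device-local heap size. Note *free == *total by construction above.
#include <cstdio>

static void list_vk_devices() {
    for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
        char desc[256];
        size_t free_mem = 0, total_mem = 0;
        ggml_backend_vk_get_device_description(i, desc, sizeof(desc));
        ggml_backend_vk_get_device_memory(i, &free_mem, &total_mem);
        printf("Vulkan%d: %s, %zu bytes device-local\n", i, desc, total_mem);
    }
}
```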
+
+//////////////////////////
+
+struct ggml_backend_vk_device_context {
+    int device;
+    std::string name;
+    std::string description;
+};
+
+static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    return ctx->name.c_str();
+}
+
+static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    return ctx->description.c_str();
+}
+
+static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
+    ggml_backend_vk_get_device_memory(ctx->device, free, total);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    return ggml_backend_vk_buffer_type(ctx->device);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) {
+    UNUSED(dev);
+    return ggml_backend_vk_host_buffer_type();
+}
 
+static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {
+    UNUSED(dev);
+    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
+}
+
+static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_vk_device_get_name(dev);
+    props->description = ggml_backend_vk_device_get_description(dev);
+    props->type        = ggml_backend_vk_device_get_type(dev);
+    ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* async       */ false,
+        /* host_buffer */ true,
+        /* events      */ false,
+    };
+}
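Callers normally reach these getters through the generic device API rather than directly. A sketch assuming the public wrapper `ggml_backend_dev_get_props` declared in ggml-backend.h, which routes into the callbacks above:

```cpp
// Sketch (assumes ggml_backend_dev_get_props from ggml-backend.h): the
// generic wrapper fills ggml_backend_dev_props via the device callbacks.
#include <cstdio>

static void print_dev_props(ggml_backend_dev_t dev) {
    struct ggml_backend_dev_props props;
    ggml_backend_dev_get_props(dev, &props);
    printf("%s (%s): %zu bytes total, host buffers: %s\n",
           props.name, props.description, props.memory_total,
           props.caps.host_buffer ? "yes" : "no");
}
```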
+
+static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
+    UNUSED(params);
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    return ggml_backend_vk_init(ctx->device);
+}
+
+static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
@@ -6701,97 +6827,101 @@ static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
             return false;
     }
 
-    UNUSED(backend);
-}
-
-static bool ggml_backend_vk_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
-    const int min_batch_size = 32;
-
-    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
-           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
-
-    UNUSED(backend);
+    UNUSED(dev);
 }
 
-static bool ggml_backend_vk_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
     if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) {
         return false;
     }
 
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
     ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
-    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-
-    return buft_ctx->device == ctx->device;
-}
-
-// TODO: enable async and synchronize
-static ggml_backend_i ggml_backend_vk_interface = {
-    /* .get_name                = */ ggml_backend_vk_name,
-    /* .free                    = */ ggml_backend_vk_free,
-    /* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
-    /* .set_tensor_async        = */ NULL,  // ggml_backend_vk_set_tensor_async,
-    /* .get_tensor_async        = */ NULL,  // ggml_backend_vk_get_tensor_async,
-    /* .cpy_tensor_async        = */ NULL,  // ggml_backend_vk_cpy_tensor_async,
-    /* .synchronize             = */ NULL,  // ggml_backend_vk_synchronize,
-    /* .graph_plan_create       = */ NULL,
-    /* .graph_plan_free         = */ NULL,
-    /* .graph_plan_update       = */ NULL,
-    /* .graph_plan_compute      = */ NULL,
-    /* .graph_compute           = */ ggml_backend_vk_graph_compute,
-    /* .supports_op             = */ ggml_backend_vk_supports_op,
-    /* .supports_buft           = */ ggml_backend_vk_supports_buft,
-    /* .offload_op              = */ ggml_backend_vk_offload_op,
-    /* .event_record            = */ NULL,
-    /* .event_wait              = */ NULL,
-};
 
-static ggml_guid_t ggml_backend_vk_guid() {
-    static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
-    return &guid;
+    return buft_ctx->device->idx == ctx->device;
 }
 
-ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
-    VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
+static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+    const int min_batch_size = 32;
 
-    ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
-    ggml_vk_init(ctx, dev_num);
+    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
 
-    ggml_backend_t vk_backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_vk_guid(),
-        /* .interface = */ ggml_backend_vk_interface,
-        /* .device    = */ nullptr,
-        /* .context   = */ ctx,
-    };
+    UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
+    /* .get_name             = */ ggml_backend_vk_device_get_name,
+    /* .get_description      = */ ggml_backend_vk_device_get_description,
+    /* .get_memory           = */ ggml_backend_vk_device_get_memory,
+    /* .get_type             = */ ggml_backend_vk_device_get_type,
+    /* .get_props            = */ ggml_backend_vk_device_get_props,
+    /* .init_backend         = */ ggml_backend_vk_device_init,
+    /* .get_buffer_type      = */ ggml_backend_vk_device_get_buffer_type,
+    /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
+    /* .buffer_from_host_ptr = */ NULL,
+    /* .supports_op          = */ ggml_backend_vk_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_vk_device_supports_buft,
+    /* .offload_op           = */ ggml_backend_vk_device_offload_op,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
 
-    return vk_backend;
+static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
+    UNUSED(reg);
+    return GGML_VK_NAME;
 }
 
-bool ggml_backend_is_vk(ggml_backend_t backend) {
-    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
+static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) {
+    UNUSED(reg);
+    return ggml_backend_vk_get_device_count();
 }
 
-int ggml_backend_vk_get_device_count() {
-    return ggml_vk_get_device_count();
-}
+static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) {
+    static std::vector<ggml_backend_dev_t> devices;
 
-void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
-    ggml_vk_get_device_description(device, description, description_size);
-}
+    static bool initialized = false;
 
-void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
-    GGML_ASSERT(device < (int) vk_instance.device_indices.size());
+    {
+        static std::mutex mutex;
+        std::lock_guard<std::mutex> lock(mutex);
+        if (!initialized) {
+            for (size_t i = 0; i < ggml_backend_vk_get_device_count(); i++) {
+                ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
+                char desc[256];
+                ggml_backend_vk_get_device_description(i, desc, sizeof(desc));
+                ctx->device = i;
+                ctx->name = GGML_VK_NAME + std::to_string(i);
+                ctx->description = desc;
+                devices.push_back(new ggml_backend_device {
+                    /* .iface   = */ ggml_backend_vk_device_i,
+                    /* .reg     = */ reg,
+                    /* .context = */ ctx,
+                });
+            }
+            initialized = true;
+        }
+    }
 
-    vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
+    GGML_ASSERT(device < devices.size());
+    return devices[device];
+}
 
-    vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
+static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
+    /* .get_name         = */ ggml_backend_vk_reg_get_name,
+    /* .get_device_count = */ ggml_backend_vk_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_vk_reg_get_device,
+    /* .get_proc_address = */ NULL,
+};
 
-    for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
-        if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
-            *total = heap.size;
-            *free = heap.size;
-            break;
-        }
-    }
+ggml_backend_reg_t ggml_backend_vk_reg() {
+    static ggml_backend_reg reg = {
+        /* .iface   = */ ggml_backend_vk_reg_i,
+        /* .context = */ nullptr,
+    };
+
+    return &reg;
 }
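End to end, the registry is now the preferred entry point into the backend. A usage sketch assuming the generic wrappers `ggml_backend_reg_dev_count`, `ggml_backend_dev_name`, and `ggml_backend_dev_init` from ggml-backend.h:

```cpp
// Sketch (assumes the generic registry wrappers from ggml-backend.h):
// enumerate Vulkan devices through the registry and initialize a backend
// on the first one via the device interface.
#include <cstdio>

int main() {
    ggml_backend_reg_t reg = ggml_backend_vk_reg();
    for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, i);
        printf("device %zu: %s\n", i, ggml_backend_dev_name(dev));
    }
    ggml_backend_t backend = ggml_backend_dev_init(ggml_backend_reg_dev_get(reg, 0), nullptr);
    ggml_backend_free(backend);
    return 0;
}
```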
 
 // Extension availability