224
224
bool support_simdgroup_mm;
225
225
226
226
bool should_capture_next_compute;
227
+
228
+ // abort ggml_metal_graph_compute if callback returns true
229
+ ggml_abort_callback abort_callback;
230
+ void * abort_callback_data;
227
231
};
228
232
229
233
// MSL code
@@ -878,8 +882,11 @@ static enum ggml_status ggml_metal_graph_compute(
878
882
id <MTLCommandBuffer > command_buffer = [ctx->queue commandBufferWithUnretainedReferences ];
879
883
command_buffer_builder[cb_idx] = command_buffer;
880
884
881
- // enqueue the command buffers in order to specify their execution order
882
- [command_buffer enqueue ];
885
+ // always enqueue the first two command buffers
886
+ // enqueue all of the command buffers if we don't need to abort
887
+ if (cb_idx < 2 || ctx->abort_callback == NULL ) {
888
+ [command_buffer enqueue ];
889
+ }
883
890
}
884
891
885
892
const id <MTLCommandBuffer > *command_buffers = command_buffer_builder;
@@ -2827,7 +2834,9 @@ static enum ggml_status ggml_metal_graph_compute(
2827
2834
2828
2835
[encoder endEncoding ];
2829
2836
2830
- [command_buffer commit ];
2837
+ if (cb_idx < 2 || ctx->abort_callback == NULL ) {
2838
+ [command_buffer commit ];
2839
+ }
2831
2840
});
2832
2841
2833
2842
// Wait for completion and check status of each command buffer
@@ -2847,6 +2856,23 @@ static enum ggml_status ggml_metal_graph_compute(
2847
2856
2848
2857
return GGML_STATUS_FAILED;
2849
2858
}
2859
+
2860
+ id <MTLCommandBuffer > next_buffer = (i + 1 < n_cb ? command_buffers[i + 1 ] : nil );
2861
+ if (!next_buffer) {
2862
+ continue ;
2863
+ }
2864
+
2865
+ bool next_queued = ([next_buffer status ] != MTLCommandBufferStatusNotEnqueued );
2866
+ if (next_queued) {
2867
+ continue ;
2868
+ }
2869
+
2870
+ if (ctx->abort_callback && ctx->abort_callback (ctx->abort_callback_data )) {
2871
+ GGML_METAL_LOG_INFO (" %s : command buffer %d aborted" , __func__, i);
2872
+ return GGML_STATUS_ABORTED;
2873
+ }
2874
+
2875
+ [next_buffer commit ];
2850
2876
}
2851
2877
2852
2878
if (should_capture) {
@@ -3242,6 +3268,15 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
3242
3268
ctx->n_cb = MIN (n_cb, GGML_METAL_MAX_BUFFERS);
3243
3269
}
3244
3270
3271
+ void ggml_backend_metal_set_abort_callback (ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
3272
+ GGML_ASSERT (ggml_backend_is_metal (backend));
3273
+
3274
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context ;
3275
+
3276
+ ctx->abort_callback = abort_callback;
3277
+ ctx->abort_callback_data = user_data;
3278
+ }
3279
+
3245
3280
bool ggml_backend_metal_supports_family (ggml_backend_t backend, int family) {
3246
3281
GGML_ASSERT (ggml_backend_is_metal (backend));
3247
3282
0 commit comments