Skip to content

Commit 85fca8d

Browse files
conradevggerganov
authored andcommitted
metal : add abort callback (ggml/905)
1 parent ebd541a commit 85fca8d

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

ggml/include/ggml-metal.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void
5050

5151
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
5252

53+
GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
54+
5355
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
5456

5557
// helper to check if the device supports a specific family

ggml/src/ggml-metal.m

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,10 @@
224224
bool support_simdgroup_mm;
225225

226226
bool should_capture_next_compute;
227+
228+
// abort ggml_metal_graph_compute if callback returns true
229+
ggml_abort_callback abort_callback;
230+
void * abort_callback_data;
227231
};
228232

229233
// MSL code
@@ -878,8 +882,11 @@ static enum ggml_status ggml_metal_graph_compute(
878882
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
879883
command_buffer_builder[cb_idx] = command_buffer;
880884

881-
// enqueue the command buffers in order to specify their execution order
882-
[command_buffer enqueue];
885+
// always enqueue the first two command buffers
886+
// enqueue all of the command buffers if we don't need to abort
887+
if (cb_idx < 2 || ctx->abort_callback == NULL) {
888+
[command_buffer enqueue];
889+
}
883890
}
884891

885892
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
@@ -2827,7 +2834,9 @@ static enum ggml_status ggml_metal_graph_compute(
28272834

28282835
[encoder endEncoding];
28292836

2830-
[command_buffer commit];
2837+
if (cb_idx < 2 || ctx->abort_callback == NULL) {
2838+
[command_buffer commit];
2839+
}
28312840
});
28322841

28332842
// Wait for completion and check status of each command buffer
@@ -2847,6 +2856,23 @@ static enum ggml_status ggml_metal_graph_compute(
28472856

28482857
return GGML_STATUS_FAILED;
28492858
}
2859+
2860+
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil);
2861+
if (!next_buffer) {
2862+
continue;
2863+
}
2864+
2865+
bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
2866+
if (next_queued) {
2867+
continue;
2868+
}
2869+
2870+
if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
2871+
GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i);
2872+
return GGML_STATUS_ABORTED;
2873+
}
2874+
2875+
[next_buffer commit];
28502876
}
28512877

28522878
if (should_capture) {
@@ -3242,6 +3268,15 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
32423268
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
32433269
}
32443270

3271+
void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
3272+
GGML_ASSERT(ggml_backend_is_metal(backend));
3273+
3274+
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
3275+
3276+
ctx->abort_callback = abort_callback;
3277+
ctx->abort_callback_data = user_data;
3278+
}
3279+
32453280
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
32463281
GGML_ASSERT(ggml_backend_is_metal(backend));
32473282

0 commit comments

Comments
 (0)