Skip to content

Commit f727124

Browse files
committed
simplify batch ptr
1 parent f8c5394 commit f727124

File tree

2 files changed

+52
-98
lines changed

2 files changed

+52
-98
lines changed

examples/llava/clip-impl.h

Lines changed: 6 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -195,70 +195,24 @@ struct clip_image_size_deleter {
195195
};
196196
typedef std::unique_ptr<clip_image_size, clip_image_size_deleter> clip_image_size_ptr;
197197

198-
// use composition to avoid problems with inheritance from STL classes
199-
template <typename T, typename Initializer, typename Deleter>
200-
struct clip_image_buffer_base {
201-
std::unique_ptr<T, Deleter> ptr;
202-
clip_image_buffer_base() : ptr(Initializer()()) {}
203-
explicit clip_image_buffer_base(T* p) : ptr(p) {}
204-
clip_image_buffer_base(const clip_image_buffer_base& other) = delete;
205-
clip_image_buffer_base& operator=(const clip_image_buffer_base& other) = delete;
206-
clip_image_buffer_base(clip_image_buffer_base&& other) noexcept = default;
207-
clip_image_buffer_base& operator=(clip_image_buffer_base&& other) noexcept = default;
208-
~clip_image_buffer_base() = default;
209-
void reset(T* p = nullptr) { ptr.reset(p); }
210-
T* get() const noexcept { return ptr.get(); }
211-
T& operator*() const { return *ptr; }
212-
T* operator->() const noexcept { return ptr.get(); }
213-
explicit operator bool() const noexcept { return static_cast<bool>(ptr); }
214-
};
215-
216198
// wrapper for clip_image_u8
217-
struct clip_image_u8_initializer {
218-
clip_image_u8 * operator()() { return clip_image_u8_init(); }
219-
};
220199
struct clip_image_u8_deleter {
221200
void operator()(clip_image_u8 * val) { clip_image_u8_free(val); }
222201
};
223-
using clip_image_u8_ptr = clip_image_buffer_base<clip_image_u8, clip_image_u8_initializer, clip_image_u8_deleter>;
202+
typedef std::unique_ptr<clip_image_u8, clip_image_u8_deleter> clip_image_u8_ptr;
224203

225204
// wrapper for clip_image_f32
226-
struct clip_image_f32_initializer {
227-
clip_image_f32 * operator()() { return clip_image_f32_init(); }
228-
};
229205
struct clip_image_f32_deleter {
230206
void operator()(clip_image_f32 * val) { clip_image_f32_free(val); }
231207
};
232-
using clip_image_f32_ptr = clip_image_buffer_base<clip_image_f32, clip_image_f32_initializer, clip_image_f32_deleter>;
233-
234-
// use composition to avoid problems with inheritance from STL classes
235-
template <typename ImagePtrType>
236-
struct clip_image_batch_base {
237-
std::vector<ImagePtrType> images;
238-
clip_image_batch_base() = default;
239-
void push_back(ImagePtrType&& value) { images.push_back(std::move(value)); }
240-
void clear() noexcept { images.clear(); }
241-
void reserve(size_t n) { images.reserve(n); }
242-
243-
// Capacity
244-
size_t size() const noexcept { return images.size(); }
245-
bool empty() const noexcept { return images.empty(); }
246-
247-
// Element access
248-
ImagePtrType& operator[](size_t pos) { return images[pos]; }
249-
const ImagePtrType& operator[](size_t pos) const { return images[pos]; }
250-
ImagePtrType& at(size_t pos) { return images.at(pos); }
251-
const ImagePtrType& at(size_t pos) const { return images.at(pos); }
252-
};
208+
typedef std::unique_ptr<clip_image_f32, clip_image_f32_deleter> clip_image_f32_ptr;
253209

254-
struct clip_image_u8_batch : clip_image_batch_base<clip_image_u8_ptr> {
255-
clip_image_u8_batch() = default;
256-
~clip_image_u8_batch() = default;
210+
struct clip_image_u8_batch {
211+
std::vector<clip_image_u8_ptr> entries;
257212
};
258213

259-
struct clip_image_f32_batch : clip_image_batch_base<clip_image_f32_ptr> {
260-
clip_image_f32_batch() = default;
261-
~clip_image_f32_batch() = default;
214+
struct clip_image_f32_batch {
215+
std::vector<clip_image_f32_ptr> entries;
262216
};
263217

264218

examples/llava/clip.cpp

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
388388
const int n_layer = hparams.n_layer;
389389
const float eps = hparams.eps;
390390

391-
GGML_ASSERT(imgs.size() == 1); // batch_size == 1
391+
GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1
392392

393393
struct ggml_init_params params = {
394394
/*.mem_size =*/ ctx->buf_compute_meta.size(),
@@ -540,16 +540,16 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
540540
image_size_width = load_image_size.width;
541541
image_size_height = load_image_size.height;
542542
if (is_inf) {
543-
image_size_width = imgs[0]->nx;
544-
image_size_height = imgs[0]->ny;
543+
image_size_width = imgs.entries[0]->nx;
544+
image_size_height = imgs.entries[0]->ny;
545545
}
546546
}
547547
else if (ctx->has_qwen2vl_merger) {
548548
// use the image's native resolution when image is available
549549
if (is_inf) {
550550
// if (imgs->data->nx && imgs->data->ny) {
551-
image_size_width = imgs[0]->nx;
552-
image_size_height = imgs[0]->ny;
551+
image_size_width = imgs.entries[0]->nx;
552+
image_size_height = imgs.entries[0]->ny;
553553
}
554554
}
555555
const int patch_size = hparams.patch_size;
@@ -564,7 +564,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
564564
const float eps = hparams.eps;
565565
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
566566

567-
const int batch_size = imgs.size();
567+
const int batch_size = imgs.entries.size();
568568

569569
if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) {
570570
GGML_ASSERT(batch_size == 1);
@@ -1477,15 +1477,15 @@ struct clip_model_loader {
14771477

14781478
// create a fake batch
14791479
clip_image_f32_batch batch;
1480-
clip_image_f32_ptr img;
1480+
clip_image_f32_ptr img(clip_image_f32_init());
14811481
clip_image_size image_size;
14821482
image_size.width = clip_get_image_size(&ctx_clip);
14831483
image_size.height = clip_get_image_size(&ctx_clip);
14841484
int n_patches = clip_get_image_size(&ctx_clip) / image_size.width;
14851485
img->nx = n_patches;
14861486
img->ny = n_patches;
14871487
img->buf.resize(n_patches * image_size.width * image_size.height * 3);
1488-
batch.push_back(std::move(img));
1488+
batch.entries.push_back(std::move(img));
14891489

14901490
ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch, image_size, false);
14911491
ggml_backend_sched_reserve(ctx_clip.sched.get(), gf);
@@ -1626,31 +1626,31 @@ void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) d
16261626
void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; }
16271627

16281628
size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) {
1629-
return batch->size();
1629+
return batch->entries.size();
16301630
}
16311631

16321632
size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx) {
1633-
if (idx < 0 || idx >= (int)batch->size()) {
1633+
if (idx < 0 || idx >= (int)batch->entries.size()) {
16341634
LOG_ERR("%s: invalid index %d\n", __func__, idx);
16351635
return 0;
16361636
}
1637-
return batch->at(idx)->nx;
1637+
return batch->entries[idx]->nx;
16381638
}
16391639

16401640
size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx) {
1641-
if (idx < 0 || idx >= (int)batch->size()) {
1641+
if (idx < 0 || idx >= (int)batch->entries.size()) {
16421642
LOG_ERR("%s: invalid index %d\n", __func__, idx);
16431643
return 0;
16441644
}
1645-
return batch->at(idx)->ny;
1645+
return batch->entries[idx]->ny;
16461646
}
16471647

16481648
clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx) {
1649-
if (idx < 0 || idx >= (int)batch->size()) {
1649+
if (idx < 0 || idx >= (int)batch->entries.size()) {
16501650
LOG_ERR("%s: invalid index %d\n", __func__, idx);
16511651
return nullptr;
16521652
}
1653-
return batch->at(idx).get();
1653+
return batch->entries[idx].get();
16541654
}
16551655

16561656
void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) {
@@ -1884,7 +1884,7 @@ static std::vector<clip_image_u8_ptr> divide_to_patches_u8(const clip_image_u8 &
18841884
int height = image.ny;
18851885
for (int i = 0; i < height; i += patch_size) {
18861886
for (int j = 0; j < width; j += patch_size) {
1887-
clip_image_u8_ptr patch;
1887+
clip_image_u8_ptr patch(clip_image_u8_init());
18881888
patch->nx = std::min(patch_size, width - j);
18891889
patch->ny = std::min(patch_size, height - i);
18901890
patch->buf.resize(3 * patch->nx * patch->ny);
@@ -1990,14 +1990,14 @@ static std::vector<std::vector<clip_image_u8_ptr>> uhd_slice_image(const clip_im
19901990

19911991
if (multiple <= 1) {
19921992
auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true);
1993-
clip_image_u8_ptr source_image;
1993+
clip_image_u8_ptr source_image(clip_image_u8_init());
19941994
bicubic_resize(*img, *source_image, best_size.first, best_size.second);
19951995
// source_image = image.resize(best_size, Image.Resampling.BICUBIC)
19961996
images.back().push_back(std::move(source_image));
19971997
}
19981998
else if (multiple > 1) {
19991999
auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size);
2000-
clip_image_u8_ptr source_image;
2000+
clip_image_u8_ptr source_image(clip_image_u8_init());
20012001
bicubic_resize(*img, *source_image, best_size.first, best_size.second);
20022002
// source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
20032003
LOG_DBG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second);
@@ -2007,7 +2007,7 @@ static std::vector<std::vector<clip_image_u8_ptr>> uhd_slice_image(const clip_im
20072007
LOG_DBG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second);
20082008

20092009
auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true);
2010-
clip_image_u8_ptr refine_image;
2010+
clip_image_u8_ptr refine_image(clip_image_u8_init());
20112011
bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second);
20122012

20132013
LOG_DBG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second);
@@ -2020,7 +2020,7 @@ static std::vector<std::vector<clip_image_u8_ptr>> uhd_slice_image(const clip_im
20202020
for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){
20212021
images.push_back(std::vector<clip_image_u8_ptr>());
20222022
for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){
2023-
clip_image_u8_ptr patch;
2023+
clip_image_u8_ptr patch(clip_image_u8_init());
20242024
patch->nx = grid_x;
20252025
patch->ny = grid_y;
20262026
patch->buf.resize(3 * patch->nx * patch->ny);
@@ -2062,36 +2062,36 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
20622062
for (size_t i = 0; i < imgs.size(); ++i) {
20632063
for (size_t j = 0; j < imgs[i].size(); ++j) {
20642064
LOG_DBG("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny);
2065-
clip_image_f32_ptr res;
2065+
clip_image_f32_ptr res(clip_image_f32_init());
20662066
normalize_image_u8_to_f32(*imgs[i][j], *res, ctx->image_mean, ctx->image_std);
2067-
res_imgs->push_back(std::move(res));
2067+
res_imgs->entries.push_back(std::move(res));
20682068
}
20692069
}
20702070
return true;
20712071
}
20722072
else if (ctx->has_qwen2vl_merger) {
2073-
clip_image_u8_ptr resized;
2073+
clip_image_u8 resized;
20742074
auto patch_size = clip_get_patch_size(ctx) * 2;
20752075
int nx = ceil((float)img->nx / patch_size) * patch_size;
20762076
int ny = ceil((float)img->ny / patch_size) * patch_size;
2077-
bicubic_resize(*img, *resized, nx, ny);
2077+
bicubic_resize(*img, resized, nx, ny);
20782078

2079-
clip_image_f32_ptr img_f32;
2080-
// clip_image_f32_ptr res;
2081-
normalize_image_u8_to_f32(*resized, *img_f32, ctx->image_mean, ctx->image_std);
2079+
clip_image_f32_ptr img_f32(clip_image_f32_init());
2080+
// clip_image_f32_ptr res(clip_image_f32_init());
2081+
normalize_image_u8_to_f32(resized, *img_f32, ctx->image_mean, ctx->image_std);
20822082
// res_imgs->data[0] = *res;
2083-
res_imgs->push_back(std::move(img_f32));
2083+
res_imgs->entries.push_back(std::move(img_f32));
20842084
return true;
20852085
}
20862086

20872087
if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
20882088
clip_image_u8 resized_image;
20892089
int32_t sz=ctx->vision_model.hparams.image_size;
20902090
bicubic_resize(*img, resized_image,sz,sz);
2091-
clip_image_f32_ptr img_f32;
2091+
clip_image_f32_ptr img_f32(clip_image_f32_init());
20922092
//clip_image_save_to_bmp(resized_image, "resized.bmp");
20932093
normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
2094-
res_imgs->push_back(std::move(img_f32));
2094+
res_imgs->entries.push_back(std::move(img_f32));
20952095
return true;
20962096
}
20972097

@@ -2106,12 +2106,12 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
21062106
pad_to_square = false;
21072107
}
21082108
// free the previous res_imgs if any set
2109-
res_imgs->clear();
2109+
res_imgs->entries.clear();
21102110

21112111
// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
21122112
// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
21132113

2114-
clip_image_u8_ptr temp; // we will keep the input image data here temporarily
2114+
clip_image_u8_ptr temp(clip_image_u8_init()); // we will keep the input image data here temporarily
21152115
if (pad_to_square && img->nx != img->ny) {
21162116
int longer_side = std::max(img->nx, img->ny);
21172117
temp->nx = longer_side;
@@ -2156,14 +2156,14 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
21562156

21572157
std::vector<clip_image_u8_ptr> patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6)
21582158

2159-
clip_image_u8_ptr image_original_resize;
2159+
clip_image_u8_ptr image_original_resize(clip_image_u8_init());
21602160
// bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
21612161
bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
21622162
patches.insert(patches.begin(), std::move(image_original_resize));
21632163
for (auto & patch : patches) {
2164-
clip_image_f32_ptr res;
2164+
clip_image_f32_ptr res(clip_image_f32_init());
21652165
normalize_image_u8_to_f32(*patch, *res, ctx->image_mean, ctx->image_std);
2166-
res_imgs->push_back(std::move(res));
2166+
res_imgs->entries.push_back(std::move(res));
21672167
}
21682168

21692169
return true;
@@ -2181,7 +2181,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
21812181

21822182
const int nx2 = ctx->vision_model.hparams.image_size;
21832183
const int ny2 = ctx->vision_model.hparams.image_size;
2184-
clip_image_f32_ptr res;
2184+
clip_image_f32_ptr res(clip_image_f32_init());
21852185
res->nx = nx2;
21862186
res->ny = ny2;
21872187
res->buf.resize(3 * nx2 * ny2);
@@ -2242,7 +2242,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
22422242
// }
22432243
// res_imgs.push_back(res);
22442244

2245-
res_imgs->push_back(std::move(res));
2245+
res_imgs->entries.push_back(std::move(res));
22462246

22472247
return true;
22482248
}
@@ -2424,9 +2424,9 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3
24242424
}
24252425

24262426
clip_image_f32_batch imgs;
2427-
clip_image_f32_ptr img_copy;
2427+
clip_image_f32_ptr img_copy(clip_image_f32_init());
24282428
*img_copy = *img;
2429-
imgs.push_back(std::move(img_copy));
2429+
imgs.entries.push_back(std::move(img_copy));
24302430

24312431
return clip_image_batch_encode(ctx, n_threads, &imgs, vec);
24322432
}
@@ -2439,7 +2439,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
24392439
return false;
24402440
}
24412441

2442-
int batch_size = imgs.size();
2442+
int batch_size = imgs.entries.size();
24432443
if (ctx->has_llava_projector) {
24442444
GGML_ASSERT(batch_size == 1); // TODO: support multiple images
24452445
}
@@ -2466,8 +2466,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
24662466
int image_size_width = image_size;
24672467
int image_size_height = image_size;
24682468
if (ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger) {
2469-
image_size_width = imgs[0]->nx;
2470-
image_size_height = imgs[0]->ny;
2469+
image_size_width = imgs.entries[0]->nx;
2470+
image_size_height = imgs.entries[0]->ny;
24712471
}
24722472
const int patch_size = hparams.patch_size;
24732473
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
@@ -2479,9 +2479,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
24792479
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
24802480
float * data = (float *)malloc(ggml_nbytes(inp_raw));
24812481

2482-
for (size_t i = 0; i < imgs.size(); i++) {
2483-
const int nx = imgs[i]->nx;
2484-
const int ny = imgs[i]->ny;
2482+
for (size_t i = 0; i < imgs.entries.size(); i++) {
2483+
const int nx = imgs.entries[i]->nx;
2484+
const int ny = imgs.entries[i]->ny;
24852485
if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) {
24862486
GGML_ASSERT(nx == image_size && ny == image_size);
24872487
}
@@ -2492,7 +2492,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
24922492
for (int k = 0; k < 3; k++) {
24932493
for (int y = 0; y < ny; y++) {
24942494
for (int x = 0; x < nx; x++) {
2495-
data[(b * 3 * n) + k * n + y * nx + x] = imgs[b]->buf[3 * (y * nx + x) + k];
2495+
data[(b * 3 * n) + k * n + y * nx + x] = imgs.entries[b]->buf[3 * (y * nx + x) + k];
24962496
}
24972497
}
24982498
}

0 commit comments

Comments
 (0)