
Commit 71b053c

GPU weights not in RAM, direct loading with cuFile
1 parent af005ce commit 71b053c

File tree: .gitignore, Makefile, ggml-cuda.cu, ggml-cuda.h, llama.cpp

5 files changed: 159 additions & 64 deletions


.gitignore

Lines changed: 9 additions & 0 deletions
@@ -48,3 +48,12 @@ qnt-*.txt
 perf-*.txt
 
 examples/jeopardy/results.txt
+
+/prompts
+*.sh
+*.log
+*.py
+*.txt
+/wikitext-2-raw/
+*.org
+/libllama.so

Makefile

Lines changed: 1 addition & 1 deletion
@@ -125,7 +125,7 @@ endif
 ifdef LLAMA_CUBLAS
 	CFLAGS    += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
 	CXXFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
-	LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
+	LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lcufile -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
 	OBJS      += ggml-cuda.o
 	NVCC      = nvcc
 	NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
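
The only build-system change is the extra -lcufile on the link line. For anyone linking the CUDA objects outside this Makefile, a hedged sketch of an equivalent manual link command (the source/object file names and the CUDA install path are assumptions mirroring the flags above):

g++ main.cpp ggml.o llama.o ggml-cuda.o -DGGML_USE_CUBLAS \
    -I/usr/local/cuda/include -L/usr/local/cuda/lib64 \
    -lcublas -lculibos -lcudart -lcublasLt -lcufile -lpthread -ldl -lrt -o main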

ggml-cuda.cu

Lines changed: 47 additions & 2 deletions
@@ -2,11 +2,13 @@
 #include <cstdint>
 #include <stdint.h>
 #include <stdio.h>
+#include <fcntl.h>
 #include <atomic>
 
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
+#include <cufile.h>
 
 #include "ggml-cuda.h"
 #include "ggml.h"
@@ -32,6 +34,15 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     } \
 } while (0)
 
+#define CUFILE_CHECK(status)                                            \
+    do {                                                                \
+        CUfileError_t status_ = (status);                               \
+        if (status_.err != CU_FILE_SUCCESS) {                           \
+            fprintf(stderr, "cuFile error %d at %s:%d\n", status_.err, __FILE__, __LINE__); \
+            exit(1);                                                    \
+        }                                                               \
+    } while (0)
+
 typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
 typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
 typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
@@ -372,7 +383,7 @@ struct cuda_buffer {
 static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
 static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
 
-static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
 
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
@@ -391,7 +402,7 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     return ptr;
 }
 
-static void ggml_cuda_pool_free(void * ptr, size_t size) {
+void ggml_cuda_pool_free(void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
 
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
@@ -431,6 +442,9 @@ void ggml_init_cublas() {
 
         // configure logging to stdout
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
+
+        // initialize cuFile for loading model parameters directly to VRAM
+        CUFILE_CHECK(cuFileDriverOpen());
     }
 }
 
@@ -893,3 +907,34 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
     tensor->data = dst;
     tensor->backend = GGML_BACKEND_CUDA;
 }
+
+bool ggml_cuda_load_data_cufile(const char * fname, struct ggml_tensor ** tensors, const int num_tensors, const size_t * offsets) {
+    CUfileDescr_t cf_descr;
+    memset((void *)&cf_descr, 0, sizeof(CUfileDescr_t));
+    const int fd_cf = open(fname, O_RDONLY|O_DIRECT, 0644);
+    cf_descr.handle.fd = fd_cf;
+    cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
+
+    CUfileHandle_t cf_handle;
+    CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr);
+
+    if (status.err == CU_FILE_INTERNAL_ERROR) {
+        fprintf(stderr, "WARNING: cuFile experienced an internal error while loading weights from \"%s\". Using a workaround (slower). "
+            "This happens with weight files on Btrfs partitions. ext4 and NTFS are confirmed to work.\n", fname);
+    }
+    if (status.err != CU_FILE_SUCCESS) {
+        return false;
+    }
+
+    for (int i = 0; i < num_tensors; ++i) {
+        ggml_tensor * tensor = tensors[i];
+        const size_t size   = ggml_nbytes(tensor);
+        const size_t offset = offsets[i];
+
+        size_t actual_size;
+        void * buf = ggml_cuda_pool_malloc(size, &actual_size);
+        cuFileRead(cf_handle, buf, size, offset, 0);
+        tensor->data = buf;
+    }
+    return true;
+}
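
For context on the cuFile calls used above, here is a minimal standalone sketch of the same open / register / read flow that ggml_cuda_load_data_cufile follows. The file name, read size, and offset are placeholders, and it uses plain cudaMalloc where the commit uses ggml_cuda_pool_malloc:

#include <cstdio>
#include <cstring>
#include <fcntl.h>
#include <unistd.h>
#include <cuda_runtime.h>
#include <cufile.h>

int main() {
    const char * fname  = "model.bin"; // placeholder path
    const size_t nbytes = 1 << 20;     // placeholder read size
    const off_t  offset = 0;           // placeholder file offset

    // The driver is opened once per process (the commit does this in ggml_init_cublas).
    CUfileError_t status = cuFileDriverOpen();
    if (status.err != CU_FILE_SUCCESS) { fprintf(stderr, "cuFileDriverOpen failed\n"); return 1; }

    // O_DIRECT is required for GPUDirect Storage reads.
    const int fd = open(fname, O_RDONLY | O_DIRECT);
    if (fd < 0) { perror("open"); return 1; }

    CUfileDescr_t descr;
    memset(&descr, 0, sizeof(descr));
    descr.handle.fd = fd;
    descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;

    CUfileHandle_t handle;
    status = cuFileHandleRegister(&handle, &descr);
    if (status.err != CU_FILE_SUCCESS) { fprintf(stderr, "cuFileHandleRegister failed\n"); return 1; }

    // Destination is device memory; the data never passes through a host bounce buffer.
    void * dev_buf = nullptr;
    cudaMalloc(&dev_buf, nbytes);

    const ssize_t n = cuFileRead(handle, dev_buf, nbytes, offset, 0); // file -> VRAM
    fprintf(stderr, "read %zd bytes into device memory\n", n);

    cuFileHandleDeregister(handle);
    close(fd);
    cudaFree(dev_buf);
    cuFileDriverClose();
    return 0;
}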

ggml-cuda.h

Lines changed: 3 additions & 0 deletions
@@ -14,8 +14,11 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
 // TODO: export these with GGML_API
 void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
+void   ggml_cuda_pool_free(void * ptr, size_t size);
 
 void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
+bool ggml_cuda_load_data_cufile(const char * fname, struct ggml_tensor ** tensors, int num_tensors, const size_t * offsets);
 
 #ifdef __cplusplus
 }
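
A hedged usage sketch for the two pool functions that are newly exported here: judging by how the commit uses them, they return device memory from ggml-cuda's buffer pool, so the pointer can be used directly as a CUDA device address (for example as a cuFileRead destination). The request size below is a placeholder:

size_t actual_size = 0;
void * dev_buf = ggml_cuda_pool_malloc(16 * 1024 * 1024, &actual_size); // request 16 MiB of VRAM; the pool may hand back more
// ... fill dev_buf on the device, e.g. via cuFileRead or cudaMemcpy ...
ggml_cuda_pool_free(dev_buf, actual_size); // return the buffer to the pool for reuse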

llama.cpp

Lines changed: 99 additions & 61 deletions
@@ -10,6 +10,7 @@
 
 #include "ggml.h"
 #ifdef GGML_USE_CUBLAS
+#include <cuda_runtime.h>
 #include "ggml-cuda.h"
 #endif
 
@@ -641,7 +642,7 @@
         }
     }
 
-    struct ggml_tensor * get_tensor(const std::string & name, const std::vector<uint32_t> & ne) {
+    struct ggml_tensor * get_tensor(const std::string & name, const std::vector<uint32_t> & ne, ggml_backend backend) {
         auto it = tensors_map.name_to_idx.find(name);
         if (it == tensors_map.name_to_idx.end()) {
             throw format("llama.cpp: tensor '%s' is missing from model", name.c_str());
@@ -652,10 +653,10 @@
                 name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str());
         }
 
-        return get_tensor_for(lt);
+        return get_tensor_for(lt, backend);
     }
 
-    struct ggml_tensor * get_tensor_for(llama_load_tensor & lt) {
+    struct ggml_tensor * get_tensor_for(llama_load_tensor & lt, ggml_backend backend) {
         struct ggml_tensor * tensor;
         if (lt.ne.size() == 2) {
             tensor = ggml_new_tensor_2d(ggml_ctx, lt.type, lt.ne.at(0), lt.ne.at(1));
@@ -665,6 +666,7 @@
         }
         ggml_set_name(tensor, lt.name.c_str());
         LLAMA_ASSERT(lt.ggml_tensor == NULL); // if this fails, we called get_tensor twice on the same tensor
+        tensor->backend = backend;
         lt.ggml_tensor = tensor;
         num_ggml_tensors_created++;
         return tensor;
@@ -683,7 +685,7 @@
         }
 
         if (use_mmap) {
-            mapping.reset(new llama_mmap(&file_loaders.at(0)->file));
+            mapping.reset(new llama_mmap(&file_loaders.at(0)->file, false));
             if (!lmlock) {
                 // Don't call the callback since the actual loading will be lazy
                 // and we can't measure it.
@@ -696,6 +698,9 @@
 
         size_t done_size = 0;
         for (llama_load_tensor & lt : tensors_map.tensors) {
+            if (lt.ggml_tensor->backend != GGML_BACKEND_CPU) {
+                continue;
+            }
            if (progress_callback) {
                progress_callback((float) done_size / data_size, progress_callback_user_data);
            }
@@ -944,26 +949,6 @@ static void llama_model_load_internal(
     ml->calc_sizes(&ctx_size, &mmapped_size);
     fprintf(stderr, "%s: ggml ctx size = %6.2f KB\n", __func__, ctx_size/1024.0);
 
-    // print memory requirements
-    {
-        const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
-
-        // this is the total memory required to run the inference
-        const size_t mem_required =
-            ctx_size +
-            mmapped_size +
-            MEM_REQ_SCRATCH0().at(model.type) +
-            MEM_REQ_SCRATCH1().at(model.type) +
-            MEM_REQ_EVAL().at(model.type);
-
-        // this is the memory required by one llama_state
-        const size_t mem_required_state =
-            scale*MEM_REQ_KV_SELF().at(model.type);
-
-        fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
-                mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
-    }
-
     // create the ggml context
     {
         lctx.model.buf.resize(ctx_size);
@@ -985,79 +970,131 @@
     }
 
     // prepare memory for the weights
+    size_t vram_total = 0;
     {
         const uint32_t n_embd  = hparams.n_embd;
         const uint32_t n_layer = hparams.n_layer;
         const uint32_t n_vocab = hparams.n_vocab;
 
         ml->ggml_ctx = ctx;
 
-        model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab});
-        model.norm   = ml->get_tensor("norm.weight",   {n_embd});
-        model.output = ml->get_tensor("output.weight", {n_embd, n_vocab});
+        model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU);
+        model.norm = ml->get_tensor("norm.weight", {n_embd}, GGML_BACKEND_CPU);
+        ggml_backend backend_output;
+        if (n_gpu_layers > int(n_layer)) {
+            backend_output = GGML_BACKEND_CUDA;
+        } else {
+            backend_output = GGML_BACKEND_CPU;
+        }
+        model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output);
 
         model.layers.resize(n_layer);
+        const int i_gpu_start = n_layer - n_gpu_layers;
         for (uint32_t i = 0; i < n_layer; ++i) {
             auto & layer = model.layers[i];
+            const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : GGML_BACKEND_CUDA;
 
             std::string layers_i = "layers." + std::to_string(i);
 
-            layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd});
+            layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend);
 
-            layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd});
-            layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd});
-            layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd});
-            layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd});
+            layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend);
+            layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, backend);
+            layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, backend);
+            layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend);
 
-            layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd});
+            layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);
 
-            layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff});
-            layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd});
-            layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff});
+            layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend);
+            layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend);
+            layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend);
+            if (backend == GGML_BACKEND_CUDA) {
+                vram_total += ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk)
+                    + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.attention_norm)
+                    + ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
+            }
         }
     }
 
     ml->done_getting_tensors();
 
+    // print memory requirements
+    {
+        const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
+
+        // this is the total memory required to run the inference
+        const size_t mem_required =
+            ctx_size +
+            mmapped_size - vram_total + // weights in VRAM not in memory
+            MEM_REQ_SCRATCH0().at(model.type) +
+            MEM_REQ_SCRATCH1().at(model.type) +
+            MEM_REQ_EVAL().at(model.type);
+
+        // this is the memory required by one llama_state
+        const size_t mem_required_state =
+            scale*MEM_REQ_KV_SELF().at(model.type);
+
+        fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
+                mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
+
+#ifdef GGML_USE_CUBLAS
+        const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
+
+        fprintf(stderr, "%s: [cublas] offloading %d layers to GPU\n", __func__, n_gpu);
+        if (n_gpu_layers > (int) hparams.n_layer) {
+            fprintf(stderr, "%s: [cublas] offloading output layer to GPU\n", __func__);
+        }
+        fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+#else
+        (void) n_gpu_layers;
+#endif
+    }
+
     // populate `tensors_by_name`
     for (llama_load_tensor & lt : ml->tensors_map.tensors) {
         model.tensors_by_name.emplace_back(lt.name, lt.ggml_tensor);
     }
 
     ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
 
-    model.mapping = std::move(ml->mapping);
 #ifdef GGML_USE_CUBLAS
     {
-        const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
+        std::vector<struct ggml_tensor *> tensors;
+        std::vector<size_t> offsets;
+        for (llama_load_tensor & lt : ml->tensors_map.tensors) {
+            if (lt.ggml_tensor->backend != GGML_BACKEND_CUDA) {
+                continue;
+            }
+            tensors.emplace_back(lt.ggml_tensor);
+            LLAMA_ASSERT(lt.shards.size() == 1);
+            offsets.emplace_back(lt.shards.at(0).file_off);
+        }
+        bool cufile_success = ggml_cuda_load_data_cufile(fname.c_str(), tensors.data(), tensors.size(), offsets.data());
 
-        fprintf(stderr, "%s: [cublas] offloading %d layers to GPU\n", __func__, n_gpu);
+        if (!cufile_success) {
+            for (llama_load_tensor & lt : ml->tensors_map.tensors) {
+                if (lt.ggml_tensor->backend != GGML_BACKEND_CUDA) {
+                    continue;
+                }
+                size_t actual_size;
+                void * buf = ggml_cuda_pool_malloc(lt.size, &actual_size);
+                void * buf_host = ggml_cuda_host_malloc(lt.size);
 
-        size_t vram_total = 0;
+                llama_file & file = ml->file_loaders.at(lt.shards.at(0).file_idx)->file;
+                file.seek(lt.shards.at(0).file_off, SEEK_SET);
+                file.read_raw(buf_host, lt.size);
 
-        for (int i = 0; i < n_gpu; ++i) {
-            const auto & layer = model.layers[i];
+                cudaMemcpy(buf, buf_host, lt.size, cudaMemcpyHostToDevice);
+                cudaDeviceSynchronize();
 
-            ggml_cuda_transform_tensor(layer.attention_norm); vram_total += ggml_nbytes(layer.attention_norm);
-            ggml_cuda_transform_tensor(layer.wq); vram_total += ggml_nbytes(layer.wq);
-            ggml_cuda_transform_tensor(layer.wk); vram_total += ggml_nbytes(layer.wk);
-            ggml_cuda_transform_tensor(layer.wv); vram_total += ggml_nbytes(layer.wv);
-            ggml_cuda_transform_tensor(layer.wo); vram_total += ggml_nbytes(layer.wo);
-            ggml_cuda_transform_tensor(layer.ffn_norm); vram_total += ggml_nbytes(layer.ffn_norm);
-            ggml_cuda_transform_tensor(layer.w1); vram_total += ggml_nbytes(layer.w1);
-            ggml_cuda_transform_tensor(layer.w2); vram_total += ggml_nbytes(layer.w2);
-            ggml_cuda_transform_tensor(layer.w3); vram_total += ggml_nbytes(layer.w3);
-        }
-        if (n_gpu_layers > (int) hparams.n_layer) {
-            fprintf(stderr, "%s: [cublas] offloading output layer to GPU\n", __func__);
-            ggml_cuda_transform_tensor(model.output); vram_total += ggml_nbytes(model.output);
+                lt.ggml_tensor->data = buf;
+                ggml_cuda_host_free(buf_host);
+            }
         }
-
-        fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
     }
-#else
-    (void) n_gpu_layers;
-#endif
+#endif // GGML_USE_CUBLAS
+
+    model.mapping = std::move(ml->mapping);
 
     // loading time will be recalculate after the first eval, so
     // we take page faults deferred by mmap() into consideration
@@ -2395,7 +2432,8 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
         }
         size_t idx = model_loader->tensors_map.name_to_idx[base_name];
         llama_load_tensor & lt = model_loader->tensors_map.tensors[idx];
-        base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] });
+        base_t = model_loader->get_tensor(
+            base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU);
         lt.data = (uint8_t *) lt.ggml_tensor->data;
         model_loader->load_data_for(lt);
         lt.ggml_tensor->data = lt.data;