Skip to content

Commit 0d6f6a7

Browse files
committed
add --reranking argument
1 parent 84b0af8 commit 0d6f6a7

File tree

7 files changed

+43
-18
lines changed

7 files changed

+43
-18
lines changed

common/arg.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,10 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
284284
params.kv_overrides.back().key[0] = 0;
285285
}
286286

287+
if (params.reranking && params.embedding) {
288+
throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
289+
}
290+
287291
return true;
288292
}
289293

@@ -1750,6 +1754,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
17501754
params.embedding = true;
17511755
}
17521756
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
1757+
add_opt(llama_arg(
1758+
{"--reranking", "--rerank"},
1759+
format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"),
1760+
[](gpt_params & params) {
1761+
params.reranking = true;
1762+
}
1763+
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
17531764
add_opt(llama_arg(
17541765
{"--api-key"}, "KEY",
17551766
"API key to use for authentication (default: none)",

common/common.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,11 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
10231023
cparams.flash_attn = params.flash_attn;
10241024
cparams.no_perf = params.no_perf;
10251025

1026+
if (params.reranking) {
1027+
cparams.embeddings = true;
1028+
cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
1029+
}
1030+
10261031
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
10271032
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
10281033

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ struct gpt_params {
271271
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
272272
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
273273
std::string embd_sep = "\n"; // separator of embeddings
274+
bool reranking = false; // enable reranking support on server
274275

275276
// server params
276277
int32_t port = 8080; // server listens on this network port

examples/server/server.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2888,8 +2888,8 @@ int main(int argc, char ** argv) {
28882888
};
28892889

28902890
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
2891-
if (ctx_server.params.embedding) {
2892-
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
2891+
if (ctx_server.params.embedding || ctx_server.params.reranking) {
2892+
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
28932893
return;
28942894
}
28952895

@@ -2949,8 +2949,8 @@ int main(int argc, char ** argv) {
29492949

29502950
// TODO: maybe merge this function with "handle_completions_generic"
29512951
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
2952-
if (ctx_server.params.embedding) {
2953-
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
2952+
if (ctx_server.params.embedding || ctx_server.params.reranking) {
2953+
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
29542954
return;
29552955
}
29562956

@@ -3074,6 +3074,11 @@ int main(int argc, char ** argv) {
30743074
};
30753075

30763076
const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
3077+
// TODO: somehow clean up these checks in the future
3078+
if (!ctx_server.params.embedding || ctx_server.params.reranking) {
3079+
res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
3080+
return;
3081+
}
30773082
const json body = json::parse(req.body);
30783083
bool is_openai = false;
30793084

@@ -3125,6 +3130,10 @@ int main(int argc, char ** argv) {
31253130
};
31263131

31273132
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
3133+
if (!ctx_server.params.reranking) {
3134+
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
3135+
return;
3136+
}
31283137
const json body = json::parse(req.body);
31293138

31303139
// TODO: implement
@@ -3148,15 +3157,9 @@ int main(int argc, char ** argv) {
31483157
return;
31493158
}
31503159

3151-
json documents;
3152-
if (body.count("documents") != 0) {
3153-
documents = body.at("documents");
3154-
if (!documents.is_array() || documents.size() == 0) {
3155-
res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
3156-
return;
3157-
}
3158-
} else {
3159-
res_error(res, format_error_response("\"documents\" must be provided", ERROR_TYPE_INVALID_REQUEST));
3160+
std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
3161+
if (documents.empty()) {
3162+
res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
31603163
return;
31613164
}
31623165

examples/server/tests/features/embeddings.feature

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Feature: llama.cpp server
1515
And 128 as batch size
1616
And 128 as ubatch size
1717
And 512 KV cache size
18-
And embeddings extraction
18+
And enable embeddings endpoint
1919
Then the server is starting
2020
Then the server is healthy
2121

examples/server/tests/features/rerank.feature

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Feature: llama.cpp server
1212
And 512 as batch size
1313
And 512 as ubatch size
1414
And 512 KV cache size
15-
And embeddings extraction
15+
And enable reranking endpoint
1616
Then the server is starting
1717
Then the server is healthy
1818

@@ -39,5 +39,4 @@ Feature: llama.cpp server
3939
"""
4040
When reranking request
4141
Then reranking results are returned
42-
# TODO: this result make no sense, probably need a better model?
43-
Then reranking highest score is index 3 and lowest score is index 0
42+
Then reranking highest score is index 2 and lowest score is index 3

examples/server/tests/features/steps/steps.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
6868
context.server_api_key = None
6969
context.server_continuous_batching = False
7070
context.server_embeddings = False
71+
context.server_reranking = False
7172
context.server_metrics = False
7273
context.server_process = None
7374
context.seed = None
@@ -176,10 +177,13 @@ def step_server_continuous_batching(context):
176177
context.server_continuous_batching = True
177178

178179

179-
@step('embeddings extraction')
180+
@step('enable embeddings endpoint')
180181
def step_server_embeddings(context):
181182
context.server_embeddings = True
182183

184+
@step('enable reranking endpoint')
185+
def step_server_reranking(context):
186+
context.server_reranking = True
183187

184188
@step('prometheus compatible metrics exposed')
185189
def step_server_metrics(context):
@@ -1408,6 +1412,8 @@ def start_server_background(context):
14081412
server_args.append('--cont-batching')
14091413
if context.server_embeddings:
14101414
server_args.append('--embedding')
1415+
if context.server_reranking:
1416+
server_args.append('--reranking')
14111417
if context.server_metrics:
14121418
server_args.append('--metrics')
14131419
if context.model_alias:

0 commit comments

Comments
 (0)