// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "content/browser/ai/echo_ai_language_model.h"
#include <optional>
#include "base/containers/to_vector.h"
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/notimplemented.h"
#include "base/time/time.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "content/browser/ai/echo_ai_manager_impl.h"
#include "content/public/browser/browser_thread.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "mojo/public/cpp/bindings/self_owned_receiver.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom.h"

namespace content {

namespace {

constexpr char kResponsePrefix[] =
    "On-device model is not available in Chromium, this API is just echoing "
    "back the input:\n";

}  // namespace

EchoAILanguageModel::EchoAILanguageModel(
blink::mojom::AILanguageModelSamplingParamsPtr sampling_params,
base::flat_set<blink::mojom::AILanguageModelPromptType> input_types,
uint32_t initial_tokens_size)
: current_tokens_(initial_tokens_size),
sampling_params_(std::move(sampling_params)),
input_types_(input_types) {}

EchoAILanguageModel::~EchoAILanguageModel() = default;
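
// Echoes `input` back to the responder identified by `responder_id`,
// applying the mock quota accounting in which one character of input counts
// as one token.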
void EchoAILanguageModel::DoMockExecution(
const std::string& input,
mojo::RemoteSetElementId responder_id) {
blink::mojom::ModelStreamingResponder* responder =
responder_set_.Get(responder_id);
if (!responder) {
return;
}
uint32_t quota = EchoAIManagerImpl::kMaxContextSizeInTokens;
if (input.size() > quota) {
responder->OnError(
blink::mojom::ModelStreamingResponseStatus::kErrorInputTooLarge,
blink::mojom::QuotaErrorInfo::New(input.size(), quota));
return;
}
  // If the accumulated context would overflow the quota, simulate eviction
  // of the old context so that only the new input is counted below.
  if (current_tokens_ > quota - input.size()) {
    current_tokens_ = 0;
    responder->OnQuotaOverflow();
  }
current_tokens_ += input.size();
responder->OnStreaming(kResponsePrefix);
responder->OnStreaming(input);
responder->OnCompletion(
blink::mojom::ModelExecutionContextInfo::New(current_tokens_));
}
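
// Validates the prompt contents against the session's allowed input types,
// flattens them into a single string, and replays it through
// DoMockExecution() after a short artificial delay.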
void EchoAILanguageModel::Prompt(
std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
on_device_model::mojom::ResponseConstraintPtr constraint,
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
pending_responder) {
if (is_destroyed_) {
mojo::Remote<blink::mojom::ModelStreamingResponder> responder(
std::move(pending_responder));
responder->OnError(
blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed,
/*quota_error_info=*/nullptr);
return;
}
std::string response = "";
for (const auto& prompt : prompts) {
for (auto& content : prompt->content) {
if (content->is_text()) {
response += content->get_text();
} else if (content->is_bitmap()) {
if (!input_types_.contains(
blink::mojom::AILanguageModelPromptType::kImage)) {
mojo::ReportBadMessage("Image input is not supported.");
return;
}
response += "<image>";
} else if (content->is_audio()) {
if (!input_types_.contains(
blink::mojom::AILanguageModelPromptType::kAudio)) {
mojo::ReportBadMessage("Audio input is not supported.");
return;
}
response += "<audio>";
} else {
NOTIMPLEMENTED_LOG_ONCE();
}
}
}
mojo::RemoteSetElementId responder_id =
responder_set_.Add(std::move(pending_responder));
// Simulate the time taken by model execution.
content::GetUIThreadTaskRunner()->PostDelayedTask(
FROM_HERE,
base::BindOnce(&EchoAILanguageModel::DoMockExecution,
weak_ptr_factory_.GetWeakPtr(), response, responder_id),
base::Seconds(1));
}
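
// The echo session keeps no real context, so Append() does not process the
// prompts; it only reports the current token count back to the caller.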
void EchoAILanguageModel::Append(
std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
pending_responder) {
mojo::Remote<blink::mojom::ModelStreamingResponder> responder(
std::move(pending_responder));
responder->OnCompletion(
blink::mojom::ModelExecutionContextInfo::New(current_tokens_));
}
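
// Creates an independent echo session that starts from this session's token
// count, sampling params, and accepted input types.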
void EchoAILanguageModel::Fork(
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
client) {
mojo::Remote<blink::mojom::AIManagerCreateLanguageModelClient> client_remote(
std::move(client));
mojo::PendingRemote<blink::mojom::AILanguageModel> language_model;
mojo::MakeSelfOwnedReceiver(
std::make_unique<EchoAILanguageModel>(sampling_params_.Clone(),
input_types_, current_tokens_),
language_model.InitWithNewPipeAndPassReceiver());
client_remote->OnResult(
std::move(language_model),
blink::mojom::AILanguageModelInstanceInfo::New(
EchoAIManagerImpl::kMaxContextSizeInTokens, current_tokens_,
          sampling_params_.Clone(), base::ToVector(input_types_)));
}
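
// Marks the session as destroyed and fails all in-flight responders; any
// later Prompt() call is likewise rejected with kErrorSessionDestroyed.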
void EchoAILanguageModel::Destroy() {
is_destroyed_ = true;
for (auto& responder : responder_set_) {
responder->OnError(
blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed,
/*quota_error_info=*/nullptr);
}
responder_set_.Clear();
}
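
// Estimates usage as one token per character of text; non-text content is
// charged a fixed placeholder cost pending a better estimate.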
void EchoAILanguageModel::MeasureInputUsage(
std::vector<blink::mojom::AILanguageModelPromptPtr> input,
MeasureInputUsageCallback callback) {
size_t total = 0;
for (const auto& prompt : input) {
for (const auto& content : prompt->content) {
if (content->is_text()) {
total += content->get_text().size();
} else {
total += 100; // TODO(crbug.com/415304330): Improve estimate.
}
}
}
std::move(callback).Run(total);
}

}  // namespace content