Skip to content

Commit 39b5149

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: Vertex RAG for enhanced generative AI
PiperOrigin-RevId: 627454806
1 parent 754c89d commit 39b5149

File tree

14 files changed

+1724
-22
lines changed

14 files changed

+1724
-22
lines changed

google/cloud/aiplatform/compat/services/__init__.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@
6363
from google.cloud.aiplatform_v1beta1.services.model_service import (
6464
client as model_service_client_v1beta1,
6565
)
66-
from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
67-
client as pipeline_service_client_v1beta1,
68-
)
6966
from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import (
7067
client as persistent_resource_service_client_v1beta1,
7168
)
69+
from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
70+
client as pipeline_service_client_v1beta1,
71+
)
7272
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
7373
client as prediction_service_client_v1beta1,
7474
)
@@ -90,10 +90,20 @@
9090
from google.cloud.aiplatform_v1beta1.services.tensorboard_service import (
9191
client as tensorboard_service_client_v1beta1,
9292
)
93+
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import (
94+
client as vertex_rag_data_service_client_v1beta1,
95+
)
96+
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import (
97+
async_client as vertex_rag_data_service_async_client_v1beta1,
98+
)
99+
from google.cloud.aiplatform_v1beta1.services.vertex_rag_service import (
100+
client as vertex_rag_service_client_v1beta1,
101+
)
93102
from google.cloud.aiplatform_v1beta1.services.vizier_service import (
94103
client as vizier_service_client_v1beta1,
95104
)
96105

106+
97107
from google.cloud.aiplatform_v1.services.dataset_service import (
98108
client as dataset_service_client_v1,
99109
)
@@ -195,9 +205,14 @@
195205
pipeline_service_client_v1beta1,
196206
prediction_service_client_v1beta1,
197207
prediction_service_async_client_v1beta1,
208+
reasoning_engine_execution_service_client_v1beta1,
209+
reasoning_engine_service_client_v1beta1,
198210
schedule_service_client_v1beta1,
199211
specialist_pool_service_client_v1beta1,
200212
metadata_service_client_v1beta1,
201213
tensorboard_service_client_v1beta1,
214+
vertex_rag_service_client_v1beta1,
215+
vertex_rag_data_service_client_v1beta1,
216+
vertex_rag_data_service_async_client_v1beta1,
202217
vizier_service_client_v1beta1,
203218
)

google/cloud/aiplatform/utils/__init__.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@
6161
persistent_resource_service_client_v1beta1,
6262
reasoning_engine_service_client_v1beta1,
6363
reasoning_engine_execution_service_client_v1beta1,
64+
vertex_rag_data_service_async_client_v1beta1,
65+
vertex_rag_data_service_client_v1beta1,
66+
vertex_rag_service_client_v1beta1,
6467
)
6568
from google.cloud.aiplatform.compat.services import (
6669
dataset_service_client_v1,
@@ -799,6 +802,39 @@ class ReasoningEngineExecutionClientWithOverride(ClientWithOverride):
799802
)
800803

801804

805+
class VertexRagDataClientWithOverride(ClientWithOverride):
806+
_is_temporary = True
807+
_default_version = compat.V1BETA1
808+
_version_map = (
809+
(
810+
compat.V1BETA1,
811+
vertex_rag_data_service_client_v1beta1.VertexRagDataServiceClient,
812+
),
813+
)
814+
815+
816+
class VertexRagDataAsyncClientWithOverride(ClientWithOverride):
817+
_is_temporary = True
818+
_default_version = compat.V1BETA1
819+
_version_map = (
820+
(
821+
compat.V1BETA1,
822+
vertex_rag_data_service_async_client_v1beta1.VertexRagDataServiceAsyncClient,
823+
),
824+
)
825+
826+
827+
class VertexRagClientWithOverride(ClientWithOverride):
828+
_is_temporary = True
829+
_default_version = compat.V1BETA1
830+
_version_map = (
831+
(
832+
compat.V1BETA1,
833+
vertex_rag_service_client_v1beta1.VertexRagServiceClient,
834+
),
835+
)
836+
837+
802838
VertexAiServiceClientWithOverride = TypeVar(
803839
"VertexAiServiceClientWithOverride",
804840
DatasetClientWithOverride,

tests/unit/vertex_rag/conftest.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2024 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
from unittest.mock import patch
18+
from google import auth
19+
from google.api_core import operation as ga_operation
20+
from google.auth import credentials as auth_credentials
21+
from vertexai.preview import rag
22+
from google.cloud.aiplatform_v1beta1 import (
23+
DeleteRagCorpusRequest,
24+
VertexRagDataServiceAsyncClient,
25+
VertexRagDataServiceClient,
26+
)
27+
import test_rag_constants as tc
28+
import mock
29+
import pytest
30+
31+
32+
_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials())
33+
34+
35+
@pytest.fixture(scope="module")
36+
def google_auth_mock():
37+
with mock.patch.object(auth, "default") as auth_mock:
38+
auth_mock.return_value = (
39+
auth_credentials.AnonymousCredentials(),
40+
tc.TEST_PROJECT,
41+
)
42+
yield auth_mock
43+
44+
45+
@pytest.fixture
46+
def authorized_session_mock():
47+
with patch(
48+
"google.auth.transport.requests.AuthorizedSession"
49+
) as MockAuthorizedSession:
50+
mock_auth_session = MockAuthorizedSession(_TEST_CREDENTIALS)
51+
yield mock_auth_session
52+
53+
54+
@pytest.fixture
55+
def rag_data_client_mock():
56+
with mock.patch.object(
57+
rag.utils._gapic_utils, "create_rag_data_service_client"
58+
) as rag_data_client_mock:
59+
api_client_mock = mock.Mock(spec=VertexRagDataServiceClient)
60+
61+
# get_rag_corpus
62+
api_client_mock.get_rag_corpus.return_value = tc.TEST_GAPIC_RAG_CORPUS
63+
# delete_rag_corpus
64+
delete_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
65+
delete_rag_corpus_lro_mock.result.return_value = DeleteRagCorpusRequest()
66+
api_client_mock.delete_rag_corpus.return_value = delete_rag_corpus_lro_mock
67+
# get_rag_file
68+
api_client_mock.get_rag_file.return_value = tc.TEST_GAPIC_RAG_FILE
69+
70+
rag_data_client_mock.return_value = api_client_mock
71+
yield rag_data_client_mock
72+
73+
74+
@pytest.fixture
75+
def rag_data_client_mock_exception():
76+
with mock.patch.object(
77+
rag.utils._gapic_utils, "create_rag_data_service_client"
78+
) as rag_data_client_mock_exception:
79+
api_client_mock = mock.Mock(spec=VertexRagDataServiceClient)
80+
# create_rag_corpus
81+
api_client_mock.create_rag_corpus.side_effect = Exception
82+
# get_rag_corpus
83+
api_client_mock.get_rag_corpus.side_effect = Exception
84+
# list_rag_corpora
85+
api_client_mock.list_rag_corpora.side_effect = Exception
86+
# delete_rag_corpus
87+
api_client_mock.delete_rag_corpus.side_effect = Exception
88+
# upload_rag_file
89+
api_client_mock.upload_rag_file.side_effect = Exception
90+
# import_rag_files
91+
api_client_mock.import_rag_files.side_effect = Exception
92+
# get_rag_file
93+
api_client_mock.get_rag_file.side_effect = Exception
94+
# list_rag_files
95+
api_client_mock.list_rag_files.side_effect = Exception
96+
# delete_rag_file
97+
api_client_mock.delete_rag_file.side_effect = Exception
98+
rag_data_client_mock_exception.return_value = api_client_mock
99+
yield rag_data_client_mock_exception
100+
101+
102+
@pytest.fixture
103+
def rag_data_async_client_mock_exception():
104+
with mock.patch.object(
105+
rag.utils._gapic_utils, "create_rag_data_service_async_client"
106+
) as rag_data_async_client_mock_exception:
107+
api_client_mock = mock.Mock(spec=VertexRagDataServiceAsyncClient)
108+
# import_rag_files
109+
api_client_mock.import_rag_files.side_effect = Exception
110+
rag_data_client_mock_exception.return_value = api_client_mock
111+
yield rag_data_async_client_mock_exception
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2024 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from vertexai.preview.rag.utils.resources import (
19+
RagCorpus,
20+
RagFile,
21+
)
22+
from google.cloud import aiplatform
23+
from google.cloud.aiplatform_v1beta1 import (
24+
GoogleDriveSource,
25+
RagFileChunkingConfig,
26+
ImportRagFilesConfig,
27+
ImportRagFilesRequest,
28+
ImportRagFilesResponse,
29+
RagCorpus as GapicRagCorpus,
30+
RagFile as GapicRagFile,
31+
RagContexts,
32+
RetrieveContextsResponse,
33+
)
34+
35+
36+
TEST_PROJECT = "test-project"
37+
TEST_PROJECT_NUMBER = "12345678"
38+
TEST_REGION = "us-central1"
39+
TEST_CORPUS_DISPLAY_NAME = "my-corpus-1"
40+
TEST_CORPUS_DISCRIPTION = "My first corpus."
41+
TEST_RAG_CORPUS_ID = "generate-123"
42+
TEST_API_ENDPOINT = "us-central1-" + aiplatform.constants.base.API_BASE_PATH
43+
TEST_RAG_CORPUS_RESOURCE_NAME = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/ragCorpora/{TEST_RAG_CORPUS_ID}"
44+
45+
# RagCorpus
46+
TEST_GAPIC_RAG_CORPUS = GapicRagCorpus(
47+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
48+
display_name=TEST_CORPUS_DISPLAY_NAME,
49+
description=TEST_CORPUS_DISCRIPTION,
50+
)
51+
TEST_RAG_CORPUS = RagCorpus(
52+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
53+
display_name=TEST_CORPUS_DISPLAY_NAME,
54+
description=TEST_CORPUS_DISCRIPTION,
55+
)
56+
TEST_PAGE_TOKEN = "test-page-token"
57+
58+
# RagFiles
59+
TEST_PATH = "usr/home/my_file.txt"
60+
TEST_GCS_PATH = "gs://usr/home/data_dir/"
61+
TEST_FILE_DISPLAY_NAME = "my-file.txt"
62+
TEST_FILE_DESCRIPTION = "my file."
63+
TEST_HEADERS = {"X-Goog-Upload-Protocol": "multipart"}
64+
TEST_UPLOAD_REQUEST_URI = "https://{}/upload/v1beta1/projects/{}/locations/{}/ragCorpora/{}/ragFiles:upload".format(
65+
TEST_API_ENDPOINT, TEST_PROJECT_NUMBER, TEST_REGION, TEST_RAG_CORPUS_ID
66+
)
67+
TEST_RAG_FILE_ID = "generate-456"
68+
TEST_RAG_FILE_RESOURCE_NAME = (
69+
TEST_RAG_CORPUS_RESOURCE_NAME + f"/ragFiles/{TEST_RAG_FILE_ID}"
70+
)
71+
TEST_UPLOAD_RAG_FILE_RESPONSE_CONTENT = ""
72+
TEST_RAG_FILE_JSON = {
73+
"ragFile": {
74+
"name": TEST_RAG_FILE_RESOURCE_NAME,
75+
"displayName": TEST_FILE_DISPLAY_NAME,
76+
}
77+
}
78+
TEST_RAG_FILE_JSON_ERROR = {"error": {"code": 13}}
79+
TEST_CHUNK_SIZE = 512
80+
TEST_CHUNK_OVERLAP = 100
81+
# GCS
82+
TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig()
83+
TEST_IMPORT_FILES_CONFIG_GCS.gcs_source.uris = [TEST_GCS_PATH]
84+
TEST_IMPORT_REQUEST_GCS = ImportRagFilesRequest(
85+
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
86+
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_GCS,
87+
)
88+
# Google Drive folders
89+
TEST_DRIVE_FOLDER_ID = "123"
90+
TEST_DRIVE_FOLDER = (
91+
f"https://drive.google.com/corp/drive/folders/{TEST_DRIVE_FOLDER_ID}"
92+
)
93+
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER = ImportRagFilesConfig()
94+
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.google_drive_source.resource_ids = [
95+
GoogleDriveSource.ResourceId(
96+
resource_id=TEST_DRIVE_FOLDER_ID,
97+
resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER,
98+
)
99+
]
100+
TEST_IMPORT_REQUEST_DRIVE_FOLDER = ImportRagFilesRequest(
101+
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
102+
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER,
103+
)
104+
# Google Drive files
105+
TEST_DRIVE_FILE_ID = "456"
106+
TEST_DRIVE_FILE = f"https://drive.google.com/file/d/{TEST_DRIVE_FILE_ID}"
107+
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE = ImportRagFilesConfig(
108+
rag_file_chunking_config=RagFileChunkingConfig(
109+
chunk_size=TEST_CHUNK_SIZE,
110+
chunk_overlap=TEST_CHUNK_OVERLAP,
111+
)
112+
)
113+
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.google_drive_source.resource_ids = [
114+
GoogleDriveSource.ResourceId(
115+
resource_id=TEST_DRIVE_FILE_ID,
116+
resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FILE,
117+
)
118+
]
119+
TEST_IMPORT_REQUEST_DRIVE_FILE = ImportRagFilesRequest(
120+
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
121+
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_DRIVE_FILE,
122+
)
123+
124+
TEST_IMPORT_RESPONSE = ImportRagFilesResponse(imported_rag_files_count=2)
125+
126+
TEST_GAPIC_RAG_FILE = GapicRagFile(
127+
name=TEST_RAG_FILE_RESOURCE_NAME,
128+
display_name=TEST_FILE_DISPLAY_NAME,
129+
description=TEST_FILE_DESCRIPTION,
130+
)
131+
TEST_RAG_FILE = RagFile(
132+
name=TEST_RAG_FILE_RESOURCE_NAME,
133+
display_name=TEST_FILE_DISPLAY_NAME,
134+
description=TEST_FILE_DESCRIPTION,
135+
)
136+
137+
# Retrieval
138+
TEST_QUERY_TEXT = "What happen to the fox and the dog?"
139+
TEST_CONTEXTS = RagContexts(
140+
contexts=[
141+
RagContexts.Context(
142+
source_uri="https://drive.google.com/file/d/123/view?usp=drivesdk",
143+
text="The quick brown fox jumps over the lazy dog.",
144+
),
145+
RagContexts.Context(text="The slow red fox jumps over the lazy dog."),
146+
]
147+
)
148+
TEST_RETRIEVAL_RESPONSE = RetrieveContextsResponse(contexts=TEST_CONTEXTS)

0 commit comments

Comments
 (0)