Skip to content

Commit e220312

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: RAG Fix v1 rag_store compatibility with generative_models Tool by changing back to v1beta1
PiperOrigin-RevId: 702520155
1 parent 0537fec commit e220312

File tree

3 files changed

+69
-18
lines changed

3 files changed

+69
-18
lines changed

tests/unit/vertex_rag/test_rag_store.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,39 @@
2222

2323
@pytest.mark.usefixtures("google_auth_mock")
2424
class TestRagStoreValidations:
25+
def test_retrieval_tool_success(self):
26+
tool = Tool.from_retrieval(
27+
retrieval=rag.Retrieval(
28+
source=rag.VertexRagStore(
29+
rag_resources=[tc.TEST_RAG_RESOURCE],
30+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
31+
),
32+
)
33+
)
34+
assert tool is not None
35+
36+
def test_retrieval_tool_vector_similarity_success(self):
37+
tool = Tool.from_retrieval(
38+
retrieval=rag.Retrieval(
39+
source=rag.VertexRagStore(
40+
rag_resources=[tc.TEST_RAG_RESOURCE],
41+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
42+
),
43+
)
44+
)
45+
assert tool is not None
46+
47+
def test_retrieval_tool_no_rag_resources(self):
48+
with pytest.raises(ValueError) as e:
49+
Tool.from_retrieval(
50+
retrieval=rag.Retrieval(
51+
source=rag.VertexRagStore(
52+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
53+
),
54+
)
55+
)
56+
e.match("rag_resources must be specified.")
57+
2558
def test_retrieval_tool_invalid_name(self):
2659
with pytest.raises(ValueError) as e:
2760
Tool.from_retrieval(

tests/unit/vertex_rag/test_rag_store_preview.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,30 @@ def test_retrieval_tool_ranking_config_success(self):
7373
)
7474
)
7575

76+
def test_empty_retrieval_tool_success(self):
77+
tool = Tool.from_retrieval(
78+
retrieval=rag.Retrieval(
79+
source=rag.VertexRagStore(
80+
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
81+
rag_retrieval_config=rag.RagRetrievalConfig(),
82+
similarity_top_k=3,
83+
vector_distance_threshold=0.4,
84+
),
85+
)
86+
)
87+
assert tool is not None
88+
89+
def test_retrieval_tool_no_rag_resources(self):
90+
with pytest.raises(ValueError) as e:
91+
Tool.from_retrieval(
92+
retrieval=rag.Retrieval(
93+
source=rag.VertexRagStore(
94+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
95+
),
96+
)
97+
)
98+
e.match("rag_resources or rag_corpora must be specified.")
99+
76100
def test_retrieval_tool_invalid_name(self):
77101
with pytest.raises(ValueError) as e:
78102
Tool.from_retrieval(

vertexai/rag/rag_store.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
import re
2020
from typing import List, Optional, Union
2121

22-
from google.cloud import aiplatform_v1
22+
from google.cloud import aiplatform_v1beta1
2323
from google.cloud.aiplatform import initializer
24-
from google.cloud.aiplatform_v1.types import tool as gapic_tool_types
25-
from vertexai.preview import generative_models
24+
from google.cloud.aiplatform_v1beta1.types import tool as gapic_tool_types
25+
from vertexai import generative_models
2626
from vertexai.rag.utils import _gapic_utils
2727
from vertexai.rag.utils import resources
2828

@@ -103,7 +103,7 @@ def __init__(
103103
)
104104

105105
# If rag_retrieval_config is not specified, set it to default values.
106-
api_retrieval_config = aiplatform_v1.RagRetrievalConfig()
106+
api_retrieval_config = aiplatform_v1beta1.RagRetrievalConfig()
107107
# If rag_retrieval_config is specified, populate the default config.
108108
if rag_retrieval_config:
109109
api_retrieval_config.top_k = rag_retrieval_config.top_k
@@ -128,17 +128,11 @@ def __init__(
128128
rag_retrieval_config.filter.vector_similarity_threshold
129129
)
130130

131-
if rag_resources:
132-
gapic_rag_resource = gapic_tool_types.VertexRagStore.RagResource(
133-
rag_corpus=rag_corpus_name,
134-
rag_file_ids=rag_resources[0].rag_file_ids,
135-
)
136-
self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore(
137-
rag_resources=[gapic_rag_resource],
138-
rag_retrieval_config=api_retrieval_config,
139-
)
140-
else:
141-
self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore(
142-
rag_corpora=[rag_corpus_name],
143-
rag_retrieval_config=api_retrieval_config,
144-
)
131+
gapic_rag_resource = gapic_tool_types.VertexRagStore.RagResource(
132+
rag_corpus=rag_corpus_name,
133+
rag_file_ids=rag_resources[0].rag_file_ids,
134+
)
135+
self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore(
136+
rag_resources=[gapic_rag_resource],
137+
rag_retrieval_config=api_retrieval_config,
138+
)

0 commit comments

Comments
 (0)