Skip to content

Commit f3addc9

Browse files
yeesiancopybara-github
authored andcommitted
feat: make it optional to pass in an instance of an agent when creating a new ReasoningEngine instance
PiperOrigin-RevId: 741007911
1 parent ff8f142 commit f3addc9

File tree

3 files changed

+71
-81
lines changed

3 files changed

+71
-81
lines changed

tests/unit/vertex_langchain/test_agent_engines.py

Lines changed: 22 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -304,29 +304,21 @@ def register_operations(self) -> Dict[str, List[str]]:
304304
)
305305
)
306306
_TEST_AGENT_ENGINE_QUERY_SCHEMA[_TEST_MODE_KEY_IN_SCHEMA] = _TEST_STANDARD_API_MODE
307+
_TEST_AGENT_ENGINE_PACKAGE_SPEC = types.ReasoningEngineSpec.PackageSpec(
308+
python_version=f"{sys.version_info.major}.{sys.version_info.minor}",
309+
pickle_object_gcs_uri=_TEST_AGENT_ENGINE_GCS_URI,
310+
dependency_files_gcs_uri=_TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI,
311+
requirements_gcs_uri=_TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
312+
)
307313
_TEST_INPUT_AGENT_ENGINE_OBJ = types.ReasoningEngine(
308314
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
309-
spec=types.ReasoningEngineSpec(
310-
package_spec=types.ReasoningEngineSpec.PackageSpec(
311-
python_version=f"{sys.version_info.major}.{sys.version_info.minor}",
312-
pickle_object_gcs_uri=_TEST_AGENT_ENGINE_GCS_URI,
313-
dependency_files_gcs_uri=_TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI,
314-
requirements_gcs_uri=_TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
315-
),
316-
),
315+
spec=types.ReasoningEngineSpec(package_spec=_TEST_AGENT_ENGINE_PACKAGE_SPEC),
317316
)
318317
_TEST_INPUT_AGENT_ENGINE_OBJ.spec.class_methods.append(_TEST_AGENT_ENGINE_QUERY_SCHEMA)
319318
_TEST_AGENT_ENGINE_OBJ = types.ReasoningEngine(
320319
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
321320
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
322-
spec=types.ReasoningEngineSpec(
323-
package_spec=types.ReasoningEngineSpec.PackageSpec(
324-
python_version=f"{sys.version_info.major}.{sys.version_info.minor}",
325-
pickle_object_gcs_uri=_TEST_AGENT_ENGINE_GCS_URI,
326-
dependency_files_gcs_uri=_TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI,
327-
requirements_gcs_uri=_TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
328-
),
329-
),
321+
spec=types.ReasoningEngineSpec(package_spec=_TEST_AGENT_ENGINE_PACKAGE_SPEC),
330322
)
331323
_TEST_AGENT_ENGINE_OBJ.spec.class_methods.append(_TEST_AGENT_ENGINE_QUERY_SCHEMA)
332324
_TEST_UPDATE_AGENT_ENGINE_OBJ = types.ReasoningEngine(
@@ -603,17 +595,6 @@ def mock_streamer():
603595
yield stream_query_agent_engine_mock
604596

605597

606-
# Function scope is required for the pytest parameterized tests.
607-
@pytest.fixture(scope="function")
608-
def types_agent_engine_mock():
609-
with mock.patch.object(
610-
types,
611-
"ReasoningEngine",
612-
return_value=types.ReasoningEngine(name=_TEST_AGENT_ENGINE_RESOURCE_NAME),
613-
) as types_agent_engine_mock:
614-
yield types_agent_engine_mock
615-
616-
617598
@pytest.fixture(scope="function")
618599
def get_gca_resource_mock():
619600
with mock.patch.object(
@@ -1234,24 +1215,22 @@ def test_create_class_methods_spec_with_registered_operations(
12341215
test_case_name,
12351216
test_engine,
12361217
want_class_methods,
1237-
types_agent_engine_mock,
1218+
create_agent_engine_mock,
12381219
):
1239-
agent_engines.create(test_engine)
1240-
want_spec = types.ReasoningEngineSpec(
1241-
package_spec=types.ReasoningEngineSpec.PackageSpec(
1242-
python_version=(f"{sys.version_info.major}.{sys.version_info.minor}"),
1243-
requirements_gcs_uri=_TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
1244-
pickle_object_gcs_uri=_TEST_AGENT_ENGINE_GCS_URI,
1245-
)
1220+
agent_engines.create(
1221+
test_engine,
1222+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
1223+
requirements=_TEST_AGENT_ENGINE_REQUIREMENTS,
1224+
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
12461225
)
1247-
want_spec.class_methods.extend(want_class_methods)
1248-
assert_called_with_diff(
1249-
types_agent_engine_mock,
1250-
{
1251-
"display_name": None,
1252-
"description": None,
1253-
"spec": want_spec,
1254-
},
1226+
spec = types.ReasoningEngineSpec(package_spec=_TEST_AGENT_ENGINE_PACKAGE_SPEC)
1227+
spec.class_methods.extend(want_class_methods)
1228+
create_agent_engine_mock.assert_called_with(
1229+
parent=_TEST_PARENT,
1230+
reasoning_engine=types.ReasoningEngine(
1231+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
1232+
spec=spec,
1233+
),
12551234
)
12561235

12571236
# pytest does not allow absl.testing.parameterized.named_parameters.

vertexai/agent_engines/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def get(resource_name: str) -> AgentEngine:
5959

6060

6161
def create(
62-
agent_engine: Union[Queryable, OperationRegistrable],
62+
agent_engine: Optional[Union[Queryable, OperationRegistrable]] = None,
6363
*,
6464
requirements: Optional[Union[str, Sequence[str]]] = None,
6565
display_name: Optional[str] = None,

vertexai/agent_engines/_agent_engines.py

Lines changed: 48 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def resource_name(self) -> str:
151151
@classmethod
152152
def create(
153153
cls,
154-
agent_engine: Union[Queryable, OperationRegistrable],
154+
agent_engine: Optional[Union[Queryable, OperationRegistrable]] = None,
155155
*,
156156
requirements: Optional[Union[str, Sequence[str]]] = None,
157157
display_name: Optional[str] = None,
@@ -197,7 +197,7 @@ def create(
197197
198198
Args:
199199
agent_engine (AgentEngineInterface):
200-
Required. The Agent Engine to be created.
200+
Optional. The Agent Engine to be created.
201201
requirements (Union[str, Sequence[str]]):
202202
Optional. The set of PyPI dependencies needed. It can either be
203203
the path to a single file (requirements.txt), or an ordered list
@@ -222,14 +222,22 @@ def create(
222222
ValueError: If the `location` was not set using `vertexai.init`.
223223
ValueError: If the `staging_bucket` was not set using vertexai.init.
224224
ValueError: If the `staging_bucket` does not start with "gs://".
225+
ValueError: If `extra_packages` is specified but `agent_engine` is None.
226+
ValueError: If `requirements` is specified but `agent_engine` is None.
225227
FileNotFoundError: If `extra_packages` includes a file or directory
226228
that does not exist.
227229
IOError: If requirements is a string that corresponds to a
228230
nonexistent file.
229231
"""
230232
sys_version = f"{sys.version_info.major}.{sys.version_info.minor}"
231233
_validate_sys_version_or_raise(sys_version)
232-
agent_engine = _validate_agent_engine_or_raise(agent_engine)
234+
if agent_engine is not None:
235+
agent_engine = _validate_agent_engine_or_raise(agent_engine)
236+
if agent_engine is None:
237+
if requirements is not None:
238+
raise ValueError("requirements must be None if agent_engine is None.")
239+
if extra_packages is not None:
240+
raise ValueError("extra_packages must be None if agent_engine is None.")
233241
requirements = _validate_requirements_or_raise(agent_engine, requirements)
234242
extra_packages = _validate_extra_packages_or_raise(extra_packages)
235243
gcs_dir_name = gcs_dir_name or _DEFAULT_GCS_DIR_NAME
@@ -251,43 +259,45 @@ def create(
251259
gcs_dir_name=gcs_dir_name,
252260
extra_packages=extra_packages,
253261
)
254-
# Update the package spec.
255-
package_spec = aip_types.ReasoningEngineSpec.PackageSpec(
256-
python_version=sys_version,
257-
pickle_object_gcs_uri="{}/{}/{}".format(
258-
staging_bucket,
259-
gcs_dir_name,
260-
_BLOB_FILENAME,
261-
),
262+
reasoning_engine = aip_types.ReasoningEngine(
263+
display_name=display_name,
264+
description=description,
262265
)
263-
if extra_packages:
264-
package_spec.dependency_files_gcs_uri = "{}/{}/{}".format(
265-
staging_bucket,
266-
gcs_dir_name,
267-
_EXTRA_PACKAGES_FILE,
266+
if agent_engine is not None:
267+
# Update the package spec.
268+
package_spec = aip_types.ReasoningEngineSpec.PackageSpec(
269+
python_version=sys_version,
270+
pickle_object_gcs_uri="{}/{}/{}".format(
271+
staging_bucket,
272+
gcs_dir_name,
273+
_BLOB_FILENAME,
274+
),
268275
)
269-
if requirements:
270-
package_spec.requirements_gcs_uri = "{}/{}/{}".format(
271-
staging_bucket,
272-
gcs_dir_name,
273-
_REQUIREMENTS_FILE,
276+
if extra_packages:
277+
package_spec.dependency_files_gcs_uri = "{}/{}/{}".format(
278+
staging_bucket,
279+
gcs_dir_name,
280+
_EXTRA_PACKAGES_FILE,
281+
)
282+
if requirements:
283+
package_spec.requirements_gcs_uri = "{}/{}/{}".format(
284+
staging_bucket,
285+
gcs_dir_name,
286+
_REQUIREMENTS_FILE,
287+
)
288+
agent_engine_spec = aip_types.ReasoningEngineSpec(
289+
package_spec=package_spec,
274290
)
275-
agent_engine_spec = aip_types.ReasoningEngineSpec(
276-
package_spec=package_spec,
277-
)
278-
class_methods_spec = _generate_class_methods_spec_or_raise(
279-
agent_engine, _get_registered_operations(agent_engine)
280-
)
281-
agent_engine_spec.class_methods.extend(class_methods_spec)
291+
class_methods_spec = _generate_class_methods_spec_or_raise(
292+
agent_engine, _get_registered_operations(agent_engine)
293+
)
294+
agent_engine_spec.class_methods.extend(class_methods_spec)
295+
reasoning_engine.spec = agent_engine_spec
282296
operation_future = sdk_resource.api_client.create_reasoning_engine(
283297
parent=initializer.global_config.common_location_path(
284298
project=sdk_resource.project, location=sdk_resource.location
285299
),
286-
reasoning_engine=aip_types.ReasoningEngine(
287-
display_name=display_name,
288-
description=description,
289-
spec=agent_engine_spec,
290-
),
300+
reasoning_engine=reasoning_engine,
291301
)
292302
_LOGGER.log_create_with_lro(cls, operation_future)
293303
_LOGGER.info(
@@ -309,10 +319,11 @@ def create(
309319
credentials=sdk_resource.credentials,
310320
location_override=sdk_resource.location,
311321
)
312-
try:
313-
_register_api_methods_or_raise(sdk_resource)
314-
except Exception as e:
315-
_LOGGER.warning("Failed to register API methods: {%s}", e)
322+
if agent_engine is not None:
323+
try:
324+
_register_api_methods_or_raise(sdk_resource)
325+
except Exception as e:
326+
_LOGGER.warning("Failed to register API methods: {%s}", e)
316327
sdk_resource._operation_schemas = None
317328
return sdk_resource
318329

0 commit comments

Comments
 (0)