Skip to content

Commit 8c95d3f

Browse files
yeesiancopybara-github
authored andcommitted
feat: Add support for system paths in ModuleAgent
PiperOrigin-RevId: 758248656
1 parent 02236be commit 8c95d3f

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

vertexai/agent_engines/_agent_engines.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def __init__(
144144
module_name: str,
145145
agent_name: str,
146146
register_operations: Dict[str, Sequence[str]],
147+
sys_paths: Optional[Sequence[str]] = None,
147148
):
148149
"""Initializes a module-based agent.
149150
@@ -154,11 +155,19 @@ def __init__(
154155
Required. The name of the agent in the module to instantiate.
155156
register_operations (Dict[str, Sequence[str]]):
156157
Required. A dictionary of API modes to a list of method names.
158+
sys_paths (Sequence[str]):
159+
Optional. The system paths to search for the module. It should
160+
be relative to the directory where the code will be running.
161+
I.e. it should correspond to the directory being passed to
162+
`extra_packages=...` in the create method. It will be appended
163+
to the system path in the sequence being specified here, and
164+
only be appended if it is not already in the system path.
157165
"""
158166
self._tmpl_attrs = {
159167
"module_name": module_name,
160168
"agent_name": agent_name,
161169
"register_operations": register_operations,
170+
"sys_paths": sys_paths,
162171
}
163172

164173
def clone(self):
@@ -167,6 +176,7 @@ def clone(self):
167176
module_name=self._tmpl_attrs.get("module_name"),
168177
agent_name=self._tmpl_attrs.get("agent_name"),
169178
register_operations=self._tmpl_attrs.get("register_operations"),
179+
sys_paths=self._tmpl_attrs.get("sys_paths"),
170180
)
171181

172182
def register_operations(self) -> Dict[str, Sequence[str]]:
@@ -178,6 +188,14 @@ def set_up(self) -> None:
178188
It runs the code to import the agent from the module, and registers the
179189
operations of the agent.
180190
"""
191+
if self._tmpl_attrs.get("sys_paths"):
192+
import sys
193+
194+
for sys_path in self._tmpl_attrs.get("sys_paths"):
195+
abs_path = os.path.abspath(sys_path)
196+
if abs_path not in sys.path:
197+
sys.path.append(abs_path)
198+
181199
import importlib
182200

183201
module = importlib.import_module(self._tmpl_attrs.get("module_name"))

0 commit comments

Comments
 (0)