Skip to content

Commit 2fc05ca

Browse files
authored
feat: support new protobuf value param types for Pipeline Job client (#797)
* feat: update PipelineJob to accept protobuf value * fix tests * address comments
1 parent 7ab05d5 commit 2fc05ca

File tree

3 files changed

+176
-133
lines changed

3 files changed

+176
-133
lines changed

google/cloud/aiplatform/utils/pipeline_utils.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,13 @@ def from_job_spec_json(
6464
.get("inputDefinitions", {})
6565
.get("parameters", {})
6666
)
67-
parameter_types = {k: v["type"] for k, v in parameter_input_definitions.items()}
67+
# 'type' is deprecated in IR and change to 'parameterType'.
68+
parameter_types = {
69+
k: v.get("parameterType") or v.get("type")
70+
for k, v in parameter_input_definitions.items()
71+
}
6872

69-
pipeline_root = runtime_config_spec.get("gcs_output_directory")
73+
pipeline_root = runtime_config_spec.get("gcsOutputDirectory")
7074
parameter_values = _parse_runtime_parameters(runtime_config_spec)
7175
return cls(pipeline_root, parameter_types, parameter_values)
7276

@@ -108,7 +112,7 @@ def build(self) -> Dict[str, Any]:
108112
"compile time, or when calling the service."
109113
)
110114
return {
111-
"gcs_output_directory": self._pipeline_root,
115+
"gcsOutputDirectory": self._pipeline_root,
112116
"parameters": {
113117
k: self._get_vertex_value(k, v)
114118
for k, v in self._parameter_values.items()
@@ -117,14 +121,14 @@ def build(self) -> Dict[str, Any]:
117121
}
118122

119123
def _get_vertex_value(
120-
self, name: str, value: Union[int, float, str]
124+
self, name: str, value: Union[int, float, str, bool, list, dict]
121125
) -> Dict[str, Any]:
122126
"""Converts primitive values into Vertex pipeline Value proto message.
123127
124128
Args:
125129
name (str):
126130
Required. The name of the pipeline parameter.
127-
value (Union[int, float, str]):
131+
value (Union[int, float, str, bool, list, dict]):
128132
Required. The value of the pipeline parameter.
129133
130134
Returns:
@@ -150,6 +154,16 @@ def _get_vertex_value(
150154
result["doubleValue"] = value
151155
elif self._parameter_types[name] == "STRING":
152156
result["stringValue"] = value
157+
elif self._parameter_types[name] == "BOOLEAN":
158+
result["boolValue"] = value
159+
elif self._parameter_types[name] == "NUMBER_DOUBLE":
160+
result["numberValue"] = value
161+
elif self._parameter_types[name] == "NUMBER_INTEGER":
162+
result["numberValue"] = value
163+
elif self._parameter_types[name] == "LIST":
164+
result["listValue"] = value
165+
elif self._parameter_types[name] == "STRUCT":
166+
result["structValue"] = value
153167
else:
154168
raise TypeError("Got unknown type of value: {}".format(value))
155169

@@ -164,19 +178,19 @@ def _parse_runtime_parameters(
164178
Raises:
165179
TypeError: if the parameter type is not one of 'INT', 'DOUBLE', 'STRING'.
166180
"""
167-
runtime_parameters = runtime_config_spec.get("parameters")
168-
if not runtime_parameters:
169-
return None
170-
171-
result = {}
172-
for name, value in runtime_parameters.items():
173-
if "intValue" in value:
174-
result[name] = int(value["intValue"])
175-
elif "doubleValue" in value:
176-
result[name] = float(value["doubleValue"])
177-
elif "stringValue" in value:
178-
result[name] = value["stringValue"]
179-
else:
180-
raise TypeError("Got unknown type of value: {}".format(value))
181+
# 'parameters' are deprecated in IR and changed to 'parameterValues'.
182+
if runtime_config_spec.get("parameterValues") is not None:
183+
return runtime_config_spec.get("parameterValues")
181184

182-
return result
185+
if runtime_config_spec.get("parameters") is not None:
186+
result = {}
187+
for name, value in runtime_config_spec.get("parameters").items():
188+
if "intValue" in value:
189+
result[name] = int(value["intValue"])
190+
elif "doubleValue" in value:
191+
result[name] = float(value["doubleValue"])
192+
elif "stringValue" in value:
193+
result[name] = value["stringValue"]
194+
else:
195+
raise TypeError("Got unknown type of value: {}".format(value))
196+
return result

0 commit comments

Comments
 (0)