Skip to content

Commit 45c4086

Browse files
authored
fix: add support for API base path overriding (#908)
1 parent 48c2bf1 commit 45c4086

File tree

3 files changed

+29
-16
lines changed

3 files changed

+29
-16
lines changed

google/cloud/aiplatform/initializer.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -194,18 +194,20 @@ def encryption_spec_key_name(self) -> Optional[str]:
194194
return self._encryption_spec_key_name
195195

196196
def get_client_options(
197-
self, location_override: Optional[str] = None, prediction_client: bool = False
197+
self,
198+
location_override: Optional[str] = None,
199+
prediction_client: bool = False,
200+
api_base_path_override: Optional[str] = None,
198201
) -> client_options.ClientOptions:
199202
"""Creates GAPIC client_options using location and type.
200203
201204
Args:
202205
location_override (str):
203-
Set this parameter to get client options for a location different from
204-
location set by initializer. Must be a GCP region supported by AI
205-
Platform (Unified).
206-
prediction_client (str): Optional flag to use a prediction endpoint.
207-
208-
206+
Optional. Set this parameter to get client options for a location different
207+
from location set by initializer. Must be a GCP region supported by
208+
Vertex AI.
209+
prediction_client (str): Optional. flag to use a prediction endpoint.
210+
api_base_path_override (str): Optional. Override default API base path.
209211
Returns:
210212
clients_options (google.api_core.client_options.ClientOptions):
211213
A ClientOptions object set with regionalized API endpoint, i.e.
@@ -222,7 +224,7 @@ def get_client_options(
222224

223225
utils.validate_region(region)
224226

225-
service_base_path = (
227+
service_base_path = api_base_path_override or (
226228
constants.PREDICTION_API_BASE_PATH
227229
if prediction_client
228230
else constants.API_BASE_PATH
@@ -261,17 +263,19 @@ def create_client(
261263
credentials: Optional[auth_credentials.Credentials] = None,
262264
location_override: Optional[str] = None,
263265
prediction_client: bool = False,
266+
api_base_path_override: Optional[str] = None,
264267
) -> utils.VertexAiServiceClientWithOverride:
265268
"""Instantiates a given VertexAiServiceClient with optional
266269
overrides.
267270
268271
Args:
269272
client_class (utils.VertexAiServiceClientWithOverride):
270-
(Required) A Vertex AI Service Client with optional overrides.
273+
Required. A Vertex AI Service Client with optional overrides.
271274
credentials (auth_credentials.Credentials):
272-
Custom auth credentials. If not provided will use the current config.
273-
location_override (str): Optional location override.
274-
prediction_client (str): Optional flag to use a prediction endpoint.
275+
Optional. Custom auth credentials. If not provided will use the current config.
276+
location_override (str): Optional. location override.
277+
prediction_client (str): Optional. flag to use a prediction endpoint.
278+
api_base_path_override (str): Optional. Override default api base path.
275279
Returns:
276280
client: Instantiated Vertex AI Service client with optional overrides
277281
"""
@@ -288,6 +292,7 @@ def create_client(
288292
"client_options": self.get_client_options(
289293
location_override=location_override,
290294
prediction_client=prediction_client,
295+
api_base_path_override=api_base_path_override,
291296
),
292297
"client_info": client_info,
293298
}

google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tensorboard_api.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from google.api_core import exceptions
3030
from google.cloud import aiplatform
3131
from google.cloud import storage
32-
from google.cloud.aiplatform.constants import base as constants
3332
from google.cloud.aiplatform.utils import TensorboardClientWithOverride
3433
from google.cloud.aiplatform.tensorboard import uploader_utils
3534
from google.cloud.aiplatform.compat.types import tensorboard_experiment
@@ -41,16 +40,16 @@
4140

4241
def _get_api_client() -> TensorboardClientWithOverride:
4342
"""Creates an Tensorboard API client."""
44-
constants.API_BASE_PATH = training_utils.environment_variables.tensorboard_api_uri
45-
4643
m = re.match(
4744
"projects/.*/locations/(.*)/tensorboards/.*",
4845
training_utils.environment_variables.tensorboard_resource_name,
4946
)
5047
region = m[1]
5148

5249
api_client = aiplatform.initializer.global_config.create_client(
53-
client_class=TensorboardClientWithOverride, location_override=region,
50+
client_class=TensorboardClientWithOverride,
51+
location_override=region,
52+
api_base_path_override=training_utils.environment_variables.tensorboard_api_uri,
5453
)
5554

5655
return api_client

tests/unit/aiplatform/test_initializer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,15 @@ def test_get_client_options(
181181
== expected_endpoint
182182
)
183183

184+
def test_get_client_options_with_api_override(self):
185+
initializer.global_config.init(location="asia-east1")
186+
187+
client_options = initializer.global_config.get_client_options(
188+
api_base_path_override="override.googleapis.com"
189+
)
190+
191+
assert client_options.api_endpoint == "asia-east1-override.googleapis.com"
192+
184193

185194
class TestThreadPool:
186195
def teardown_method(self):

0 commit comments

Comments
 (0)