|
31 | 31 | from google.cloud import aiplatform
|
32 | 32 | from google.cloud.aiplatform import base
|
33 | 33 | from google.cloud.aiplatform.compat.types import custom_job as gca_custom_job_compat
|
34 |
| -from google.cloud.aiplatform.compat.types import ( |
35 |
| - custom_job_v1beta1 as gca_custom_job_v1beta1, |
36 |
| -) |
37 | 34 | from google.cloud.aiplatform.compat.types import io as gca_io_compat
|
38 | 35 | from google.cloud.aiplatform.compat.types import job_state as gca_job_state_compat
|
39 | 36 | from google.cloud.aiplatform.compat.types import (
|
40 | 37 | encryption_spec as gca_encryption_spec_compat,
|
41 | 38 | )
|
42 | 39 | from google.cloud.aiplatform_v1.services.job_service import client as job_service_client
|
43 |
| -from google.cloud.aiplatform_v1beta1.services.job_service import ( |
44 |
| - client as job_service_client_v1beta1, |
45 |
| -) |
46 | 40 |
|
47 | 41 | _TEST_PROJECT = "test-project"
|
48 | 42 | _TEST_LOCATION = "us-central1"
|
|
114 | 108 | )
|
115 | 109 |
|
116 | 110 |
|
117 |
| -def _get_custom_job_proto(state=None, name=None, error=None, version="v1"): |
| 111 | +def _get_custom_job_proto(state=None, name=None, error=None): |
118 | 112 | custom_job_proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO)
|
119 | 113 | custom_job_proto.name = name
|
120 | 114 | custom_job_proto.state = state
|
121 | 115 | custom_job_proto.error = error
|
122 |
| - |
123 |
| - if version == "v1beta1": |
124 |
| - v1beta1_custom_job_proto = gca_custom_job_v1beta1.CustomJob() |
125 |
| - v1beta1_custom_job_proto._pb.MergeFromString( |
126 |
| - custom_job_proto._pb.SerializeToString() |
127 |
| - ) |
128 |
| - custom_job_proto = v1beta1_custom_job_proto |
129 |
| - custom_job_proto.job_spec.tensorboard = _TEST_TENSORBOARD_NAME |
130 |
| - |
131 | 116 | return custom_job_proto
|
132 | 117 |
|
133 | 118 |
|
134 |
| -def _get_custom_job_proto_with_enable_web_access( |
135 |
| - state=None, name=None, error=None, version="v1" |
136 |
| -): |
137 |
| - custom_job_proto = _get_custom_job_proto( |
138 |
| - state=state, name=name, error=error, version=version |
139 |
| - ) |
| 119 | +def _get_custom_job_proto_with_enable_web_access(state=None, name=None, error=None): |
| 120 | + custom_job_proto = _get_custom_job_proto(state=state, name=name, error=error) |
140 | 121 | custom_job_proto.job_spec.enable_web_access = _TEST_ENABLE_WEB_ACCESS
|
141 | 122 | if state == gca_job_state_compat.JobState.JOB_STATE_RUNNING:
|
142 | 123 | custom_job_proto.web_access_uris = _TEST_WEB_ACCESS_URIS
|
@@ -260,24 +241,25 @@ def create_custom_job_mock_with_enable_web_access():
|
260 | 241 |
|
261 | 242 |
|
262 | 243 | @pytest.fixture
|
263 |
| -def create_custom_job_mock_fail(): |
| 244 | +def create_custom_job_mock_with_tensorboard(): |
264 | 245 | with mock.patch.object(
|
265 | 246 | job_service_client.JobServiceClient, "create_custom_job"
|
266 | 247 | ) as create_custom_job_mock:
|
267 |
| - create_custom_job_mock.side_effect = RuntimeError("Mock fail") |
| 248 | + custom_job_proto = _get_custom_job_proto( |
| 249 | + name=_TEST_CUSTOM_JOB_NAME, |
| 250 | + state=gca_job_state_compat.JobState.JOB_STATE_PENDING, |
| 251 | + ) |
| 252 | + custom_job_proto.job_spec.tensorboard = _TEST_TENSORBOARD_NAME |
| 253 | + create_custom_job_mock.return_value = custom_job_proto |
268 | 254 | yield create_custom_job_mock
|
269 | 255 |
|
270 | 256 |
|
271 | 257 | @pytest.fixture
|
272 |
| -def create_custom_job_v1beta1_mock(): |
| 258 | +def create_custom_job_mock_fail(): |
273 | 259 | with mock.patch.object(
|
274 |
| - job_service_client_v1beta1.JobServiceClient, "create_custom_job" |
| 260 | + job_service_client.JobServiceClient, "create_custom_job" |
275 | 261 | ) as create_custom_job_mock:
|
276 |
| - create_custom_job_mock.return_value = _get_custom_job_proto( |
277 |
| - name=_TEST_CUSTOM_JOB_NAME, |
278 |
| - state=gca_job_state_compat.JobState.JOB_STATE_PENDING, |
279 |
| - version="v1beta1", |
280 |
| - ) |
| 262 | + create_custom_job_mock.side_effect = RuntimeError("Mock fail") |
281 | 263 | yield create_custom_job_mock
|
282 | 264 |
|
283 | 265 |
|
@@ -573,7 +555,7 @@ def test_get_web_access_uris_job_succeeded(
|
573 | 555 |
|
574 | 556 | @pytest.mark.parametrize("sync", [True, False])
|
575 | 557 | def test_create_custom_job_with_tensorboard(
|
576 |
| - self, create_custom_job_v1beta1_mock, get_custom_job_mock, sync |
| 558 | + self, create_custom_job_mock_with_tensorboard, get_custom_job_mock, sync |
577 | 559 | ):
|
578 | 560 |
|
579 | 561 | aiplatform.init(
|
@@ -601,9 +583,10 @@ def test_create_custom_job_with_tensorboard(
|
601 | 583 |
|
602 | 584 | job.wait()
|
603 | 585 |
|
604 |
| - expected_custom_job = _get_custom_job_proto(version="v1beta1") |
| 586 | + expected_custom_job = _get_custom_job_proto() |
| 587 | + expected_custom_job.job_spec.tensorboard = _TEST_TENSORBOARD_NAME |
605 | 588 |
|
606 |
| - create_custom_job_v1beta1_mock.assert_called_once_with( |
| 589 | + create_custom_job_mock_with_tensorboard.assert_called_once_with( |
607 | 590 | parent=_TEST_PARENT, custom_job=expected_custom_job
|
608 | 591 | )
|
609 | 592 |
|
|
0 commit comments