Skip to content

Commit 96a850f

Browse files
authored
feat: add create_batch_prediction_job samples (#67)
* chore: sample tests lint * lint * lnt * lint * feat: add create_batch_prediction_job samples * lint
1 parent 77956b2 commit 96a850f

4 files changed

+296
-0
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_create_batch_prediction_job_bigquery_sample]
16+
from google.cloud import aiplatform
17+
from google.protobuf import json_format
18+
from google.protobuf.struct_pb2 import Value
19+
20+
21+
def create_batch_prediction_job_bigquery_sample(
22+
project: str,
23+
display_name: str,
24+
model_name: str,
25+
instances_format: str,
26+
bigquery_source_input_uri: str,
27+
predictions_format: str,
28+
bigquery_destination_output_uri: str,
29+
location: str = "us-central1",
30+
api_endpoint: str = "us-central1-aiplatform.googleapis.com",
31+
):
32+
client_options = {"api_endpoint": api_endpoint}
33+
# Initialize client that will be used to create and send requests.
34+
# This client only needs to be created once, and can be reused for multiple requests.
35+
client = aiplatform.gapic.JobServiceClient(client_options=client_options)
36+
model_parameters_dict = {}
37+
model_parameters = json_format.ParseDict(model_parameters_dict, Value())
38+
39+
batch_prediction_job = {
40+
"display_name": display_name,
41+
# Format: 'projects/{project}/locations/{location}/models/{model_id}'
42+
"model": model_name,
43+
"model_parameters": model_parameters,
44+
"input_config": {
45+
"instances_format": instances_format,
46+
"bigquery_source": {"input_uri": bigquery_source_input_uri},
47+
},
48+
"output_config": {
49+
"predictions_format": predictions_format,
50+
"bigquery_destination": {"output_uri": bigquery_destination_output_uri},
51+
},
52+
# optional
53+
"generate_explanation": True,
54+
}
55+
parent = f"projects/{project}/locations/{location}"
56+
response = client.create_batch_prediction_job(
57+
parent=parent, batch_prediction_job=batch_prediction_job
58+
)
59+
print("response:", response)
60+
61+
62+
# [END aiplatform_create_batch_prediction_job_bigquery_sample]
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from uuid import uuid4
17+
18+
from google.cloud import aiplatform
19+
import pytest
20+
21+
import create_batch_prediction_job_bigquery_sample
22+
import helpers
23+
24+
PROJECT_ID = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT")
25+
LOCATION = "us-central1"
26+
MODEL_ID = "3125638878883479552" # bq all
27+
DISPLAY_NAME = f"temp_create_batch_prediction_job_test_{uuid4()}"
28+
BIGQUERY_SOURCE_INPUT_URI = "bq://ucaip-sample-tests.table_test.all_bq_types"
29+
BIGQUERY_DESTINATION_OUTPUT_URI = "bq://ucaip-sample-tests"
30+
INSTANCES_FORMAT = "bigquery"
31+
PREDICTIONS_FORMAT = "bigquery"
32+
33+
34+
@pytest.fixture
35+
def shared_state():
36+
state = {}
37+
yield state
38+
39+
40+
@pytest.fixture
41+
def job_client():
42+
job_client = aiplatform.gapic.JobServiceClient(
43+
client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
44+
)
45+
return job_client
46+
47+
48+
@pytest.fixture(scope="function", autouse=True)
49+
def teardown(shared_state, job_client):
50+
yield
51+
52+
job_client.cancel_batch_prediction_job(name=shared_state["batch_prediction_job_name"])
53+
54+
# Waiting until the job is in CANCELLED state.
55+
helpers.wait_for_job_state(
56+
get_job_method=job_client.get_batch_prediction_job,
57+
name=shared_state["batch_prediction_job_name"],
58+
)
59+
60+
job_client.delete_batch_prediction_job(name=shared_state["batch_prediction_job_name"])
61+
62+
63+
def test_ucaip_generated_create_batch_prediction_job_bigquery_sample(
64+
capsys, shared_state
65+
):
66+
67+
model_name = f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}"
68+
69+
create_batch_prediction_job_bigquery_sample.create_batch_prediction_job_bigquery_sample(
70+
project=PROJECT_ID,
71+
display_name=DISPLAY_NAME,
72+
model_name=model_name,
73+
bigquery_source_input_uri=BIGQUERY_SOURCE_INPUT_URI,
74+
bigquery_destination_output_uri=BIGQUERY_DESTINATION_OUTPUT_URI,
75+
instances_format=INSTANCES_FORMAT,
76+
predictions_format=PREDICTIONS_FORMAT,
77+
)
78+
79+
out, _ = capsys.readouterr()
80+
81+
# Save resource name of the newly created batch prediction job
82+
shared_state["batch_prediction_job_name"] = helpers.get_name(out)
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_create_batch_prediction_job_sample]
16+
from google.cloud import aiplatform
17+
from google.protobuf import json_format
18+
from google.protobuf.struct_pb2 import Value
19+
20+
21+
def create_batch_prediction_job_sample(
22+
project: str,
23+
display_name: str,
24+
model_name: str,
25+
instances_format: str,
26+
gcs_source_uri: str,
27+
predictions_format: str,
28+
gcs_destination_output_uri_prefix: str,
29+
location: str = "us-central1",
30+
api_endpoint: str = "us-central1-aiplatform.googleapis.com",
31+
):
32+
client_options = {"api_endpoint": api_endpoint}
33+
# Initialize client that will be used to create and send requests.
34+
# This client only needs to be created once, and can be reused for multiple requests.
35+
client = aiplatform.gapic.JobServiceClient(client_options=client_options)
36+
model_parameters_dict = {}
37+
model_parameters = json_format.ParseDict(model_parameters_dict, Value())
38+
39+
batch_prediction_job = {
40+
"display_name": display_name,
41+
# Format: 'projects/{project}/locations/{location}/models/{model_id}'
42+
"model": model_name,
43+
"model_parameters": model_parameters,
44+
"input_config": {
45+
"instances_format": instances_format,
46+
"gcs_source": {"uris": [gcs_source_uri]},
47+
},
48+
"output_config": {
49+
"predictions_format": predictions_format,
50+
"gcs_destination": {"output_uri_prefix": gcs_destination_output_uri_prefix},
51+
},
52+
"dedicated_resources": {
53+
"machine_spec": {
54+
"machine_type": "n1-standard-2",
55+
"accelerator_type": aiplatform.gapic.AcceleratorType.NVIDIA_TESLA_K80,
56+
"accelerator_count": 1,
57+
},
58+
"starting_replica_count": 1,
59+
"max_replica_count": 1,
60+
},
61+
}
62+
parent = f"projects/{project}/locations/{location}"
63+
response = client.create_batch_prediction_job(
64+
parent=parent, batch_prediction_job=batch_prediction_job
65+
)
66+
print("response:", response)
67+
68+
69+
# [END aiplatform_create_batch_prediction_job_sample]
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from uuid import uuid4
17+
18+
from google.cloud import aiplatform
19+
import pytest
20+
21+
import create_batch_prediction_job_sample
22+
import helpers
23+
24+
PROJECT_ID = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT")
25+
LOCATION = "us-central1"
26+
MODEL_ID = "1478306577684365312" # Permanent 50 flowers model
27+
DISPLAY_NAME = f"temp_create_batch_prediction_job_test_{uuid4()}"
28+
GCS_SOURCE_URI = (
29+
"gs://ucaip-samples-test-output/inputs/icn_batch_prediction_input.jsonl"
30+
)
31+
GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/"
32+
INSTANCES_FORMAT = "jsonl"
33+
PREDICTIONS_FORMAT = "jsonl"
34+
35+
36+
@pytest.fixture
37+
def shared_state():
38+
state = {}
39+
yield state
40+
41+
42+
@pytest.fixture
43+
def job_client():
44+
job_client = aiplatform.gapic.JobServiceClient(
45+
client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
46+
)
47+
return job_client
48+
49+
50+
@pytest.fixture(scope="function", autouse=True)
51+
def teardown(shared_state, job_client):
52+
yield
53+
54+
job_client.cancel_batch_prediction_job(name=shared_state["batch_prediction_job_name"])
55+
56+
# Waiting until the job is in CANCELLED state.
57+
helpers.wait_for_job_state(
58+
get_job_method=job_client.get_batch_prediction_job,
59+
name=shared_state["batch_prediction_job_name"],
60+
)
61+
62+
job_client.delete_batch_prediction_job(name=shared_state["batch_prediction_job_name"])
63+
64+
65+
# Creating AutoML Vision Classification batch prediction job
66+
def test_ucaip_generated_create_batch_prediction_sample(capsys, shared_state):
67+
68+
model_name = f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}"
69+
70+
create_batch_prediction_job_sample.create_batch_prediction_job_sample(
71+
project=PROJECT_ID,
72+
display_name=DISPLAY_NAME,
73+
model_name=model_name,
74+
gcs_source_uri=GCS_SOURCE_URI,
75+
gcs_destination_output_uri_prefix=GCS_OUTPUT_URI,
76+
instances_format=INSTANCES_FORMAT,
77+
predictions_format=PREDICTIONS_FORMAT,
78+
)
79+
80+
out, _ = capsys.readouterr()
81+
82+
# Save resource name of the newly created batch prediction job
83+
shared_state["batch_prediction_job_name"] = helpers.get_name(out)

0 commit comments

Comments
 (0)