
Commit 481d172

fix: pass credentials to BQ and GCS clients (#469)

1 parent: c2cf612

3 files changed: +76, -18 lines
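
Before this fix, the column-name helpers created their GCS and BigQuery clients without credentials, so custom credentials supplied for the dataset were silently ignored. Below is a minimal sketch of the user-facing behavior the fix enables; the key path and dataset resource name are placeholders:

from google.oauth2 import service_account
from google.cloud import aiplatform

# Placeholder service-account key; any google.auth credentials object works.
creds = service_account.Credentials.from_service_account_file("/path/to/key.json")

# Credentials passed here are now forwarded to the storage.Client and
# bigquery.Client instances used to read the dataset's source schema.
dataset = aiplatform.TabularDataset(
    dataset_name="projects/123/locations/us-central1/datasets/456",
    credentials=creds,
)
print(dataset.column_names)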

google/cloud/aiplatform/datasets/dataset.py (+1, -1)
@@ -68,7 +68,7 @@ def __init__(
                 Optional location to retrieve dataset from. If not set, location
                 set in aiplatform.init will be used.
             credentials (auth_credentials.Credentials):
-                Custom credentials to use to upload this model. Overrides
+                Custom credentials to use to retreive this Dataset. Overrides
                 credentials set in aiplatform.init.
         """

google/cloud/aiplatform/datasets/tabular_dataset.py (+25, -9)
@@ -73,20 +73,28 @@ def column_names(self) -> List[str]:
             gcs_source_uris.sort()
 
             # Get the first file in sorted list
-            return TabularDataset._retrieve_gcs_source_columns(
-                self.project, gcs_source_uris[0]
+            return self._retrieve_gcs_source_columns(
+                project=self.project,
+                gcs_csv_file_path=gcs_source_uris[0],
+                credentials=self.credentials,
             )
         elif bq_source:
             bq_table_uri = bq_source.get("uri")
             if bq_table_uri:
-                return TabularDataset._retrieve_bq_source_columns(
-                    self.project, bq_table_uri
+                return self._retrieve_bq_source_columns(
+                    project=self.project,
+                    bq_table_uri=bq_table_uri,
+                    credentials=self.credentials,
                 )
 
         raise RuntimeError("No valid CSV or BigQuery datasource found.")
 
     @staticmethod
-    def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
+    def _retrieve_gcs_source_columns(
+        project: str,
+        gcs_csv_file_path: str,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> List[str]:
         """Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage
 
         Example Usage:
@@ -104,7 +112,8 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[s
             gcs_csv_file_path (str):
                 Required. A full path to a CSV files stored on Google Cloud Storage.
                 Must include "gs://" prefix.
-
+            credentials (auth_credentials.Credentials):
+                Credentials to use to with GCS Client.
         Returns:
             List[str]
                 A list of columns names in the CSV file.
@@ -116,7 +125,7 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[s
         gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
             gcs_csv_file_path
         )
-        client = storage.Client(project=project)
+        client = storage.Client(project=project, credentials=credentials)
         bucket = client.bucket(gcs_bucket)
         blob = bucket.blob(gcs_blob)
 
@@ -135,6 +144,7 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[s
             line += blob.download_as_bytes(
                 start=start_index, end=start_index + increment
             ).decode("utf-8")
+
             first_new_line_index = line.find("\n")
             start_index += increment
 
@@ -156,7 +166,11 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[s
         return next(csv_reader)
 
     @staticmethod
-    def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
+    def _retrieve_bq_source_columns(
+        project: str,
+        bq_table_uri: str,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> List[str]:
         """Retrieve the columns from a table on Google BigQuery
 
         Example Usage:
@@ -174,6 +188,8 @@ def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
             bq_table_uri (str):
                 Required. A URI to a BigQuery table.
                 Can include "bq://" prefix but not required.
+            credentials (auth_credentials.Credentials):
+                Credentials to use with BQ Client.
 
         Returns:
             List[str]
@@ -185,7 +201,7 @@ def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
         if bq_table_uri.startswith(prefix):
             bq_table_uri = bq_table_uri[len(prefix) :]
 
-        client = bigquery.Client(project=project)
+        client = bigquery.Client(project=project, credentials=credentials)
         table = client.get_table(bq_table_uri)
         schema = table.schema
         return [schema.name for schema in schema]
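
Both client constructors fall back to Application Default Credentials when credentials is None, which is why the old code worked under ADC but ignored explicitly supplied credentials. A standalone sketch of the pattern the BQ helper now follows; the table URI in the comment is a placeholder:

from typing import List, Optional

from google.auth import credentials as auth_credentials
from google.cloud import bigquery


def bq_column_names(
    project: str,
    bq_table_uri: str,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> List[str]:
    # credentials=None preserves the old behavior (Application Default
    # Credentials); an explicit credentials object now takes effect.
    client = bigquery.Client(project=project, credentials=credentials)
    table = client.get_table(bq_table_uri)  # e.g. "project.dataset.table"
    return [field.name for field in table.schema]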

tests/unit/aiplatform/test_datasets.py (+50, -8)
@@ -341,16 +341,30 @@ def list_datasets_mock():
 
 
 @pytest.fixture
 def gcs_client_download_as_bytes_mock():
-    with patch.object(storage.Blob, "download_as_bytes") as bigquery_blob_mock:
-        bigquery_blob_mock.return_value = b'"column_1","column_2"\n0, 1'
-        yield bigquery_blob_mock
+    with patch.object(storage.Blob, "download_as_bytes") as gcs_blob_mock:
+        gcs_blob_mock.return_value = b'"column_1","column_2"\n0, 1'
+        yield gcs_blob_mock
 
 
 @pytest.fixture
-def bigquery_client_mock():
-    with patch.object(bigquery.Client, "get_table") as bigquery_client_mock:
-        bigquery_client_mock.return_value = bigquery.Table("project.dataset.table")
-        yield bigquery_client_mock
+def gcs_client_mock():
+    with patch.object(storage, "Client") as client_mock:
+        yield client_mock
+
+
+@pytest.fixture
+def bq_client_mock():
+    with patch.object(bigquery, "Client") as client_mock:
+        yield client_mock
+
+
+@pytest.fixture
+def bigquery_client_table_mock():
+    with patch.object(bigquery.Client, "get_table") as bigquery_client_table_mock:
+        bigquery_client_table_mock.return_value = bigquery.Table(
+            "project.dataset.table"
+        )
+        yield bigquery_client_table_mock
 
 
 @pytest.fixture
@@ -995,9 +1009,37 @@ def test_tabular_dataset_column_name_gcs(self):
 
         assert my_dataset.column_names == ["column_1", "column_2"]
 
+    @pytest.mark.usefixtures("get_dataset_tabular_gcs_mock")
+    def test_tabular_dataset_column_name_gcs_with_creds(self, gcs_client_mock):
+        creds = auth_credentials.AnonymousCredentials()
+        my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME, credentials=creds)
+
+        # we are just testing creds passing
+        # this exception if from the mock not returning
+        # the csv data which is tested above
+        try:
+            my_dataset.column_names
+        except StopIteration:
+            pass
+
+        gcs_client_mock.assert_called_once_with(
+            project=_TEST_PROJECT, credentials=creds
+        )
+
+    @pytest.mark.usefixtures("get_dataset_tabular_bq_mock",)
+    def test_tabular_dataset_column_name_bq_with_creds(self, bq_client_mock):
+        creds = auth_credentials.AnonymousCredentials()
+        my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME, credentials=creds)
+
+        my_dataset.column_names
+
+        assert bq_client_mock.call_args_list[0] == mock.call(
+            project=_TEST_PROJECT, credentials=creds
+        )
+
     @pytest.mark.usefixtures(
         "get_dataset_tabular_bq_mock",
-        "bigquery_client_mock",
+        "bigquery_client_table_mock",
         "bigquery_table_schema_mock",
     )
     def test_tabular_dataset_column_name_bigquery(self):
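
The new tests patch the client classes themselves (storage.Client, bigquery.Client) rather than individual methods, so they can assert on the constructor call. A condensed sketch of that pattern outside the fixture setup; the function names are illustrative:

from unittest import mock

from google.auth import credentials as auth_credentials
from google.cloud import bigquery


def read_schema(project, creds):
    # Stand-in for the library code under test.
    client = bigquery.Client(project=project, credentials=creds)
    return client.get_table("project.dataset.table").schema


def test_credentials_forwarded():
    creds = auth_credentials.AnonymousCredentials()
    with mock.patch.object(bigquery, "Client") as client_mock:
        read_schema("my-project", creds)
        # Patching the class records its instantiation, so the test can
        # verify that the credentials object was forwarded verbatim.
        client_mock.assert_called_once_with(project="my-project", credentials=creds)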
