Skip to content

Commit 3fc1d44

Browse files
authored
fix: Handle nested fields from BigQuery source when getting default column_names (#522)
* Handle nested fields from BigQuery source * Added unit test for nested BigQuery fields and refactored column_names to return a Set instead of a List * Added comment * Fixed minor issues with tabular_dataset * Switched TabularDataset.column_names back to returning a List as to not introduce a breaking change at this time
1 parent 2508fe9 commit 3fc1d44

File tree

2 files changed

+135
-22
lines changed

2 files changed

+135
-22
lines changed

google/cloud/aiplatform/datasets/tabular_dataset.py

Lines changed: 71 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import csv
1919
import logging
2020

21-
from typing import List, Optional, Sequence, Tuple, Union
21+
from typing import List, Optional, Sequence, Set, Tuple, Union
2222

2323
from google.auth import credentials as auth_credentials
2424

@@ -73,18 +73,24 @@ def column_names(self) -> List[str]:
7373
gcs_source_uris.sort()
7474

7575
# Get the first file in sorted list
76-
return self._retrieve_gcs_source_columns(
77-
project=self.project,
78-
gcs_csv_file_path=gcs_source_uris[0],
79-
credentials=self.credentials,
76+
# TODO(b/193044977): Return as Set instead of List
77+
return list(
78+
self._retrieve_gcs_source_columns(
79+
project=self.project,
80+
gcs_csv_file_path=gcs_source_uris[0],
81+
credentials=self.credentials,
82+
)
8083
)
8184
elif bq_source:
8285
bq_table_uri = bq_source.get("uri")
8386
if bq_table_uri:
84-
return self._retrieve_bq_source_columns(
85-
project=self.project,
86-
bq_table_uri=bq_table_uri,
87-
credentials=self.credentials,
87+
# TODO(b/193044977): Return as Set instead of List
88+
return list(
89+
self._retrieve_bq_source_columns(
90+
project=self.project,
91+
bq_table_uri=bq_table_uri,
92+
credentials=self.credentials,
93+
)
8894
)
8995

9096
raise RuntimeError("No valid CSV or BigQuery datasource found.")
@@ -94,7 +100,7 @@ def _retrieve_gcs_source_columns(
94100
project: str,
95101
gcs_csv_file_path: str,
96102
credentials: Optional[auth_credentials.Credentials] = None,
97-
) -> List[str]:
103+
) -> Set[str]:
98104
"""Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage
99105
100106
Example Usage:
@@ -104,7 +110,7 @@ def _retrieve_gcs_source_columns(
104110
"gs://example-bucket/path/to/csv_file"
105111
)
106112
107-
# column_names = ["column_1", "column_2"]
113+
# column_names = {"column_1", "column_2"}
108114
109115
Args:
110116
project (str):
@@ -115,8 +121,8 @@ def _retrieve_gcs_source_columns(
115121
credentials (auth_credentials.Credentials):
116122
Credentials to use to with GCS Client.
117123
Returns:
118-
List[str]
119-
A list of columns names in the CSV file.
124+
Set[str]
125+
A set of columns names in the CSV file.
120126
121127
Raises:
122128
RuntimeError: When the retrieved CSV file is invalid.
@@ -163,15 +169,53 @@ def _retrieve_gcs_source_columns(
163169
finally:
164170
logger.removeFilter(logging_warning_filter)
165171

166-
return next(csv_reader)
172+
return set(next(csv_reader))
173+
174+
@staticmethod
175+
def _get_bq_schema_field_names_recursively(
176+
schema_field: bigquery.SchemaField,
177+
) -> Set[str]:
178+
"""Retrieve the name for a schema field along with ancestor fields.
179+
Nested schema fields are flattened and concatenated with a ".".
180+
Schema fields with child fields are not included, but the children are.
181+
182+
Args:
183+
project (str):
184+
Required. Project to initiate the BigQuery client with.
185+
bq_table_uri (str):
186+
Required. A URI to a BigQuery table.
187+
Can include "bq://" prefix but not required.
188+
credentials (auth_credentials.Credentials):
189+
Credentials to use with BQ Client.
190+
191+
Returns:
192+
Set[str]
193+
A set of columns names in the BigQuery table.
194+
"""
195+
196+
ancestor_names = {
197+
nested_field_name
198+
for field in schema_field.fields
199+
for nested_field_name in TabularDataset._get_bq_schema_field_names_recursively(
200+
field
201+
)
202+
}
203+
204+
# Only return "leaf nodes", basically any field that doesn't have children
205+
if len(ancestor_names) == 0:
206+
return {schema_field.name}
207+
else:
208+
return {f"{schema_field.name}.{name}" for name in ancestor_names}
167209

168210
@staticmethod
169211
def _retrieve_bq_source_columns(
170212
project: str,
171213
bq_table_uri: str,
172214
credentials: Optional[auth_credentials.Credentials] = None,
173-
) -> List[str]:
174-
"""Retrieve the columns from a table on Google BigQuery
215+
) -> Set[str]:
216+
"""Retrieve the column names from a table on Google BigQuery
217+
Nested schema fields are flattened and concatenated with a ".".
218+
Schema fields with child fields are not included, but the children are.
175219
176220
Example Usage:
177221
@@ -180,7 +224,7 @@ def _retrieve_bq_source_columns(
180224
"bq://project_id.dataset.table"
181225
)
182226
183-
# column_names = ["column_1", "column_2"]
227+
# column_names = {"column_1", "column_2", "column_3.nested_field"}
184228
185229
Args:
186230
project (str):
@@ -192,8 +236,8 @@ def _retrieve_bq_source_columns(
192236
Credentials to use with BQ Client.
193237
194238
Returns:
195-
List[str]
196-
A list of columns names in the BigQuery table.
239+
Set[str]
240+
A set of column names in the BigQuery table.
197241
"""
198242

199243
# Remove bq:// prefix
@@ -204,7 +248,14 @@ def _retrieve_bq_source_columns(
204248
client = bigquery.Client(project=project, credentials=credentials)
205249
table = client.get_table(bq_table_uri)
206250
schema = table.schema
207-
return [schema.name for schema in schema]
251+
252+
return {
253+
field_name
254+
for field in schema
255+
for field_name in TabularDataset._get_bq_schema_field_names_recursively(
256+
field
257+
)
258+
}
208259

209260
@classmethod
210261
def create(

tests/unit/aiplatform/test_datasets.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,59 @@ def bigquery_table_schema_mock():
375375
bigquery_table_schema_mock.return_value = [
376376
bigquery.SchemaField("column_1", "FLOAT", "NULLABLE", "", (), None),
377377
bigquery.SchemaField("column_2", "FLOAT", "NULLABLE", "", (), None),
378+
bigquery.SchemaField(
379+
"column_3",
380+
"RECORD",
381+
"NULLABLE",
382+
"",
383+
(
384+
bigquery.SchemaField(
385+
"nested_3_1",
386+
"RECORD",
387+
"NULLABLE",
388+
"",
389+
(
390+
bigquery.SchemaField(
391+
"nested_3_1_1", "FLOAT", "NULLABLE", "", (), None
392+
),
393+
bigquery.SchemaField(
394+
"nested_3_1_2", "FLOAT", "NULLABLE", "", (), None
395+
),
396+
),
397+
None,
398+
),
399+
bigquery.SchemaField(
400+
"nested_3_2", "FLOAT", "NULLABLE", "", (), None
401+
),
402+
bigquery.SchemaField(
403+
"nested_3_3",
404+
"RECORD",
405+
"NULLABLE",
406+
"",
407+
(
408+
bigquery.SchemaField(
409+
"nested_3_3_1",
410+
"RECORD",
411+
"NULLABLE",
412+
"",
413+
(
414+
bigquery.SchemaField(
415+
"nested_3_3_1_1",
416+
"FLOAT",
417+
"NULLABLE",
418+
"",
419+
(),
420+
None,
421+
),
422+
),
423+
None,
424+
),
425+
),
426+
None,
427+
),
428+
),
429+
None,
430+
),
378431
]
379432
yield bigquery_table_schema_mock
380433

@@ -1007,7 +1060,7 @@ def test_tabular_dataset_column_name_missing_datasource(self):
10071060
def test_tabular_dataset_column_name_gcs(self):
10081061
my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)
10091062

1010-
assert my_dataset.column_names == ["column_1", "column_2"]
1063+
assert set(my_dataset.column_names) == {"column_1", "column_2"}
10111064

10121065
@pytest.mark.usefixtures("get_dataset_tabular_gcs_mock")
10131066
def test_tabular_dataset_column_name_gcs_with_creds(self, gcs_client_mock):
@@ -1045,7 +1098,16 @@ def test_tabular_dataset_column_name_bq_with_creds(self, bq_client_mock):
10451098
def test_tabular_dataset_column_name_bigquery(self):
10461099
my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)
10471100

1048-
assert my_dataset.column_names == ["column_1", "column_2"]
1101+
assert set(my_dataset.column_names) == set(
1102+
[
1103+
"column_1",
1104+
"column_2",
1105+
"column_3.nested_3_1.nested_3_1_1",
1106+
"column_3.nested_3_1.nested_3_1_2",
1107+
"column_3.nested_3_2",
1108+
"column_3.nested_3_3.nested_3_3_1.nested_3_3_1_1",
1109+
]
1110+
)
10491111

10501112

10511113
class TestTextDataset:

0 commit comments

Comments
 (0)