18
18
import csv
19
19
import logging
20
20
21
- from typing import List , Optional , Sequence , Tuple , Union
21
+ from typing import List , Optional , Sequence , Set , Tuple , Union
22
22
23
23
from google .auth import credentials as auth_credentials
24
24
@@ -73,18 +73,24 @@ def column_names(self) -> List[str]:
73
73
gcs_source_uris .sort ()
74
74
75
75
# Get the first file in sorted list
76
- return self ._retrieve_gcs_source_columns (
77
- project = self .project ,
78
- gcs_csv_file_path = gcs_source_uris [0 ],
79
- credentials = self .credentials ,
76
+ # TODO(b/193044977): Return as Set instead of List
77
+ return list (
78
+ self ._retrieve_gcs_source_columns (
79
+ project = self .project ,
80
+ gcs_csv_file_path = gcs_source_uris [0 ],
81
+ credentials = self .credentials ,
82
+ )
80
83
)
81
84
elif bq_source :
82
85
bq_table_uri = bq_source .get ("uri" )
83
86
if bq_table_uri :
84
- return self ._retrieve_bq_source_columns (
85
- project = self .project ,
86
- bq_table_uri = bq_table_uri ,
87
- credentials = self .credentials ,
87
+ # TODO(b/193044977): Return as Set instead of List
88
+ return list (
89
+ self ._retrieve_bq_source_columns (
90
+ project = self .project ,
91
+ bq_table_uri = bq_table_uri ,
92
+ credentials = self .credentials ,
93
+ )
88
94
)
89
95
90
96
raise RuntimeError ("No valid CSV or BigQuery datasource found." )
@@ -94,7 +100,7 @@ def _retrieve_gcs_source_columns(
94
100
project : str ,
95
101
gcs_csv_file_path : str ,
96
102
credentials : Optional [auth_credentials .Credentials ] = None ,
97
- ) -> List [str ]:
103
+ ) -> Set [str ]:
98
104
"""Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage
99
105
100
106
Example Usage:
@@ -104,7 +110,7 @@ def _retrieve_gcs_source_columns(
104
110
"gs://example-bucket/path/to/csv_file"
105
111
)
106
112
107
- # column_names = [ "column_1", "column_2"]
113
+ # column_names = { "column_1", "column_2"}
108
114
109
115
Args:
110
116
project (str):
@@ -115,8 +121,8 @@ def _retrieve_gcs_source_columns(
115
121
credentials (auth_credentials.Credentials):
116
122
Credentials to use with GCS Client.
117
123
Returns:
118
- List [str]
119
- A list of columns names in the CSV file.
124
+ Set [str]
125
+ A set of column names in the CSV file.
120
126
121
127
Raises:
122
128
RuntimeError: When the retrieved CSV file is invalid.
@@ -163,15 +169,53 @@ def _retrieve_gcs_source_columns(
163
169
finally :
164
170
logger .removeFilter (logging_warning_filter )
165
171
166
- return next (csv_reader )
172
+ return set (next (csv_reader ))
173
+
174
@staticmethod
def _get_bq_schema_field_names_recursively(
    schema_field: bigquery.SchemaField,
) -> Set[str]:
    """Retrieve the flattened leaf names found under a BigQuery schema field.

    Nested schema fields are flattened and concatenated with a ".".
    Schema fields with child fields are not included, but the children are.

    Args:
        schema_field (bigquery.SchemaField):
            Required. The schema field to flatten. Its ``fields`` are
            walked recursively and its ``name`` is used as the prefix
            for every nested name.

    Returns:
        Set[str]
            The name of ``schema_field`` itself when it has no nested
            names, otherwise the "."-joined names of every leaf field
            nested beneath it.
    """

    # Fully qualified names (relative to this field) of every leaf below it.
    descendant_names = {
        nested_field_name
        for field in schema_field.fields
        for nested_field_name in TabularDataset._get_bq_schema_field_names_recursively(
            field
        )
    }

    # Only return "leaf nodes", basically any field that doesn't have children
    if not descendant_names:
        return {schema_field.name}

    # Prefix each nested name with this field's name to build the full path.
    return {f"{schema_field.name}.{name}" for name in descendant_names}
167
209
168
210
@staticmethod
169
211
def _retrieve_bq_source_columns (
170
212
project : str ,
171
213
bq_table_uri : str ,
172
214
credentials : Optional [auth_credentials .Credentials ] = None ,
173
- ) -> List [str ]:
174
- """Retrieve the columns from a table on Google BigQuery
215
+ ) -> Set [str ]:
216
+ """Retrieve the column names from a table on Google BigQuery
217
+ Nested schema fields are flattened and concatenated with a ".".
218
+ Schema fields with child fields are not included, but the children are.
175
219
176
220
Example Usage:
177
221
@@ -180,7 +224,7 @@ def _retrieve_bq_source_columns(
180
224
"bq://project_id.dataset.table"
181
225
)
182
226
183
- # column_names = [ "column_1", "column_2"]
227
+ # column_names = { "column_1", "column_2", "column_3.nested_field"}
184
228
185
229
Args:
186
230
project (str):
@@ -192,8 +236,8 @@ def _retrieve_bq_source_columns(
192
236
Credentials to use with BQ Client.
193
237
194
238
Returns:
195
- List [str]
196
- A list of columns names in the BigQuery table.
239
+ Set [str]
240
+ A set of column names in the BigQuery table.
197
241
"""
198
242
199
243
# Remove bq:// prefix
@@ -204,7 +248,14 @@ def _retrieve_bq_source_columns(
204
248
client = bigquery .Client (project = project , credentials = credentials )
205
249
table = client .get_table (bq_table_uri )
206
250
schema = table .schema
207
- return [schema .name for schema in schema ]
251
+
252
+ return {
253
+ field_name
254
+ for field in schema
255
+ for field_name in TabularDataset ._get_bq_schema_field_names_recursively (
256
+ field
257
+ )
258
+ }
208
259
209
260
@classmethod
210
261
def create (
0 commit comments