Skip to content

Commit c9f68c6

Browse files
yfang1 and Yicheng Fang authored
feat: pre batch creating TensorboardRuns and TensorboardTimeSeries in one_shot mode to speed up uploading (#772)
* feat: pre batch creating TensorboardRuns and TensorboardTimeSeries in one_shot mode to speed up uploading

Co-authored-by: Yicheng Fang <[email protected]>
1 parent 8ef0ded commit c9f68c6

File tree

3 files changed

+270
-27
lines changed

3 files changed

+270
-27
lines changed

google/cloud/aiplatform/tensorboard/uploader.py

Lines changed: 115 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
Iterable,
3030
Optional,
3131
ContextManager,
32+
Tuple,
3233
)
3334
import uuid
3435

@@ -195,6 +196,7 @@ def __init__(
195196
self._logdir = logdir
196197
self._allowed_plugins = frozenset(allowed_plugins)
197198
self._run_name_prefix = run_name_prefix
199+
self._is_brand_new_experiment = False
198200

199201
self._upload_limits = upload_limits
200202
if not self._upload_limits:
@@ -265,6 +267,9 @@ def active_filter(secs):
265267
self._logdir_loader = logdir_loader.LogdirLoader(
266268
self._logdir, directory_loader_factory
267269
)
270+
self._logdir_loader_pre_create = logdir_loader.LogdirLoader(
271+
self._logdir, directory_loader_factory
272+
)
268273
self._tracker = upload_tracker.UploadTracker(verbosity=self._verbosity)
269274

270275
self._create_additional_senders()
@@ -290,6 +295,7 @@ def _create_or_get_experiment(self) -> tensorboard_experiment.TensorboardExperim
290295
tensorboard_experiment=tb_experiment,
291296
tensorboard_experiment_id=self._experiment_name,
292297
)
298+
self._is_brand_new_experiment = True
293299
except exceptions.AlreadyExists:
294300
logger.info("Creating experiment failed. Retrieving experiment.")
295301
experiment_name = os.path.join(
@@ -303,7 +309,11 @@ def create_experiment(self):
303309

304310
experiment = self._create_or_get_experiment()
305311
self._experiment = experiment
306-
request_sender = _BatchedRequestSender(
312+
self._one_platform_resource_manager = uploader_utils.OnePlatformResourceManager(
313+
self._experiment.name, self._api
314+
)
315+
316+
self._request_sender = _BatchedRequestSender(
307317
self._experiment.name,
308318
self._api,
309319
allowed_plugins=self._allowed_plugins,
@@ -313,6 +323,7 @@ def create_experiment(self):
313323
blob_rpc_rate_limiter=self._blob_rpc_rate_limiter,
314324
blob_storage_bucket=self._blob_storage_bucket,
315325
blob_storage_folder=self._blob_storage_folder,
326+
one_platform_resource_manager=self._one_platform_resource_manager,
316327
tracker=self._tracker,
317328
)
318329

@@ -323,7 +334,8 @@ def create_experiment(self):
323334
)
324335

325336
self._dispatcher = _Dispatcher(
326-
request_sender=request_sender, additional_senders=self._additional_senders,
337+
request_sender=self._request_sender,
338+
additional_senders=self._additional_senders,
327339
)
328340

329341
def _create_additional_senders(self) -> Dict[str, uploader_utils.RequestSender]:
@@ -366,6 +378,17 @@ def start_uploading(self):
366378
"""
367379
if self._dispatcher is None:
368380
raise RuntimeError("Must call create_experiment() before start_uploading()")
381+
382+
if self._one_shot:
383+
if self._is_brand_new_experiment:
384+
self._pre_create_runs_and_time_series()
385+
else:
386+
logger.warning(
387+
"Please consider uploading to a new experiment instead of "
388+
"an existing one, as the former allows for better upload "
389+
"performance."
390+
)
391+
369392
while True:
370393
self._logdir_poll_rate_limiter.tick()
371394
self._upload_once()
@@ -377,6 +400,58 @@ def start_uploading(self):
377400
"without any uploadable data" % self._logdir
378401
)
379402

403+
def _pre_create_runs_and_time_series(self):
404+
"""
405+
Iterates though the log dir to collect TensorboardRuns and
406+
TensorboardTimeSeries that need to be created, and creates them in batch
407+
to speed up uploading later on.
408+
"""
409+
self._logdir_loader_pre_create.synchronize_runs()
410+
run_to_events = self._logdir_loader_pre_create.get_run_events()
411+
if self._run_name_prefix:
412+
run_to_events = {
413+
self._run_name_prefix + k: v for k, v in run_to_events.items()
414+
}
415+
416+
run_names = []
417+
run_tag_name_to_time_series_proto = {}
418+
for (run_name, events) in run_to_events.items():
419+
run_names.append(run_name)
420+
for event in events:
421+
_filter_graph_defs(event)
422+
for value in event.summary.value:
423+
metadata, is_valid = self._request_sender.get_metadata_and_validate(
424+
run_name, value
425+
)
426+
if not is_valid:
427+
continue
428+
if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
429+
value_type = (
430+
tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR
431+
)
432+
elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
433+
value_type = (
434+
tensorboard_time_series.TensorboardTimeSeries.ValueType.TENSOR
435+
)
436+
elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
437+
value_type = (
438+
tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE
439+
)
440+
441+
run_tag_name_to_time_series_proto[
442+
(run_name, value.tag)
443+
] = tensorboard_time_series.TensorboardTimeSeries(
444+
display_name=value.tag,
445+
value_type=value_type,
446+
plugin_name=metadata.plugin_data.plugin_name,
447+
plugin_data=metadata.plugin_data.content,
448+
)
449+
450+
self._one_platform_resource_manager.batch_create_runs(run_names)
451+
self._one_platform_resource_manager.batch_create_time_series(
452+
run_tag_name_to_time_series_proto
453+
)
454+
380455
def _upload_once(self):
381456
"""Runs one upload cycle, sending zero or more RPCs."""
382457
logger.info("Starting an upload cycle")
@@ -439,6 +514,7 @@ def __init__(
439514
blob_rpc_rate_limiter: util.RateLimiter,
440515
blob_storage_bucket: storage.Bucket,
441516
blob_storage_folder: str,
517+
one_platform_resource_manager: uploader_utils.OnePlatformResourceManager,
442518
tracker: upload_tracker.UploadTracker,
443519
):
444520
"""Constructs _BatchedRequestSender for the given experiment resource.
@@ -456,16 +532,16 @@ def __init__(
456532
Note the chunk stream is internally rate-limited by backpressure from
457533
the server, so it is not a concern that we do not explicitly rate-limit
458534
within the stream here.
535+
one_platform_resource_manager: An instance of the One Platform
536+
resource management class.
459537
tracker: Upload tracker to track information about uploads.
460538
"""
461539
self._experiment_resource_name = experiment_resource_name
462540
self._api = api
463541
self._tag_metadata = {}
464542
self._allowed_plugins = frozenset(allowed_plugins)
465543
self._tracker = tracker
466-
self._one_platform_resource_manager = uploader_utils.OnePlatformResourceManager(
467-
self._experiment_resource_name, self._api
468-
)
544+
self._one_platform_resource_manager = one_platform_resource_manager
469545
self._scalar_request_sender = _ScalarBatchedRequestSender(
470546
experiment_resource_id=experiment_resource_name,
471547
api=api,
@@ -516,6 +592,37 @@ def send_request(
516592
RuntimeError: If no progress can be made because even a single
517593
point is too large (say, due to a gigabyte-long tag name).
518594
"""
595+
metadata, is_valid = self.get_metadata_and_validate(run_name, value)
596+
if not is_valid:
597+
return
598+
plugin_name = metadata.plugin_data.plugin_name
599+
self._tracker.add_plugin_name(plugin_name)
600+
601+
if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
602+
self._scalar_request_sender.add_event(run_name, event, value, metadata)
603+
elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
604+
self._tensor_request_sender.add_event(run_name, event, value, metadata)
605+
elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
606+
self._blob_request_sender.add_event(run_name, event, value, metadata)
607+
608+
def flush(self):
609+
"""Flushes any events that have been stored."""
610+
self._scalar_request_sender.flush()
611+
self._tensor_request_sender.flush()
612+
self._blob_request_sender.flush()
613+
614+
def get_metadata_and_validate(
615+
self, run_name: str, value: tf.compat.v1.Summary.Value
616+
) -> Tuple[tf.compat.v1.SummaryMetadata, bool]:
617+
"""
618+
619+
:param run_name: Name of the run retrieved by
620+
`LogdirLoader.get_run_events`
621+
:param value: A single `tf.compat.v1.Summary.Value` from the event,
622+
where there can be multiple values per event.
623+
:return: (metadata, is_valid): a metadata derived from the value, and
624+
whether the value itself is valid.
625+
"""
519626

520627
time_series_key = (run_name, value.tag)
521628

@@ -539,29 +646,16 @@ def send_request(
539646
metadata.plugin_data.plugin_name,
540647
value.metadata.plugin_data.plugin_name,
541648
)
542-
return
649+
return metadata, False
543650
if plugin_name not in self._allowed_plugins:
544651
if first_in_time_series:
545652
logger.info(
546653
"Skipping time series %r with unsupported plugin name %r",
547654
time_series_key,
548655
plugin_name,
549656
)
550-
return
551-
self._tracker.add_plugin_name(plugin_name)
552-
553-
if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
554-
self._scalar_request_sender.add_event(run_name, event, value, metadata)
555-
elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
556-
self._tensor_request_sender.add_event(run_name, event, value, metadata)
557-
elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
558-
self._blob_request_sender.add_event(run_name, event, value, metadata)
559-
560-
def flush(self):
561-
"""Flushes any events that have been stored."""
562-
self._scalar_request_sender.flush()
563-
self._tensor_request_sender.flush()
564-
self._blob_request_sender.flush()
657+
return metadata, False
658+
return metadata, True
565659

566660

567661
class _Dispatcher(object):

google/cloud/aiplatform/tensorboard/uploader_utils.py

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import logging
2323
import re
2424
import time
25-
from typing import Callable, Dict, Generator, Optional
25+
from typing import Callable, Dict, Generator, Optional, List, Tuple
2626
import uuid
2727

2828
from tensorboard.util import tb_logging
@@ -39,7 +39,6 @@
3939
tensorboard_time_series_v1beta1 as tensorboard_time_series,
4040
)
4141
from google.cloud.aiplatform.compat.services import tensorboard_service_client_v1beta1
42-
from google.cloud.aiplatform_v1beta1.types import TensorboardRun
4342

4443
TensorboardServiceClient = tensorboard_service_client_v1beta1.TensorboardServiceClient
4544

@@ -66,6 +65,9 @@ def send_requests(run_name: str):
6665
class OnePlatformResourceManager(object):
6766
"""Helper class managing One Platform resources."""
6867

68+
CREATE_RUN_BATCH_SIZE = 1000
69+
CREATE_TIME_SERIES_BATCH_SIZE = 1000
70+
6971
def __init__(self, experiment_resource_name: str, api: TensorboardServiceClient):
7072
"""Constructor for OnePlatformResourceManager.
7173
@@ -81,6 +83,96 @@ def __init__(self, experiment_resource_name: str, api: TensorboardServiceClient)
8183
self._run_name_to_run_resource_name: Dict[str, str] = {}
8284
self._run_tag_name_to_time_series_name: Dict[(str, str), str] = {}
8385

86+
def batch_create_runs(
87+
self, run_names: List[str]
88+
) -> List[tensorboard_run.TensorboardRun]:
89+
"""Batch creates TensorboardRuns.
90+
91+
Args:
92+
run_names: a list of run_names for creating the TensorboardRuns.
93+
Returns:
94+
the created TensorboardRuns
95+
"""
96+
batch_size = OnePlatformResourceManager.CREATE_RUN_BATCH_SIZE
97+
created_runs = []
98+
for i in range(0, len(run_names), batch_size):
99+
one_batch_run_names = run_names[i : i + batch_size]
100+
tb_run_requests = [
101+
tensorboard_service.CreateTensorboardRunRequest(
102+
parent=self._experiment_resource_name,
103+
tensorboard_run=tensorboard_run.TensorboardRun(
104+
display_name=run_name
105+
),
106+
tensorboard_run_id=str(uuid.uuid4()),
107+
)
108+
for run_name in one_batch_run_names
109+
]
110+
111+
tb_runs = self._api.batch_create_tensorboard_runs(
112+
parent=self._experiment_resource_name, requests=tb_run_requests,
113+
).tensorboard_runs
114+
115+
self._run_name_to_run_resource_name.update(
116+
{run.display_name: run.name for run in tb_runs}
117+
)
118+
119+
created_runs.extend(tb_runs)
120+
121+
return created_runs
122+
123+
def batch_create_time_series(
124+
self,
125+
run_tag_name_to_time_series: Dict[
126+
Tuple[str, str], tensorboard_time_series.TensorboardTimeSeries
127+
],
128+
) -> List[tensorboard_time_series.TensorboardTimeSeries]:
129+
"""Batch creates TensorboardTimeSeries.
130+
131+
Args:
132+
run_tag_name_to_time_series: a dictionary of
133+
(run_name, tag_name) to TensorboardTimeSeries proto, containing
134+
the TensorboardTimeSeries to create.
135+
Returns:
136+
the created TensorboardTimeSeries
137+
"""
138+
batch_size = OnePlatformResourceManager.CREATE_TIME_SERIES_BATCH_SIZE
139+
run_tag_name_to_time_series_entries = list(run_tag_name_to_time_series.items())
140+
run_resource_name_to_run_name = {
141+
v: k for k, v in self._run_name_to_run_resource_name.items()
142+
}
143+
created_time_series = []
144+
for i in range(0, len(run_tag_name_to_time_series_entries), batch_size):
145+
requests = [
146+
tensorboard_service.CreateTensorboardTimeSeriesRequest(
147+
parent=self._run_name_to_run_resource_name[run_name],
148+
tensorboard_time_series=time_series,
149+
)
150+
for (
151+
(run_name, tag_name),
152+
time_series,
153+
) in run_tag_name_to_time_series_entries[i : i + batch_size]
154+
]
155+
156+
time_series = self._api.batch_create_tensorboard_time_series(
157+
parent=self._experiment_resource_name, requests=requests,
158+
).tensorboard_time_series
159+
160+
self._run_tag_name_to_time_series_name.update(
161+
{
162+
(
163+
run_resource_name_to_run_name[
164+
ts.name[: ts.name.index("/timeSeries")]
165+
],
166+
ts.display_name,
167+
): ts.name
168+
for ts in time_series
169+
}
170+
)
171+
172+
created_time_series.extend(time_series)
173+
174+
return created_time_series
175+
84176
def get_run_resource_name(self, run_name: str) -> str:
85177
"""
86178
Get the resource name of the run if it exists, otherwise creates the run
@@ -99,7 +191,9 @@ def get_run_resource_name(self, run_name: str) -> str:
99191
self._run_name_to_run_resource_name[run_name] = tb_run.name
100192
return self._run_name_to_run_resource_name[run_name]
101193

102-
def _create_or_get_run_resource(self, run_name: str) -> TensorboardRun:
194+
def _create_or_get_run_resource(
195+
self, run_name: str
196+
) -> tensorboard_run.TensorboardRun:
103197
"""Creates a new run resource in current tensorboard experiment resource.
104198
105199
Args:

0 commit comments

Comments (0)