     Iterable,
     Optional,
     ContextManager,
+    Tuple,
 )
 
 import uuid
 
@@ -195,6 +196,7 @@ def __init__(
         self._logdir = logdir
         self._allowed_plugins = frozenset(allowed_plugins)
         self._run_name_prefix = run_name_prefix
+        self._is_brand_new_experiment = False
 
         self._upload_limits = upload_limits
         if not self._upload_limits:
@@ -265,6 +267,9 @@ def active_filter(secs):
         self._logdir_loader = logdir_loader.LogdirLoader(
             self._logdir, directory_loader_factory
         )
+        self._logdir_loader_pre_create = logdir_loader.LogdirLoader(
+            self._logdir, directory_loader_factory
+        )
         self._tracker = upload_tracker.UploadTracker(verbosity=self._verbosity)
 
         self._create_additional_senders()
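A second `LogdirLoader` over the same logdir is introduced here, presumably because the pre-create pass added later in this diff must walk every event once without consuming the stream the main upload loop reads. A minimal, self-contained sketch of the underlying pitfall, using plain Python generators rather than the real loader API:

```python
# Generators are single-pass: a "pre-create" scan that drains the stream
# would leave nothing for the actual upload pass.
def read_events():
    yield from ("event_1", "event_2")

stream = read_events()
print(list(stream))  # ['event_1', 'event_2'] -- consumed by the first pass
print(list(stream))  # [] -- a second pass over the same iterator sees nothing

# A separate, independent stream restores the events for the second pass:
print(list(read_events()))  # ['event_1', 'event_2']
```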
@@ -290,6 +295,7 @@ def _create_or_get_experiment(self) -> tensorboard_experiment.TensorboardExperim
                 tensorboard_experiment=tb_experiment,
                 tensorboard_experiment_id=self._experiment_name,
             )
+            self._is_brand_new_experiment = True
         except exceptions.AlreadyExists:
             logger.info("Creating experiment failed. Retrieving experiment.")
             experiment_name = os.path.join(
@@ -303,7 +309,11 @@ def create_experiment(self):
 
         experiment = self._create_or_get_experiment()
         self._experiment = experiment
-        request_sender = _BatchedRequestSender(
+        self._one_platform_resource_manager = uploader_utils.OnePlatformResourceManager(
+            self._experiment.name, self._api
+        )
+
+        self._request_sender = _BatchedRequestSender(
             self._experiment.name,
             self._api,
             allowed_plugins=self._allowed_plugins,
@@ -313,6 +323,7 @@ def create_experiment(self):
             blob_rpc_rate_limiter=self._blob_rpc_rate_limiter,
             blob_storage_bucket=self._blob_storage_bucket,
             blob_storage_folder=self._blob_storage_folder,
+            one_platform_resource_manager=self._one_platform_resource_manager,
             tracker=self._tracker,
         )
 
@@ -323,7 +334,8 @@ def create_experiment(self):
         )
 
         self._dispatcher = _Dispatcher(
-            request_sender=request_sender, additional_senders=self._additional_senders,
+            request_sender=self._request_sender,
+            additional_senders=self._additional_senders,
         )
 
     def _create_additional_senders(self) -> Dict[str, uploader_utils.RequestSender]:
@@ -366,6 +378,17 @@ def start_uploading(self):
         """
         if self._dispatcher is None:
             raise RuntimeError("Must call create_experiment() before start_uploading()")
+
+        if self._one_shot:
+            if self._is_brand_new_experiment:
+                self._pre_create_runs_and_time_series()
+            else:
+                logger.warning(
+                    "Please consider uploading to a new experiment instead of "
+                    "an existing one, as the former allows for better upload "
+                    "performance."
+                )
+
         while True:
             self._logdir_poll_rate_limiter.tick()
             self._upload_once()
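The one-shot branch above pre-creates resources only when the experiment was created in this invocation; otherwise it just warns. A hedged sketch of the intended call order follows; the uploader class name and constructor arguments are assumptions, since only methods appear in this diff:

```python
# Hypothetical usage; class name and constructor arguments are assumptions.
uploader = TensorBoardUploader(
    experiment_name="my-experiment",
    logdir="/tmp/logs",
    one_shot=True,
)
uploader.create_experiment()  # sets _is_brand_new_experiment on a fresh create
uploader.start_uploading()    # pre-creates runs/time series, then uploads once
```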
@@ -377,6 +400,58 @@ def start_uploading(self):
                 "without any uploadable data" % self._logdir
             )
 
+    def _pre_create_runs_and_time_series(self):
+        """
+        Iterates through the log dir to collect TensorboardRuns and
+        TensorboardTimeSeries that need to be created, and creates them in batch
+        to speed up uploading later on.
+        """
+        self._logdir_loader_pre_create.synchronize_runs()
+        run_to_events = self._logdir_loader_pre_create.get_run_events()
+        if self._run_name_prefix:
+            run_to_events = {
+                self._run_name_prefix + k: v for k, v in run_to_events.items()
+            }
+
+        run_names = []
+        run_tag_name_to_time_series_proto = {}
+        for (run_name, events) in run_to_events.items():
+            run_names.append(run_name)
+            for event in events:
+                _filter_graph_defs(event)
+                for value in event.summary.value:
+                    metadata, is_valid = self._request_sender.get_metadata_and_validate(
+                        run_name, value
+                    )
+                    if not is_valid:
+                        continue
+                    if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
+                        value_type = (
+                            tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR
+                        )
+                    elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
+                        value_type = (
+                            tensorboard_time_series.TensorboardTimeSeries.ValueType.TENSOR
+                        )
+                    elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
+                        value_type = (
+                            tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE
+                        )
+
+                    run_tag_name_to_time_series_proto[
+                        (run_name, value.tag)
+                    ] = tensorboard_time_series.TensorboardTimeSeries(
+                        display_name=value.tag,
+                        value_type=value_type,
+                        plugin_name=metadata.plugin_data.plugin_name,
+                        plugin_data=metadata.plugin_data.content,
+                    )
+
+        self._one_platform_resource_manager.batch_create_runs(run_names)
+        self._one_platform_resource_manager.batch_create_time_series(
+            run_tag_name_to_time_series_proto
+        )
+
     def _upload_once(self):
         """Runs one upload cycle, sending zero or more RPCs."""
         logger.info("Starting an upload cycle")
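`batch_create_runs` and `batch_create_time_series` live in `uploader_utils` and are not shown in this diff. The reason batching helps is that many resource creations collapse into few RPCs; one plausible shape of such a helper is sketched below, where the chunk size and function names are illustrative assumptions:

```python
from typing import Iterable, List

def _chunk(items: List[str], size: int) -> Iterable[List[str]]:
    # Yield successive fixed-size slices of `items`.
    for start in range(0, len(items), size):
        yield items[start : start + size]

def count_batch_create_rpcs(names: List[str], max_per_request: int = 1000) -> int:
    # Each chunk becomes one batch-create RPC instead of one RPC per resource.
    return sum(1 for _ in _chunk(names, max_per_request))

print(count_batch_create_rpcs([f"run_{i}" for i in range(2500)]))  # 3, not 2500
```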
@@ -439,6 +514,7 @@ def __init__(
         blob_rpc_rate_limiter: util.RateLimiter,
         blob_storage_bucket: storage.Bucket,
         blob_storage_folder: str,
+        one_platform_resource_manager: uploader_utils.OnePlatformResourceManager,
         tracker: upload_tracker.UploadTracker,
     ):
         """Constructs _BatchedRequestSender for the given experiment resource.
@@ -456,16 +532,16 @@ def __init__(
             Note the chunk stream is internally rate-limited by backpressure from
             the server, so it is not a concern that we do not explicitly rate-limit
             within the stream here.
+          one_platform_resource_manager: An instance of the One Platform
+            resource management class.
           tracker: Upload tracker to track information about uploads.
         """
         self._experiment_resource_name = experiment_resource_name
         self._api = api
         self._tag_metadata = {}
         self._allowed_plugins = frozenset(allowed_plugins)
         self._tracker = tracker
-        self._one_platform_resource_manager = uploader_utils.OnePlatformResourceManager(
-            self._experiment_resource_name, self._api
-        )
+        self._one_platform_resource_manager = one_platform_resource_manager
         self._scalar_request_sender = _ScalarBatchedRequestSender(
             experiment_resource_id=experiment_resource_name,
             api=api,
@@ -516,6 +592,37 @@ def send_request(
           RuntimeError: If no progress can be made because even a single
             point is too large (say, due to a gigabyte-long tag name).
         """
+        metadata, is_valid = self.get_metadata_and_validate(run_name, value)
+        if not is_valid:
+            return
+        plugin_name = metadata.plugin_data.plugin_name
+        self._tracker.add_plugin_name(plugin_name)
+
+        if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
+            self._scalar_request_sender.add_event(run_name, event, value, metadata)
+        elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
+            self._tensor_request_sender.add_event(run_name, event, value, metadata)
+        elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
+            self._blob_request_sender.add_event(run_name, event, value, metadata)
+
+    def flush(self):
+        """Flushes any events that have been stored."""
+        self._scalar_request_sender.flush()
+        self._tensor_request_sender.flush()
+        self._blob_request_sender.flush()
+
+    def get_metadata_and_validate(
+        self, run_name: str, value: tf.compat.v1.Summary.Value
+    ) -> Tuple[tf.compat.v1.SummaryMetadata, bool]:
+        """Gets the metadata of a summary value and validates it.
+
+        :param run_name: Name of the run retrieved by
+            `LogdirLoader.get_run_events`
+        :param value: A single `tf.compat.v1.Summary.Value` from the event,
+            where there can be multiple values per event.
+        :return: (metadata, is_valid): a metadata derived from the value, and
+            whether the value itself is valid.
+        """
 
         time_series_key = (run_name, value.tag)
 
@@ -539,29 +646,16 @@ def send_request(
                 metadata.plugin_data.plugin_name,
                 value.metadata.plugin_data.plugin_name,
             )
-            return
+            return metadata, False
         if plugin_name not in self._allowed_plugins:
             if first_in_time_series:
                 logger.info(
                     "Skipping time series %r with unsupported plugin name %r",
                     time_series_key,
                     plugin_name,
                 )
-            return
-        self._tracker.add_plugin_name(plugin_name)
-
-        if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
-            self._scalar_request_sender.add_event(run_name, event, value, metadata)
-        elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
-            self._tensor_request_sender.add_event(run_name, event, value, metadata)
-        elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
-            self._blob_request_sender.add_event(run_name, event, value, metadata)
-
-    def flush(self):
-        """Flushes any events that have been stored."""
-        self._scalar_request_sender.flush()
-        self._tensor_request_sender.flush()
-        self._blob_request_sender.flush()
+            return metadata, False
+        return metadata, True
 
 
 class _Dispatcher(object):
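The net effect of the last two hunks is that validation is factored out of `send_request` into `get_metadata_and_validate`, so the pre-create pass and the upload path share one code path and one `(metadata, is_valid)` contract. A self-contained sketch of that pattern; the types and plugin names below are illustrative, not the real uploader's:

```python
from typing import NamedTuple, Tuple

class Metadata(NamedTuple):
    plugin_name: str

ALLOWED_PLUGINS = frozenset({"scalars"})

def get_metadata_and_validate(tag: str, plugin_name: str) -> Tuple[Metadata, bool]:
    # Mirrors the diff's contract: report validity instead of returning early.
    metadata = Metadata(plugin_name)
    return metadata, plugin_name in ALLOWED_PLUGINS

# Caller 1: an upload path that dispatches only valid values.
meta, ok = get_metadata_and_validate("loss", "scalars")
assert ok and meta.plugin_name == "scalars"

# Caller 2: a pre-create pass reusing the exact same validation.
meta, ok = get_metadata_and_validate("graph", "graphs")
assert not ok
```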