Skip to content

Commit 6d5c7c4

Browse files
authored
feat: Add cloud profiler to training_utils
1 parent 91dd3c0 commit 6d5c7c4

File tree

16 files changed

+1414
-3
lines changed

16 files changed

+1414
-3
lines changed

README.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,24 @@ To use Explanation Metadata in endpoint deployment and model upload:
464464
aiplatform.Model.upload(..., explanation_metadata=explanation_metadata)
465465
466466
467+
Cloud Profiler
468+
----------------------------
469+
470+
Cloud Profiler allows you to profile your remote Vertex AI Training jobs on demand and visualize the results in Vertex TensorBoard.
471+
472+
To start using the profiler with TensorFlow, update your training script to include the following:
473+
474+
.. code-block:: Python
475+
476+
from google.cloud.aiplatform.training_utils import cloud_profiler
477+
...
478+
cloud_profiler.init()
479+
480+
Next, run the job with a Vertex TensorBoard instance. For full details on how to do this, visit https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-overview
481+
482+
Finally, visit your TensorBoard in your Google Cloud Console, navigate to the "Profile" tab, and click the `Capture Profile` button. This will allow users to capture profiling statistics for the running jobs.
483+
484+
467485
Next Steps
468486
~~~~~~~~~~
469487

google/cloud/aiplatform/tensorboard/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# limitations under the License.
1616
#
1717

18-
from google.cloud.aiplatform.tensorboard.tensorboard import Tensorboard
18+
from google.cloud.aiplatform.tensorboard.tensorboard_resource import Tensorboard
1919

2020

2121
__all__ = ("Tensorboard",)

google/cloud/aiplatform/tensorboard/uploader_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,6 @@ def get_or_create(
406406
filter="display_name = {}".format(json.dumps(str(tag_name))),
407407
)
408408
)
409-
410409
num = 0
411410
time_series = None
412411

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
Cloud Profiler
2+
=================================
3+
4+
Cloud Profiler allows you to profile your remote Vertex AI Training jobs on demand and visualize the results in Vertex TensorBoard.
5+
6+
Quick Start
7+
------------
8+
9+
To start using the profiler with TensorFlow, update your training script to include the following:
10+
11+
.. code-block:: Python
12+
13+
from google.cloud.aiplatform.training_utils import cloud_profiler
14+
...
15+
cloud_profiler.init()
16+
17+
18+
Next, run the job with a Vertex TensorBoard instance. For full details on how to do this, visit https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-overview
19+
20+
Finally, visit your TensorBoard in your Google Cloud Console, navigate to the "Profile" tab, and click the `Capture Profile` button. This will allow users to capture profiling statistics for the running jobs.
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2021 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
#     https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
# The initializer pulls in optional dependencies; re-raise a failed import
# with the pip command that installs them, keeping the original error chained.
try:
    import google.cloud.aiplatform.training_utils.cloud_profiler.initializer as initializer
except ImportError as err:
    raise ImportError(
        "Could not load the cloud profiler. To use the profiler, "
        'install the SDK using "pip install google-cloud-aiplatform[cloud-profiler]"'
    ) from err

"""
Initialize the cloud profiler for tensorflow.

Usage:
    from google.cloud.aiplatform.training_utils import cloud_profiler

    cloud_profiler.init(plugin='tensorflow')
"""

# Public entry point: `cloud_profiler.init` is an alias of `initializer.initialize`.
init = initializer.initialize
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2021 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
#     https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import logging
19+
import threading
20+
from typing import Optional, Type
21+
from werkzeug import serving
22+
23+
from google.cloud.aiplatform.training_utils import environment_variables
24+
from google.cloud.aiplatform.training_utils.cloud_profiler import webserver
25+
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import base_plugin
26+
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins.tensorflow import (
27+
tf_profiler,
28+
)
29+
30+
# Registry of profiler plugin classes, keyed by the name passed to initialize().
_AVAILABLE_PLUGINS = {
    "tensorflow": tf_profiler.TFProfiler,
}
32+
33+
34+
class MissingEnvironmentVariableException(Exception):
    """Raised when an environment variable required by the profiler is not set."""
36+
37+
38+
def _build_plugin(
    plugin: Type[base_plugin.BasePlugin],
) -> Optional[base_plugin.BasePlugin]:
    """Runs a plugin class through its lifecycle hooks and instantiates it.

    The hooks are called on the class itself, in this order:
    ``can_initialize()``, ``setup()``, ``post_setup_check()``.

    Args:
        plugin (Type[base_plugin.BasePlugin]):
            Required. The plugin class (not an instance) to build.

    Returns:
        A plugin instance, or None when either ``can_initialize()`` or
        ``post_setup_check()`` reports False.
    """
    if not plugin.can_initialize():
        logging.warning("Cannot initialize the plugin")
        return None

    # setup() always runs once the plugin can initialize, even if the
    # post-setup check then decides no instance is needed.
    plugin.setup()

    return plugin() if plugin.post_setup_check() else None
61+
62+
63+
def _run_app_thread(server: webserver.WebServer, port: int) -> None:
    """Run the webserver in a separate daemon thread.

    Args:
        server (webserver.WebServer):
            Required. A webserver to accept requests.
        port (int):
            Required. The port to run the webserver on.
    """
    # daemon=True is passed to the constructor instead of calling the
    # deprecated Thread.setDaemon(); a daemon thread does not keep the
    # interpreter alive once the main thread exits.
    daemon = threading.Thread(
        name="profile_server",
        target=serving.run_simple,
        # "0.0.0.0" makes the server listen on all network interfaces.
        args=("0.0.0.0", port, server),
        daemon=True,
    )
    daemon.start()
79+
80+
81+
def initialize(plugin: str = "tensorflow") -> None:
    """Initializes the profiling SDK.

    Looks up the requested plugin, builds it, and, if the build yields a
    plugin instance, starts a webserver for it in a background thread on
    the port given by ``environment_variables.http_handler_port``.

    Args:
        plugin (str):
            Optional. Name of the plugin to initialize. Defaults to
            "tensorflow". Current options are ["tensorflow"].

    Raises:
        ValueError:
            The plugin does not exist.
        MissingEnvironmentVariableException:
            An environment variable that is needed is not set.
    """
    plugin_obj = _AVAILABLE_PLUGINS.get(plugin)

    if plugin_obj is None:
        # list(...) renders as ['tensorflow'] rather than dict_keys([...]).
        raise ValueError(
            "Plugin {} not available, must choose from {}".format(
                plugin, list(_AVAILABLE_PLUGINS)
            )
        )

    prof_plugin = _build_plugin(plugin_obj)

    # No plugin instance means there is nothing to serve; return without
    # starting the webserver (and without requiring the port variable).
    if prof_plugin is None:
        return

    server = webserver.WebServer([prof_plugin])

    if not environment_variables.http_handler_port:
        raise MissingEnvironmentVariableException(
            "'AIP_HTTP_HANDLER_PORT' must be set."
        )

    port = int(environment_variables.http_handler_port)

    _run_app_thread(server, port)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2021 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
#     https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import abc
19+
from typing import Callable, Dict
20+
from werkzeug import Response
21+
22+
23+
class BasePlugin(abc.ABC):
    """Abstract interface for plugins that expose cloud training tool endpoints.

    A plugin contributes http handlers that are served for
    AI Platform training jobs. Subclasses must implement every method below.
    """

    @staticmethod
    @abc.abstractmethod
    def setup() -> None:
        """Perform plugin setup that must happen before the webserver starts."""
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def can_initialize() -> bool:
        """Report whether this plugin is able to be initialized.

        Intended for checks such as installed dependencies and system
        requirements.

        Returns:
            True if the plugin can be initialized, False otherwise.
        """
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def post_setup_check() -> bool:
        """Report whether the plugin should be used once setup has run.

        Example: only the main training node needs the web server; other
        nodes just need 'setup()' to have run so the rpc server is started.

        Returns:
            True if the post setup checks pass, False otherwise.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_routes(self) -> Dict[str, Callable[..., Response]]:
        """Return the plugin's route table.

        Plugins use this method to bind each route (path) to the handler
        that should serve it.

        Returns:
            A mapping from a route to a handler.
        """
        raise NotImplementedError

0 commit comments

Comments
 (0)