提交 cda80a78 编写于 作者: E Eric Liu 提交者: TensorFlower Gardener

[tpu profiler] Dump HLO graphs in profile responses to the log directory.

PiperOrigin-RevId: 163318992
上级 dd1f0cdd
......@@ -21,8 +21,10 @@ cc_binary(
visibility = ["//tensorflow/contrib/tpu/profiler:__subpackages__"],
deps = [
":tpu_profiler_proto_cc",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"@grpc//:grpc++_unsecure",
],
......
......@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
......@@ -33,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/core/util/events_writer.h"
namespace tensorflow {
namespace tpu {
......@@ -47,6 +49,7 @@ using ::tensorflow::WriteStringToFile;
constexpr char kProfilePluginDirectory[] = "plugins/profile/";
constexpr char kTraceFileName[] = "trace";
constexpr char kGraphRunPrefix[] = "tpu_profiler.hlo_graph.";
tensorflow::string GetCurrentTimeStampAsString() {
char s[128];
......@@ -55,10 +58,10 @@ tensorflow::string GetCurrentTimeStampAsString() {
return s;
}
// The trace will be stored in <logdir>/plugins/profile/<timestamp>/trace.
void DumpTraceToLogDirectory(const tensorflow::string& logdir,
// The trace will be stored in <logdir>/plugins/profile/<run>/trace.
void DumpTraceToLogDirectory(tensorflow::StringPiece logdir,
tensorflow::StringPiece run,
tensorflow::StringPiece trace) {
tensorflow::string run = GetCurrentTimeStampAsString();
tensorflow::string run_dir = JoinPath(logdir, kProfilePluginDirectory, run);
TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(run_dir));
tensorflow::string path = JoinPath(run_dir, kTraceFileName);
......@@ -83,6 +86,18 @@ ProfileResponse Profile(const tensorflow::string& service_addr,
return response;
}
void DumpGraph(tensorflow::StringPiece logdir, tensorflow::StringPiece run,
const tensorflow::string& graph_def) {
// The graph plugin expects the graph in <logdir>/<run>/<event.file>.
tensorflow::string run_dir =
JoinPath(logdir, tensorflow::strings::StrCat(kGraphRunPrefix, run));
TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(run_dir));
tensorflow::EventsWriter event_writer(JoinPath(run_dir, "events"));
tensorflow::Event event;
event.set_graph_def(graph_def);
event_writer.WriteEvent(event);
}
} // namespace
} // namespace tpu
} // namespace tensorflow
......@@ -111,14 +126,28 @@ int main(int argc, char** argv) {
int duration_ms = FLAGS_duration_ms;
tensorflow::ProfileResponse response =
tensorflow::tpu::Profile(FLAGS_service_addr, duration_ms);
// Use the current timestamp as the run name.
tensorflow::string run = tensorflow::tpu::GetCurrentTimeStampAsString();
// Ignore computation_graph for now.
if (response.encoded_trace().empty()) {
LOG(WARNING) << "No trace event is collected during the " << duration_ms
<< "ms interval.";
} else {
tensorflow::tpu::DumpTraceToLogDirectory(FLAGS_logdir,
tensorflow::tpu::DumpTraceToLogDirectory(FLAGS_logdir, run,
response.encoded_trace());
}
int num_graphs = response.computation_graph_size();
if (num_graphs > 0) {
// The server might generates multiple graphs for one program; we simply
// pick the first one.
if (num_graphs > 1) {
LOG(INFO) << num_graphs
<< " TPU program variants observed over the profiling period. "
<< "One computation graph will be chosen arbitrarily.";
}
tensorflow::tpu::DumpGraph(
FLAGS_logdir, run, response.computation_graph(0).SerializeAsString());
}
// Print this at the end so that it's not buried in irrelevant LOG messages.
std::cout
<< "NOTE: using the trace duration " << duration_ms << "ms." << std::endl
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册