提交 ba545bc5 编写于 作者: K Kathy Wu

Merge branch 'r2.2' of https://github.com/tensorflow/tensorflow into cherrypicks_SXE8X

......@@ -46,7 +46,6 @@
# sycl_asan:
# sycl_trisycl:
# mkl: Enable full mkl support.
# mkl_open_source_only: Enable MKL support only using open source MKL libraries.
# tensorrt: Enable Tensorrt support.
# ngraph: Enable ngraph support.
# numa: Enable numa using hwloc.
......@@ -140,13 +139,6 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define=build_with_mkl_dnn_v1_only=true
build:mkl -c opt
# This config option is used to enable MKL-DNN open source library only,
# without depending on MKL binary version.
build:mkl_open_source_only --define=build_with_mkl_dnn_only=true
build:mkl_open_source_only --define=build_with_mkl_dnn_v1_only=true
build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
......
......@@ -8,6 +8,9 @@ glob_lit_tests(
":test_utilities",
],
driver = "@llvm-project//mlir:run_lit.sh",
tags_override = {
"error-message-with-source-info.pbtxt": ["no_oss"], # TODO(b/150946057): to be fixed on oss.
},
test_file_exts = ["pbtxt"],
)
......
......@@ -229,6 +229,7 @@ cc_library(
":op_stats_to_input_pipeline_analysis",
":op_stats_to_overview_page",
":op_stats_to_tf_stats",
":trace_events_to_json",
":xplane_to_op_stats",
":xplane_to_trace_events",
"//tensorflow/core:lib",
......@@ -240,6 +241,7 @@ cc_library(
"//tensorflow/core/profiler/protobuf:overview_page_proto_cc",
"//tensorflow/core/profiler/protobuf:tf_stats_proto_cc",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/rpc/client:save_profile",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
......
......@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h"
#include "tensorflow/core/profiler/convert/op_stats_to_overview_page.h"
#include "tensorflow/core/profiler/convert/op_stats_to_tf_stats.h"
#include "tensorflow/core/profiler/convert/trace_events_to_json.h"
#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h"
#include "tensorflow/core/profiler/convert/xplane_to_trace_events.h"
#include "tensorflow/core/profiler/profiler_service.pb.h"
......@@ -30,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/profiler/protobuf/overview_page.pb.h"
#include "tensorflow/core/profiler/protobuf/tf_stats.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/rpc/client/save_profile.h"
namespace tensorflow {
namespace profiler {
......@@ -57,24 +59,30 @@ void AddToolData(absl::string_view tool_name, const Proto& tool_output,
// Returns the tool name with extension.
string ToolName(absl::string_view tool) {
if (tool == kTraceViewer) return "trace";
if (tool == kTraceViewer) return "trace.json.gz";
return absl::StrCat(tool, ".pb");
}
} // namespace
void ConvertXSpaceToProfileResponse(const XSpace& xspace,
const ProfileRequest& req,
ProfileResponse* response) {
Status ConvertXSpaceToProfileResponse(const XSpace& xspace,
const ProfileRequest& req,
ProfileResponse* response) {
absl::flat_hash_set<absl::string_view> tools(req.tools().begin(),
req.tools().end());
if (tools.empty()) return;
if (tools.empty()) return Status::OK();
if (tools.contains(kTraceViewer)) {
Trace trace;
ConvertXSpaceToTraceEvents(xspace, &trace);
AddToolData(ToolName(kTraceViewer), trace, response);
if (trace.trace_events().empty()) {
response->set_empty_trace(true);
return Status::OK();
}
TF_RETURN_IF_ERROR(SaveGzippedToolDataToTensorboardProfile(
req.repository_root(), req.session_id(), req.host_name(),
ToolName(kTraceViewer), TraceEventsToJson(trace)));
// Trace viewer is the only tool, skip OpStats conversion.
if (tools.size() == 1) return;
if (tools.size() == 1) return Status::OK();
}
OpStats op_stats = ConvertXSpaceToOpStats(xspace);
HardwareType hw_type =
......@@ -99,6 +107,7 @@ void ConvertXSpaceToProfileResponse(const XSpace& xspace,
if (tools.contains(kKernelStats)) {
AddToolData(ToolName(kKernelStats), op_stats.kernel_stats_db(), response);
}
return Status::OK();
}
} // namespace profiler
......
......@@ -27,10 +27,11 @@ namespace profiler {
// Convert collected trace in XSpace format to tools data based on the
// specified list of tools, and save to ProfileResponse.
// The accepted tools are:
// "overview_page", "input_pipeline" and "tensorflow_stats".
void ConvertXSpaceToProfileResponse(const XSpace& xspace,
const ProfileRequest& req,
ProfileResponse* response);
// "overview_page", "input_pipeline", "tensorflow_stats", "kernel_stats"
// and "trace_viewer".
Status ConvertXSpaceToProfileResponse(const XSpace& xspace,
const ProfileRequest& req,
ProfileResponse* response);
} // namespace profiler
} // namespace tensorflow
......
......@@ -66,7 +66,7 @@ TEST(ConvertXPlaneToProfileResponse, TraceViewer) {
CreateXSpace(&xspace);
ProfileRequest request;
ProfileResponse response;
ConvertXSpaceToProfileResponse(xspace, request, &response);
TF_CHECK_OK(ConvertXSpaceToProfileResponse(xspace, request, &response));
}
TEST(ConvertXPlaneToProfileResponse, OverviewPage) {
......@@ -75,7 +75,7 @@ TEST(ConvertXPlaneToProfileResponse, OverviewPage) {
ProfileRequest request;
request.add_tools("overview_page");
ProfileResponse response;
ConvertXSpaceToProfileResponse(xspace, request, &response);
TF_CHECK_OK(ConvertXSpaceToProfileResponse(xspace, request, &response));
EXPECT_EQ(1, response.tool_data_size());
EXPECT_EQ("overview_page.pb", response.tool_data(/*index=*/0).name());
OverviewPage overview_page;
......@@ -89,7 +89,7 @@ TEST(ConvertXPlaneToProfileResponse, InputPipeline) {
ProfileRequest request;
request.add_tools("input_pipeline");
ProfileResponse response;
ConvertXSpaceToProfileResponse(xspace, request, &response);
TF_CHECK_OK(ConvertXSpaceToProfileResponse(xspace, request, &response));
EXPECT_EQ(1, response.tool_data_size());
EXPECT_EQ("input_pipeline.pb", response.tool_data(/*index=*/0).name());
InputPipelineAnalysisResult input_pipeline;
......@@ -103,7 +103,7 @@ TEST(ConvertXPlaneToProfileResponse, TensorflowStats) {
ProfileRequest request;
request.add_tools("tensorflow_stats");
ProfileResponse response;
ConvertXSpaceToProfileResponse(xspace, request, &response);
TF_CHECK_OK(ConvertXSpaceToProfileResponse(xspace, request, &response));
EXPECT_EQ(1, response.tool_data_size());
EXPECT_EQ("tensorflow_stats.pb", response.tool_data(/*index=*/0).name());
TfStatsDatabase tf_stats_db;
......
......@@ -29,5 +29,6 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler:profiler_service_proto_cc",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
],
)
......@@ -14,8 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
#include <cstdio>
#include <ctime>
#include <vector>
#include "grpcpp/grpcpp.h"
......@@ -37,14 +35,6 @@ namespace {
constexpr uint64 kMaxEvents = 1000000;
string GetCurrentTimeStampAsString() {
char s[128];
std::time_t t = std::time(nullptr);
auto result = std::strftime(s, sizeof(s), "%F_%T", std::localtime(&t));
DCHECK_NE(result, 0);
return s;
}
ProfileRequest PopulateProfileRequest(int duration_ms,
const string& repository_root,
const string& session_id,
......@@ -52,12 +42,8 @@ ProfileRequest PopulateProfileRequest(int duration_ms,
ProfileRequest request;
request.set_duration_ms(duration_ms);
request.set_max_events(kMaxEvents);
if (absl::StartsWith(repository_root, "gs://")) {
// For backward compatibilities, only generate tracetable etc when the
// user provide a GCS path for model directory.
request.set_repository_root(repository_root);
request.set_session_id(session_id);
}
request.set_repository_root(repository_root);
request.set_session_id(session_id);
request.add_tools("trace_viewer");
request.add_tools("op_profile");
request.add_tools("input_pipeline");
......@@ -94,11 +80,12 @@ Status Profile(const string& service_addr, const string& logdir,
const ProfileOptions& opts) {
ProfileRequest request =
PopulateProfileRequest(duration_ms, logdir, session_id, opts);
std::vector<string> parts = absl::StrSplit(service_addr, ':');
request.set_host_name(parts[0]);
::grpc::ClientContext context;
::grpc::ChannelArguments channel_args;
// TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their
// `ValidateHostPortPair` checks for empty host string case.
channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH,
std::numeric_limits<int32>::max());
std::unique_ptr<grpc::ProfilerService::Stub> stub =
......@@ -110,8 +97,8 @@ Status Profile(const string& service_addr, const string& logdir,
FromGrpcStatus(stub->Profile(&context, request, &response)));
if (!response.empty_trace()) {
TF_CHECK_OK(
SaveTensorboardProfile(logdir, session_id, "", response, &std::cout));
TF_RETURN_IF_ERROR(SaveTensorboardProfile(
logdir, session_id, request.host_name(), response, &std::cout));
// Print this at the end so that it's not buried in irrelevant LOG messages.
std::cout
<< "NOTE: using the trace duration " << duration_ms << "ms.\n"
......@@ -145,7 +132,6 @@ Status NewSession(const string& service_addr, const string& repository_root,
::grpc::ClientContext context;
::grpc::ChannelArguments channel_args;
// TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their
// `ValidateHostPortPair` checks for empty host string case.
channel_args.SetMaxReceiveMessageSize(std::numeric_limits<int32>::max());
// TODO(jiesun): GRPC support following relevant naming scheme:
// 1. dns:///host:port
......
......@@ -21,8 +21,11 @@ limitations under the License.
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/strip.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/compression.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
// Windows.h #defines ERROR, but it is also used in
......@@ -35,18 +38,51 @@ namespace tensorflow {
namespace profiler {
namespace {
using ::tensorflow::io::JoinPath;
// Separator inserted between path components by ProfilerJoinPathImpl:
// a backslash when PLATFORM_WINDOWS is defined, a forward slash otherwise.
#ifdef PLATFORM_WINDOWS
const absl::string_view kPathSep = "\\";
#else
const absl::string_view kPathSep = "/";
#endif
// Joins the non-empty entries of `paths` into a single path.
//
// The first non-empty entry is copied verbatim. Every later non-empty entry
// has one leading kPathSep removed (if present) and is appended so that
// exactly one separator is inserted when the accumulated path does not
// already end with kPathSep. Returns an empty string when every entry is
// empty.
string ProfilerJoinPathImpl(std::initializer_list<absl::string_view> paths) {
  string joined;
  for (absl::string_view component : paths) {
    if (component.empty()) continue;
    if (joined.empty()) {
      // First real component: keep it exactly as given.
      joined = string(component);
    } else {
      // Add a separator only if the accumulated path lacks a trailing one,
      // then append the component without its own leading separator.
      if (!absl::EndsWith(joined, kPathSep)) {
        strings::StrAppend(&joined, kPathSep);
      }
      strings::StrAppend(&joined, absl::StripPrefix(component, kPathSep));
    }
  }
  return joined;
}
// A local duplication of ::tensorflow::io::JoinPath that supports windows.
// Accepts any number of path components, packs them into an initializer list
// and forwards them to ProfilerJoinPathImpl, returning the joined path.
// TODO(b/150699701): revert to use ::tensorflow::io::JoinPath when fixed.
template <typename... T>
string ProfilerJoinPath(const T&... args) {
return ProfilerJoinPathImpl({args...});
}
constexpr char kProtoTraceFileName[] = "trace";
constexpr char kTfStatsHelperSuffix[] = "tf_stats_helper_result";
Status DumpToolDataToLogDirectory(StringPiece run_dir,
const string& host_prefix,
Status DumpToolDataToLogDirectory(StringPiece run_dir, const string& host,
const ProfileToolData& tool,
std::ostream* os) {
// Don't save the intermediate results for combining the per host tool data.
if (absl::EndsWith(tool.name(), kTfStatsHelperSuffix)) return Status::OK();
string path = JoinPath(run_dir, absl::StrCat(host_prefix, tool.name()));
string host_prefix = host.empty() ? "" : absl::StrCat(host, ".");
string path =
ProfilerJoinPath(run_dir, absl::StrCat(host_prefix, tool.name()));
TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, tool.data()));
if (os) {
*os << "Dumped tool data for " << tool.name() << " to " << path
......@@ -68,38 +104,81 @@ Status MaybeCreateEmptyEventFile(const string& logdir) {
return Status::OK();
}
}
EventsWriter event_writer(JoinPath(logdir, "events"));
EventsWriter event_writer(ProfilerJoinPath(logdir, "events"));
return event_writer.InitWithSuffix(kProfileEmptySuffix);
}
// Writes `data` to `filepath` through a zlib output buffer configured with
// the GZIP compression options.
//
// Returns the first non-OK status produced while opening the file,
// initializing the buffer, appending the data, closing the buffer, or closing
// the file; returns OK otherwise.
Status WriteGzippedDataToFile(const string& filepath, const string& data) {
  std::unique_ptr<WritableFile> gz_file;
  TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(filepath, &gz_file));

  io::ZlibCompressionOptions gzip_opts = io::ZlibCompressionOptions::GZIP();
  io::ZlibOutputBuffer zlib_out(gz_file.get(), gzip_opts.input_buffer_size,
                                gzip_opts.output_buffer_size, gzip_opts);
  TF_RETURN_IF_ERROR(zlib_out.Init());
  TF_RETURN_IF_ERROR(zlib_out.Append(data));
  // Close the compression buffer before the underlying file.
  TF_RETURN_IF_ERROR(zlib_out.Close());
  return gz_file->Close();
}
// Computes the profile run directory for `run` under `logdir`, stores it in
// `*profile_run_dir`, and creates it via Env::RecursivelyCreateDir. Afterwards
// calls MaybeCreateEmptyEventFile(logdir).
//
// Args:
//   logdir: TensorBoard log directory.
//   run: name of the profiling run (used as the last path component).
//   profile_run_dir: output; receives the computed run directory. Must not be
//     null.
//   os: optional stream that receives a "Creating directory: ..." message.
//     May be null, in which case the message is skipped (matching the null
//     check done in DumpToolDataToLogDirectory).
//
// Returns the first non-OK status from directory or event-file creation, or
// OK on success.
Status GetOrCreateProfileRunDir(const string& logdir, const string& run,
                                string* profile_run_dir, std::ostream* os) {
  // Dumps profile data to <logdir>/plugins/profile/<run>/.
  *profile_run_dir =
      ProfilerJoinPath(GetTensorBoardProfilePluginDir(logdir), run);
  if (os) {
    *os << "Creating directory: " << *profile_run_dir;
  }
  TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(*profile_run_dir));

  // Creates an empty event file so that TensorBoard plugin logic can find
  // the logdir.
  TF_RETURN_IF_ERROR(MaybeCreateEmptyEventFile(logdir));
  return Status::OK();
}
} // namespace
string GetTensorBoardProfilePluginDir(const string& logdir) {
constexpr char kPluginName[] = "plugins";
constexpr char kProfileName[] = "profile";
return JoinPath(logdir, kPluginName, kProfileName);
return ProfilerJoinPath(logdir, kPluginName, kProfileName);
}
Status SaveTensorboardProfile(const string& logdir, const string& run,
const string& host,
const ProfileResponse& response,
std::ostream* os) {
// Dumps profile data to <logdir>/plugins/profile/<run>/.
string host_prefix = host.empty() ? "" : absl::StrCat(host, ".");
string profile_run_dir =
JoinPath(GetTensorBoardProfilePluginDir(logdir), run);
*os << "Creating directory: " << profile_run_dir;
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(profile_run_dir));
// Creates an empty event file so that TensorBoard plugin logic can find
// the logdir.
TF_RETURN_IF_ERROR(MaybeCreateEmptyEventFile(logdir));
string profile_run_dir;
TF_RETURN_IF_ERROR(
GetOrCreateProfileRunDir(logdir, run, &profile_run_dir, os));
for (const auto& tool_data : response.tool_data()) {
TF_RETURN_IF_ERROR(DumpToolDataToLogDirectory(profile_run_dir, host_prefix,
tool_data, os));
TF_RETURN_IF_ERROR(
DumpToolDataToLogDirectory(profile_run_dir, host, tool_data, os));
}
return Status::OK();
}
// Gzips `data` and writes it into the profile run directory for `run` under
// `logdir`. The file is named "<host>.<tool_name>", or just "<tool_name>" when
// `host` is empty. The run directory is obtained (and created if needed) via
// GetOrCreateProfileRunDir.
//
// Returns the first non-OK status from directory setup or from writing the
// gzipped file, or OK on success.
Status SaveGzippedToolDataToTensorboardProfile(const string& logdir,
                                               const string& run,
                                               const string& host,
                                               const string& tool_name,
                                               const string& data) {
  string run_dir;
  std::stringstream dir_messages;
  Status dir_status =
      GetOrCreateProfileRunDir(logdir, run, &run_dir, &dir_messages);
  // Log whatever was captured before propagating a possible failure, so the
  // message is not lost on the error path.
  LOG(INFO) << dir_messages.str();
  TF_RETURN_IF_ERROR(dir_status);

  // Prefix the tool file with the host name only when one was supplied.
  string file_name;
  if (host.empty()) {
    file_name = tool_name;
  } else {
    file_name = absl::StrCat(host, ".", tool_name);
  }
  string path = ProfilerJoinPath(run_dir, file_name);

  TF_RETURN_IF_ERROR(WriteGzippedDataToFile(path, data));
  LOG(INFO) << "Dumped gzipped tool data for " << tool_name << " to " << path;
  return Status::OK();
}
// Returns the current time, expressed in the local time zone, formatted as
// <4-digit year>_<month>_<day>_<hour>_<minute>_<second>.
string GetCurrentTimeStampAsString() {
  const absl::Time now = absl::Now();
  const absl::TimeZone local_zone = absl::LocalTimeZone();
  return absl::FormatTime("%E4Y_%m_%d_%H_%M_%S", now, local_zone);
}
} // namespace profiler
} // namespace tensorflow
......@@ -22,6 +22,8 @@ limitations under the License.
namespace tensorflow {
namespace profiler {
// Returns the current local time formatted as
// <year>_<month>_<day>_<hour>_<minute>_<second>.
string GetCurrentTimeStampAsString();
// Returns the profile plugin directory given a logdir to TensorBoard.
string GetTensorBoardProfilePluginDir(const string& logdir);
......@@ -34,6 +36,13 @@ Status SaveTensorboardProfile(const string& logdir, const string& run,
const ProfileResponse& response,
std::ostream* os);
// Gzips `data` and saves it under the profile run directory
// <logdir>/plugins/profile/<run>/ as "<host>.<tool_name>" (the "<host>."
// prefix is omitted when `host` is empty). The run directory is created if
// needed. Returns a non-OK status if directory creation or the write fails.
Status SaveGzippedToolDataToTensorboardProfile(const string& logdir,
const string& run,
const string& host,
const string& tool_name,
const string& data);
} // namespace profiler
} // namespace tensorflow
......
......@@ -36,7 +36,8 @@ Status CollectDataToResponse(const ProfileRequest& req,
ProfileResponse* response) {
profiler::XSpace xspace;
TF_RETURN_IF_ERROR(profiler->CollectData(&xspace));
profiler::ConvertXSpaceToProfileResponse(xspace, req, response);
TF_RETURN_IF_ERROR(
profiler::ConvertXSpaceToProfileResponse(xspace, req, response));
return Status::OK();
}
......
......@@ -1829,7 +1829,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
for (dirpath, dirnames, filenames) in os.walk(profile_dir):
del dirnames # unused
for filename in filenames:
if filename.endswith('.trace'):
if filename.endswith('.trace.json.gz'):
return os.path.join(dirpath, filename)
return None
......
......@@ -94,6 +94,8 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
x=train_ds,
epochs=num_epoch,
steps_per_epoch=steps,
validation_data=train_ds,
validation_steps=steps,
callbacks=[
callbacks.ModelCheckpoint(
filepath=saving_filepath, save_weights_only=save_weights_only)
......
......@@ -859,10 +859,13 @@ class Network(base_layer.Layer):
argspec = self._layer_call_argspecs[layer].args
if 'training' in argspec:
kwargs.setdefault('training', training)
if (type(kwargs['training']) is ops.Tensor and # pylint: disable=unidiomatic-typecheck
any([kwargs['training'] is x
for x in backend._GRAPH_LEARNING_PHASES.values()])):
if 'training' not in kwargs or kwargs['training'] is None:
kwargs['training'] = training
elif (type(kwargs['training']) is ops.Tensor and # pylint: disable=unidiomatic-typecheck
any([
kwargs['training'] is x
for x in backend._GRAPH_LEARNING_PHASES.values()
])):
kwargs['training'] = training # Materialize placeholder.
# Map Keras tensors in kwargs to their computed value.
......
......@@ -32,7 +32,7 @@ from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import input_layer as input_layer_lib
from tensorflow.python.keras.engine import network as network_lib
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine import training as training_lib
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
......@@ -1109,20 +1109,20 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
outputs = keras.layers.Dense(4)(inputs)
with self.assertRaisesRegexp(TypeError, 'unexpected argument'):
model = training.Model(inputs, outputs, name='m', trainable=False,
dtype='int64')
model = training_lib.Model(
inputs, outputs, name='m', trainable=False, dtype='int64')
with self.assertRaisesRegexp(TypeError, 'unexpected argument'):
model = training.Model(inputs, outputs, name='m', trainable=False,
dynamic=False)
model = training_lib.Model(
inputs, outputs, name='m', trainable=False, dynamic=False)
model = training.Model(inputs, outputs, name='m', trainable=False)
model = training_lib.Model(inputs, outputs, name='m', trainable=False)
self.assertEqual('m', model.name)
self.assertFalse(model.trainable)
self.assertFalse(model.dynamic)
# Subclassed model
model = training.Model(name='subclassed', trainable=True, dtype='int64',
dynamic=True)
model = training_lib.Model(
name='subclassed', trainable=True, dtype='int64', dynamic=True)
self.assertEqual('subclassed', model.name)
self.assertTrue(model.dynamic)
self.assertTrue(model.trainable)
......@@ -1875,6 +1875,44 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
y = net(x)
self.assertEqual(y.shape.rank, 2)
def test_training_passed_during_construction(self):
class MyLayer(base_layer.Layer):
def call(self, x, training=None):
self.training = training
return x
my_layer = MyLayer()
x = np.ones((1, 10))
inputs = input_layer_lib.Input(10)
outputs = my_layer(inputs, training=True)
network = network_lib.Network(inputs, outputs)
network(x, training=False)
# Hard-coded value passed during construction is respected.
self.assertTrue(my_layer.training)
inputs = input_layer_lib.Input(10)
outputs = my_layer(inputs, training=False)
network = network_lib.Network(inputs, outputs)
network(x, training=True)
# Hard-coded value passed during construction is respected.
self.assertFalse(my_layer.training)
inputs = input_layer_lib.Input(10)
outputs = my_layer(inputs, training=None)
network = network_lib.Network(inputs, outputs)
network(x, training=True)
# `None` value passed during construction is overridden.
self.assertTrue(my_layer.training)
network(x, training=False)
# `None` value passed during construction is overridden.
self.assertFalse(my_layer.training)
if __name__ == '__main__':
test.main()
......@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.eager import backprop
......@@ -61,6 +62,10 @@ def enable_multi_worker(method):
if not self._in_multi_worker_mode(): # pylint: disable=protected-access
return method(self, *args, **kwargs)
# Running inside `run_distribute_coordinator` already.
if dc_context.get_current_worker_context():
return method(self, *args, **kwargs)
return dc.run_distribute_coordinator(
lambda _: method(self, *args, **kwargs),
self.distribute_strategy,
......
......@@ -514,11 +514,9 @@ class BatchNormalizationBase(Layer):
K.zeros_like(update_delta))
return state_ops.assign_sub(variable, update_delta, name=scope)
def _assign_new_value(self, variable, value, inputs_size=None):
def _assign_new_value(self, variable, value):
with K.name_scope('AssignNewValue') as scope:
with ops.colocate_with(variable):
if inputs_size is not None:
value = array_ops.where(inputs_size > 0, value, variable)
return state_ops.assign(variable, value, name=scope)
def _fused_batch_norm(self, inputs, training):
......@@ -569,6 +567,9 @@ class BatchNormalizationBase(Layer):
data_format=self._data_format,
exponential_avg_factor=exponential_avg_factor)
def _fused_batch_norm_training_empty():
return inputs, self.moving_mean, self.moving_variance
def _fused_batch_norm_inference():
return nn.fused_batch_norm(
inputs,
......@@ -580,8 +581,14 @@ class BatchNormalizationBase(Layer):
is_training=False,
data_format=self._data_format)
output, mean, variance = tf_utils.smart_cond(
training, _fused_batch_norm_training, _fused_batch_norm_inference)
train_op = _fused_batch_norm_training
if compat.forward_compatible(2020, 3, 6) and inputs_size is not None:
train_op = lambda: tf_utils.smart_cond(inputs_size > 0,
_fused_batch_norm_training,
_fused_batch_norm_training_empty)
output, mean, variance = tf_utils.smart_cond(training, train_op,
_fused_batch_norm_inference)
variance = _maybe_add_or_remove_bessels_correction(variance, remove=True)
training_value = tf_utils.constant_value(training)
......@@ -596,7 +603,7 @@ class BatchNormalizationBase(Layer):
def mean_update():
"""Update self.moving_mean with the most recent data point."""
if compat.forward_compatible(2020, 3, 6):
return self._assign_new_value(self.moving_mean, mean, inputs_size)
return self._assign_new_value(self.moving_mean, mean)
else:
return self._assign_moving_average(self.moving_mean, mean, momentum,
inputs_size)
......@@ -604,8 +611,7 @@ class BatchNormalizationBase(Layer):
def variance_update():
"""Update self.moving_variance with the most recent data point."""
if compat.forward_compatible(2020, 3, 6):
return self._assign_new_value(self.moving_variance, variance,
inputs_size)
return self._assign_new_value(self.moving_variance, variance)
else:
return self._assign_moving_average(self.moving_variance, variance,
momentum, inputs_size)
......
......@@ -59,7 +59,6 @@ cuda_py_test(
python_version = "PY3",
tags = [
"no_pip",
"no_windows",
],
deps = [
":profiler_v2",
......
......@@ -125,7 +125,6 @@ tf_python_pybind_extension(
"//tensorflow/core/profiler/rpc/client:save_profile",
"//tensorflow/python:pybind11_status",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/time",
"@pybind11",
],
)
......@@ -16,7 +16,6 @@ limitations under the License.
#include <memory>
#include "absl/memory/memory.h"
#include "absl/time/time.h"
#include "include/pybind11/pybind11.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/types.h"
......@@ -31,18 +30,18 @@ namespace py = ::pybind11;
namespace {
tensorflow::string GetCurrentTimeStampAsString() {
return absl::FormatTime("%E4Y-%m-%d_%H:%M:%S", absl::Now(),
absl::LocalTimeZone());
}
tensorflow::ProfileRequest MakeProfileRequest() {
tensorflow::ProfileRequest MakeProfileRequest(
const tensorflow::string& logdir, const tensorflow::string& session_id,
const tensorflow::string& host) {
tensorflow::ProfileRequest request;
request.add_tools("trace_viewer");
request.add_tools("overview_page");
request.add_tools("input_pipeline");
request.add_tools("kernel_stats");
request.add_tools("tensorflow_stats");
request.set_host_name(host);
request.set_repository_root(logdir);
request.set_session_id(session_id);
return request;
}
......@@ -71,20 +70,22 @@ class ProfilerSessionWrapper {
tensorflow::Status status;
status = session_->CollectData(&xspace);
session_.reset();
if (!status.ok()) {
tensorflow::MaybeRaiseRegisteredFromStatus(status);
return;
}
tensorflow::MaybeRaiseRegisteredFromStatus(status);
tensorflow::ProfileResponse response;
tensorflow::profiler::ConvertXSpaceToProfileResponse(
xspace, MakeProfileRequest(), &response);
tensorflow::ProfileRequest request = MakeProfileRequest(
logdir_, tensorflow::profiler::GetCurrentTimeStampAsString(),
tensorflow::port::Hostname());
status = tensorflow::profiler::ConvertXSpaceToProfileResponse(
xspace, request, &response);
tensorflow::MaybeRaiseRegisteredFromStatus(status);
std::stringstream ss; // Record LOG messages.
status = tensorflow::profiler::SaveTensorboardProfile(
logdir_, GetCurrentTimeStampAsString(), tensorflow::port::Hostname(),
request.repository_root(), request.session_id(), request.host_name(),
response, &ss);
LOG(INFO) << ss.str();
tensorflow::MaybeRaiseRegisteredFromStatus(tensorflow::Status::OK());
tensorflow::MaybeRaiseRegisteredFromStatus(status);
}
private:
......
......@@ -81,6 +81,9 @@ def start(logdir):
'server and profiler APIs at the same time.')
raise errors.AlreadyExistsError(None, None,
'Another profiler is running.')
except Exception:
_profiler = None
raise
@tf_export('profiler.experimental.stop', v1=[])
......@@ -102,7 +105,11 @@ def stop(save=True):
None, None,
'Cannot export profiling results. No profiler is running.')
if save:
_profiler.export_to_tb()
try:
_profiler.export_to_tb()
except Exception:
_profiler = None
raise
_profiler = None
......@@ -127,12 +134,8 @@ def start_server(port):
Args:
port: port profiler server listens to.
Example usage:
```python
tf.profiler.experimental.server.start('6009')
# do your training here.
Example usage: ```python tf.profiler.experimental.server.start('6009') # do
your training here.
"""
_pywrap_profiler.start_server(port)
......
......@@ -21,10 +21,7 @@ from __future__ import print_function
import os
import socket
from tensorflow.core.protobuf import trace_events_pb2
from tensorflow.python.eager import profiler
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
......@@ -45,6 +42,16 @@ class ProfilerTest(test_util.TensorFlowTestCase):
with self.assertRaises(errors.UnavailableError):
profiler.stop()
# Test with a bad logdir, and it correctly raises exception and deletes
# profiler.
# pylint: disable=anomalous-backslash-in-string
profiler.start('/\/\/:123')
# pylint: enable=anomalous-backslash-in-string
with self.assertRaises(Exception):
profiler.stop()
profiler.start(logdir)
profiler.stop()
def test_save_profile(self):
logdir = self.get_temp_dir()
profiler.start(logdir)
......@@ -74,19 +81,10 @@ class ProfilerTest(test_util.TensorFlowTestCase):
tensorflow_stats = os.path.join(profile_dir, run,
hostname + '.tensorflow_stats.pb')
self.assertTrue(gfile.Exists(tensorflow_stats))
trace_file = os.path.join(profile_dir, run, hostname + '.trace')
kernel_stats = os.path.join(profile_dir, run, hostname + '.kernel_stats.pb')
self.assertTrue(gfile.Exists(kernel_stats))
trace_file = os.path.join(profile_dir, run, hostname + '.trace.json.gz')
self.assertTrue(gfile.Exists(trace_file))
with gfile.Open(trace_file, 'rb') as f:
profile_pb = trace_events_pb2.Trace()
profile_pb.ParseFromString(f.read())
devices = frozenset(device.name for device in profile_pb.devices.values())
self.assertIn('/host:CPU', devices)
if config.list_physical_devices('GPU'):
self.assertIn('/device:GPU:0', devices)
events = frozenset(event.name for event in profile_pb.trace_events)
self.assertIn('three_times_five', events)
self.assertIn('Mul:Mul', events)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册