diff --git a/WORKSPACE b/WORKSPACE
index 1bf1069f8801c9d135d77c871520ff733b7713e9..b40913801ba8e3c8ee73f7ba69540b520ad698a6 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -5,7 +5,7 @@ http_archive(
     sha256 = "110fe68753413777944b473c25eed6368c4a0487cee23a7bac1b13cc49d3e257",
     strip_prefix = "rules_closure-4af89ef1db659eb41f110df189b67d4cf14073e1",
     urls = [
-        "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz",
+        "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz",
        "https://github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz",  # 2017-08-28
    ],
)
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index e99d9e29ef2f82be9efa6eb458aea72132fe5feb..e3bed8cc6213828ac03404ec59ada7a5ffd7ad8e 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -348,6 +348,7 @@ filegroup(
         "//tensorflow/compiler/xla/service/llvm_ir:all_files",
         "//tensorflow/compiler/xla/tests:all_files",
         "//tensorflow/compiler/xla/tools:all_files",
+        "//tensorflow/compiler/xla/tools/parser:all_files",
         "//tensorflow/contrib:all_files",
         "//tensorflow/contrib/all_reduce:all_files",
         "//tensorflow/contrib/android:all_files",
@@ -421,7 +422,6 @@ filegroup(
         "//tensorflow/contrib/remote_fused_graph/pylib:all_files",
         "//tensorflow/contrib/resampler:all_files",
         "//tensorflow/contrib/rnn:all_files",
-        "//tensorflow/contrib/s3:all_files",
         "//tensorflow/contrib/saved_model:all_files",
         "//tensorflow/contrib/saved_model/cc/saved_model:all_files",
         "//tensorflow/contrib/seq2seq:all_files",
@@ -475,6 +475,7 @@ filegroup(
         "//tensorflow/core/platform/cloud:all_files",
         "//tensorflow/core/platform/default/build_config:all_files",
         "//tensorflow/core/platform/hadoop:all_files",
+        "//tensorflow/core/platform/s3:all_files",
         "//tensorflow/core/profiler:all_files",
         "//tensorflow/core/profiler/internal:all_files",
         "//tensorflow/core/profiler/internal/advisor:all_files",
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 96f3c3e195e7025252c1e3cda5436237ad89257b..c77896b80b478cd34d3502e1061a7e76204ba021 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -3,6 +3,7 @@ licenses(["notice"])  # Apache 2.0
 
 load(
     "//tensorflow:tensorflow.bzl",
+    "tf_cuda_cc_test",
     "tf_cc_test",
     "tf_copts",
     "tf_cuda_library",
@@ -50,7 +51,7 @@ tf_cuda_library(
     ],
 )
 
-tf_cc_test(
+tf_cuda_cc_test(
     name = "c_api_test",
     srcs = ["c_api_test.cc"],
     deps = [
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 514a4010bc81bb280c3a1208b57a5db752f52f8a..28ea2edee4face6e66424c35ae6336bd87bdb3d8 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -54,9 +54,23 @@ string DeviceName(tensorflow::Device* d) {
 
 extern "C" {
 
-TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
+TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
+
+void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
+                                 size_t proto_len, TF_Status* status) {
+  TF_SetConfig(&options->session_options, proto, proto_len, status);
+}
+
+void TFE_ContextOptionsSetDevicePlacementPolicy(
+    TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
+  options->policy = policy;
+}
+
+void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
+
+TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
   TF_Graph* graph = TF_NewGraph();
-  TF_Session* session = TF_NewSession(graph, opts, status);
+  TF_Session* session = TF_NewSession(graph, &opts->session_options, status);
   if (status->status.ok()) {
     if (session->device_mgr == nullptr || session->devices.empty()) {
       status->status = tensorflow::errors::InvalidArgument(
@@ -71,9 +85,10 @@ TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
   }
 
   TFE_Context* ret = new TFE_Context(session);
+  ret->policy = opts->policy;
   ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime(
-      ret->session->device_mgr, opts->options.env, TF_GRAPH_DEF_VERSION,
-      &ret->func_lib_def, {}));
+      ret->session->device_mgr, opts->session_options.options.env,
+      TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {}));
   ret->rendezvous =
       new tensorflow::IntraProcessRendezvous(ret->session->device_mgr);
 
@@ -408,8 +423,10 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
 namespace {
 
 tensorflow::Status ValidateInputTypeAndPlacement(
-    tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op,
-    const tensorflow::OpKernel* kernel) {
+    TFE_Context* ctx, tensorflow::Device* host_device,
+    tensorflow::Device* op_device, TFE_Op* op,
+    const tensorflow::OpKernel* kernel,
+    std::vector<TFE_TensorHandle*>* copied_tensors) {
   const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
   if (memtypes.size() != op->inputs.size()) {
     return tensorflow::errors::InvalidArgument(
@@ -421,11 +438,42 @@ tensorflow::Status ValidateInputTypeAndPlacement(
     const tensorflow::Device* actual_device =
         op->input_devices[i] == nullptr ? host_device : op->input_devices[i];
     if (expected_device != actual_device) {
-      return tensorflow::errors::InvalidArgument(
-          "cannot compute ", op->name, " as input #", i,
-          " was expected to be on ", expected_device->name(),
-          " but is actually on ", actual_device->name(),
-          " (operation running on ", op_device->name(), ")");
+      switch (ctx->policy) {
+        case TFE_DEVICE_PLACEMENT_EXPLICIT:
+          return tensorflow::errors::InvalidArgument(
+              "cannot compute ", op->name, " as input #", i,
+              " was expected to be on ", expected_device->name(),
+              " but is actually on ", actual_device->name(),
+              " (operation running on ", op_device->name(), ")");
+        case TFE_DEVICE_PLACEMENT_WARN:
+          LOG(WARNING) << "before computing " << op->name << " input #" << i
+                       << " was expected to be on " << expected_device->name()
+                       << " but is actually on " << actual_device->name()
+                       << " (operation running on " << op_device->name()
+                       << "). This triggers a copy which can be a performance "
+                          "bottleneck.";
+          break;
+        case TFE_DEVICE_PLACEMENT_SILENT:  // Do nothing.
+          break;
+      }
+      // We are only here if the policy is warn or silent copies, so we should
+      // trigger a copy.
+      TFE_TensorHandle original{op->inputs[i], op->input_devices[i]};
+      TF_Status* s = TF_NewStatus();
+      TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice(
+          &original, ctx, expected_device->name().c_str(), s);
+      if (!s->status.ok()) {
+        tensorflow::Status status = s->status;
+        delete s;
+        return tensorflow::errors::Internal(
+            "Failed copying input tensor from ", actual_device->name(), " to ",
+            expected_device->name(), " in order to run ", op->name, ": ",
+            status.error_message());
+      }
+      op->inputs[i] = copied_tensor->t;
+      copied_tensors->push_back(copied_tensor);
+      op->input_devices[i] = copied_tensor->d;
+      delete s;
     }
     if (op->inputs[i].dtype() != kernel->input_type(i)) {
       return tensorflow::errors::InvalidArgument(
@@ -468,10 +516,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     }
     tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
   }
-  status->status = ValidateInputTypeAndPlacement(ctx->devices()[0], device, op,
-                                                 kernel->kernel());
+  std::vector<TFE_TensorHandle*> copied_tensors;
+  status->status = ValidateInputTypeAndPlacement(
+      ctx, ctx->devices()[0], device, op, kernel->kernel(), &copied_tensors);
   output_memory_types = &kernel->kernel()->output_memory_types();
   if (!status->status.ok()) {
+    for (auto* t : copied_tensors) {
+      TFE_DeleteTensorHandle(t);
+    }
     return;
   }
   // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
@@ -483,6 +535,9 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
   // sense for FunctionLibraryRuntime to ensure thread-safe access to
   // FunctionLibraryDefinition?).
   status->status = kernel->Run(&op->inputs, &outputs);
+  for (auto* t : copied_tensors) {
+    TFE_DeleteTensorHandle(t);
+  }
   if (!status->status.ok()) return;
   *num_retvals = std::min(*num_retvals, outputs.size());
   for (int i = 0; i < *num_retvals; ++i) {
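Note: a minimal sketch of the context-options lifecycle this patch introduces, written against the functions added above. Error handling is elided and the policy choice is illustrative; see c_api_test.cc later in this diff for complete usage.

#include "tensorflow/c/eager/c_api.h"

void RunWithSilentPlacement() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  // With SILENT, inputs on the "wrong" device are copied instead of failing.
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);  // The context keeps its own copy of the policy.
  /* ... build and execute ops ... */
  TFE_DeleteContext(ctx, status);
  TF_DeleteStatus(status);
}
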
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 9bfa63711b5360b33819434f9a551030e0f988c8..865580c5f3a823d9cf49fe460bd007e3b3b88767 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -43,14 +43,46 @@ limitations under the License.
 extern "C" {
 #endif
 
+typedef struct TFE_ContextOptions TFE_ContextOptions;
+
+// Return a new options object.
+TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions();
+
+// Set the config in TF_ContextOptions.options.
+// config should be a serialized tensorflow.ConfigProto proto.
+// If config was not parsed successfully as a ConfigProto, record the
+// error information in *status.
+TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig(
+    TFE_ContextOptions* options, const void* proto, size_t proto_len,
+    TF_Status* status);
+
+// Controls how to act when we try to run an operation on a given device but
+// some input tensors are not on that device.
+typedef enum TFE_ContextDevicePlacementPolicy {
+  // The default: running operations with input tensors on the wrong device
+  // will fail.
+  TFE_DEVICE_PLACEMENT_EXPLICIT = 0,
+  // Copy the tensor to the right device but log a warning.
+  TFE_DEVICE_PLACEMENT_WARN = 1,
+  // Silently copy the tensor, which has a performance cost since the
+  // operation will be blocked till the copy completes.
+  TFE_DEVICE_PLACEMENT_SILENT = 2,
+} TFE_ContextDevicePlacementPolicy;
+
+TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
+    TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
+
+// Destroy an options object.
+TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);
+
 // "Context" under which operations/functions are executed. It encapsulates
 // things like the available devices, resource manager etc.
 //
 // TODO(ashankar): Merge with TF_Session?
 typedef struct TFE_Context TFE_Context;
 
-TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts,
-                                                  TF_Status* status);
+TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(
    const TFE_ContextOptions* opts, TF_Status* status);
 
 TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status);
 TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
                                                             TF_Status* status);
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 712526f17002a612a145f80538977fedfde00038..0971e2ab2fe98cc8bf6f631f41d5adce90ee7051 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -35,9 +35,16 @@ limitations under the License.
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 
+struct TFE_ContextOptions {
+  TF_SessionOptions session_options;
+  TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_EXPLICIT};
+};
+
 struct TFE_Context {
   explicit TFE_Context(TF_Session* s) : session(s) {}
 
+  TFE_ContextDevicePlacementPolicy policy;
+
   // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph.
   TF_Session* session;
   tensorflow::Rendezvous* rendezvous;
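Note: TFE_ContextOptionsSetConfig takes a serialized tensorflow.ConfigProto, mirroring TF_SetConfig (which it forwards to, per c_api.cc above). A sketch, assuming the C++ proto is linked in; the soft-placement setting is illustrative:

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/protobuf/config.pb.h"

TFE_ContextOptions* MakeOptions(TF_Status* status) {
  tensorflow::ConfigProto config;
  config.set_allow_soft_placement(true);
  const std::string serialized = config.SerializeAsString();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  // Records an error in *status if the bytes do not parse as a ConfigProto.
  TFE_ContextOptionsSetConfig(opts, serialized.data(), serialized.size(),
                              status);
  return opts;
}
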
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 72e0fe8a1565a9a717c01aed83044cab2dd2dfbc..4af91b8853d0e85570bad136752a9d0a04b87da5 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -62,10 +62,10 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
 void BM_InitOp(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
 
   TFE_TensorHandle* m = TestMatrixTensorHandle();
   tensorflow::testing::StartTiming();
@@ -84,10 +84,10 @@ BENCHMARK(BM_InitOp);
 void BM_Execute(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
 
   TFE_TensorHandle* m = TestMatrixTensorHandle();
   TFE_Op* matmul = MatMulOp(ctx, m, m);
@@ -109,9 +109,9 @@ BENCHMARK(BM_Execute);
 
 TEST(CAPI, Context) {
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
 
   TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -150,9 +150,9 @@ TEST(CAPI, TensorHandle) {
 TEST(CAPI, TensorHandleCopyBetweenDevices) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status.get());
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 
   TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
@@ -216,12 +216,58 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) {
   EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 }
 
+TEST(CAPI, TensorHandleSilentCopy) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_Context* ctx = TFE_NewContext(opts, status.get());
+  TFE_DeleteContextOptions(opts);
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
+  TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  const int num_devices = TF_DeviceListCount(devices);
+
+  // Disable the test if no GPU is present.
+  if (num_devices > 1) {
+    const int device_to_use = 1;
+    const string name(TF_DeviceListName(devices, device_to_use, status.get()));
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    TFE_TensorHandle* hgpu =
+        TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
+    TFE_OpSetDevice(matmul, name.c_str(), status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    TFE_TensorHandle* retvals[1];
+    int num_retvals = 1;
+    TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    TFE_DeleteOp(matmul);
+    TFE_DeleteTensorHandle(retvals[0]);
+    TFE_DeleteTensorHandle(hgpu);
+  }
+
+  TF_DeleteDeviceList(devices);
+  TF_DeleteTensor(t);
+  TFE_DeleteTensorHandle(hcpu);
+  TFE_DeleteContext(ctx, status.get());
+  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+}
+
 TEST(CAPI, Execute) {
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
 
   TFE_TensorHandle* m = TestMatrixTensorHandle();
   TFE_Op* matmul = MatMulOp(ctx, m, m);
@@ -285,10 +331,10 @@ string MatMulFunction() {
 TEST(CAPI, FunctionDefAndExecute) {
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
 
   string function_def = MatMulFunction();
   TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
@@ -326,10 +372,10 @@ TEST(CAPI, FunctionDefAndExecute) {
 void BM_ExecuteFunction(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
 
   string function_def = MatMulFunction();
   TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
@@ -406,10 +452,10 @@ TEST(CAPI, Variables) {
   // Variables use resource handles, so this is really a test for resource
   // tensor handling.
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
 
   TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -446,10 +492,10 @@ TEST(CAPI, Variables) {
 void BM_ReadVariable(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
 
   TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
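Note: the test changes above are all the same mechanical migration (TF_SessionOptions -> TFE_ContextOptions); the behavioral addition is TensorHandleSilentCopy. Distilled from that test, with gpu_device_name standing in for the queried device name:

// hcpu lives on the CPU, the op is placed on the GPU; under
// TFE_DEVICE_PLACEMENT_SILENT the copy happens inside TFE_Execute
// (see ValidateInputTypeAndPlacement earlier in this diff).
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
TFE_OpSetDevice(matmul, gpu_device_name, status);
TFE_Execute(matmul, &retvals[0], &num_retvals, status);  // copies hcpu to GPU
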
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index cdd9c8847f56e25bcb807a9cf0631e72bf4355ee..93c2a806780140ac3d60e1eacf66c60c024cd9b7 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -138,6 +138,11 @@ class ComputationBuilder {
   ComputationDataHandle ConstantR2(
       std::initializer_list<std::initializer_list<NativeT>> values);
   template <typename NativeT>
+  ComputationDataHandle ConstantFromArrayWithLayout(
+      const Array<NativeT>& values, const Layout& layout);
+  template <typename NativeT>
+  ComputationDataHandle ConstantFromArray(const Array<NativeT>& values);
+  template <typename NativeT>
   ComputationDataHandle ConstantR2FromArray2DWithLayout(
       const Array2D<NativeT>& values, const Layout& layout);
   template <typename NativeT>
@@ -910,48 +915,54 @@ ComputationDataHandle ComputationBuilder::ConstantR2(
 }
 
 template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
-    const Array2D<NativeT>& values, const Layout& layout) {
+ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout(
+    const Array<NativeT>& values, const Layout& layout) {
   return ConstantOp([&values, &layout](Literal* literal) {
-    literal->PopulateR2FromArray2DWithLayout(values, layout);
+    literal->PopulateFromArrayWithLayout(values, layout);
   });
 }
 
+template <typename NativeT>
+ComputationDataHandle ComputationBuilder::ConstantFromArray(
+    const Array<NativeT>& values) {
+  return ConstantOp(
+      [&values](Literal* literal) { literal->PopulateFromArray(values); });
+}
+
+template <typename NativeT>
+ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
+    const Array2D<NativeT>& values, const Layout& layout) {
+  return ConstantFromArrayWithLayout(values, layout);
+}
+
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D(
     const Array2D<NativeT>& values) {
-  return ConstantOp(
-      [&values](Literal* literal) { literal->PopulateR2FromArray2D(values); });
+  return ConstantFromArray(values);
 }
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout(
     const Array3D<NativeT>& values, const Layout& layout) {
-  return ConstantOp([&values, &layout](Literal* literal) {
-    literal->PopulateR3FromArray3DWithLayout(values, layout);
-  });
+  return ConstantFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D(
     const Array3D<NativeT>& values) {
-  return ConstantOp(
-      [&values](Literal* literal) { literal->PopulateR3FromArray3D(values); });
+  return ConstantFromArray(values);
 }
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout(
     const Array4D<NativeT>& values, const Layout& layout) {
-  return ConstantOp([&values, &layout](Literal* literal) {
-    literal->PopulateR4FromArray4DWithLayout(values, layout);
-  });
+  return ConstantFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
     const Array4D<NativeT>& values) {
-  return ConstantOp(
-      [&values](Literal* literal) { literal->PopulateR4FromArray4D(values); });
+  return ConstantFromArray(values);
 }
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index 011fc3c194e0eb9ebd6b9e42571deddaf25c09ff..5c2cc2a7a99cc51ded3d98c9dd5903e4b3078548 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -83,6 +83,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
   return CreateDefaultLayoutForRank(shape.dimensions_size());
 }
 
+/* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) {
+  return CreateDefaultLayoutForRank(rank);
+}
+
 /* static */ Layout LayoutUtil::GetDefaultLayoutForR2() {
   return CreateDefaultLayoutForRank(2);
 }
diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h
index 5de0a653f66688ac75fc377c18ff93012314abdd..bc42e222292933be35e82d1fe50802e8830d16b3 100644
--- a/tensorflow/compiler/xla/layout_util.h
+++ b/tensorflow/compiler/xla/layout_util.h
@@ -40,6 +40,7 @@ class LayoutUtil {
   static Layout GetDefaultLayoutForShape(const Shape& shape);
 
   // Helper functions that create default layouts for various ranks.
+  static Layout GetDefaultLayoutForRank(int64 rank);
   static Layout GetDefaultLayoutForR2();
   static Layout GetDefaultLayoutForR3();
   static Layout GetDefaultLayoutForR4();
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index 8892bfbe929d168c602af24cfbb507256dc05328..f2cdd9669c727bb778fce495ede0faaf2d9a923d 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -206,9 +206,9 @@ void AllocateFlags() {
           flag_values->xla_gpu_disable_multi_streaming(),
           "If true, multi-streaming in the GPU backend is disabled."),
       tensorflow::Flag(
-          "xla_dump_debug_json_to",
-          flag_values->mutable_xla_dump_debug_json_to(),
-          "Dump compilation artifacts as JSON into this directory."),
+          "xla_dump_hlo_proto_to",
+          flag_values->mutable_xla_dump_hlo_proto_to(),
+          "Dump compilation artifacts as proto binary into this directory."),
      tensorflow::Flag(
          "xla_test_all_output_layouts",
          bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),
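Note: with the rank-generic entry points above, callers can hand any xla::Array-derived container to the builder; the rank-specific wrappers now just forward. A sketch, assuming a ComputationBuilder named builder is in scope:

xla::Array2D<float> values({{1.0f, 2.0f}, {3.0f, 4.0f}});
// Rank and dimensions are taken from the array itself.
auto constant = builder.ConstantFromArray(values);
// The retained rank-specific wrapper is now equivalent:
auto same = builder.ConstantR2FromArray2D(values);
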
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index e8cee732d4cf5629c1e2b9c98d1f1bbe1e29a122..4063cb05a91bff21afa318440104af16694f8ed1 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -334,6 +334,11 @@ class Literal {
   // WithLayout use the default XLA layout for the literal's linear
   // representation in memory.
   template <typename NativeT>
+  static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
+  template <typename NativeT>
+  static std::unique_ptr<Literal> CreateFromArrayWithLayout(
+      const Array<NativeT>& values, const Layout& layout);
+  template <typename NativeT>
   static std::unique_ptr<Literal> CreateR2FromArray2D(
       const Array2D<NativeT>& values);
   template <typename NativeT>
@@ -481,6 +486,11 @@ class Literal {
       std::initializer_list<std::initializer_list<NativeT>> values,
       const Layout& layout);
   template <typename NativeT>
+  void PopulateFromArray(const Array<NativeT>& values);
+  template <typename NativeT>
+  void PopulateFromArrayWithLayout(const Array<NativeT>& values,
+                                   const Layout& layout);
+  template <typename NativeT>
   void PopulateR2FromArray2D(const Array2D<NativeT>& values);
   template <typename NativeT>
   void PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
@@ -816,33 +826,42 @@ template <typename NativeT>
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout(
-    const Array2D<NativeT>& values, const Layout& layout) {
+/* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout(
+    const Array<NativeT>& values, const Layout& layout) {
   auto literal = MakeUnique<Literal>();
-  literal->PopulateR2FromArray2DWithLayout(values, layout);
+  literal->PopulateFromArrayWithLayout(values, layout);
   return literal;
 }
 
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> Literal::CreateFromArray(
+    const Array<NativeT>& values) {
+  return CreateFromArrayWithLayout(
+      values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout(
+    const Array2D<NativeT>& values, const Layout& layout) {
+  return CreateFromArrayWithLayout(values, layout);
+}
+
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2D(
     const Array2D<NativeT>& values) {
-  return CreateR2FromArray2DWithLayout(values,
-                                       LayoutUtil::GetDefaultLayoutForR2());
+  return CreateFromArray(values);
 }
 
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3DWithLayout(
     const Array3D<NativeT>& values, const Layout& layout) {
-  auto literal = MakeUnique<Literal>();
-  literal->PopulateR3FromArray3DWithLayout(values, layout);
-  return literal;
+  return CreateFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3D(
     const Array3D<NativeT>& values) {
-  return CreateR3FromArray3DWithLayout(values,
-                                       LayoutUtil::GetDefaultLayoutForR3());
+  return CreateFromArray(values);
 }
 
 template <typename NativeT>
@@ -901,16 +920,13 @@ template <typename NativeT>
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4D(
     const Array4D<NativeT>& values) {
-  return CreateR4FromArray4DWithLayout(values,
-                                       LayoutUtil::GetDefaultLayoutForR4());
+  return CreateFromArray(values);
 }
 
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4DWithLayout(
     const Array4D<NativeT>& values, const Layout& layout) {
-  auto literal = MakeUnique<Literal>();
-  literal->PopulateR4FromArray4DWithLayout(values, layout);
-  return literal;
+  return CreateFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
@@ -1070,82 +1086,53 @@ void Literal::PopulateR2(
 }
 
 template <typename NativeT>
-void Literal::PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
-                                              const Layout& layout) {
+void Literal::PopulateFromArrayWithLayout(const Array<NativeT>& values,
+                                          const Layout& layout) {
   *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
-      primitive_util::NativeToPrimitiveType<NativeT>(),
-      {values.height(), values.width()}, AsInt64Slice(layout.minor_to_major()));
+      primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
+      AsInt64Slice(layout.minor_to_major()));
+  Reserve(values.num_elements());
+  values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
+                     NativeT value) { this->Set(indices, value); });
+}
 
-  const int64 dim1_size = values.width();
-  const int64 dim0_size = values.height();
-  CHECK_EQ(dim0_size, shape().dimensions(0));
-  CHECK_EQ(dim1_size, shape().dimensions(1));
-  Reserve(dim1_size * dim0_size);
-  for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
-    for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) {
-      Set({dim0, dim1}, values(dim0, dim1));
-    }
-  }
+template <typename NativeT>
+void Literal::PopulateFromArray(const Array<NativeT>& values) {
+  PopulateFromArrayWithLayout(
+      values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
+}
+
+template <typename NativeT>
+void Literal::PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
+                                              const Layout& layout) {
+  PopulateFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
 void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
-  PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
+  PopulateFromArray(values);
 }
 
 template <typename NativeT>
 void Literal::PopulateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
                                               const Layout& layout) {
-  *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
-      primitive_util::NativeToPrimitiveType<NativeT>(),
-      {values.n1(), values.n2(), values.n3()},
-      AsInt64Slice(layout.minor_to_major()));
-
-  CHECK_EQ(values.n1(), shape().dimensions(0));
-  CHECK_EQ(values.n2(), shape().dimensions(1));
-  CHECK_EQ(values.n3(), shape().dimensions(2));
-  Reserve(values.n1() * values.n2() * values.n3());
-  for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) {
-    for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) {
-      for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) {
-        Set({dim0, dim1, dim2}, values(dim0, dim1, dim2));
-      }
-    }
-  }
+  PopulateFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
 void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
-  PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
+  PopulateFromArray(values);
 }
 
 template <typename NativeT>
 void Literal::PopulateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
                                               const Layout& layout) {
-  *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
-      primitive_util::NativeToPrimitiveType<NativeT>(),
-      {values.planes(), values.depth(), values.height(), values.width()},
-      AsInt64Slice(layout.minor_to_major()));
-
-  CHECK_EQ(values.n1(), shape().dimensions(0));
-  CHECK_EQ(values.n2(), shape().dimensions(1));
-  CHECK_EQ(values.n3(), shape().dimensions(2));
-  CHECK_EQ(values.n4(), shape().dimensions(3));
-  Reserve(values.n1() * values.n2() * values.n3() * values.n4());
-  for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) {
-    for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) {
-      for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) {
-        for (int64 dim3 = 0; dim3 < values.n4(); ++dim3) {
-          Set({dim0, dim1, dim2, dim3}, values(dim0, dim1, dim2, dim3));
-        }
-      }
-    }
-  }
+  PopulateFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
 void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
-  PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
+  PopulateFromArray(values);
 }
 
 template <typename NativeT>
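Note: the deleted per-rank nested loops are replaced by a single walk over Array<NativeT>, which carries its own dimensions(). A sketch of the resulting user-visible behavior (assumes Array3D exposes the base class's Fill):

xla::Array3D<int32> cube(2, 3, 4);
cube.Fill(7);
// One code path for every rank: shape is s32[2,3,4] with the default layout
// for rank 3 (via the new LayoutUtil::GetDefaultLayoutForRank).
std::unique_ptr<xla::Literal> literal = xla::Literal::CreateFromArray(cube);
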
return InternalError("MessageToJsonString failed: %s", - status.error_message().data()); - } - return json_output; -} - namespace { string SanitizeFilename(const string& file_name) { @@ -65,17 +51,6 @@ string SanitizeFilename(const string& file_name) { } // namespace -Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message, - const string& directory, const string& file_name) { - TF_ASSIGN_OR_RETURN(const string json_output, ToJson(message)); - - tensorflow::Env* env = tensorflow::Env::Default(); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); - string safe_file_name = SanitizeFileName(file_name) + ".json"; - const string path = tensorflow::io::JoinPath(directory, safe_file_name); - return tensorflow::WriteStringToFile(env, path, json_output); -} - Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name) { tensorflow::Env* env = tensorflow::Env::Default(); diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index 7accb22e0c7720d5af896f8ca833ee26175fb89f..3667621367c7639c40ff17aee7b77305d4d34e33 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -32,17 +32,12 @@ namespace protobuf_util { extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1, const tensorflow::protobuf::Message& m2); -// Returns 'message' as a JSON string. -StatusOr ToJson(const tensorflow::protobuf::Message& message); - -// Writes the given message in binary proto or JSON format to the path formed by -// joining 'directory/file_name.pb' (or file_name.json). The 'directory' is -// recursively created if it doesn't already exist, and the 'file_name' is -// sanitized by replacing illegal characters with underscore '_'. +// Writes the given message in binary proto to the path formed by joining +// 'directory/file_name.pb'. The 'directory' is recursively created if it +// doesn't already exist, and the 'file_name' is sanitized by replacing +// illegal characters with underscore '_'. 
diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h
index 7accb22e0c7720d5af896f8ca833ee26175fb89f..3667621367c7639c40ff17aee7b77305d4d34e33 100644
--- a/tensorflow/compiler/xla/protobuf_util.h
+++ b/tensorflow/compiler/xla/protobuf_util.h
@@ -32,17 +32,12 @@ namespace protobuf_util {
 extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
                            const tensorflow::protobuf::Message& m2);
 
-// Returns 'message' as a JSON string.
-StatusOr<string> ToJson(const tensorflow::protobuf::Message& message);
-
-// Writes the given message in binary proto or JSON format to the path formed
-// by joining 'directory/file_name.pb' (or file_name.json). The 'directory' is
-// recursively created if it doesn't already exist, and the 'file_name' is
-// sanitized by replacing illegal characters with underscore '_'.
+// Writes the given message in binary proto to the path formed by joining
+// 'directory/file_name.pb'. The 'directory' is recursively created if it
+// doesn't already exist, and the 'file_name' is sanitized by replacing
+// illegal characters with underscore '_'.
 Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
                             const string& directory, const string& file_name);
 
-Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message,
-                           const string& directory, const string& file_name);
-
 }  // namespace protobuf_util
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index aaa860ff951bbae45be5d014a53eadf64191f293..c931fe02e936a15f99391e3fa9003ff01e58be1b 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2064,6 +2064,29 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "hlo_runner",
+    srcs = ["hlo_runner.cc"],
+    hdrs = ["hlo_runner.h"],
+    deps = [
+        ":executable",
+        ":hlo",
+        ":transfer_manager",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:backend",
+        "//tensorflow/compiler/xla/service:compiler",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_no_cuda",
+        "//third_party/eigen3",
+    ],
+)
+
 # -----------------------------------------------------------------------------
 
 filegroup(
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index ce4d109214b9ad236fbf125179276bf53f4cbf57..06e7ec0c7cbe8cd2fb73c7206de99bb7dac29688 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -475,8 +475,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
   // ownership is std::moved.
   const bool embed_ir_in_executable =
       module->config().debug_options().xla_embed_ir_in_executable();
-  const string dump_debug_json_to =
-      module->config().debug_options().xla_dump_debug_json_to();
+  const string xla_dump_hlo_proto_to =
+      module->config().debug_options().xla_dump_hlo_proto_to();
 
   if (options::CpuParallelBackendRequested(module->config())) {
     VLOG(1) << "Using parallel cpu backend";
@@ -496,10 +496,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
     // print one ourselves.
     XLA_VLOG_LINES(2, assignment->ToString());
 
-    if (!dump_debug_json_to.empty()) {
+    if (!xla_dump_hlo_proto_to.empty()) {
       HloProto proto = MakeHloProto(*module, *assignment);
-      TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
-          proto, dump_debug_json_to, module->name()));
+      TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
+          proto, xla_dump_hlo_proto_to, module->name()));
     }
 
     // If we are using the parallel CPU backend, we need to create map from
@@ -603,12 +603,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
     // print one ourselves.
     XLA_VLOG_LINES(2, assignment->ToString());
 
-    if (!dump_debug_json_to.empty()) {
+    if (!xla_dump_hlo_proto_to.empty()) {
       HloProto proto = MakeHloProto(*module, *assignment);
-      TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
-          proto, dump_debug_json_to, module->name()));
+      TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
+          proto, xla_dump_hlo_proto_to, module->name()));
     }
-
     // Each computation is a single function. Emit all embedded computations
     // before the entry computation. The order of computations returned from
     // GetEmbeddedComputations guarantees that a called computation occurs
@@ -775,12 +774,12 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
     // print one ourselves.
     XLA_VLOG_LINES(2, assignment->ToString());
 
-    const string dump_debug_json_to =
-        module->config().debug_options().xla_dump_debug_json_to();
-    if (!dump_debug_json_to.empty()) {
+    const string xla_dump_hlo_proto_to =
+        module->config().debug_options().xla_dump_hlo_proto_to();
+    if (!xla_dump_hlo_proto_to.empty()) {
       HloProto proto = MakeHloProto(*module, *assignment);
-      TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
-          proto, dump_debug_json_to, module->name()));
+      TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
+          proto, xla_dump_hlo_proto_to, module->name()));
     }
 
     IrEmitter ir_emitter(*module, *assignment, &llvm_module,
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index 5afb2e67fff639e9cabb3740c5240e1ca90b5644..c2213c8f2ef592c537daf9abe2ffa10b83a8fa4c 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -136,6 +136,8 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
       instruction->opcode() == HloOpcode::kCall ||
       instruction->opcode() == HloOpcode::kCustomCall ||
       instruction->opcode() == HloOpcode::kSelectAndScatter ||
+      instruction->opcode() == HloOpcode::kGetTupleElement ||
+      instruction->opcode() == HloOpcode::kBitcast ||
       (instruction->opcode() == HloOpcode::kConvolution &&
        PotentiallyImplementedAsEigenConvolution(*instruction)) ||
       PotentiallyImplementedAsEigenDot(*instruction) ||
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 3e16e4e3c42cebce75b4e4e95fd7c6477fb230ae..9c7ca9ea38ebf7db38700d0c42e8e11ec8a65c73 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -318,12 +318,12 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
   // print one ourselves.
   XLA_VLOG_LINES(2, buffer_assignment->ToString());
 
-  const string dump_debug_json_to =
-      module->config().debug_options().xla_dump_debug_json_to();
-  if (!dump_debug_json_to.empty()) {
+  const string xla_dump_hlo_proto_to =
+      module->config().debug_options().xla_dump_hlo_proto_to();
+  if (!xla_dump_hlo_proto_to.empty()) {
     HloProto proto = MakeHloProto(*module, *buffer_assignment);
-    TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
-        proto, dump_debug_json_to, module->name()));
+    TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
+        proto, xla_dump_hlo_proto_to, module->name()));
   }
 
   IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(),
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 9b3104eaacdbb083db2a55c75fae3e94c8ff282f..51ead753f043ea97a2908eaaf85eb8727b42938c 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -373,8 +373,8 @@ string HloComputation::ToString(int nested_level) const {
   for (int i = 0; i < nested_level; i++) {
     s << "    ";
   }
-  s << name() << " " << ShapeUtil::HumanString(ComputeProgramShape())
-    << " { \n";
+  s << "%" << name() << " " << ShapeUtil::HumanString(ComputeProgramShape())
+    << " {\n";
   for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
     for (int i = 0; i < nested_level; i++) {
       s << "    ";
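Note: the ToString change above prefixes computation names with '%', matching how instruction names already print in HLO text, and drops the stray space before the brace. Illustrative output for a computation named add (the shape is made up):

%add (x: f32[], y: f32[]) -> f32[] {
  ...
}
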
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d5d7042a02b474f9b4793e7f5ed67c91420686b4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -0,0 +1,199 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_runner.h"
+
+#include <set>
+#include <string>
+#include <utility>
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+/*static*/ StatusOr<std::unique_ptr<HloModule>>
+HloRunner::ReadModuleFromHloProtoFile(const char* filename,
+                                      const DebugOptions& debug_options) {
+  HloProto proto;
+  TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
+                                                 filename, &proto));
+  HloModuleConfig config;
+  config.set_debug_options(debug_options);
+  TF_ASSIGN_OR_RETURN(auto module, HloModule::CreateFromProto(
+                                       proto.hlo_module(),
+                                       VersionedComputationHandle(), config));
+  return std::move(module);
+}
+
+// Define this in .cc file to avoid having to include eigen or forward declare
+// these types in the header.
+struct HloRunner::EigenThreadPoolWrapper {
+  std::unique_ptr<tensorflow::EigenThreadPoolWrapper> pool;
+  std::unique_ptr<Eigen::ThreadPoolDevice> device;
+};
+
+HloRunner::HloRunner() {}
+
+HloRunner::HloRunner(se::Platform* platform) {
+  BackendOptions backend_options;
+  backend_options.set_platform(platform);
+  backend_ = Backend::CreateBackend(backend_options).ConsumeValueOrDie();
+  VLOG(1) << "Created HloRunner for platform: " << platform->Name();
+}
+
+HloRunner::~HloRunner() {
+  // Deallocate all the memory allocated during the tests.
+  for (auto& allocation : allocations_) {
+    backend().default_stream_executor()->Deallocate(&allocation);
+  }
+}
+
+StatusOr<se::DeviceMemoryBase> HloRunner::Execute(
+    std::unique_ptr<HloModule> module,
+    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
+    Shape* result_shape) {
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<Executable> executable,
+      backend().compiler()->Compile(std::move(module),
+                                    backend().default_stream_executor()));
+
+  se::Stream stream(backend().default_stream_executor());
+  stream.Init();
+
+  ExecutableRunOptions run_options;
+  run_options.set_stream(&stream);
+  run_options.set_allocator(backend().memory_allocator());
+  run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool());
+  run_options.set_intra_op_thread_pool(
+      backend().eigen_intra_op_thread_pool_device());
+
+  HloExecutionProfile hlo_execution_profile;
+  ServiceExecutableRunOptions service_run_options(
+      run_options, backend().StreamBorrower(),
+      backend().inter_op_thread_pool());
+  TF_ASSIGN_OR_RETURN(
+      se::DeviceMemoryBase result,
+      executable->ExecuteOnStream(&service_run_options, arguments,
+                                  &hlo_execution_profile));
+  TF_RET_CHECK(stream.BlockHostUntilDone());
+
+  allocations_.push_back(result);
+
+  *result_shape = executable->result_shape();
+
+  if (ShapeUtil::IsTuple(*result_shape)) {
+    // We must record element buffers of tuples as well to avoid leaks.
+    DCHECK(!ShapeUtil::IsNestedTuple(*result_shape));
+    TF_ASSIGN_OR_RETURN(
+        std::vector<se::DeviceMemoryBase> element_buffers,
+        backend().transfer_manager()->ShallowCopyTupleFromDevice(
+            backend().default_stream_executor(), result, *result_shape));
+
+    // A tuple may contain the same buffer in more than one element. Keep track
+    // of the buffers already added to avoid duplicates in allocations_.
+    std::set<void*> added_opaques;
+    for (auto element_buffer : element_buffers) {
+      if (added_opaques.count(element_buffer.opaque()) == 0) {
+        CHECK(element_buffer.opaque() != nullptr);
+        added_opaques.insert(element_buffer.opaque());
+        allocations_.push_back(element_buffer);
+      }
+    }
+  }
+
+  return result;
+}
+
+se::DeviceMemoryBase HloRunner::TransferToDevice(const Literal& literal) {
+  // Allocate memory on the device using the stream executor.
+  int64 allocation_size =
+      backend().transfer_manager()->GetByteSizeRequirement(literal.shape());
+  se::DeviceMemoryBase allocation =
+      backend().default_stream_executor()->AllocateArray<uint8>(
+          allocation_size);
+  allocations_.push_back(allocation);
+
+  TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToDevice(
+      backend().default_stream_executor(), literal, &allocation));
+
+  return allocation;
+}
+
+std::unique_ptr<Literal> HloRunner::TransferFromDevice(
+    const Shape& shape, se::DeviceMemoryBase device_base) {
+  auto literal = MakeUnique<Literal>();
+  TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromDevice(
+      backend().default_stream_executor(), device_base, shape, shape,
+      literal.get()));
+  return literal;
+}
+
+std::unique_ptr<Literal> HloRunner::ExecuteAndTransfer(
+    std::unique_ptr<HloModule> module,
+    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
+  Shape result_shape;
+  se::DeviceMemoryBase device_base =
+      Execute(std::move(module), arguments, &result_shape).ValueOrDie();
+  return TransferFromDevice(result_shape, device_base);
+}
+
+template <>
+std::unique_ptr<Literal> HloRunner::Execute(
+    std::unique_ptr<HloModule> module,
+    const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>>& literals) {
+  std::vector<se::DeviceMemoryBase> arguments;
+  for (const auto& literal : literals) {
+    arguments.push_back(TransferToDevice(*literal));
+  }
+  return ExecuteAndTransfer(std::move(module), arguments);
+}
+
+template <>
+std::unique_ptr<Literal> HloRunner::Execute(
+    std::unique_ptr<HloModule> module,
+    const tensorflow::gtl::ArraySlice<Literal*>& literals) {
+  std::vector<se::DeviceMemoryBase> arguments;
+  for (const auto& literal : literals) {
+    arguments.push_back(TransferToDevice(*literal));
+  }
+  return ExecuteAndTransfer(std::move(module), arguments);
+}
+
+Backend& HloRunner::backend() {
+  if (!backend_) {
+    backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
+    VLOG(1) << "executing on platform " << backend().platform()->Name();
+  }
+  return *backend_;
+}
+
+}  // namespace xla
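Note: intended use of the new HloRunner, per the header that follows. A sketch: the proto path is illustrative, error handling is elided, and the argument literal is arbitrary.

xla::HloRunner runner;  // the backend is created lazily on first use
xla::DebugOptions debug_options;
std::unique_ptr<xla::HloModule> module =
    xla::HloRunner::ReadModuleFromHloProtoFile("/tmp/xla_dump/my_module.pb",
                                               debug_options)
        .ConsumeValueOrDie();
auto arg = xla::Literal::CreateR2FromArray2D<float>(
    xla::Array2D<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
std::vector<xla::Literal*> args = {arg.get()};
std::unique_ptr<xla::Literal> result =
    runner.Execute<xla::Literal*>(std::move(module), args);
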
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
new file mode 100644
index 0000000000000000000000000000000000000000..d74a1b59a8c98fb3bb115f2086b12ee4e3eab965
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -0,0 +1,100 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+
+// A base class for running an HloModule. This executes the given HloModule on
+// a certain backend directly without using the client interface. HloModule
+// can be explicitly built, or loaded from a serialization file (e.g., hlo
+// proto file).
+class HloRunner {
+ public:
+  HloRunner();
+
+  HloRunner(::perftools::gputools::Platform* platform);
+
+  ~HloRunner();
+
+  // Reads the binary proto file in xla.HloProto format, creates and returns
+  // the HloModule.
+  static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloProtoFile(
+      const char* filename, const DebugOptions& debug_options);
+
+  // Executes the given module with given literals as input and returns the
+  // result as a Literal. The LiteralPtr type accepts Literal* or
+  // std::unique_ptr<Literal>.
+  template <typename LiteralPtr>
+  std::unique_ptr<Literal> Execute(
+      std::unique_ptr<HloModule> module,
+      const tensorflow::gtl::ArraySlice<LiteralPtr>& literals);
+
+  // Executes the given module and returns a global data handle.
+  StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
+      std::unique_ptr<HloModule> module,
+      tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+          arguments,
+      Shape* result_shape);
+
+  // Transfers the given literal to the device and returns the data handle.
+  perftools::gputools::DeviceMemoryBase TransferToDevice(
+      const Literal& literal);
+
+  // Transfers the array referred to by the given handle from the device and
+  // returns as a Literal.
+  std::unique_ptr<Literal> TransferFromDevice(
+      const Shape& shape, perftools::gputools::DeviceMemoryBase device_base);
+
+  // Executes the given module and return the result as a Literal.
+  std::unique_ptr<Literal> ExecuteAndTransfer(
+      std::unique_ptr<HloModule> module,
+      tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+          arguments);
+
+  // If backend is not created in the constructor, creates and returns the
+  // default backend. If creation fails, crashes the program.
+  //
+  // This creates the backend lazily so it's possible to instantiate an
+  // HloRunner in a program without any backends linked in.
+  Backend& backend();
+
+ private:
+  struct EigenThreadPoolWrapper;
+
+  std::vector<perftools::gputools::DeviceMemoryBase> allocations_;
+
+  std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
+
+  std::unique_ptr<Backend> backend_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index 816c8a7485bb9c5c12d3dc9e17404c74460113f5..8c2640adf52f10c387e7a9c09c0d73a09c054919 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -58,14 +58,32 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution(
     return {};
   }
 
-  // We only support folding the RHS.
-  const int64 kRhsOperandIndex = 1;
-  auto& operand = *convolution.operand(kRhsOperandIndex);
-  if (operand.opcode() == HloOpcode::kTranspose && operand.user_count() == 1) {
-    return transposable_conv_operands(convolution, {kRhsOperandIndex});
+  const ConvolutionDimensionNumbers& dnums =
+      convolution.convolution_dimension_numbers();
+
+  TransposeFolding::OperandIndices operand_set;
+  for (int64 i = 0; i < convolution.operand_count(); ++i) {
+    auto& operand = *convolution.operand(i);
+    if (operand.opcode() == HloOpcode::kTranspose &&
+        operand.user_count() == 1) {
+      const auto& transpose_dimensions = operand.dimensions();
+      // We can transpose the LHS so long as it doesn't move around spatial
+      // dimensions because ConvolutionDimensionNumbers doesn't have different
+      // fields for input and output spatial dimensions.
+      if (i == 0 &&
+          std::any_of(dnums.spatial_dimensions().begin(),
+                      dnums.spatial_dimensions().end(),
+                      [&](const int64 spatial_dimension) {
+                        return transpose_dimensions[spatial_dimension] !=
+                               spatial_dimension;
+                      })) {
+        continue;
+      }
+      operand_set.push_back(i);
+    }
   }
 
-  return {};
+  return transposable_conv_operands(convolution, operand_set);
 }
 
 using InstructionOperandsPair =
@@ -98,40 +116,61 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair) {
 // Returns whether the module is changed.
 bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
   auto& convolution = *pair.first;
-
-  // We only support fusing the RHS transpose into convolution.
-  //
-  // ConvolutionDimensionNumbers doesn't make enough of a distinction between
-  // the output and the activations.
-  //
-  // TODO(b/37125184): Support transposing the LHS too.
-  if (pair.second.size() != 1 || pair.second.front() != 1) {
-    return false;
-  }
+  auto& operand_indices = pair.second;
 
   const ConvolutionDimensionNumbers& dnums =
      convolution.convolution_dimension_numbers();
-  HloInstruction& transpose = *convolution.mutable_operand(1);
-  CHECK_EQ(transpose.opcode(), HloOpcode::kTranspose);
-  const auto& transpose_dimensions = transpose.dimensions();
-  HloInstruction& transpose_operand = *transpose.mutable_operand(0);
-
-  // Everything remains the same except for the kernel dimension numbers. We
-  // need to apply the transpose permutation to the original shape to figure
-  // out what the new logical dimensions are.
   ConvolutionDimensionNumbers new_dnums = dnums;
-  new_dnums.set_kernel_input_feature_dimension(
-      transpose_dimensions[dnums.kernel_input_feature_dimension()]);
-  new_dnums.set_kernel_output_feature_dimension(
-      transpose_dimensions[dnums.kernel_output_feature_dimension()]);
-  for (auto& kernel_spatial_dimension :
-       *new_dnums.mutable_kernel_spatial_dimensions()) {
-    kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
+
+  HloInstruction* new_lhs;
+  const int64 kLhsIdx = 0;
+  if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) !=
+      operand_indices.end()) {
+    HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx);
+    const auto& transpose_dimensions = transpose.dimensions();
+    HloInstruction& transpose_operand = *transpose.mutable_operand(0);
+
+    // Everything remains the same except for the input/output dimension
+    // numbers. We need to apply the transpose permutation to the original
+    // shape to figure out what the new logical dimensions are.
+    new_dnums.set_input_batch_dimension(
+        transpose_dimensions[dnums.input_batch_dimension()]);
+    new_dnums.set_input_feature_dimension(
+        transpose_dimensions[dnums.input_feature_dimension()]);
+    for (const auto& spatial_dimension : dnums.spatial_dimensions()) {
+      CHECK_EQ(spatial_dimension, transpose_dimensions[spatial_dimension]);
+    }
+    new_lhs = &transpose_operand;
+  } else {
+    new_lhs = convolution.mutable_operand(kLhsIdx);
+  }
+
+  HloInstruction* new_rhs;
+  const int64 kRhsIdx = 1;
+  if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) !=
+      operand_indices.end()) {
+    HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx);
+    const auto& transpose_dimensions = transpose.dimensions();
+    HloInstruction& transpose_operand = *transpose.mutable_operand(0);
+
+    // Everything remains the same except for the kernel dimension numbers. We
+    // need to apply the transpose permutation to the original shape to figure
+    // out what the new logical dimensions are.
+    new_dnums.set_kernel_input_feature_dimension(
+        transpose_dimensions[dnums.kernel_input_feature_dimension()]);
+    new_dnums.set_kernel_output_feature_dimension(
+        transpose_dimensions[dnums.kernel_output_feature_dimension()]);
+    for (auto& kernel_spatial_dimension :
+         *new_dnums.mutable_kernel_spatial_dimensions()) {
+      kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
+    }
+    new_rhs = &transpose_operand;
+  } else {
+    new_rhs = convolution.mutable_operand(kRhsIdx);
   }
 
   auto new_conv = HloInstruction::CreateConvolve(
-      convolution.shape(), convolution.mutable_operand(0), &transpose_operand,
-      convolution.window(), new_dnums);
+      convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums);
   TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
       &convolution, std::move(new_conv)));
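Note: a worked example of the dimension-number rewrite above, using the kind of permutation the FoldConvTransposeLhs test below exercises:

// transpose_dimensions = {1, 0, 2, 3}  (batch <-> feature, spatial fixed)
// dnums:     input_batch_dimension = 0, input_feature_dimension = 1
// new_dnums: input_batch_dimension   = transpose_dimensions[0] = 1
//            input_feature_dimension = transpose_dimensions[1] = 0
// Spatial dims 2 and 3 satisfy transpose_dimensions[d] == d, which is what
// CanFoldOperandsIntoConvolution requires and the CHECK_EQ re-asserts; a
// transpose that moved a spatial dimension would simply not be folded.
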
std::unordered_set instruction_set( entry_computation->instructions().begin(), entry_computation->instructions().end()); - CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(transpose_x)) - << "transpose_x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(conv)) - << "transpose_x is not in entry_computation."; - CHECK_EQ(0, instruction_set.size()) - << "entry_computation should contain exactly 4 instructions."; + EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + EXPECT_EQ(1, instruction_set.size()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* new_conv = *instruction_set.begin(); + EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); + EXPECT_EQ(dnums.input_feature_dimension(), + new_conv->convolution_dimension_numbers().input_batch_dimension()); + EXPECT_EQ( + dnums.input_batch_dimension(), + new_conv->convolution_dimension_numbers().input_feature_dimension()); + EXPECT_EQ(dnums.spatial_dimensions(0), + new_conv->convolution_dimension_numbers().spatial_dimensions(0)); + EXPECT_EQ(dnums.spatial_dimensions(1), + new_conv->convolution_dimension_numbers().spatial_dimensions(1)); } } // namespace diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index b3506b72bf5ab1aa27704c18c8a1dc69881caf71..065d2580c6807f64919b8edef7f28446fe6c2c2a 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -1843,10 +1844,17 @@ UserComputation::GetEmbeddedComputations( XLA_VLOG_LINES(3, session_computation_.DebugString()); std::vector computations; + std::vector sorted_handles; for (const auto& handle_request : session_computation_.requests()) { - int64 handle_value = handle_request.first; + sorted_handles.push_back(handle_request.first); + } + std::sort(sorted_handles.begin(), sorted_handles.end()); + for (int64 handle : sorted_handles) { + const auto& handle_request = session_computation_.requests().find(handle); + CHECK(handle_request != session_computation_.requests().end()); + int64 handle_value = handle_request->first; if (handle_value <= version) { - const OperationRequest& request = handle_request.second; + const OperationRequest& request = handle_request->second; switch (request.request().op_case()) { case OpRequest::kCallRequest: { CHECK_EQ(1, request.embedded_computation_versions_size()); diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 8e16056b239a9e1d1776bfe91f6e36862e0feeec..af583bed625b7c28e1d73fc014197d249efd90bc 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -102,6 +102,32 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { return true; } +// Constructs and returns the new shape with the given minor_to_major order in +// its Layout. 
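+// For example (illustrative): dimensions {2,3} with minor_to_major {1,0}
+// describe a row-major layout (dimension 1 varies fastest in memory), while
+// minor_to_major {0,1} describes the column-major layout.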
+StatusOr MakeShapeWithLayoutInternal( + PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice minor_to_major) { + if (dimensions.size() != minor_to_major.size()) { + return InvalidArgument("Dimensions size is %ld, but layout size is %ld.", + dimensions.size(), minor_to_major.size()); + } + if (element_type == OPAQUE || element_type == TUPLE) { + return InvalidArgument("Unsupported element type: %s", + PrimitiveType_Name(element_type).c_str()); + } + Shape shape = ShapeUtil::MakeShape(element_type, dimensions); + auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); + min2maj->Clear(); + for (int64 value : minor_to_major) { + min2maj->Add(value); + } + if (!shape.has_layout()) { + return InvalidArgument("Shape has no layout."); + } + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); + return shape; +} + } // namespace /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { @@ -152,16 +178,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { /* static */ Shape ShapeUtil::MakeShapeWithLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice minor_to_major) { - CHECK_EQ(dimensions.size(), minor_to_major.size()); - Shape shape = MakeShape(element_type, dimensions); - auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); - min2maj->Clear(); - for (int64 value : minor_to_major) { - min2maj->Add(value); - } - DCHECK(shape.has_layout()); - TF_DCHECK_OK(ValidateShape(shape)); - return shape; + return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major) + .ValueOrDie(); } /* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( @@ -499,11 +517,10 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { // Extract the layout minor-to-major and set it. 
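   // e.g. (illustrative) for the shape string "f32[2,3]{1,0}", layout_string
   // is "1,0" and min2maj becomes {1, 0}.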
TF_ASSIGN_OR_RETURN(std::vector min2maj, comma_list_to_int64s(layout_string)); - TF_RET_CHECK(dimensions.size() == min2maj.size()); - result = - ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj); + TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal( + primitive_type, dimensions, min2maj)); } - TF_DCHECK_OK(ShapeUtil::ValidateShape(result)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result)); return std::move(result); } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index b02d906d93e8854fc33fc49514f97e6a1b81b110..43127925e65a8e988cb56b22af75910843479576 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -102,28 +102,18 @@ cc_library( deps = [ ":literal_test_util", "//tensorflow/compiler/xla:shape_layout", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:backend", - "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_layout", - "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_execution_profile", - "//tensorflow/compiler/xla/service:hlo_graph_dumper", - "//tensorflow/compiler/xla/service:transfer_manager", - "//tensorflow/core:core_cpu_internal", + "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", - "//third_party/eigen3", ], ) diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 26513d6ce8e0b8896e9f9838ecf28f1ed5bbb383..3e244fbfd9d63cc835ecb940d4accce7aeaf6a45 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -19,24 +19,9 @@ limitations under the License. #include #include -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/backend.h" -#include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/executable.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/transfer_manager.h" -#include "tensorflow/compiler/xla/shape_layout.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -45,22 +30,6 @@ namespace se = ::perftools::gputools; namespace xla { -// Define this in .cc file to avoid having to include eigen or forward declare -// these types in the header. 
-struct HloTestBase::EigenThreadPoolWrapper { - std::unique_ptr pool; - std::unique_ptr device; -}; - -HloTestBase::HloTestBase() {} - -HloTestBase::~HloTestBase() { - // Deallocate all the memory allocated during the tests. - for (auto& allocation : allocations_) { - backend().default_stream_executor()->Deallocate(&allocation); - } -} - /* static */ std::unique_ptr HloTestBase::CreateNewModule() { HloModuleConfig config; @@ -80,98 +49,25 @@ StatusOr HloTestBase::Execute( tensorflow::gtl::ArraySlice arguments, Shape* result_shape) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - backend().compiler()->Compile(std::move(module), - backend().default_stream_executor())); - - se::Stream stream(backend().default_stream_executor()); - stream.Init(); - - ExecutableRunOptions run_options; - run_options.set_stream(&stream); - run_options.set_allocator(backend().memory_allocator()); - run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool()); - run_options.set_intra_op_thread_pool( - backend().eigen_intra_op_thread_pool_device()); - - HloExecutionProfile hlo_execution_profile; - ServiceExecutableRunOptions service_run_options( - run_options, backend().StreamBorrower(), - backend().inter_op_thread_pool()); - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase result, - executable->ExecuteOnStream(&service_run_options, arguments, - &hlo_execution_profile)); - TF_RET_CHECK(stream.BlockHostUntilDone()); - - allocations_.push_back(result); - - *result_shape = executable->result_shape(); - - if (ShapeUtil::IsTuple(*result_shape)) { - // We must record element buffers of tuples as well to avoid leaks. - DCHECK(!ShapeUtil::IsNestedTuple(*result_shape)); - TF_ASSIGN_OR_RETURN( - std::vector element_buffers, - backend().transfer_manager()->ShallowCopyTupleFromDevice( - backend().default_stream_executor(), result, *result_shape)); - - // A tuple may contain the same buffer in more than one element. Keep track - // of the buffers already added to avoid duplicates in allocations_. - std::set added_opaques; - for (auto element_buffer : element_buffers) { - if (added_opaques.count(element_buffer.opaque()) == 0) { - CHECK(element_buffer.opaque() != nullptr); - added_opaques.insert(element_buffer.opaque()); - allocations_.push_back(element_buffer); - } - } - } - - return result; + return runner_.Execute(std::move(module), arguments, result_shape); } se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) { - // Allocate memory on the device using the stream executor. 
- int64 allocation_size = - backend().transfer_manager()->GetByteSizeRequirement(literal.shape()); - se::DeviceMemoryBase allocation = - backend().default_stream_executor()->AllocateArray( - allocation_size); - allocations_.push_back(allocation); - - TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToDevice( - backend().default_stream_executor(), literal, &allocation)); - - return allocation; + return runner_.TransferToDevice(literal); } std::unique_ptr HloTestBase::TransferFromDevice( const Shape& shape, se::DeviceMemoryBase device_base) { - auto literal = MakeUnique(); - TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromDevice( - backend().default_stream_executor(), device_base, shape, shape, - literal.get())); - return literal; + return runner_.TransferFromDevice(shape, device_base); } std::unique_ptr HloTestBase::ExecuteAndTransfer( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments) { - Shape result_shape; - se::DeviceMemoryBase device_base = - Execute(std::move(module), arguments, &result_shape).ValueOrDie(); - return TransferFromDevice(result_shape, device_base); + return runner_.ExecuteAndTransfer(std::move(module), arguments); } -Backend& HloTestBase::backend() { - if (!backend_) { - backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie(); - VLOG(1) << "executing on platform " << backend().platform()->Name(); - } - return *backend_; -} +Backend& HloTestBase::backend() { return runner_.backend(); } /* static */ string HloTestBase::TestName() { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 275f1f5c7baa11245186d119f5b38b4d02b84566..7f068dce36be3546298de2f06bf6d33446d07ca2 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -21,12 +21,12 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/backend.h" -#include "tensorflow/compiler/xla/service/compiler.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -39,10 +39,9 @@ namespace xla { // building a graph of HLO instructions to run. class HloTestBase : public ::testing::Test { protected: - struct EigenThreadPoolWrapper; - HloTestBase(); + HloTestBase() {} - ~HloTestBase() override; + ~HloTestBase() override {} // Creates a new HLO module for a test. The module created will have // TestName() for its name; it will also automatically populate its debug @@ -102,23 +101,12 @@ class HloTestBase : public ::testing::Test { static string TestName(); - // Creates (if necessary) and returns the default backend. If creation fails, - // crashes the program. - // - // This creates the backend lazily so it's possible to instantiate an - // HloTestBase in a program without any backends linked in. + // Returns the backend owned by the HloRunner. Backend& backend(); - // This vector contains handles of all the device memory allocations performed - // by the test. 
These are deallocated on destruction of the test object.
-  std::vector<se::DeviceMemoryBase> allocations_;
+  HloRunner runner_;

   ErrorSpec error_spec_{0.0001};
-
-  std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
-
- private:
-  std::unique_ptr<Backend> backend_;  // Lazily populated. Access via backend().
 };

 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index 0451537af777e127df333da8a941a89e6fe315c2..759921dce5acf3cd23a121776f3ab0731c9bb623 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -210,6 +210,18 @@ tf_cc_binary(
     ],
 )

+tf_cc_binary(
+    name = "hlo_proto_to_json",
+    srcs = ["hlo_proto_to_json.cc"],
+    deps = [
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla/service:hlo_proto",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+    ],
+)
+
 # -----------------------------------------------------------------------------

 filegroup(
diff --git a/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4e02e17db65c0a4220672733be8319e1a0cc4f0f
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Usage:
+//   hlo_proto_to_json --input_file=some_binary_proto
+//   --output_file=path_to_dump_output
+//
+// Reads one serialized HLO module, converts it into JSON format, and dumps it
+// into the output file. some_binary_proto is obtained by serializing an HLO
+// module to disk using the --xla_dump_hlo_proto_to debug option.
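+//
+// Example invocation (hypothetical paths):
+//   hlo_proto_to_json --input_file=/tmp/module_0001.hlo.pb \
+//     --output_file=/tmp/module_0001.json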
+ +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" + +using tensorflow::Env; +using xla::string; + +namespace xla { +namespace tools { + +StatusOr ToJson(const tensorflow::protobuf::Message& message) { + string json_output; + tensorflow::protobuf::util::JsonPrintOptions json_options; + json_options.add_whitespace = true; + json_options.always_print_primitive_fields = true; + auto status = tensorflow::protobuf::util::MessageToJsonString( + message, &json_output, json_options); + if (!status.ok()) { + return InternalError("MessageToJsonString failed: %s", + status.error_message().data()); + } + return json_output; +} + +void RealMain(const string& input, const string& output) { + HloProto hlo_proto; + TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), input, + &hlo_proto)) + << "Can't open, read, or parse input file " << input; + + auto statusor = ToJson(hlo_proto); + QCHECK(statusor.ok()) << "Error converting " << input << " to JSON." + << statusor.status(); + + TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), output, + statusor.ValueOrDie())); +} + +} // namespace tools +} // namespace xla + +int main(int argc, char** argv) { + string input_file, output_file; + const std::vector flag_list = { + tensorflow::Flag("input_file", &input_file, "file to convert."), + tensorflow::Flag("output_file", &output_file, "converted file"), + }; + const string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + tensorflow::port::InitMain(usage.c_str(), &argc, &argv); + QCHECK(parse_ok && argc == 1) << "\n" << usage; + + QCHECK(!input_file.empty()) << "--input_file is required"; + QCHECK(!output_file.empty()) << "--output_file is required"; + + xla::tools::RealMain(input_file, output_file); + + return 0; +} diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..c84ca9fc833881ce49bcaad5dd85394145151912 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/BUILD @@ -0,0 +1,84 @@ +# Build file for the Hlo parser. + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [":friends"], +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. 
+filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "hlo_lexer", + srcs = ["hlo_lexer.cc"], + hdrs = [ + "hlo_lexer.h", + "hlo_token.h", + ], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + ], +) + +cc_library( + name = "hlo_parser", + srcs = ["hlo_parser.cc"], + hdrs = ["hlo_parser.h"], + deps = [ + ":hlo_lexer", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_cc_test( + name = "hlo_parser_test", + size = "small", + srcs = ["hlo_parser_test.cc"], + deps = [ + ":hlo_parser", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a334bc2b297490a6e3a5c976656f806d07d634cb --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/README.md @@ -0,0 +1,69 @@ +# HloModule string syntax + +TODO: Support subcomputations (for fusion, reduce, while, ...). + +TODO: Support ops that require extra attributes, e.g. dimensions, strides. + +```yacc +hlo_module + : 'HloModule' name computation + ; + +computation + : 'ENTRY' name param_list '->' shape instruction_list + ; + +instruction_list + : '{' instruction_list1 '}' + ; +instruction_list1 + : instruction + | instruction_list1 instruction + ; +instruction + : name '=' shape opcode operands + ; + +operands + : '(' operands1 ')' + ; +operands1 + : /*empty*/ + | operand + | operands1 ',' operand + ; +operand + : shape name + ; + +param_list + : '(' param_list1 ')' + ; +param_list1 + : /*empty*/ + | param + | param_list1 ',' param + ; +param + : name shape + ; + +shape + : shape_val_ + | '(' tuple_elements ')' + ; +tuple_elements + : /*empty*/ + | shape (',' shape)* + ; + +name + : identifier ':' + | '%' identifier + ; + +identifier + : [a-zA-Z_][a-zA-Z0-9_.-]* + ; + +``` diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc new file mode 100644 index 0000000000000000000000000000000000000000..3e84ffcbd2c36884d1f56d08a73d8a8a3c1e785d --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -0,0 +1,270 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" + +#include + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/platform/regexp.h" + +namespace xla { +namespace tools { + +using tensorflow::StringPiece; + +namespace { + +constexpr int kEOF = -1; +constexpr int kError = -2; + +// [a-zA-Z0-9_.-] +bool IsIdentifierChar(char c) { + return isalnum(static_cast(c)) || c == '-' || c == '.' || + c == '_'; +} + +} // namespace + +int HloLexer::GetNextChar() { + int current_char = PeekCurrentChar(); + if (current_char != kEOF && current_char != kError) { + current_ptr_++; + } + return current_char; +} + +int HloLexer::PeekCurrentChar() const { + if (current_ptr_ == buf_.end()) { + return kEOF; + } + char current_char = *current_ptr_; + if (current_char == 0) { + // '\0' should not appear in the middle of the string. + return kError; + } + return static_cast(current_char); +} + +bool HloLexer::CanDereference(const char* ptr) const { + return ptr < buf_.end() && ptr >= buf_.begin(); +} + +StringPiece HloLexer::StringPieceFromPointers(const char* begin, + const char* end) const { + CHECK(begin <= end); + CHECK(begin == buf_.end() || CanDereference(begin)); + CHECK(end == buf_.end() || CanDereference(end)); + return StringPiece(begin, end - begin); +} + +tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( + const char* begin, const char* end) const { + CHECK(begin <= end); + CHECK(begin == buf_.end() || CanDereference(begin)); + CHECK(end == buf_.end() || CanDereference(end)); + return tensorflow::RegexpStringPiece(begin, end - begin); +} + +TokKind HloLexer::LexToken() { + while (true) { + token_start_ = current_ptr_; + + int current_char = GetNextChar(); + switch (current_char) { + default: + // [a-zA-Z_] + if (isalpha(static_cast(current_char)) || + current_char == '_') { + return LexIdentifier(); + } + return TokKind::kError; + case kEOF: + // Hit the end of the input buffer. + return TokKind::kEof; + case kError: + // Hit an invalid character in the input buffer. + return TokKind::kError; + case ' ': + case '\t': + case '\n': + case '\r': + // Ignore whitespace. + continue; + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case '-': + if (current_char == '-' && PeekCurrentChar() == '>') { + current_ptr_++; + return TokKind::kArrow; + } + return LexDigitOrNegative(); + case '=': + return TokKind::kEqual; + case ',': + return TokKind::kComma; + case '%': + return LexPercent(); + case ':': + return TokKind::kColon; + case '[': + return TokKind::kLsquare; + case ']': + return TokKind::kRsquare; + case '{': + return TokKind::kLbrace; + case '}': + return TokKind::kRbrace; + case '(': + return TokKind::kLparen; + case ')': + return TokKind::kRparen; + } + } +} + +// Lex a shape, name, keyword, or opcode. +// shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})? +// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*: +// keyword ::= HloModule, ENTRY, ... +// opcode ::= add, greater-than, ... 
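+// For example (illustrative): "f32[2,4]{1,0}" lexes as a shape token,
+// "axpy.v5:" as a name, "HloModule" as a keyword, and "multiply" as an
+// opcode.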
+TokKind HloLexer::LexIdentifier() { + { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + // 'consumable' will be advanced iff its prefix matches the pattern. + static LazyRE2 shape_pattern = { + R"(^(\w*\d*)\[([\d,]*)\](?:\s*{([\d,]*)})?)"}; + if (RE2::Consume(&consumable, *shape_pattern)) { + auto status_or_shape = ShapeUtil::ParseShapeString( + StringPieceFromPointers(token_start_, consumable.begin())); + if (status_or_shape.ok()) { + // This is a shape string. + shape_val_ = status_or_shape.ValueOrDie(); + current_ptr_ = consumable.begin(); + return TokKind::kShape; + } + } + } + + while (IsIdentifierChar(PeekCurrentChar())) { + current_ptr_++; + } + + // If followed by ':', it's a name. + if (PeekCurrentChar() == ':') { + str_val_.assign(token_start_, current_ptr_); + current_ptr_++; // skip ':' + return TokKind::kName; + } + + StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_); + + // See if this is a keyword. +#define KEYWORD(STR) \ + do { \ + if (identifier == #STR) { \ + return TokKind::kw_##STR; \ + } \ + } while (false) + + KEYWORD(true); + KEYWORD(false); + KEYWORD(HloModule); + KEYWORD(ENTRY); + +#undef KEYWORD + + // See if this is an opcode. + auto opcode = StringToHloOpcode(identifier.ToString()); + if (opcode.ok()) { + opcode_val_ = opcode.ValueOrDie(); + return TokKind::kOpcode; + } + + current_ptr_ = token_start_ + 1; + return TokKind::kError; +} + +// Lex names after a % character. +// name ::= [a-zA-Z_][a-zA-Z0-9_.-]* +TokKind HloLexer::LexPercent() { + const char* name_start = current_ptr_; + if (isalpha(static_cast(PeekCurrentChar())) || + PeekCurrentChar() == '_') { + current_ptr_++; + while (IsIdentifierChar(PeekCurrentChar())) { + current_ptr_++; + } + str_val_.assign(name_start, current_ptr_); + return TokKind::kName; + } + return TokKind::kError; +} + +// Lex integer and floating-point values. 
+// int [-]?[0-9]+ +// fp with exp [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+) +// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+) +TokKind HloLexer::LexDigitOrNegative() { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + static LazyRE2 float_pattern = { + R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"}; + if (RE2::Consume(&consumable, *float_pattern)) { + current_ptr_ = consumable.begin(); + tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(), + &decimal_val_); + return TokKind::kDecimal; + } + + static LazyRE2 int_pattern = {R"([-]?\d+)"}; + if (RE2::Consume(&consumable, *int_pattern)) { + current_ptr_ = consumable.begin(); + tensorflow::strings::safe_strto64( + StringPieceFromPointers(token_start_, current_ptr_), &int64_val_); + return TokKind::kInt; + } + + return TokKind::kError; +} + +StringPiece HloLexer::GetCurrentLine() const { + const char* start = token_start_; + const char* end = current_ptr_; + if (!CanDereference(start) || !CanDereference(end)) { + return "LINE OUT OF RANGE"; + } + while (start > buf_.begin() && *start != '\n') { + start--; + } + while (end < buf_.end() && *end != '\n') { + end++; + } + return StringPieceFromPointers(start, end); +} + +} // namespace tools +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h new file mode 100644 index 0000000000000000000000000000000000000000..20278fd6cde4374629c33e1cab1e7c2839ae5731 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h @@ -0,0 +1,108 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_token.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace tools { + +// Lexer for the HloModule::ToString() format text. 
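+//
+// Example use (illustrative):
+//   HloLexer lexer(input_text);
+//   while (lexer.Lex() != TokKind::kEof) {
+//     if (lexer.GetKind() == TokKind::kName) {
+//       string name = lexer.GetStrVal();
+//     }
+//   }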
+class HloLexer { + public: + explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) { + current_ptr_ = buf_.begin(); + } + + TokKind Lex() { return current_kind_ = LexToken(); } + TokKind GetKind() const { return current_kind_; } + string GetStrVal() const { + CHECK(GetKind() == TokKind::kName); + return str_val_; + } + Shape GetShapeVal() const { + CHECK(GetKind() == TokKind::kShape); + return shape_val_; + } + HloOpcode GetOpcodeVal() const { + CHECK(GetKind() == TokKind::kOpcode); + return opcode_val_; + } + int64 GetInt64Val() const { + CHECK(GetKind() == TokKind::kInt); + return int64_val_; + } + double GetDecimalVal() const { + CHECK(GetKind() == TokKind::kDecimal); + return decimal_val_; + } + + // Returns the line of text that is currently being lexed. + tensorflow::StringPiece GetCurrentLine() const; + + private: + // Returns the current character. If it's neither the end of input buffer nor + // an invalid character, moves the pointer forward. + int GetNextChar(); + + // Returns the current character. + int PeekCurrentChar() const; + + // Creates StringPiece with the given begin and end. Exits if the begin > end, + // or it's out of the range of the current buffer. + tensorflow::StringPiece StringPieceFromPointers(const char* begin, + const char* end) const; + tensorflow::RegexpStringPiece RegexpStringPieceFromPointers( + const char* begin, const char* end) const; + + // Returns true if the given ptr is dereferenceable within the range of the + // current buffer. + bool CanDereference(const char* ptr) const; + + TokKind LexToken(); + + TokKind LexIdentifier(); + TokKind LexPercent(); + TokKind LexShape(); + TokKind LexConstant(); + TokKind LexDigitOrNegative(); + + const tensorflow::StringPiece buf_; + const char* current_ptr_; + + // Information about the current token. + const char* token_start_; + TokKind current_kind_; + string str_val_; + Shape shape_val_; + HloOpcode opcode_val_; + int64 int64_val_; + double decimal_val_; +}; + +} // namespace tools +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc new file mode 100644 index 0000000000000000000000000000000000000000..57700493e6cc699a303092cc6b379de88d2c3cbc --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -0,0 +1,502 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace tools { + +namespace { + +using tensorflow::StringPiece; +using tensorflow::strings::StrCat; + +// Parser for the HloModule::ToString() format text. 
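+//
+// Exposed via tools::Parse(). For example (illustrative), the string
+//   "HloModule m: ENTRY %m () -> f32[] { %c = f32[] constant(4.2) }"
+// parses into an HloModule whose entry computation holds one constant.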
+class HloParser { + public: + explicit HloParser(StringPiece str) : lexer_(str) {} + + // Runs the parser. Returns false if an error occurred. + bool Run(); + + // Returns the parsed HloModule. + std::unique_ptr ConsumeHloModule() { return std::move(module_); } + + // Returns the error information. + string GetError() const { return tensorflow::str_util::Join(error_, "\n"); } + + private: + // ParseXXX returns false if an error occurred. + bool ParseHloModule(); + bool ParseComputation(); + bool ParseInstructionList(HloComputation::Builder* builder); + bool ParseInstruction(HloComputation::Builder* builder); + bool ParseLiteral(std::unique_ptr* literal, const Shape& shape); + bool ParseOperands(std::vector* operands, + const int expected_size); + bool ParseParamList(); + bool ParseName(string* result); + bool ParseShape(Shape* result); + bool ParseOpcode(HloOpcode* result); + bool ParseInt64(int64* result); + bool ParseDecimal(double* result); + bool ParseBool(bool* result); + bool ParseToken(TokKind kind, const string& msg); + + // Logs the current parsing line and the given message. Always returns false. + bool TokenError(StringPiece msg); + + // If the current token is 'kind', eats it (i.e. lexes the next token) and + // returns true. + bool EatIfPresent(TokKind kind); + + // Adds the instruction to the pool. Returns false and emits an error if the + // instruction already exists. + bool AddInstruction(const string& name, HloInstruction* instruction); + + // The map from the instruction name to the instruction. This does not own the + // instructions. + std::unordered_map instruction_pool_; + + HloLexer lexer_; + std::unique_ptr module_; + std::vector error_; +}; + +bool HloParser::TokenError(StringPiece msg) { + error_.push_back( + StrCat("was parsing \"", lexer_.GetCurrentLine(), "\"; ", msg)); + return false; +} + +bool HloParser::Run() { + lexer_.Lex(); + return ParseHloModule(); +} + +// ::= 'HloModule' name computation +bool HloParser::ParseHloModule() { + if (lexer_.GetKind() != TokKind::kw_HloModule) { + return TokenError("expects HloModule"); + } + // Eat 'HloModule' + lexer_.Lex(); + + string name; + if (!ParseName(&name)) { + return false; + } + + module_ = MakeUnique(name); + + return ParseComputation(); +} + +// computation ::= 'ENTRY' name param_list '->' shape instruction_list +bool HloParser::ParseComputation() { + string name; + if (!ParseToken(TokKind::kw_ENTRY, "expects 'ENTRY'") || !ParseName(&name)) { + return false; + } + auto builder = MakeUnique(name); + + Shape shape; + if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'") || + !ParseShape(&shape) || !ParseInstructionList(builder.get())) { + return false; + } + module_->AddEntryComputation(builder->Build()); + return true; +} + +// instruction_list ::= '{' instruction_list1 '}' +// instruction_list1 ::= (instruction)+ +bool HloParser::ParseInstructionList(HloComputation::Builder* builder) { + if (!ParseToken(TokKind::kLbrace, + "expects '{' at the beginning of instruction list.")) { + return false; + } + do { + if (!ParseInstruction(builder)) { + return false; + } + } while (lexer_.GetKind() != TokKind::kRbrace); + return ParseToken(TokKind::kRbrace, + "expects '}' at the end of instruction list."); +} + +// instruction ::= name '=' shape opcode operands +bool HloParser::ParseInstruction(HloComputation::Builder* builder) { + string name; + Shape shape; + HloOpcode opcode; + std::vector operands; + if (!ParseName(&name) || + !ParseToken(TokKind::kEqual, "expects '=' in instruction") || + 
!ParseShape(&shape) || !ParseOpcode(&opcode)) {
+    return false;
+  }
+  switch (opcode) {
+    case HloOpcode::kParameter: {
+      int64 parameter_number;
+      return ParseToken(TokKind::kLparen,
+                        "expects '(' before parameter number") &&
+             ParseInt64(&parameter_number) &&
+             ParseToken(TokKind::kRparen,
+                        "expects ')' after parameter number") &&
+             AddInstruction(
+                 name, builder->AddInstruction(HloInstruction::CreateParameter(
+                           parameter_number, shape, name)));
+    }
+    case HloOpcode::kConstant: {
+      std::unique_ptr<Literal> literal;
+      return ParseToken(TokKind::kLparen,
+                        "expects '(' before constant value") &&
+             ParseLiteral(&literal, shape) &&
+             ParseToken(TokKind::kRparen,
+                        "expects ')' after constant value") &&
+             AddInstruction(
+                 name, builder->AddInstruction(
+                           HloInstruction::CreateConstant(std::move(literal))));
+    }
+    // Unary ops.
+    case HloOpcode::kAbs:
+    case HloOpcode::kRoundNearestAfz:
+    case HloOpcode::kBitcast:
+    case HloOpcode::kCeil:
+    case HloOpcode::kCopy:
+    case HloOpcode::kCos:
+    case HloOpcode::kExp:
+    case HloOpcode::kIsFinite:
+    case HloOpcode::kFloor:
+    case HloOpcode::kLog:
+    case HloOpcode::kNot:
+    case HloOpcode::kNegate:
+    case HloOpcode::kSign:
+    case HloOpcode::kSin:
+    case HloOpcode::kSort:
+    case HloOpcode::kTanh: {
+      return ParseOperands(&operands, /*expected_size=*/1) &&
+             AddInstruction(name,
+                            builder->AddInstruction(HloInstruction::CreateUnary(
+                                shape, opcode, operands[0])));
+    }
+    // Binary ops.
+    case HloOpcode::kAdd:
+    case HloOpcode::kDivide:
+    case HloOpcode::kMultiply:
+    case HloOpcode::kSubtract:
+    case HloOpcode::kEq:
+    case HloOpcode::kGe:
+    case HloOpcode::kGt:
+    case HloOpcode::kLe:
+    case HloOpcode::kLt:
+    case HloOpcode::kNe:
+    case HloOpcode::kDot:
+    case HloOpcode::kMaximum:
+    case HloOpcode::kMinimum:
+    case HloOpcode::kPower:
+    case HloOpcode::kRemainder:
+    case HloOpcode::kAnd:
+    case HloOpcode::kOr:
+    case HloOpcode::kShiftLeft:
+    case HloOpcode::kShiftRightArithmetic:
+    case HloOpcode::kShiftRightLogical: {
+      return ParseOperands(&operands, /*expected_size=*/2) &&
+             AddInstruction(
+                 name, builder->AddInstruction(HloInstruction::CreateBinary(
+                           shape, opcode, operands[0], operands[1])));
+    }
+    // Ternary ops.
+    case HloOpcode::kClamp:
+    case HloOpcode::kSelect: {
+      return ParseOperands(&operands, /*expected_size=*/3) &&
+             AddInstruction(
+                 name,
+                 builder->AddInstruction(HloInstruction::CreateTernary(
+                     shape, opcode, operands[0], operands[1], operands[2])));
+    }
+    // Other supported ops.
+ case HloOpcode::kConvert: { + return ParseOperands(&operands, /*expected_size=*/1) && + AddInstruction( + name, builder->AddInstruction( + HloInstruction::CreateConvert(shape, operands[0]))); + } + case HloOpcode::kCrossReplicaSum: { + return ParseOperands(&operands, /*expected_size=*/1) && + AddInstruction(name, builder->AddInstruction( + HloInstruction::CreateCrossReplicaSum( + shape, operands[0]))); + } + case HloOpcode::kReshape: { + return ParseOperands(&operands, /*expected_size=*/1) && + AddInstruction( + name, builder->AddInstruction( + HloInstruction::CreateReshape(shape, operands[0]))); + } + case HloOpcode::kBroadcast: + case HloOpcode::kCall: + case HloOpcode::kCustomCall: + case HloOpcode::kConcatenate: + case HloOpcode::kReducePrecision: + case HloOpcode::kConvolution: + case HloOpcode::kGetTupleElement: + case HloOpcode::kMap: + case HloOpcode::kPad: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kReverse: + case HloOpcode::kRng: + case HloOpcode::kSlice: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kTranspose: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + case HloOpcode::kFusion: + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kRecv: + case HloOpcode::kSend: + case HloOpcode::kUpdate: + case HloOpcode::kIndex: + case HloOpcode::kTrace: + return TokenError(StrCat("parsing not yet implemented for op: ", + HloOpcodeString(opcode))); + } +} + +bool HloParser::ParseLiteral(std::unique_ptr* literal, + const Shape& shape) { + switch (shape.element_type()) { + case PRED: + bool b; + if (!ParseBool(&b)) { + return false; + } + *literal = Literal::CreateR0(b); + return true; + case S32: + int64 i; + if (!ParseInt64(&i)) { + return false; + } + *literal = Literal::CreateR0(i); + return true; + case F32: + double d; + if (!ParseDecimal(&d)) { + return false; + } + *literal = Literal::CreateR0(d); + return true; + default: + return TokenError(StrCat("unsupported constant in shape: ", + ShapeUtil::HumanString(shape))); + } +} + +// operands ::= '(' operands1 ')' +// operands1 +// ::= /*empty*/ +// ::= operand (, operand)* +// operand ::= shape name +bool HloParser::ParseOperands(std::vector* operands, + const int expected_size) { + if (!ParseToken(TokKind::kLparen, + "expects '(' at the beginning of operands")) { + return false; + } + if (lexer_.GetKind() == TokKind::kRparen) { + // empty + } else { + do { + Shape shape; + string name; + if (!ParseShape(&shape) || !ParseName(&name)) { + return false; + } + HloInstruction* instruction = + tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); + if (!instruction) { + return TokenError(StrCat("instruction does not exist: ", name)); + } + operands->push_back(instruction); + } while (EatIfPresent(TokKind::kComma)); + } + if (expected_size != operands->size()) { + return TokenError(StrCat("expects ", expected_size, " operands, but has ", + operands->size(), " operands")); + } + return ParseToken(TokKind::kRparen, "expects ')' at the end of operands"); +} + +// param_list ::= '(' param_list1 ')' +// param_list1 +// ::= /*empty*/ +// ::= param (',' param)* +// param ::= name shape +bool HloParser::ParseParamList() { + if (!ParseToken(TokKind::kLparen, + "expects '(' at the beginning of param list")) { + return false; + } + + if (lexer_.GetKind() == TokKind::kRparen) { + // empty + } else { 
+ do { + Shape shape; + if (!ParseToken(TokKind::kName, "expects name in parameter") || + !ParseShape(&shape)) { + return false; + } + } while (EatIfPresent(TokKind::kComma)); + } + return ParseToken(TokKind::kRparen, "expects ')' at the end of param list"); +} + +// shape ::= shape_val_ +// shape ::= '(' tuple_elements ')' +// tuple_elements +// ::= /*empty*/ +// ::= shape (',' shape)* +bool HloParser::ParseShape(Shape* result) { + if (EatIfPresent(TokKind::kLparen)) { // Tuple + std::vector shapes; + if (lexer_.GetKind() == TokKind::kRparen) { + /*empty*/ + } else { + // shape (',' shape)* + do { + shapes.emplace_back(); + if (!ParseShape(&shapes.back())) { + return false; + } + } while (EatIfPresent(TokKind::kComma)); + } + *result = ShapeUtil::MakeTupleShape(shapes); + return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple."); + } + + if (lexer_.GetKind() != TokKind::kShape) { + return TokenError("expects shape"); + } + *result = lexer_.GetShapeVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseName(string* result) { + VLOG(1) << "ParseName"; + if (lexer_.GetKind() != TokKind::kName) { + return TokenError("expects name"); + } + *result = lexer_.GetStrVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseOpcode(HloOpcode* result) { + VLOG(1) << "ParseOpcode"; + if (lexer_.GetKind() != TokKind::kOpcode) { + return TokenError("expects opcode"); + } + *result = lexer_.GetOpcodeVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseInt64(int64* result) { + VLOG(1) << "ParseInt64"; + if (lexer_.GetKind() != TokKind::kInt) { + return TokenError("expects integer"); + } + *result = lexer_.GetInt64Val(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseDecimal(double* result) { + switch (lexer_.GetKind()) { + case TokKind::kDecimal: + *result = lexer_.GetDecimalVal(); + break; + case TokKind::kInt: + *result = static_cast(lexer_.GetInt64Val()); + break; + default: + return TokenError("expects decimal or integer"); + } + lexer_.Lex(); + return true; +} + +bool HloParser::ParseBool(bool* result) { + if (lexer_.GetKind() != TokKind::kw_true && + lexer_.GetKind() != TokKind::kw_false) { + return TokenError("expects true or false"); + } + *result = lexer_.GetKind() == TokKind::kw_true; + lexer_.Lex(); + return true; +} + +bool HloParser::ParseToken(TokKind kind, const string& msg) { + if (lexer_.GetKind() != kind) { + return TokenError(msg); + } + lexer_.Lex(); + return true; +} + +bool HloParser::EatIfPresent(TokKind kind) { + if (lexer_.GetKind() != kind) { + return false; + } + lexer_.Lex(); + return true; +} + +bool HloParser::AddInstruction(const string& name, + HloInstruction* instruction) { + auto result = instruction_pool_.insert({name, instruction}); + if (!result.second) { + return TokenError(StrCat("instruction already exists: ", name)); + } + return true; +} + +} // namespace + +StatusOr> Parse(StringPiece str) { + HloParser parser(str); + if (!parser.Run()) { + return InvalidArgument("Syntax error: %s", parser.GetError().c_str()); + } + return parser.ConsumeHloModule(); +} + +} // namespace tools +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/tools/parser/hlo_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..9aaf18ef20d769cd9ac6f0e48bc92f62292ba31a --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.h @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
+#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace tools {
+
+// The API of the HLO parser. Given a string in the HloModule::ToString()
+// format, returns the parsed HloModule.
+StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str);
+
+}  // namespace tools
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4ecece3eac1677433051a71d8ba981add5d9cf20
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -0,0 +1,240 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +#include +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace tools { +namespace { + +struct TestData { + string test_name; + string module_string; +}; + +string TestDataToString(const ::testing::TestParamInfo& data) { + return data.param.test_name; +} + +std::vector CreateTestCases() { + // clang-format off + return std::vector({ +// ax + y +{ +"AxpyParam", +R"(HloModule axpy_module: + +ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[2,4]{1,0} parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %alpha, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} + +)" +}, +// pred constant +{ +"ConstantPred", +R"(HloModule constant_pred_module: + +ENTRY %constant_pred () -> pred[] { + %constant = pred[] constant(true) +} + +)" +}, +// s32 constant +{ +"ConstantS32", +R"(HloModule constant_s32_module: + +ENTRY %constant_s32 () -> s32[] { + %constant = s32[] constant(-42) +} + +)" +}, +// f32 constant, but the value is not a decimal +{ +"ConstantF32", R"(HloModule ConstantF32_module: + +ENTRY %ConstantF32.v4 () -> f32[] { + %constant = f32[] constant(42) +} + +)" +}, +// constant + constant +{ +"AddConstants", +R"(HloModule add_constants_module: + +ENTRY %add_constants () -> f32[] { + %constant = f32[] constant(3.14) + %add = f32[] add(f32[] %constant, f32[] %constant) +} + +)" +}, +// v1 > v2 ? v1 : v2 +{ +"SelectR1F32", +R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module: + +ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] { + %v1 = f32[4]{0} parameter(0) + %v2 = f32[4]{0} parameter(1) + %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2) + %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2) +} + +)" +} + }); + // clang-format on +} + +class HloParserTest : public ::testing::Test, + public ::testing::WithParamInterface { + protected: + void ExpectSuccess() { + const string& original = GetParam().module_string; + auto result = Parse(original); + TF_EXPECT_OK(result.status()); + EXPECT_EQ(original, result.ValueOrDie()->ToString()); + } +}; + +TEST_P(HloParserTest, Run) { ExpectSuccess(); } + +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, + ::testing::ValuesIn(CreateTestCases()), + TestDataToString); + +TEST_F(HloParserTest, Empty) { + const string original = ""; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, Garbage) { + const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, WrongOpcode) { + const string original = R"(HloModule wrong_opcode: + +ENTRY %blabla (x: f32[], y: f32[]) -> f32[] { + %x = f32[]{} parameter(0) + %y = f32[]{} parameter(1) + %le = pred[]{} le(f32[]{} %x, f32[]{} %y) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, WrongShape) { + const string original = R"(HloModule wrong_opcode: + +ENTRY %blabla (x: g32[]) -> g32[] { + %x = g32[]{} parameter(0) +} + +)"; + auto result = Parse(original); + 
EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, WrongOperandsSize) { + const string original = R"(HloModule wrong_opcode: + +ENTRY %blabla (x: f32[]) -> pred[] { + %x = f32[]{} parameter(0) + %eq = pred[]{} equal-to(f32[]{} %x) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, OperandNotFound) { + const string original = R"(HloModule operand_not_found: +ENTRY %blabla (x: f32[]) -> pred[] { + %x = f32[]{} parameter(0) + %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y) +} +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, MoreConstants) { + const string original = R"(HloModule SelectScalarS32True_module: + +ENTRY %SelectScalarS32True.v4 () -> s32[] { + %constant.2 = pred[] constant(true) + %constant.1 = s32[] constant(-42) + %constant = s32[] constant(42) + %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant) +} + +)"; + auto result = Parse(original); + TF_EXPECT_OK(result.status()); + // Constant instructions have no name. The string will be parsed successfully + // but the constant names will not be exactly the same. +} + +TEST_F(HloParserTest, ConstantWithExp) { + const string original = R"(HloModule ConstantWithExp_module: + +ENTRY %ConstantWithExp.v4 () -> f32[] { + %constant.1 = f32[] constant(3e+2) +} + +)"; + auto result = Parse(original); + TF_EXPECT_OK(result.status()); + // The string will be parsed successfully but the output strings are not + // exactly the same, because "3e2" is parsed into value 300 and will be + // printed as "300". +} + +TEST_F(HloParserTest, Tuple) { + const string original = R"(HloModule EmptyTupleCreate_module: + +ENTRY %EmptyTupleCreate.v1 () -> () { + %tuple = () tuple() +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +} // namespace +} // namespace tools +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h new file mode 100644 index 0000000000000000000000000000000000000000..1f75e17c7f0ff66f91fd71113b27252e5487eecd --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ + +namespace xla { +namespace tools { + +// Defines different kinds of tokens in a hlo module string. +enum class TokKind { + // Markers + kEof, + kError, + + // Tokens with no info. + kEqual, // = + kComma, // , + kColon, // : + kLsquare, + kRsquare, // [ ] + kLbrace, + kRbrace, // { } + kLparen, + kRparen, // ( ) + + kArrow, // -> + + // Keywords + kw_HloModule, + kw_ENTRY, + kw_true, + kw_false, + + // Typed tokens. 
+ kName, // %foo + kShape, // f32[2,3]{1,0} + kOpcode, // add + kInt, // 42 + kDecimal, // 4.2 +}; + +} // namespace tools +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 7f4bd26d1bcc3ff9cc002adb28d2adfcf96f59ab..ce3c3eee68ad7f7ebb42836e3cae14803f8650d7 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -82,8 +82,8 @@ message DebugOptions { // Dump all HLO modules as text into the provided directory path. string xla_generate_hlo_text_to = 7; - // Dump compilation artifacts as JSON into this directory. - string xla_dump_debug_json_to = 8; + // Dump compilation artifacts in binary proto into this directory. + string xla_dump_hlo_proto_to = 8; // Instrument the computation to collect per-HLO cycle counts. bool xla_hlo_profile = 9; diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index 1555a3427fd5e40ca54c134a2c80f9d2c5feca36..ae3f48f1b276b1f13078e8845c4c87cf3473513f 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -69,6 +69,28 @@ tf_cc_test( ], ) +cc_library( + name = "adaptive_shared_batch_scheduler", + hdrs = ["adaptive_shared_batch_scheduler.h"], + deps = [ + ":batch_scheduler", + "//tensorflow/contrib/batching/util:periodic_function_dynamic", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "adaptive_shared_batch_scheduler_test", + srcs = ["adaptive_shared_batch_scheduler_test.cc"], + deps = [ + ":adaptive_shared_batch_scheduler", + "//tensorflow/contrib/batching/test_util:fake_clock_env", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "basic_batch_scheduler", hdrs = ["basic_batch_scheduler.h"], diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..ac32f096395ddb86d4230b56b7c047643340afc2 --- /dev/null +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h @@ -0,0 +1,463 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
+
+#include <functional>
+#include <memory>
+#include <queue>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/batching/batch_scheduler.h"
+#include "tensorflow/contrib/batching/util/periodic_function.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace serving {
+namespace internal {
+template <typename TaskType>
+class ASBSBatch;
+
+template <typename TaskType>
+class ASBSQueue;
+}  // namespace internal
+
+// Shared batch scheduler designed to minimize latency. The scheduler keeps
+// track of a number of queues (one per model or model version) which are
+// continuously enqueuing requests. The scheduler groups the requests into
+// batches which it periodically sends off for processing (see
+// shared_batch_scheduler.h for more details). The AdaptiveSharedBatchScheduler
+// prioritizes batches by age (i.e. the batch's oldest request) irrespective of
+// queue. The scheduler will process the oldest batch at an adjustable rate,
+// regardless of batch size. The user can provide feedback to help set this
+// rate to achieve some goal (e.g. minimize overall latency, limit cpu usage,
+// etc).
+//
+// The rate (or rather, the corresponding period) is adjusted each time a batch
+// is processed, using an exponentially weighted moving average to smooth
+// potentially noisy feedback:
+// ewma_feedback = ((N - 1) * ewma_feedback + feedback()) / N
+// period *= (1 + K * ewma_feedback)
+//
+// Some potential use cases:
+// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing
+//   involves serial processing by a device, from a latency perspective it is
+//   desirable to keep the device evenly loaded, avoiding the need to wait for
+//   the device to process prior batches.
+//   feedback = num_pending_on_device() - desired_pending.
+// CPU utilization - If the batch processing is cpu dominated, you can reap
+//   latency gains when underutilized by increasing the processing rate, but
+//   back the rate off when the load increases to avoid overload.
+//   feedback = cpu_rate() - desired_cpu_rate.
+
+template <typename TaskType>
+class AdaptiveSharedBatchScheduler
+    : public std::enable_shared_from_this<
+          AdaptiveSharedBatchScheduler<TaskType>> {
+ public:
+  struct Options {
+    // The name to use for the pool of batch threads.
+    string thread_pool_name = {"batch_threads"};
+    // Number of batch processing threads; equivalently the maximum number of
+    // concurrently running batches.
+    int64 num_batch_threads = port::NumSchedulableCPUs();
+    // The environment to use (typically only overridden by test code).
+    Env* env = Env::Default();
+    // Initial batch scheduling period in microseconds. Will be altered for
+    // non-zero rate_feedback.
+    double initial_scheduling_period_micros = 500;
+    // Minimum batch scheduling period in microseconds. Recommend setting this
+    // value greater than 0, otherwise it may take a while to recover from a
+    // sustained time of negative scheduling_period_feedback (which may occur
+    // under low load).
+    double min_scheduling_period_micros = 100;
+    // Maximum batch scheduling period in microseconds.
+    double max_scheduling_period_micros = 10000;
+    // Feedback function used to modify the scheduling period each time a batch
+    // is scheduled. Should return values roughly O(1), with positive values
+    // resulting in an increased period.
+    std::function<double()> scheduling_period_feedback = [] { return 0.; };
+    // To handle potentially noisy scheduling_period_feedback, the period is
+    // adjusted using an exponentially weighted moving average over the
+    // previous feedback_smoothing_batches batches. Must be greater than 0.
+    int64 feedback_smoothing_batches = 10;
+  };
+
+  // Ownership is shared between the caller of Create() and any queues created
+  // via AddQueue().
+  static Status Create(
+      const Options& options,
+      std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler);
+
+  struct QueueOptions {
+    // Maximum size of each batch.
+    int max_batch_size = 1000;
+    // Maximum number of enqueued (i.e. non-scheduled) batches.
+    int max_enqueued_batches = 10;
+  };
+
+  using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;
+
+  // Adds queue (and its callback) to be managed by this scheduler.
+  Status AddQueue(const QueueOptions& options,
+                  BatchProcessor process_batch_callback,
+                  std::unique_ptr<BatchScheduler<TaskType>>* queue);
+
+ private:
+  // access to AddBatch, RemoveQueue, GetEnv.
+  friend class internal::ASBSQueue<TaskType>;
+
+  explicit AdaptiveSharedBatchScheduler(const Options& options);
+
+  // Batch scheduling function which runs every scheduling_period_
+  // microseconds.
+  void ProcessOneBatch();
+
+  // Notifies scheduler of non-empty batch which is eligible for processing.
+  void AddBatch(internal::ASBSBatch<TaskType>*);
+
+  // Removes queue from scheduler.
+  void RemoveQueue(const internal::ASBSQueue<TaskType>* queue);
+
+  Env* GetEnv() const { return options_.env; }
+
+  const Options options_;
+
+  struct BatchCompare {
+    bool operator()(const internal::ASBSBatch<TaskType>* a,
+                    const internal::ASBSBatch<TaskType>* b);
+  };
+
+  // Collection of batches added by AddBatch, ordered by age. Owned by
+  // scheduler until they are released for processing.
+  std::priority_queue<internal::ASBSBatch<TaskType>*,
+                      std::vector<internal::ASBSBatch<TaskType>*>,
+                      BatchCompare>
+      batches_ GUARDED_BY(mu_);
+
+  // Unowned queues and callbacks added by AddQueue.
+  std::unordered_map<const internal::ASBSQueue<TaskType>*, BatchProcessor>
+      queues_and_callbacks_ GUARDED_BY(mu_);
+
+  mutex mu_;
+
+  // Responsible for running ProcessOneBatch. PeriodicFunction was used in
+  // order to check for deletion so that the thread can be shut down.
+  std::unique_ptr<PeriodicFunction> scheduling_thread_;
+
+  // Responsible for running the batch processing callbacks.
+  std::unique_ptr<thread::ThreadPool> batch_thread_pool_;
+
+  // Time interval in microseconds between successive ProcessOneBatch calls.
+  double scheduling_period_;
+
+  // Exponentially weighted moving average of
+  // options_.scheduling_period_feedback() evaluated in each ProcessOneBatch
+  // call.
+  double ewma_feedback_ = 0;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler);
+};
+
+//////////////////////////////////////////////////////////
+// Implementation details follow. API users need not read.
+
+namespace internal {
+// Consolidates tasks into batches, passing them off to the
+// AdaptiveSharedBatchScheduler for processing.
+template <typename TaskType>
+class ASBSQueue : public BatchScheduler<TaskType> {
+ public:
+  using QueueOptions =
+      typename AdaptiveSharedBatchScheduler<TaskType>::QueueOptions;
+
+  ASBSQueue(std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
+            const QueueOptions& options);
+
+  ~ASBSQueue() override;
+
+  // Adds task to current batch. Fails if the task size is larger than the
+  // batch size or if the current batch is full and this queue's number of
+  // outstanding batches is at its maximum.
+  Status Schedule(std::unique_ptr<TaskType>* task) override;
+
+  // Number of tasks waiting to be scheduled.
+  size_t NumEnqueuedTasks() const override;
+
+  // Number of size 1 tasks which could currently be scheduled without failing.
+  size_t SchedulingCapacity() const override;
+
+  // Notifies queue that a batch is about to be scheduled; the queue should not
+  // place any more tasks in this batch.
+  void ReleaseBatch(const ASBSBatch<TaskType>* batch);
+
+ private:
+  std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_;
+  const QueueOptions options_;
+  // Owned by scheduler_.
+  ASBSBatch<TaskType>* current_batch_ GUARDED_BY(mu_) = nullptr;
+  int64 num_enqueued_batches_ GUARDED_BY(mu_) = 0;
+  int64 num_enqueued_tasks_ GUARDED_BY(mu_) = 0;
+  mutable mutex mu_;
+  TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue);
+};
+
+// Batch which remembers when and by whom it was created.
+template <typename TaskType>
+class ASBSBatch : public Batch<TaskType> {
+ public:
+  ASBSBatch(ASBSQueue<TaskType>* queue, int64 creation_time_micros)
+      : queue_(queue), creation_time_micros_(creation_time_micros) {}
+
+  ~ASBSBatch() override {}
+
+  ASBSQueue<TaskType>* queue() const { return queue_; }
+
+  int64 creation_time_micros() const { return creation_time_micros_; }
+
+ private:
+  ASBSQueue<TaskType>* queue_;
+  const int64 creation_time_micros_;
+  TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch);
+};
+}  // namespace internal
+
+// ---------------- AdaptiveSharedBatchScheduler ----------------
+
+template <typename TaskType>
+Status AdaptiveSharedBatchScheduler<TaskType>::Create(
+    const Options& options,
+    std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler) {
+  if (options.num_batch_threads < 1) {
+    return errors::InvalidArgument("num_batch_threads must be positive; was ",
+                                   options.num_batch_threads);
+  }
+  if (options.min_scheduling_period_micros < 0) {
+    return errors::InvalidArgument(
+        "min_scheduling_period_micros must be >= 0; was ",
+        options.min_scheduling_period_micros);
+  }
+  if (options.min_scheduling_period_micros >
+      options.initial_scheduling_period_micros) {
+    return errors::InvalidArgument(
+        "initial_scheduling_period_micros (",
+        options.initial_scheduling_period_micros,
+        ") must be >= min_scheduling_period_micros (",
+        options.min_scheduling_period_micros, ")");
+  }
+  if (options.initial_scheduling_period_micros >
+      options.max_scheduling_period_micros) {
+    return errors::InvalidArgument(
+        "initial_scheduling_period_micros (",
+        options.initial_scheduling_period_micros,
+        ") must be <= max_scheduling_period_micros (",
+        options.max_scheduling_period_micros, ")");
+  }
+  if (options.feedback_smoothing_batches < 1) {
+    return errors::InvalidArgument(
+        "feedback_smoothing_batches must be positive; was ",
+        options.feedback_smoothing_batches);
+  }
+  scheduler->reset(new AdaptiveSharedBatchScheduler<TaskType>(options));
+  return Status::OK();
+}
+
+template <typename TaskType>
+AdaptiveSharedBatchScheduler<TaskType>::AdaptiveSharedBatchScheduler(
+    const Options& options)
+    : options_(options),
+      scheduling_period_(options.initial_scheduling_period_micros) {
+  PeriodicFunction::Options opts;
+  opts.thread_name_prefix = "scheduling_thread";
+  opts.env = GetEnv();
+  scheduling_thread_.reset(
+      new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts));
+  batch_thread_pool_.reset(new thread::ThreadPool(
+      GetEnv(), options.thread_pool_name, options.num_batch_threads));
+}
+
+template <typename TaskType>
+Status AdaptiveSharedBatchScheduler<TaskType>::AddQueue(
+    const QueueOptions& options, BatchProcessor process_batch_callback,
+    std::unique_ptr<BatchScheduler<TaskType>>* queue) {
+  if (options.max_batch_size <= 0) {
+    return errors::InvalidArgument("max_batch_size must be positive; was ",
+                                   options.max_batch_size);
+  }
+  if (options.max_enqueued_batches <= 0) {
+    return errors::InvalidArgument(
+        "max_enqueued_batches must be positive; was ",
+        options.max_enqueued_batches);
+  }
+  internal::ASBSQueue<TaskType>* asbs_queue_raw;
+  queue->reset(asbs_queue_raw = new internal::ASBSQueue<TaskType>(
+                   this->shared_from_this(), options));
+  mutex_lock l(mu_);
+  queues_and_callbacks_[asbs_queue_raw] = process_batch_callback;
+  return Status::OK();
+}
+
+template <typename TaskType>
+void AdaptiveSharedBatchScheduler<TaskType>::AddBatch(
+    internal::ASBSBatch<TaskType>* batch) {
+  mutex_lock l(mu_);
+  batches_.push(batch);
+}
+
+template <typename TaskType>
+void AdaptiveSharedBatchScheduler<TaskType>::RemoveQueue(
+    const internal::ASBSQueue<TaskType>* queue) {
+  mutex_lock l(mu_);
+  queues_and_callbacks_.erase(queue);
+}
+
+template <typename TaskType>
+void AdaptiveSharedBatchScheduler<TaskType>::ProcessOneBatch() {
+  static const double kFeedbackMultiplier = .001;
+  internal::ASBSBatch<TaskType>* batch = nullptr;
+  BatchProcessor callback;
+  const int64 start_time_micros = GetEnv()->NowMicros();
+  {
+    mutex_lock l(mu_);
+    if (!batches_.empty()) {
+      batch = batches_.top();
+      batches_.pop();
+      callback = queues_and_callbacks_[batch->queue()];
+    }
+  }
+  if (batch != nullptr) {
+    double feedback = options_.scheduling_period_feedback();
+    const int64 N = options_.feedback_smoothing_batches;
+    ewma_feedback_ = ((N - 1) * ewma_feedback_ + feedback) / N;
+    scheduling_period_ *= (1 + kFeedbackMultiplier * ewma_feedback_);
+    if (scheduling_period_ < options_.min_scheduling_period_micros) {
+      scheduling_period_ = options_.min_scheduling_period_micros;
+    } else if (scheduling_period_ > options_.max_scheduling_period_micros) {
+      scheduling_period_ = options_.max_scheduling_period_micros;
+    }
+    // Queue may destroy itself after ReleaseBatch is called.
+    batch->queue()->ReleaseBatch(batch);
+    batch_thread_pool_->Schedule([callback, batch] {
+      callback(std::unique_ptr<Batch<TaskType>>(batch));
+    });
+  }
+  const int64 sleep_time =
+      scheduling_period_ - (GetEnv()->NowMicros() - start_time_micros);
+  if (sleep_time > 0) {
+    GetEnv()->SleepForMicroseconds(sleep_time);
+  }
+}
+
+template <typename TaskType>
+bool AdaptiveSharedBatchScheduler<TaskType>::BatchCompare::operator()(
+    const internal::ASBSBatch<TaskType>* a,
+    const internal::ASBSBatch<TaskType>* b) {
+  return a->creation_time_micros() > b->creation_time_micros();
+}
+
+// ---------------- ASBSQueue ----------------
+
+namespace internal {
+template <typename TaskType>
+ASBSQueue<TaskType>::ASBSQueue(
+    std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
+    const QueueOptions& options)
+    : scheduler_(scheduler), options_(options) {}
+
+template <typename TaskType>
+ASBSQueue<TaskType>::~ASBSQueue() {
+  // Wait until last batch has been scheduled.
+  const int kSleepMicros = 1000;
+  for (;;) {
+    {
+      mutex_lock l(mu_);
+      if (num_enqueued_batches_ == 0) {
+        break;
+      }
+    }
+    scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros);
+  }
+  scheduler_->RemoveQueue(this);
+}
+
+template <typename TaskType>
+Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
+  bool added_new_batch = false;
+  size_t size = (*task)->size();
+  if (size > options_.max_batch_size) {
+    return errors::InvalidArgument("Task size ", size,
+                                   " is larger than maximum batch size ",
+                                   options_.max_batch_size);
+  }
+  {
+    mutex_lock l(mu_);
+    // Current batch is full, create another if allowed.
+    if (current_batch_ &&
+        current_batch_->size() + size > options_.max_batch_size) {
+      if (num_enqueued_batches_ >= options_.max_enqueued_batches) {
+        return errors::Unavailable("The batch scheduling queue is full");
+      }
+      current_batch_->Close();
+      current_batch_ = nullptr;
+    }
+    if (!current_batch_) {
+      added_new_batch = true;
+      num_enqueued_batches_++;
+      current_batch_ =
+          new ASBSBatch<TaskType>(this, scheduler_->GetEnv()->NowMicros());
+    }
+    current_batch_->AddTask(std::move(*task));
+    num_enqueued_tasks_++;
+  }
+  if (added_new_batch) scheduler_->AddBatch(current_batch_);
+  return Status::OK();
+}
+
+template <typename TaskType>
+void ASBSQueue<TaskType>::ReleaseBatch(const ASBSBatch<TaskType>* batch) {
+  mutex_lock l(mu_);
+  num_enqueued_batches_--;
+  num_enqueued_tasks_ -= batch->num_tasks();
+  if (batch == current_batch_) {
+    current_batch_->Close();
+    current_batch_ = nullptr;
+  }
+}
+
+template <typename TaskType>
+size_t ASBSQueue<TaskType>::NumEnqueuedTasks() const {
+  mutex_lock l(mu_);
+  return num_enqueued_tasks_;
+}
+
+template <typename TaskType>
+size_t ASBSQueue<TaskType>::SchedulingCapacity() const {
+  mutex_lock l(mu_);
+  const int current_batch_capacity =
+      current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
+  const int spare_batches =
+      options_.max_enqueued_batches - num_enqueued_batches_;
+  return spare_batches * options_.max_batch_size + current_batch_capacity;
+}
+}  // namespace internal
+}  // namespace serving
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a07cd6d834fa28904bf7748b16972cca217503c1
--- /dev/null
+++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc
@@ -0,0 +1,438 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h"
+
+#include "tensorflow/contrib/batching/test_util/fake_clock_env.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace serving {
+namespace anonymous {
+
+class FakeTask : public BatchTask {
+ public:
+  explicit FakeTask(size_t size) : size_(size) {}
+
+  ~FakeTask() override = default;
+
+  size_t size() const override { return size_; }
+
+ private:
+  const size_t size_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(FakeTask);
+};
+
+// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on
+// that task. Returns the resulting status.
+Status ScheduleTask(size_t task_size, BatchScheduler<FakeTask>* scheduler) {
+  std::unique_ptr<FakeTask> task(new FakeTask(task_size));
+  Status status = scheduler->Schedule(&task);
+  // Schedule() should have consumed 'task' iff it returned Status::OK.
+  CHECK_EQ(status.ok(), task == nullptr);
+  return status;
+}
+
+// Creates a thread that waits on 'start' and then advances the fake clock in
+// 'env' in a loop until 'stop' is notified. Useful for allowing objects that
+// use the clock to be destroyed.
+std::unique_ptr<Thread> CreateFakeClockAdvancerThread(
+    test_util::FakeClockEnv* env, Notification* start, Notification* stop) {
+  return std::unique_ptr<Thread>(Env::Default()->StartThread(
+      {}, "FakeClockAdvancerThread", [env, start, stop] {
+        start->WaitForNotification();
+        while (!stop->HasBeenNotified()) {
+          env->AdvanceByMicroseconds(10);
+          Env::Default()->SleepForMicroseconds(10);
+        }
+      }));
+}
+
+TEST(AdaptiveSharedBatchSchedulerTest, Basic) {
+  for (const bool delete_scheduler_early : {false, true}) {
+    for (const bool delete_queue_1_early : {false, true}) {
+      int queue_0_tasks = 0;
+      auto queue_0_callback =
+          [&queue_0_tasks](std::unique_ptr<Batch<FakeTask>> batch) {
+            ASSERT_TRUE(batch->IsClosed());
+            EXPECT_GT(batch->num_tasks(), 0);
+            for (int i = 0; i < batch->num_tasks(); i++) {
+              queue_0_tasks += batch->task(i).size();
+            }
+          };
+      int queue_1_tasks = 0;
+      auto queue_1_callback =
+          [&queue_1_tasks](std::unique_ptr<Batch<FakeTask>> batch) {
+            ASSERT_TRUE(batch->IsClosed());
+            EXPECT_GT(batch->num_tasks(), 0);
+            for (int i = 0; i < batch->num_tasks(); i++) {
+              queue_1_tasks += batch->task(i).size();
+            }
+          };
+      {
+        std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
+        TF_ASSERT_OK(
+            AdaptiveSharedBatchScheduler<FakeTask>::Create({}, &scheduler));
+
+        // Create two queues.
+        std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
+        TF_ASSERT_OK(scheduler->AddQueue({}, queue_0_callback, &queue_0));
+        std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
+        TF_ASSERT_OK(scheduler->AddQueue({}, queue_1_callback, &queue_1));
+
+        if (delete_scheduler_early) {
+          // Delete our copy of the scheduler. The queues should keep it alive
+          // under the covers.
+          scheduler = nullptr;
+        }
+        // Submit tasks to the two queues, and (optionally) remove the queues.
+        TF_ASSERT_OK(ScheduleTask(1, queue_0.get()));
+        TF_ASSERT_OK(ScheduleTask(2, queue_1.get()));
+        TF_ASSERT_OK(ScheduleTask(3, queue_0.get()));
+        TF_ASSERT_OK(ScheduleTask(4, queue_1.get()));
+        if (delete_queue_1_early) {
+          queue_1 = nullptr;
+        }
+        TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
+      }
+      EXPECT_EQ(queue_0_tasks, 9);
+      EXPECT_EQ(queue_1_tasks, 6);
+    }
+  }
+}
+
+TEST(AdaptiveSharedBatchSchedulerTest, BadOptions) {
+  using Scheduler = AdaptiveSharedBatchScheduler<FakeTask>;
+  std::shared_ptr<Scheduler> scheduler;
+  Scheduler::Options options;
+  options.num_batch_threads = 0;
+  EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+  options = Scheduler::Options();
+  options.min_scheduling_period_micros = 50;
+  options.max_scheduling_period_micros = 100;
+  options.initial_scheduling_period_micros = 1;
+  EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+  options = Scheduler::Options();
+  options.min_scheduling_period_micros = 50;
+  options.max_scheduling_period_micros = 100;
+  options.initial_scheduling_period_micros = 1000;
+  EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+  options = Scheduler::Options();
+  options.min_scheduling_period_micros = 100;
+  options.max_scheduling_period_micros = 50;
+  options.initial_scheduling_period_micros = 75;
+  EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+  options = Scheduler::Options();
+  options.feedback_smoothing_batches = 0;
+  EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+}
+
+TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) {
+  test_util::FakeClockEnv env(Env::Default());
+  Notification start_teardown, stop_teardown;
+  std::unique_ptr<Thread> teardown_thread =
+      CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
+  {
+    AdaptiveSharedBatchScheduler<FakeTask>::Options options;
+    options.initial_scheduling_period_micros = 1000;
+    options.env = &env;
+    std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
+    TF_ASSERT_OK(
+        AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
+    std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
+    std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
+    int queue_0_tasks = 0;
+    int queue_1_tasks = 0;
+    auto queue_0_callback = [&queue_0_tasks,
+                             &env](std::unique_ptr<Batch<FakeTask>> batch) {
+      ASSERT_TRUE(batch->IsClosed());
+      EXPECT_GT(batch->num_tasks(), 0);
+      for (int i = 0; i < batch->num_tasks(); i++) {
+        queue_0_tasks += batch->task(i).size();
+      }
+      env.SleepForMicroseconds(1);
+    };
+    auto queue_1_callback = [&queue_1_tasks,
+                             &env](std::unique_ptr<Batch<FakeTask>> batch) {
+      ASSERT_TRUE(batch->IsClosed());
+      EXPECT_GT(batch->num_tasks(), 0);
+      for (int i = 0; i < batch->num_tasks(); i++) {
+        queue_1_tasks += batch->task(i).size();
+      }
+      env.SleepForMicroseconds(1);
+    };
+    AdaptiveSharedBatchScheduler<FakeTask>::QueueOptions queue_options;
+    queue_options.max_batch_size = 10;
+    queue_options.max_enqueued_batches = 0;
+    // Queue must have max_enqueued_batches > 0.
+    EXPECT_FALSE(
+        scheduler->AddQueue(queue_options, queue_0_callback, &queue_0).ok());
+    queue_options.max_enqueued_batches = 2;
+    TF_ASSERT_OK(
+        scheduler->AddQueue(queue_options, queue_0_callback, &queue_0));
+    queue_options.max_batch_size = 0;
+    // Queue must have max_batch_size > 0.
+    EXPECT_FALSE(
+        scheduler->AddQueue(queue_options, queue_1_callback, &queue_1).ok());
+    queue_options.max_batch_size = 2;
+    queue_options.max_enqueued_batches = 1;
+    TF_ASSERT_OK(
+        scheduler->AddQueue(queue_options, queue_1_callback, &queue_1));
+
+    // Wait for scheduling_thread to sleep.
+    env.BlockUntilThreadsAsleep(1);
+    // Task larger than max_batch_size shouldn't schedule.
+    EXPECT_FALSE(ScheduleTask(15, queue_0.get()).ok());
+    TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
+    TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
+    env.AdvanceByMicroseconds(1);
+
+    // Task larger than max_batch_size shouldn't schedule.
+    EXPECT_FALSE(ScheduleTask(3, queue_1.get()).ok());
+    TF_ASSERT_OK(ScheduleTask(1, queue_1.get()));
+    TF_ASSERT_OK(ScheduleTask(1, queue_1.get()));
+    env.AdvanceByMicroseconds(1);
+    // Exceeds max_enqueued_batches, shouldn't schedule.
+    EXPECT_FALSE(ScheduleTask(1, queue_1.get()).ok());
+
+    TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
+    // Exceeds max_enqueued_batches, shouldn't schedule.
+    EXPECT_FALSE(ScheduleTask(6, queue_0.get()).ok());
+    TF_ASSERT_OK(ScheduleTask(4, queue_0.get()));
+
+    // Batches should be processed in order from oldest to newest.
+    env.AdvanceByMicroseconds(1000);
+    env.BlockUntilThreadsAsleep(2);
+    EXPECT_EQ(queue_0_tasks, 10);
+    EXPECT_EQ(queue_1_tasks, 0);
+
+    env.AdvanceByMicroseconds(1000);
+    env.BlockUntilThreadsAsleep(2);
+    EXPECT_EQ(queue_0_tasks, 10);
+    EXPECT_EQ(queue_1_tasks, 2);
+
+    env.AdvanceByMicroseconds(1000);
+    env.BlockUntilThreadsAsleep(2);
+    EXPECT_EQ(queue_0_tasks, 19);
+    EXPECT_EQ(queue_1_tasks, 2);
+    start_teardown.Notify();
+  }
+  stop_teardown.Notify();
+}
+
+TEST(AdaptiveSharedBatchSchedulerTest, RateFeedback) {
+  test_util::FakeClockEnv env(Env::Default());
+  Notification start_teardown, stop_teardown;
+  std::unique_ptr<Thread> teardown_thread =
+      CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
+  {
+    double feedback = 0;
+    AdaptiveSharedBatchScheduler<FakeTask>::Options options;
+    options.initial_scheduling_period_micros = 1000;
+    options.min_scheduling_period_micros = 200;
+    options.max_scheduling_period_micros = 2000;
+    options.env = &env;
+    options.scheduling_period_feedback = [&feedback] { return feedback; };
+    options.feedback_smoothing_batches = 1;
+    std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
+    TF_ASSERT_OK(
+        AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
+    std::unique_ptr<BatchScheduler<FakeTask>> queue;
+    int scheduled_items = 0;
+    auto queue_callback = [&scheduled_items,
+                           &env](std::unique_ptr<Batch<FakeTask>> batch) {
+      ASSERT_TRUE(batch->IsClosed());
+      EXPECT_GT(batch->num_tasks(), 0);
+      scheduled_items = 0;
+      for (int i = 0; i < batch->num_tasks(); i++) {
+        scheduled_items += batch->task(i).size();
+      }
+      env.SleepForMicroseconds(1);
+    };
+
+    TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
+
+    // Wait for scheduling_thread to sleep.
+    env.BlockUntilThreadsAsleep(1);
+    // Enqueue 6 batches.
+    for (int i = 0; i < 6; i++) {
+      TF_ASSERT_OK(ScheduleTask(900 + i, queue.get()));
+      env.AdvanceByMicroseconds(1);
+    }
+    feedback = -500;
+    env.AdvanceByMicroseconds(994);
+    env.BlockUntilThreadsAsleep(2);  // scheduling period = 500 usec.
+    EXPECT_EQ(scheduled_items, 900);
+    env.AdvanceByMicroseconds(500);
+    env.BlockUntilThreadsAsleep(2);  // scheduling period = 250 usec.
+    EXPECT_EQ(scheduled_items, 901);
+    feedback = 0;
+    env.AdvanceByMicroseconds(250);
+    env.BlockUntilThreadsAsleep(2);  // scheduling period = 250 usec.
+    EXPECT_EQ(scheduled_items, 902);
+    feedback = 10000;  // large feedback should hit max_scheduling_period.
+    env.AdvanceByMicroseconds(250);
+    env.BlockUntilThreadsAsleep(2);  // scheduling period = 2000 usec.
+    EXPECT_EQ(scheduled_items, 903);
+    feedback = -10000;  // large feedback should hit min_scheduling_period.
+    env.AdvanceByMicroseconds(1999);
+    // No callback scheduled, only scheduling thread sleeping.
+    env.BlockUntilThreadsAsleep(1);
+    EXPECT_EQ(scheduled_items, 903);
+    env.AdvanceByMicroseconds(1);
+    env.BlockUntilThreadsAsleep(2);  // scheduling period = 200 usec.
+    EXPECT_EQ(scheduled_items, 904);
+    env.AdvanceByMicroseconds(200);
+    env.BlockUntilThreadsAsleep(2);
+    EXPECT_EQ(scheduled_items, 905);
+    start_teardown.Notify();
+  }
+  stop_teardown.Notify();
+}
+
+TEST(AdaptiveSharedBatchSchedulerTest, FeedbackSmoothing) {
+  test_util::FakeClockEnv env(Env::Default());
+  Notification start_teardown, stop_teardown;
+  std::unique_ptr<Thread> teardown_thread =
+      CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
+  {
+    double feedback = 0;
+    AdaptiveSharedBatchScheduler<FakeTask>::Options options;
+    options.initial_scheduling_period_micros = 1000;
+    options.env = &env;
+    options.scheduling_period_feedback = [&feedback] { return feedback; };
+    options.feedback_smoothing_batches = 3;
+    std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
+    TF_ASSERT_OK(
+        AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
+    std::unique_ptr<BatchScheduler<FakeTask>> queue;
+    int scheduled_items = 0;
+    auto queue_callback = [&scheduled_items,
+                           &env](std::unique_ptr<Batch<FakeTask>> batch) {
+      ASSERT_TRUE(batch->IsClosed());
+      EXPECT_GT(batch->num_tasks(), 0);
+      scheduled_items = 0;
+      for (int i = 0; i < batch->num_tasks(); i++) {
+        scheduled_items += batch->task(i).size();
+      }
+      env.SleepForMicroseconds(1);
+    };
+
+    TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
+
+    // Wait for scheduling_thread to sleep.
+    env.BlockUntilThreadsAsleep(1);
+    // Enqueue 4 batches.
+    for (int i = 0; i < 4; i++) {
+      TF_ASSERT_OK(ScheduleTask(900 + i, queue.get()));
+      env.AdvanceByMicroseconds(1);
+    }
+    feedback = -300;
+    env.AdvanceByMicroseconds(996);
+    env.BlockUntilThreadsAsleep(2);
+    // ewma_feedback = -100, scheduling_period = 900.
+    EXPECT_EQ(scheduled_items, 900);
+    env.AdvanceByMicroseconds(899);
+    // No callback scheduled, only scheduling thread sleeping.
+    env.BlockUntilThreadsAsleep(1);
+    EXPECT_EQ(scheduled_items, 900);
+    env.AdvanceByMicroseconds(1);
+    env.BlockUntilThreadsAsleep(2);
+    // ewma_feedback = -167, scheduling_period = 750.
+    EXPECT_EQ(scheduled_items, 901);
+    env.AdvanceByMicroseconds(749);
+    // No callback scheduled, only scheduling thread sleeping.
+    env.BlockUntilThreadsAsleep(1);
+    EXPECT_EQ(scheduled_items, 901);
+    feedback = 1000 / 3.;
+    env.AdvanceByMicroseconds(1);
+    env.BlockUntilThreadsAsleep(2);
+    // ewma_feedback = 0, scheduling_period = 750.
+    EXPECT_EQ(scheduled_items, 902);
+    env.AdvanceByMicroseconds(749);
+    // No callback scheduled, only scheduling thread sleeping.
+    env.BlockUntilThreadsAsleep(1);
+    EXPECT_EQ(scheduled_items, 902);
+    env.AdvanceByMicroseconds(1);
+    env.BlockUntilThreadsAsleep(2);
+    EXPECT_EQ(scheduled_items, 903);
+    start_teardown.Notify();
+  }
+  stop_teardown.Notify();
+}
+
+TEST(AdaptiveSharedBatchSchedulerTest, QueueCapacityInfo) {
+  test_util::FakeClockEnv env(Env::Default());
+  Notification start_teardown, stop_teardown;
+  std::unique_ptr<Thread> teardown_thread =
+      CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
+  {
+    AdaptiveSharedBatchScheduler<FakeTask>::Options options;
+    options.initial_scheduling_period_micros = 1000;
+    options.env = &env;
+    std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
+    TF_ASSERT_OK(
+        AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
+    std::unique_ptr<BatchScheduler<FakeTask>> queue;
+    int scheduled_items = 0;
+    auto queue_callback = [&scheduled_items,
+                           &env](std::unique_ptr<Batch<FakeTask>> batch) {
+      ASSERT_TRUE(batch->IsClosed());
+      EXPECT_GT(batch->num_tasks(), 0);
+      scheduled_items = 0;
+      for (int i = 0; i < batch->num_tasks(); i++) {
+        scheduled_items += batch->task(i).size();
+      }
+      env.SleepForMicroseconds(1);
+    };
+    AdaptiveSharedBatchScheduler<FakeTask>::QueueOptions queue_options;
+    queue_options.max_batch_size = 10;
+    queue_options.max_enqueued_batches = 10;
+    TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue));
+
+    // Wait for scheduling_thread to sleep.
+    env.BlockUntilThreadsAsleep(1);
+    // Enqueue 3 tasks.
+    EXPECT_EQ(queue->NumEnqueuedTasks(), 0);
+    EXPECT_EQ(queue->SchedulingCapacity(), 100);
+    TF_ASSERT_OK(ScheduleTask(5, queue.get()));
+    EXPECT_EQ(queue->NumEnqueuedTasks(), 1);
+    EXPECT_EQ(queue->SchedulingCapacity(), 95);
+    env.AdvanceByMicroseconds(1);
+    TF_ASSERT_OK(ScheduleTask(6, queue.get()));
+    EXPECT_EQ(queue->NumEnqueuedTasks(), 2);
+    EXPECT_EQ(queue->SchedulingCapacity(), 84);
+    env.AdvanceByMicroseconds(1);
+    TF_ASSERT_OK(ScheduleTask(1, queue.get()));
+    EXPECT_EQ(queue->NumEnqueuedTasks(), 3);
+    EXPECT_EQ(queue->SchedulingCapacity(), 83);
+
+    env.AdvanceByMicroseconds(998);
+    env.BlockUntilThreadsAsleep(2);
+    EXPECT_EQ(scheduled_items, 5);
+    env.AdvanceByMicroseconds(1000);
+    env.BlockUntilThreadsAsleep(2);
+    EXPECT_EQ(scheduled_items, 7);
+    start_teardown.Notify();
+  }
+  stop_teardown.Notify();
+}
+}  // namespace anonymous
+}  // namespace serving
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h
index 7c41ad88180badd37398f5bae057dcd0006922c3..a5072f439abad3c5db79a514a7f2baff0b021b39 100644
--- a/tensorflow/contrib/batching/batch_scheduler.h
+++ b/tensorflow/contrib/batching/batch_scheduler.h
@@ -78,7 +78,7 @@ template <typename TaskType>
 class Batch {
  public:
   Batch() = default;
-  ~Batch();  // Blocks until the batch is closed.
+  virtual ~Batch();  // Blocks until the batch is closed.
 
   // Appends 'task' to the batch. After calling AddTask(), the newly-added task
   // can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1).
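As a quick orientation to the API introduced by the header above: a caller creates one shared scheduler, attaches a queue per model (or model version) with a batch-processing callback, and then pushes tasks through the queue. The following is a minimal sketch based only on the declarations shown above; `MyTask` and `RunSketch` are hypothetical names used for illustration, not part of the change.

```cpp
#include "tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h"

namespace tensorflow {
namespace serving {

// A task must subclass BatchTask and report its size (e.g. number of
// examples it contributes to a batch).
class MyTask : public BatchTask {
 public:
  explicit MyTask(size_t size) : size_(size) {}
  size_t size() const override { return size_; }

 private:
  const size_t size_;
};

Status RunSketch() {
  // Create the scheduler; ownership is shared with any queues it creates.
  std::shared_ptr<AdaptiveSharedBatchScheduler<MyTask>> scheduler;
  TF_RETURN_IF_ERROR(
      AdaptiveSharedBatchScheduler<MyTask>::Create({}, &scheduler));

  // Attach a queue. The callback receives closed batches, oldest first,
  // on the scheduler's batch thread pool.
  std::unique_ptr<BatchScheduler<MyTask>> queue;
  TF_RETURN_IF_ERROR(scheduler->AddQueue(
      {},
      [](std::unique_ptr<Batch<MyTask>> batch) {
        for (int i = 0; i < batch->num_tasks(); ++i) {
          // ... process batch->task(i) ...
        }
      },
      &queue));

  // Enqueue work; Schedule() consumes the task on success.
  std::unique_ptr<MyTask> task(new MyTask(/*size=*/1));
  return queue->Schedule(&task);
}

}  // namespace serving
}  // namespace tensorflow
```

Note the ownership pattern this exercises (and which the `Basic` test above verifies): the caller may drop its `shared_ptr` to the scheduler early, because each live queue keeps the scheduler alive through `enable_shared_from_this`.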
diff --git a/tensorflow/contrib/cmake/external/cub.cmake b/tensorflow/contrib/cmake/external/cub.cmake index 7b263806d733f0e1deafe3e8fdd9baf2bb6fd81f..836889895567f679d9960e29ece1600d1a7a58eb 100644 --- a/tensorflow/contrib/cmake/external/cub.cmake +++ b/tensorflow/contrib/cmake/external/cub.cmake @@ -14,7 +14,7 @@ # ============================================================================== include (ExternalProject) -set(cub_URL https://github.com/NVlabs/cub/archive/1.7.4.zip) +set(cub_URL https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip) set(cub_HASH SHA256=20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31) set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) diff --git a/tensorflow/contrib/cmake/external/gif.cmake b/tensorflow/contrib/cmake/external/gif.cmake index 5cb719b8787781084335779960887613df90217d..3d53c51fffcec1602a3b5553cdf3b225e3b0ae46 100644 --- a/tensorflow/contrib/cmake/external/gif.cmake +++ b/tensorflow/contrib/cmake/external/gif.cmake @@ -15,7 +15,7 @@ include (ExternalProject) set(gif_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/gif_archive/giflib-5.1.4/) -set(gif_URL http://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz) +set(gif_URL https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz) set(gif_HASH SHA256=34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1) set(gif_INSTALL ${CMAKE_BINARY_DIR}/gif/install) set(gif_BUILD ${CMAKE_BINARY_DIR}/gif/src/gif) diff --git a/tensorflow/contrib/cmake/external/jpeg.cmake b/tensorflow/contrib/cmake/external/jpeg.cmake index 058f554b8f2ffc4f925012e8772c684965304833..d9a165e856c588880ebdf996666d70c9e7f53da8 100644 --- a/tensorflow/contrib/cmake/external/jpeg.cmake +++ b/tensorflow/contrib/cmake/external/jpeg.cmake @@ -15,7 +15,7 @@ include (ExternalProject) set(jpeg_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/jpeg_archive) -set(jpeg_URL http://mirror.bazel.build/www.ijg.org/files/jpegsrc.v9a.tar.gz) +set(jpeg_URL https://mirror.bazel.build/www.ijg.org/files/jpegsrc.v9a.tar.gz) set(jpeg_HASH SHA256=3a753ea48d917945dd54a2d97de388aa06ca2eb1066cbfdc6652036349fe05a7) set(jpeg_BUILD ${CMAKE_CURRENT_BINARY_DIR}/jpeg/src/jpeg) set(jpeg_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/jpeg/install) diff --git a/tensorflow/contrib/cmake/external/lmdb.cmake b/tensorflow/contrib/cmake/external/lmdb.cmake index 28ec833babe8f8e600c7c0179dff511ce4d26105..79971b7cfc3c72e4b6290ccb71d40a20d1180c01 100644 --- a/tensorflow/contrib/cmake/external/lmdb.cmake +++ b/tensorflow/contrib/cmake/external/lmdb.cmake @@ -15,7 +15,7 @@ include (ExternalProject) set(lmdb_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/lmdb) -set(lmdb_URL http://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz) +set(lmdb_URL https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz) set(lmdb_HASH SHA256=108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326) set(lmdb_BUILD ${CMAKE_BINARY_DIR}/lmdb/src/lmdb) set(lmdb_INSTALL ${CMAKE_BINARY_DIR}/lmdb/install) diff --git a/tensorflow/contrib/cmake/external/snappy.cmake b/tensorflow/contrib/cmake/external/snappy.cmake index a35d8654fb6fa5f5b5d230ffbc061d050e5aeb5e..2d2451521c0f9127e2c76e6270694ac21fe8db93 100644 --- a/tensorflow/contrib/cmake/external/snappy.cmake +++ b/tensorflow/contrib/cmake/external/snappy.cmake @@ -47,4 +47,4 @@ ExternalProject_Add(snappy ) # actually enables snappy in the source code 
-add_definitions(-DSNAPPY) \ No newline at end of file +add_definitions(-DTF_USE_SNAPPY) diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 0c61630aa8f79e3efd25584478547abd99f30285..ace17424fefd4235c6b3ac49c75c81929c1c47a0 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -86,7 +86,7 @@ cuda_py_test( "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python/eager:graph_callable", - "//tensorflow/python:platform_test", + "//tensorflow/python/eager:test", "//tensorflow/python:variables", ], ) @@ -132,11 +132,12 @@ py_library( "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", "//tensorflow/python:layers_base", "//tensorflow/python:math_ops", "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:function", ], ) @@ -146,6 +147,10 @@ py_test( srcs_version = "PY2AND3", deps = [ ":metrics", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], ) @@ -160,6 +165,8 @@ py_library( deps = [ ":datasets", ":metrics", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:function", ], ) diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py index 099e10e2307b2e3c406ccf847fc8ee2bce9ce407..b18463c31a71be37d232a362544deb13c27a6c73 100644 --- a/tensorflow/contrib/eager/python/evaluator_test.py +++ b/tensorflow/contrib/eager/python/evaluator_test.py @@ -86,7 +86,7 @@ class EvaluatorTest(test.TestCase): for v in e.metric_variables: p = v.name.split("/")[0] prefix_count[p] = prefix_count.get(p, 0) + 1 - self.assertEqual({"outer-mean": 2, "mean": 2}, prefix_count) + self.assertEqual({"outer_mean": 2, "mean": 2}, prefix_count) def testDataset(self): e = SimpleEvaluator(IdentityModel()) diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 63a0f8d9a45dfb12fd1d61a1156b9acf20cf4c81..2a624b218cc65a3d1e2d48bbb8bc1dacbbf461c0 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -18,6 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import re + +from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops @@ -25,55 +29,69 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope +_to_replace = re.compile("[^A-Za-z0-9.]") + + class Metric(object): """A metric holds state for aggregating statistics over an evaluation run. Users will use Evaluator.add_metric() to add Metric objects to their - evaluation, call them in each step, and then use - Evaluator.all_metric_results() at the end. + evaluation, call them in each step (treating the object as a callable), + and then use Evaluator.all_metric_results() at the end. Descendants will implement: - * call(): Should follow this pattern: - if not self.built: - self.var = self.add_variable(...) 
-      self.add_update(self.var.assign_add(...))
-  * aggregate(): Adds in the state from a list of metrics of the same type
-    as `self`.  (Default of summing all the variables will be fine for most
-    descendants.)
-  * result(): Computes and returns a final value for the metric
+  * `build()`: All variables should be created in this method, by calling
+    `self.add_variable()` as in: `self.var = self.add_variable(...)`
+    build() will be called in the first invocation of `__call__()`, with
+    the same arguments passed `call()`.
+  * `call()`: Has all updates to variables, as in:
+      self.var.assign_add(...)
+  * `result()`: Computes and returns a final value for the metric
     from the variables in `self`.
+
+  Descendants may override, but usually won't need to:
+  * `aggregate()`: Adds in the state from a list of metrics of the same type
+    as `self`. (Default is to sum all the variables.)
+  * `reset()`: Reset all variables to their initial state. (Default is to
+    zero all the variables.)
+  Note that users should not call `aggregate()` or `reset()`; they are for
+  use by TensorFlow infrastructure.
   """
 
   def __init__(self, name=None):
-    self.built = False
+    self._built = False
     self._vars = []
     self._updates = []
-    self._name = name or self.__class__.__name__
-    # TODO(josh11b): Need some way to make sure two Metrics in the same
-    # Network have distinct names. Maybe we can get a unique name from
-    # a name/variable scope?
-    # TODO(josh11b): self._in_graph_mode = context.in_graph_mode()
+    name = name or self.__class__.__name__
+    # Replace things like spaces in name to create a valid scope name.
+    scope_name = _to_replace.sub("_", name)
+    # We create the variable scope now to get the unique name that will
+    # be used as a variable prefix when build() calls add_variable().
+    with variable_scope.variable_scope(
+        None, default_name=scope_name, use_resource=True, reuse=False) as scope:
+      pos = scope.name.rfind(scope_name)
+      self._name = name + scope.name[pos + len(scope_name):]
+      self._scope = scope
+    if context.in_graph_mode():
+      # We make self.call() into a graph callable here, so that we can
+      # return a single op that performs all of the variable updates.
+      self.call = function.defun(self.call)
 
   # ---- API for users ----
   def __call__(self, *args, **kwargs):
-    # TODO(josh11b): If self._in_graph_mode is true, make self.call() into a
-    # graph callable here, so that variable updates happen without requiring
-    # a separate fetch.
-    # TODO(josh11b): Do we need a separate build() method to separate
-    # initialization from each update? If so, how do we get the arguments
-    # to it? We *could* just pass in *args and **kwargs...
-    if not self.built:
-      # TODO(ashankar): Set up container isolation so there is no chance
-      # distinct metrics objects accidentally share variables.
-      # TODO(josh11b): Replace things like spaces in self._name to create
-      # a valid scope name.
-      with variable_scope.variable_scope(
-          self._name, use_resource=True, reuse=False):
-        ret = self.call(*args, **kwargs)
-      self.built = True
-    else:
-      ret = self.call(*args, **kwargs)
-    return ret
+    """Returns op to execute to update this metric for these inputs.
+
+    Returns None if eager execution is enabled.
+
+    Args:
+      *args:
+      **kwargs: A mini-batch of inputs to the Metric, passed on to `call()`.
+ """ + if not self._built: + with variable_scope.variable_scope(self._scope): + self.build(*args, **kwargs) + self._built = True + return self.call(*args, **kwargs) @property def name(self): @@ -84,10 +102,43 @@ class Metric(object): return self._vars # ---- To be implemented by descendants --- + def build(self, *args, **kwargs): + """Method to create variables. + + Called by `__call__()` before `call()` for the first time. + + Args: + *args: + **kwargs: The arguments to the first invocation of `__call__()`. + `build()` may use the shape and/or dtype of these arguments + when deciding how to create variables. + """ + raise NotImplementedError("Metrics must define a build() member function") + def call(self, *args, **kwargs): - """Accumulates statistics for the metric.""" + """Accumulates statistics for the metric. Users should use __call__ instead. + + Note: This function is executed as a graph function in graph mode. + This means: + a) Operations on the same resource are executed in textual order. + This should make it easier to do things like add the updated + value of a variable to another, for example. + b) You don't need to worry about collecting the update ops to execute. + All update ops added to the graph by this function will be executed. + As a result, code should generally work the same way with graph or + eager execution. + + Args: + *args: + **kwargs: A mini-batch of inputs to the Metric, as passed to + `__call__()`. + """ raise NotImplementedError("Metrics must define a call() member function") + def result(self): # TODO(josh11b): Add an optional summary_writer parameter. + """Computes and returns a final value for the metric.""" + raise NotImplementedError("Metrics must define a result() member function") + # We can support two different strategies of for doing data-parallel # distributed metric computations: # * Put metric variables on the first device and rely on small @@ -123,16 +174,19 @@ class Metric(object): self._vars[i].assign_add(math_ops.add_n([m._vars[i] for m in metrics])) # pylint: enable=protected-access - def result(self): # TODO(josh11b): Add an optional summary_writer parameter. - """Computes and returns a final value for the metric.""" - raise NotImplementedError("Metrics must define a result() member function") + def reset(self): + """Reset this metric to a freshly initialized state. + + Default implementation zeros all the metric variables. + """ + for v in self._vars: + v.assign(math_ops.zeros_like(v)) # ---- For use by descendants --- def add_variable(self, name, shape=None, dtype=None, initializer=None): """***Only for use by descendants of Metric***.""" - if self.built: - raise RuntimeError("Can't call add_variable() after a Metric has been " - "built in the first call().") + if self._built: + raise RuntimeError("Can't call add_variable() except in build().") v = variable_scope.get_variable(name, shape, dtype, initializer, trainable=False, use_resource=True) self._vars.append(v) @@ -144,6 +198,15 @@ class Mean(Metric): # TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64? # Or defaults to type of the input if it is tf.float32, else tf.float64? 
+ def build(self, values, weights=None): + del values, weights # build() does not use call's arguments + self.numer = self.add_variable(name="numer", shape=(), + dtype=dtypes.float64, + initializer=init_ops.zeros_initializer) + self.denom = self.add_variable(name="denom", shape=(), + dtype=dtypes.float64, + initializer=init_ops.zeros_initializer) + def call(self, values, weights=None): """Accumulate statistics for computing the mean. @@ -154,13 +217,6 @@ class Mean(Metric): values: Tensor with the per-example value. weights: Optional weighting of each example. Defaults to 1. """ - if not self.built: # False only in the first call(). - self.numer = self.add_variable(name="numer", shape=(), - dtype=dtypes.float64, - initializer=init_ops.zeros_initializer) - self.denom = self.add_variable(name="denom", shape=(), - dtype=dtypes.float64, - initializer=init_ops.zeros_initializer) if weights is None: self.denom.assign_add( math_ops.cast(array_ops.size(values), dtypes.float64)) @@ -179,6 +235,10 @@ class Mean(Metric): class Accuracy(Mean): """Calculates how often `predictions` matches `labels`.""" + def build(self, labels, predictions, weights=None): + del labels, predictions, weights + super(Accuracy, self).build(None) # Arguments are unused + def call(self, labels, predictions, weights=None): """Accumulate accuracy statistics. diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 089bad5a0e3049543bdc09b571319262a734809f..bfb79cd72e0551d7ee5172d871dbe24117fd9d2f 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -19,7 +19,11 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.eager.python import metrics +from tensorflow.python.eager import context from tensorflow.python.eager import test +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables class MetricsTest(test.TestCase): @@ -56,6 +60,53 @@ class MetricsTest(test.TestCase): m([7], [2]) # 0 correct, weight 1 self.assertEqual(2.5/5, m.result().numpy()) + def testTwoMeans(self): + # Verify two metrics with the same class and name don't + # accidentally share state. + m1 = metrics.Mean() + m2 = metrics.Mean() + m1(0) + m2(2) + self.assertEqual(0, m1.result().numpy()) + self.assertEqual(2, m2.result().numpy()) + self.assertNotEqual(m1.name, m2.name) + + def testNamesWithSpaces(self): + # Verify two metrics with the same class and name don't + # accidentally share state. + m1 = metrics.Mean("has space") + m2 = metrics.Mean("has space") + m2(2) + m1(0) + self.assertEqual(m1.name, "has space") + self.assertEqual(m1.numer.name, "has_space/numer:0") + self.assertEqual(m2.name, "has space_1") + self.assertEqual(m2.numer.name, "has_space_1/numer:0") + + def testGraph(self): + with context.graph_mode(), self.test_session() as sess: + m = metrics.Mean() + p = array_ops.placeholder(dtypes.float32) + accumulate = m(p) + variables.global_variables_initializer().run() + sess.run(accumulate, feed_dict={p: [1, 10, 100]}) + sess.run(accumulate, feed_dict={p: 1000}) + sess.run(accumulate, feed_dict={p: [10000, 100000]}) + self.assertAllEqual(m.result().eval(), 111111.0/6) + + def testTwoMeansGraph(self): + # Verify two metrics with the same class and name don't + # accidentally share state. 
+ with context.graph_mode(), self.test_session() as sess: + m1 = metrics.Mean() + m2 = metrics.Mean() + accumulate1 = m1(0) + accumulate2 = m2(2) + variables.global_variables_initializer().run() + sess.run([accumulate1, accumulate2]) + self.assertEqual(0, m1.result().eval()) + self.assertEqual(2, m2.result().eval()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index 29af2b531f4dee7f46c1538ff23409ece5785ceb..1605435d8d78b604bd3e0eac3184b23a4ac13ab0 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -22,6 +22,7 @@ import os from tensorflow.contrib.eager.python import saver as _saver from tensorflow.python.eager import context from tensorflow.python.eager import graph_callable +from tensorflow.python.eager import test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -29,7 +30,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import test class SaverTest(test.TestCase): @@ -38,7 +38,7 @@ class SaverTest(test.TestCase): return '/device:GPU:0' if context.num_gpus() else '/device:CPU:0' def testBasics(self): - with context.eager_mode(), ops.device(self._dev()): + with ops.device(self._dev()): v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') def model(): return array_ops.constant(2.0) * v1 @@ -54,8 +54,42 @@ class SaverTest(test.TestCase): saver.restore(ckpt_prefix) self.assertEqual(v1.read_value().numpy(), 1.0) - def testRestoreOnCreate(self): + def testSameNameNoClobbering(self): + with context.eager_mode(), ops.device(self._dev()): + # Note that this test purposefully uses Graphs rather than + # IsolateTest. Users are more likely to accidentally create the same + # variable name this way. + first_graph = ops.Graph() + with first_graph.as_default(): + v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1') + with ops.Graph().as_default(): + v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1') + saver = _saver.Saver([v1_first_graph, v1_second_graph]) + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + with self.assertRaisesRegexp(ValueError, 'v1'): + saver.save(ckpt_prefix) + + def testDifferentGraphError(self): + with context.eager_mode(), ops.device(self._dev()): + with ops.Graph().as_default(): + v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') + with ops.Graph().as_default(): + saver = _saver.Saver([v1]) + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + with self.assertRaisesRegexp(ValueError, 'Graph'): + saver.save(ckpt_prefix) + + def testSameObjectOK(self): with context.eager_mode(), ops.device(self._dev()): + v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') + # While different objects with the same shared_name are not good, passing + # in the same object multiple times is fine. 
+      saver = _saver.Saver([v1, v1])
+      ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
+      saver.save(ckpt_prefix)
+
+  def testRestoreOnCreate(self):
+    with ops.device(self._dev()):
       def model(init_val):
         v1 = resource_variable_ops.ResourceVariable(init_val, name='v1')
         return array_ops.constant(1.0) * v1, v1
@@ -71,12 +105,9 @@ class SaverTest(test.TestCase):
       # Value is from checkpoint, but not from argument.
       ret, _ = model(2.0)
       self.assertEqual(ret.numpy(), 1.0)
-      # Create it a second time won't re-assign the checkpoint value.
-      v1_2 = resource_variable_ops.ResourceVariable(3.0, name='v1')
-      self.assertEqual(v1_2.read_value().numpy(), 3.0)
 
   def testRestoreNotFound(self):
-    with context.eager_mode(), ops.device(self._dev()):
+    with ops.device(self._dev()):
       def model(v):
         return array_ops.constant(1.0) * v
 
@@ -92,7 +123,7 @@ class SaverTest(test.TestCase):
       _ = model(resource_variable_ops.ResourceVariable(1.0, name='v2'))
 
   def testSaveRestoreGraphCallable(self):
-    with context.eager_mode(), ops.device(self._dev()):
+    with ops.device(self._dev()):
       @graph_callable.graph_callable(
           [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
       def model(x):
diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py
index 25942aadfbbe2805aa5baf4115d7fdaa3a8a687a..4ed258f6ffbded62db8576f0af6b84cc69443685 100644
--- a/tensorflow/contrib/eager/python/tfe.py
+++ b/tensorflow/contrib/eager/python/tfe.py
@@ -53,6 +53,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
 
 @@in_eager_mode
 @@in_graph_mode
+@@IsolateTest
 @@run_test_in_graph_and_eager_modes
 """
 
@@ -84,6 +85,7 @@ from tensorflow.python.eager.execution_callbacks import nan_callback
 from tensorflow.python.eager.execution_callbacks import seterr
 from tensorflow.python.framework.ops import enable_eager_execution
 from tensorflow.python.framework.ops import eager_run as run
+from tensorflow.python.framework.test_util import IsolateTest
 from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes
 from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
 from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/factorization/g3doc/kmeans.md b/tensorflow/contrib/factorization/g3doc/kmeans.md
index b55c9d09ad386b84623d3648c5be83cbba8bbff9..c1843f0bf0704503d43c28d186dc826f0677711f 100644
--- a/tensorflow/contrib/factorization/g3doc/kmeans.md
+++ b/tensorflow/contrib/factorization/g3doc/kmeans.md
@@ -24,7 +24,11 @@ the full-batch version.
 approach for computing the initial cluster assignments that is expensive but
 is typically less prone to getting stuck in bad local minima.
 
-We provide distributed implementations of both full-batch and mini-batch
-K-Means algorithm. Both K-Means++ and random initialization are supported.
-The user can also choose between **Cosine** and **Squared Euclidean** distance
-metrics.
+**[k-MC2](https://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/view/12147/11759)**
+provides a very fast seeding method that produces high-quality centers
+comparable to K-Means++ seeding. k-MC2 works particularly well if it is combined
+with Mini-batch K-Means.
+
+We provide distributed implementations of both the full-batch and mini-batch
+K-Means algorithms. K-Means++, k-MC2 and random initialization are supported. The
+user can also choose between **Cosine** and **Squared Euclidean** distance
+metrics.
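The kernel added in the next hunk implements the heart of k-MC2 seeding: a single Markov chain over a vector of candidate distances, where the chain moves from the current candidate to candidate `i` with probability `min(1, distances[i] / distances[current])`, and the chain's final state becomes the next center. As a self-contained illustration of that acceptance rule (plain C++ with the standard library rather than TensorFlow's Philox RNG; names are illustrative, and a non-empty `distances` vector is assumed):

```cpp
#include <cstdint>
#include <random>
#include <vector>

// Runs one k-MC2 Markov chain over precomputed (squared) distances from each
// candidate point to the current set of centers. Returns the index of the
// chain's final state, which is taken as the next center.
int64_t KMC2ChainStep(const std::vector<float>& distances, uint64_t seed) {
  std::mt19937_64 rng(seed);
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);

  // The chain starts at the first candidate.
  int64_t selected_index = 0;
  float selected_distance = distances[0];

  for (int64_t i = 1; i < static_cast<int64_t>(distances.size()); ++i) {
    const float candidate_distance = distances[i];
    // With u ~ Uniform(0, 1), the condition below holds with probability
    // min(1, candidate_distance / selected_distance).
    if (candidate_distance > uniform(rng) * selected_distance) {
      selected_index = i;
      selected_distance = candidate_distance;
    }
  }
  return selected_index;
}
```

This sampling-by-comparison trick avoids ever normalizing the distance distribution, which is what makes k-MC2 seeding so much cheaper than exact K-Means++ sampling.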
diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops.cc b/tensorflow/contrib/factorization/kernels/clustering_ops.cc
index a2136c08bbc2e91f4587b1cdacbfe3b1d1073949..dd61f59585aee2e0245cfd6797b313b972c19bc5 100644
--- a/tensorflow/contrib/factorization/kernels/clustering_ops.cc
+++ b/tensorflow/contrib/factorization/kernels/clustering_ops.cc
@@ -224,6 +224,58 @@ class KmeansPlusPlusInitializationOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization").Device(DEVICE_CPU),
                         KmeansPlusPlusInitializationOp);

+// Implementation of a single Markov chain for the k-MC^2 algorithm.
+class KMC2ChainInitializationOp : public OpKernel {
+ public:
+  explicit KMC2ChainInitializationOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context,
+                   context->MatchSignature({DT_FLOAT, DT_INT64}, {DT_INT64}));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& distances_tensor = context->input(0);
+    const Tensor& seed_tensor = context->input(1);
+    OP_REQUIRES(context, TensorShapeUtils::IsVector(distances_tensor.shape()),
+                InvalidArgument("Input distances should be a vector."));
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
+                InvalidArgument("Input seed should be a scalar."));
+    const int64 num_points = distances_tensor.dim_size(0);
+    const int64 seed = seed_tensor.scalar<int64>()();
+    OP_REQUIRES(context, num_points > 0,
+                InvalidArgument("Expected distances_tensor.size() > 0."));
+
+    random::PhiloxRandom random(seed);
+    random::SimplePhilox rng(&random);
+
+    auto distances = distances_tensor.flat<float>();
+    // Set the initial state of the Markov chain to be the first candidate.
+    int64 selected_index = 0;
+    float selected_distance = distances(selected_index);
+    // Build a Markov chain of length num_points.
+    for (int64 i = 1; i < num_points; ++i) {
+      const float candidate_distance = distances(i);
+      // Set the next state of the Markov chain to be the candidate with
+      // probability min(1, candidate_distance/selected_distance).
+      if (candidate_distance > rng.RandFloat() * selected_distance) {
+        selected_index = i;
+        selected_distance = candidate_distance;
+      }
+    }
+
+    Tensor* output_sampled_index_tensor;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, TensorShape({}),
+                                            &output_sampled_index_tensor));
+    auto output = output_sampled_index_tensor->scalar<int64>();
+    // Return the last state of the Markov chain as the new center.
+    output() = selected_index;
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("KMC2ChainInitialization").Device(DEVICE_CPU),
+                        KMC2ChainInitializationOp);
+
 // Operator for computing the nearest neighbors for a set of points.
 class NearestNeighborsOp : public OpKernel {
  public:
diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc b/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc
index c4a96b048db878169acc69b4d8caed5d4e04c18f..8172a7cebb81de70c530dbdd9ce0ca3eda4bc2ce 100644
--- a/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc
+++ b/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc
@@ -116,6 +116,62 @@ RUN_BM_KmeansPlusPlusInitialization(k3RetriesPerSample);
 #undef RUN_BM_KmeansPlusPlusInitialization
 #undef BENCHMARK_KMEANS_PLUS_PLUS

+Graph* SetUpKMC2Initialization(int num_points) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor distances(DT_FLOAT, TensorShape({num_points}));
+  Tensor seed(DT_INT64, TensorShape({}));
+  distances.flat<float>().setRandom();
+  seed.flat<int64>().setConstant(12345);
+
+  TF_CHECK_OK(
+      NodeBuilder("KMC2ChainInitializationOp", "KMC2ChainInitialization")
+          .Input(test::graph::Constant(g, distances))
+          .Input(test::graph::Constant(g, seed))
+          .Finalize(g, nullptr /* node */));
+  return g;
+}
+
+template <int num_points, int num_dims, int num_to_sample>
+void BM_KMC2Initialization(int iters) {
+  testing::StopTiming();
+  testing::ItemsProcessed(static_cast<int64>(iters) * num_points * num_dims *
+                          num_to_sample);
+  testing::UseRealTime();
+  Graph* g = SetUpKMC2Initialization(num_points);
+  testing::StartTiming();
+  test::Benchmark("cpu", g).Run(iters);
+}
+#define BENCHMARK_KMC2(p, c, d)                           \
+  void BM_KMC2Initialization_##p##_##c##_##d(int iters) { \
+    BM_KMC2Initialization<p, c, d>(iters);                \
+  }                                                       \
+  BENCHMARK(BM_KMC2Initialization_##p##_##c##_##d);
+
+#define RUN_BM_KMC2Initialization                  \
+  BENCHMARK_KMC2(k10Points, k2Centers, k100Dim);   \
+  BENCHMARK_KMC2(k10Points, k5Centers, k100Dim);   \
+  BENCHMARK_KMC2(k10Points, k10Centers, k100Dim);  \
+  BENCHMARK_KMC2(k100Points, k10Centers, k100Dim); \
+  BENCHMARK_KMC2(k100Points, k20Centers, k100Dim); \
+  BENCHMARK_KMC2(k100Points, k50Centers, k100Dim); \
+  BENCHMARK_KMC2(k100Points, k100Centers, k100Dim); \
+  BENCHMARK_KMC2(k1kPoints, k100Centers, k100Dim);  \
+  BENCHMARK_KMC2(k1kPoints, k200Centers, k100Dim);  \
+  BENCHMARK_KMC2(k1kPoints, k500Centers, k100Dim);  \
+  BENCHMARK_KMC2(k1kPoints, k1kCenters, k100Dim);   \
+  BENCHMARK_KMC2(k10kPoints, k100Centers, k100Dim); \
+  BENCHMARK_KMC2(k10kPoints, k200Centers, k100Dim); \
+  BENCHMARK_KMC2(k10kPoints, k500Centers, k100Dim); \
+  BENCHMARK_KMC2(k10kPoints, k1kCenters, k100Dim);  \
+  BENCHMARK_KMC2(k1MPoints, k100Centers, k100Dim);  \
+  BENCHMARK_KMC2(k1MPoints, k200Centers, k100Dim);  \
+  BENCHMARK_KMC2(k1MPoints, k500Centers, k100Dim);  \
+  BENCHMARK_KMC2(k1MPoints, k1kCenters, k100Dim)
+
+RUN_BM_KMC2Initialization;
+#undef RUN_BM_KMC2Initialization
+#undef BENCHMARK_KMC2
+
 Graph* SetUpNearestNeighbors(int num_dims, int num_points, int num_centers,
                              int k) {
   Graph* g = new Graph(OpRegistry::Global());
diff --git a/tensorflow/contrib/factorization/ops/clustering_ops.cc b/tensorflow/contrib/factorization/ops/clustering_ops.cc
index f2dfcf7ed0fb05264b10dee9980a246a5f2e49fa..2686702c1d5768f661dac610c96089eb02e360d7 100644
--- a/tensorflow/contrib/factorization/ops/clustering_ops.cc
+++ b/tensorflow/contrib/factorization/ops/clustering_ops.cc
@@ -44,6 +44,25 @@ num_retries_per_sample: Scalar. For each row that is sampled, this parameter
 samples: Matrix of shape (num_to_sample, d). The sampled rows.
)"); +REGISTER_OP("KMC2ChainInitialization") + .Input("distances: float32") + .Input("seed: int64") + .Output("index: int64") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"( +Returns the index of a data point that should be added to the seed set. + +Entries in distances are assumed to be squared distances of candidate points to +the already sampled centers in the seed set. The op constructs one Markov chain +of the k-MC^2 algorithm and returns the index of one candidate point to be added +as an additional cluster center. + +distances: Vector with squared distances to the closest previously sampled + cluster center for each candidate point. +seed: Scalar. Seed for initializing the random number generator. +index: Scalar with the index of the sampled point. +)"); + REGISTER_OP("NearestNeighbors") .Input("points: float32") .Input("centers: float32") diff --git a/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py index 450f64063a2a357e422cd14761864d511c0e6cce..1322f7ce5f83d82c76040a30699137cd2bf491b5 100644 --- a/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py +++ b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py @@ -55,6 +55,63 @@ class KmeansPlusPlusInitializationTest(test.TestCase): self.runTestWithSeed(seed) +class KMC2InitializationTest(test.TestCase): + + def runTestWithSeed(self, seed): + with self.test_session(): + distances = np.zeros(1000).astype(np.float32) + distances[6] = 10e7 + distances[4] = 10e3 + + sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed) + self.assertEquals(sampled_point.eval(), 6) + distances[6] = 0.0 + sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed) + self.assertEquals(sampled_point.eval(), 4) + + def testBasic(self): + for seed in range(100): + self.runTestWithSeed(seed) + + +class KMC2InitializationLargeTest(test.TestCase): + + def setUp(self): + self._distances = np.zeros(1001) + self._distances[500] = 100.0 + self._distances[1000] = 50.0 + + def testBasic(self): + with self.test_session(): + counts = {} + seed = 0 + for i in range(50): + sample = clustering_ops.kmc2_chain_initialization( + self._distances, seed + i).eval() + counts[sample] = counts.get(sample, 0) + 1 + self.assertEquals(len(counts), 2) + self.assertTrue(500 in counts) + self.assertTrue(1000 in counts) + self.assertGreaterEqual(counts[500], 5) + self.assertGreaterEqual(counts[1000], 5) + + +class KMC2InitializationCornercaseTest(test.TestCase): + + def setUp(self): + self._distances = np.zeros(10) + + def runTestWithSeed(self, seed): + with self.test_session(): + sampled_point = clustering_ops.kmc2_chain_initialization( + self._distances, seed) + self.assertEquals(sampled_point.eval(), 0) + + def testBasic(self): + for seed in range(100): + self.runTestWithSeed(seed) + + # A simple test that can be verified by hand. 
 class NearestCentersTest(test.TestCase):
diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py
index d7320aeb3def08d23a256dcfee242bb4ecd9b6bd..96cc80ce241347ebca5b68140f1b1c8b9898ae72 100644
--- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py
@@ -50,6 +50,7 @@ COSINE_DISTANCE = 'cosine'

 RANDOM_INIT = 'random'
 KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'
+KMC2_INIT = 'kmc2'

 # The name of the variable holding the cluster centers. Used by the Estimator.
 CLUSTERS_VAR_NAME = 'clusters'
@@ -66,7 +67,8 @@ class KMeans(object):
                use_mini_batch=False,
                mini_batch_steps_per_iteration=1,
                random_seed=0,
-               kmeans_plus_plus_num_retries=2):
+               kmeans_plus_plus_num_retries=2,
+               kmc2_chain_length=200):
     """Creates an object for generating KMeans clustering graph.

     This class implements the following variants of K-means algorithm:
@@ -95,7 +97,8 @@
     exactly like a full-batch version.

     Args:
-      inputs: An input tensor or list of input tensors
+      inputs: An input tensor or list of input tensors. It is assumed that the
+        data points have been previously randomly permuted.
       num_clusters: An integer tensor specifying the number of clusters. This
         argument is ignored if initial_clusters is a tensor or numpy array.
       initial_clusters: Specifies the clusters used during initialization. One
@@ -104,6 +107,7 @@
         - a function f(inputs, k) that returns up to k centers from `inputs`.
         - "random": Choose centers randomly from `inputs`.
         - "kmeans_plus_plus": Use kmeans++ to choose centers from `inputs`.
+        - "kmc2": Use the fast k-MC2 algorithm to choose centers from `inputs`.
         In the last three cases, one batch of `inputs` may not yield
         `num_clusters` centers, in which case initialization will require
         multiple batches until enough centers are chosen. In the case of
@@ -121,13 +125,17 @@
        additional points to draw from the current distribution before selecting
        the best. If a negative value is specified, a heuristic is used to
        sample O(log(num_to_sample)) additional points.
+      kmc2_chain_length: Determines how many candidate points are used by the
+        k-MC2 algorithm to produce one new cluster center. If a (mini-)batch
+        contains fewer points, one new cluster center is generated from the
+        (mini-)batch.

     Raises:
       ValueError: An invalid argument was passed to initial_clusters or
         distance_metric.
""" if isinstance(initial_clusters, str) and initial_clusters not in [ - RANDOM_INIT, KMEANS_PLUS_PLUS_INIT + RANDOM_INIT, KMEANS_PLUS_PLUS_INIT, KMC2_INIT ]: raise ValueError( "Unsupported initialization algorithm '%s'" % initial_clusters) @@ -141,6 +149,7 @@ class KMeans(object): self._mini_batch_steps_per_iteration = int(mini_batch_steps_per_iteration) self._random_seed = random_seed self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries + self._kmc2_chain_length = kmc2_chain_length @classmethod def _distance_graph(cls, inputs, clusters, distance_metric): @@ -302,9 +311,10 @@ class KMeans(object): else: cluster_centers_updated = cluster_centers update_in_steps = None - cluster_counts = (variable_scope.variable( - array_ops.ones([num_clusters], dtype=dtypes.int64)) - if self._use_mini_batch else None) + cluster_counts = ( + variable_scope.variable( + array_ops.ones([num_clusters], dtype=dtypes.int64)) + if self._use_mini_batch else None) return (cluster_centers, cluster_centers_initialized, cluster_counts, cluster_centers_updated, update_in_steps) @@ -359,7 +369,7 @@ class KMeans(object): init_op = _InitializeClustersOpFactory( self._inputs, num_clusters, initial_clusters, self._distance_metric, self._random_seed, self._kmeans_plus_plus_num_retries, - cluster_centers_var, cluster_centers_updated, + self._kmc2_chain_length, cluster_centers_var, cluster_centers_updated, cluster_centers_initialized).op() cluster_centers = cluster_centers_var @@ -520,8 +530,9 @@ class KMeans(object): array_ops.reshape(array_ops.shape(inp)[0], [-1])), [-1, 1]), cluster_idx, num_clusters)) with ops.colocate_with(cluster_centers, ignore_existing=True): - new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast( - math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon) + new_clusters_centers = math_ops.add_n(cluster_sums) / ( + math_ops.cast(math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + + epsilon) if self._clusters_l2_normalized(): new_clusters_centers = nn_impl.l2_normalize(new_clusters_centers, dim=1) return state_ops.assign(cluster_centers, new_clusters_centers) @@ -548,9 +559,12 @@ class _InitializeClustersOpFactory(object): cluster_centers_initialized := true """ + # TODO(ccolby): Refactor this class so that kmc2 isn't so much a special case. + def __init__(self, inputs, num_clusters, initial_clusters, distance_metric, - random_seed, kmeans_plus_plus_num_retries, cluster_centers, - cluster_centers_updated, cluster_centers_initialized): + random_seed, kmeans_plus_plus_num_retries, kmc2_chain_length, + cluster_centers, cluster_centers_updated, + cluster_centers_initialized): """Creates an op factory. Args: @@ -560,6 +574,7 @@ class _InitializeClustersOpFactory(object): distance_metric: See KMeans constructor. random_seed: See KMeans constructor. kmeans_plus_plus_num_retries: See KMeans constructor. + kmc2_chain_length: See KMeans constructor. cluster_centers: The TF variable holding the initial centers. It may already contain some centers when the op is executed. 
      cluster_centers_updated: A second TF variable to hold a copy of the
@@ -575,6 +590,7 @@
     self._distance_metric = distance_metric
     self._random_seed = random_seed
     self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
+    self._kmc2_chain_length = kmc2_chain_length
     self._cluster_centers = cluster_centers
     self._cluster_centers_updated = cluster_centers_updated
     self._cluster_centers_initialized = cluster_centers_initialized
@@ -604,6 +620,90 @@
         math_ops.to_int64(self._num_remaining), self._random_seed,
         self._kmeans_plus_plus_num_retries)

+  def _kmc2_multiple_centers(self):
+    """Adds new initial cluster centers using the k-MC2 algorithm.
+
+    In each call to the op, the provided batch is split into subsets based on
+    the specified `kmc2_chain_length`. On each subset, a single Markov chain of
+    the k-MC2 algorithm is used to add *one* new cluster center. If there
+    are fewer than `kmc2_chain_length` points in the subset, a single center is
+    added using one Markov chain on the full input. It is assumed that the
+    provided batch has previously been randomly permuted. Otherwise, k-MC2 may
+    return suboptimal centers.
+
+    Returns:
+      An op that adds new cluster centers.
+    """
+    # The op only operates on the first shard of data.
+    first_shard = self._inputs[0]
+    # Number of points in the input that can be used.
+    batch_size = array_ops.shape(first_shard)[0]
+    # Maximum number of subsets such that the size of each subset is at least
+    # `kmc2_chain_length`. Final subsets may be larger.
+    max_to_sample = math_ops.cast(
+        batch_size / self._kmc2_chain_length, dtype=dtypes.int32)
+    # We sample at least one new center and at most all remaining centers.
+    num_to_sample = math_ops.maximum(
+        math_ops.minimum(self._num_remaining, max_to_sample), 1)
+
+    def _cond(i, _):
+      """Stopping condition for the while loop."""
+      return math_ops.less(i, num_to_sample)
+
+    def _body(i, _):
+      """Body that adds a single new center based on a subset."""
+
+      def _sample_random():
+        """Returns a random point as a cluster center."""
+        # By assumption the batch is reshuffled and _sample_random is always
+        # called for i=0. Hence, we simply return the first point.
+        new_center = array_ops.reshape(first_shard[0], [1, -1])
+        if self._distance_metric == COSINE_DISTANCE:
+          new_center = nn_impl.l2_normalize(new_center, dim=1)
+        return new_center
+
+      def _sample_kmc2_chain():
+        """Returns previous centers as well as a new center sampled using k-MC2.
+        """
+        # Extract the subset from the underlying batch.
+        start = i * self._kmc2_chain_length
+        end = start + self._kmc2_chain_length
+        subset = first_shard[start:end]
+        # Compute the distances from points in the subset to previous centers.
+        _, distances = gen_clustering_ops.nearest_neighbors(
+            subset, self._cluster_centers, 1)
+        # Sample index of new center using k-MC2 Markov chain.
+        new_center_index = gen_clustering_ops.kmc2_chain_initialization(
+            array_ops.squeeze(distances), self._random_seed)
+        # Extract actual new center.
+        newly_sampled_center = array_ops.reshape(subset[new_center_index],
+                                                 [1, -1])
+        # Return concatenation with previously sampled centers.
+        if self._distance_metric == COSINE_DISTANCE:
+          newly_sampled_center = nn_impl.l2_normalize(
+              newly_sampled_center, dim=1)
+        return array_ops.concat([self._cluster_centers, newly_sampled_center],
+                                0)
+
+      # Obtain a random point if there are no previously sampled centers.
+      # Otherwise, construct a k-MC2 Markov chain.
+ new_centers = control_flow_ops.cond( + math_ops.equal(self._num_selected, 0), _sample_random, + _sample_kmc2_chain) + # Assign new cluster centers to underlying variable. + assigned_centers = state_ops.assign( + self._cluster_centers, new_centers, validate_shape=False) + if self._cluster_centers_updated is not self._cluster_centers: + assigned_centers = state_ops.assign( + self._cluster_centers_updated, + assigned_centers, + validate_shape=False) + return i + 1, self._num_clusters - array_ops.shape(assigned_centers)[0] + + # Add num_to_sample new data points. + _, num_remaining = control_flow_ops.while_loop(_cond, _body, [0, 0]) + return num_remaining + def _greedy_batch_sampler(self, sampler): # If the input dataset size is smaller than the number of centers # remaining, choose the entire input dataset as centers. This can happen @@ -657,7 +757,10 @@ class _InitializeClustersOpFactory(object): with ops.control_dependencies([ check_ops.assert_positive(self._num_remaining), ]): - num_now_remaining = self._add_new_centers() + if self._initial_clusters == KMC2_INIT: + num_now_remaining = self._kmc2_multiple_centers() + else: + num_now_remaining = self._add_new_centers() return control_flow_ops.cond( math_ops.equal(num_now_remaining, 0), lambda: state_ops.assign(self._cluster_centers_initialized, True), diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 2081a11f47d71106f8e57227f46639717a791855..8421ba7c0423c6ed274f92ba74930822d0171e05 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -37,6 +37,7 @@ See the @{$python/contrib.framework} guide. @@arg_scope @@add_arg_scope +@@current_arg_scope @@has_arg_scope @@arg_scoped_arguments diff --git a/tensorflow/contrib/framework/python/ops/arg_scope.py b/tensorflow/contrib/framework/python/ops/arg_scope.py index 9c194ec202ab6150278b26e844b9d3e97a7d6761..2bce00fde2459878a12027bb4d98bd3818bc92a2 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope.py @@ -67,6 +67,7 @@ from tensorflow.python.util import tf_decorator __all__ = ['arg_scope', 'add_arg_scope', + 'current_arg_scope', 'has_arg_scope', 'arg_scoped_arguments'] @@ -83,7 +84,7 @@ def _get_arg_stack(): return _ARGSTACK -def _current_arg_scope(): +def current_arg_scope(): stack = _get_arg_stack() return stack[-1] @@ -144,7 +145,7 @@ def arg_scope(list_ops_or_scope, **kwargs): raise TypeError('list_ops_or_scope must either be a list/tuple or reused' 'scope (i.e. dict)') try: - current_scope = _current_arg_scope().copy() + current_scope = current_arg_scope().copy() for op in list_ops_or_scope: key_op = _key_op(op) if not has_arg_scope(op): @@ -172,7 +173,7 @@ def add_arg_scope(func): A tuple with the decorated function func_with_args(). 
""" def func_with_args(*args, **kwargs): - current_scope = _current_arg_scope() + current_scope = current_arg_scope() current_args = kwargs key_func = _key_op(func) if key_func in current_scope: diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index bdb88b89bb3dba95a229724994874b0a26b1fc3f..4b34fc62849766370979bb2002d42ee03ea7161a 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -442,7 +442,8 @@ def read_keyed_batch_features(file_pattern, feature_queue_capacity=100, num_enqueue_threads=2, parse_fn=None, - name=None): + name=None, + read_batch_size=None): """Adds operations to read, queue, batch and parse `Example` protos. Given file pattern (or list of files), will setup a queue for file names, @@ -482,6 +483,8 @@ def read_keyed_batch_features(file_pattern, parse_fn: Parsing function, takes `Example` Tensor returns parsed representation. If `None`, no parsing is done. name: Name of resulting op. + read_batch_size: An int or scalar `Tensor` specifying the number of + records to read at once. If `None`, defaults to `batch_size`. Returns: Returns tuple of: @@ -493,6 +496,7 @@ def read_keyed_batch_features(file_pattern, """ with ops.name_scope(name, 'read_batch_features', [file_pattern]) as scope: + if read_batch_size is None: read_batch_size = batch_size keys, examples = read_keyed_batch_examples( file_pattern, batch_size, @@ -501,7 +505,7 @@ def read_keyed_batch_features(file_pattern, num_epochs=num_epochs, queue_capacity=queue_capacity, num_threads=reader_num_threads, - read_batch_size=batch_size, + read_batch_size=read_batch_size, parse_fn=parse_fn, name=scope) # Parse the example. @@ -727,7 +731,8 @@ def read_batch_features(file_pattern, reader_num_threads=1, num_enqueue_threads=2, parse_fn=None, - name=None): + name=None, + read_batch_size=None): """Adds operations to read, queue, batch and parse `Example` protos. Given file pattern (or list of files), will setup a queue for file names, @@ -768,6 +773,8 @@ def read_batch_features(file_pattern, parse_fn: Parsing function, takes `Example` Tensor returns parsed representation. If `None`, no parsing is done. name: Name of resulting op. + read_batch_size: An int or scalar `Tensor` specifying the number of + records to read at once. If `None`, defaults to `batch_size`. Returns: A dict of `Tensor` or `SparseTensor` objects for each in `features`. 
@@ -786,6 +793,7 @@ def read_batch_features(file_pattern, reader_num_threads=reader_num_threads, feature_queue_capacity=feature_queue_capacity, num_enqueue_threads=num_enqueue_threads, + read_batch_size=read_batch_size, parse_fn=parse_fn, name=name) return features diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index be7c790ee9e11ca90c0756011003a919f7d930f8..3dcff3d4a3df63b482905221b91623a7c5e81b9b 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -502,6 +502,7 @@ $(wildcard tensorflow/core/platform/google/*) \ $(wildcard tensorflow/core/platform/google/*/*) \ $(wildcard tensorflow/core/platform/jpeg.*) \ $(wildcard tensorflow/core/platform/png.*) \ +$(wildcard tensorflow/core/platform/s3/*) \ $(wildcard tensorflow/core/platform/stream_executor.*) \ $(wildcard tensorflow/core/platform/windows/*) \ $(wildcard tensorflow/core/user_ops/*.cu.cc) \ diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index a63cd89e89eff674b71496336d3667de4fb41c7c..12e3f589306d54b10b38a48d8aed356de4ddc91b 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -20,11 +20,11 @@ DOWNLOADS_DIR=tensorflow/contrib/makefile/downloads BZL_FILE_PATH=tensorflow/workspace.bzl EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" -GEMMLOWP_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" -NSYNC_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -PROTOBUF_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -RE2_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +RE2_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 85c8e9038ac5642d0dbb20aea968474e0d7aa5f4..09485c4fa2a894240d9ad83e9babc9c12b2db278 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -56,7 +56,10 @@ def _safe_div(numerator, denominator, name): name=name) -def _create_local(name, shape, collections=None, validate_shape=True, +def _create_local(name, + shape, + collections=None, + validate_shape=True, dtype=dtypes.float32): """Creates a new local variable. 
@@ -87,7 +90,9 @@ def _assert_weights_rank(weights, values): return check_ops.assert_rank_in(weights, (0, array_ops.rank(values))) -def _count_condition(values, weights=None, metrics_collections=None, +def _count_condition(values, + weights=None, + metrics_collections=None, updates_collections=None): """Sums the weights of cases where the given values are True. @@ -134,7 +139,9 @@ def _count_condition(values, weights=None, metrics_collections=None, return value_tensor, update_op -def streaming_true_positives(predictions, labels, weights=None, +def streaming_true_positives(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -168,12 +175,17 @@ def streaming_true_positives(predictions, labels, weights=None, tuple. """ return metrics.true_positives( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_true_negatives(predictions, labels, weights=None, +def streaming_true_negatives(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -206,20 +218,22 @@ def streaming_true_negatives(predictions, labels, weights=None, either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope( - name, 'true_negatives', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'true_negatives', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) - is_true_negative = math_ops.logical_and(math_ops.equal(labels, False), - math_ops.equal(predictions, False)) + is_true_negative = math_ops.logical_and( + math_ops.equal(labels, False), math_ops.equal(predictions, False)) return _count_condition(is_true_negative, weights, metrics_collections, updates_collections) -def streaming_false_positives(predictions, labels, weights=None, +def streaming_false_positives(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -253,12 +267,17 @@ def streaming_false_positives(predictions, labels, weights=None, tuple. """ return metrics.false_positives( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_false_negatives(predictions, labels, weights=None, +def streaming_false_negatives(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -291,9 +310,12 @@ def streaming_false_negatives(predictions, labels, weights=None, or tuple. """ return metrics.false_negatives( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) # TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py. 
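All of the `streaming_*` functions being reformatted in this file share the same two-op contract: each returns a `(value, update_op)` pair whose accumulators live in local variables. As a reminder of how that contract is consumed (a minimal sketch, not part of the diff):

    import tensorflow as tf

    values = tf.placeholder(tf.float32, [None])
    mean, update_op = tf.contrib.metrics.streaming_mean(values)

    with tf.Session() as sess:
      # Streaming metrics keep their accumulators in local variables.
      sess.run(tf.local_variables_initializer())
      for batch in ([1.0, 2.0], [3.0]):
        sess.run(update_op, feed_dict={values: batch})
      print(sess.run(mean))  # 2.0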
@@ -317,17 +339,18 @@ def _broadcast_weights(weights, values): with ops.name_scope(None, 'broadcast_weights', (values, weights)) as scope: weights_shape = weights.get_shape() values_shape = values.get_shape() - if (weights_shape.is_fully_defined() and - values_shape.is_fully_defined() and + if (weights_shape.is_fully_defined() and values_shape.is_fully_defined() and weights_shape.is_compatible_with(values_shape)): return weights with ops.control_dependencies((_assert_weights_rank(weights, values),)): - return math_ops.multiply( - weights, array_ops.ones_like(values), name=scope) + return math_ops.multiply(weights, array_ops.ones_like(values), name=scope) -def streaming_mean(values, weights=None, metrics_collections=None, - updates_collections=None, name=None): +def streaming_mean(values, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the (weighted) mean of the given values. The `streaming_mean` function creates two local variables, `total` and `count` @@ -365,12 +388,18 @@ def streaming_mean(values, weights=None, metrics_collections=None, or tuple. """ return metrics.mean( - values=values, weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + values=values, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) -def streaming_mean_tensor(values, weights=None, metrics_collections=None, - updates_collections=None, name=None): +def streaming_mean_tensor(values, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the element-wise (weighted) mean of the given tensors. In contrast to the `streaming_mean` function which returns a scalar with the @@ -412,12 +441,18 @@ def streaming_mean_tensor(values, weights=None, metrics_collections=None, or tuple. """ return metrics.mean_tensor( - values=values, weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + values=values, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) -def streaming_accuracy(predictions, labels, weights=None, - metrics_collections=None, updates_collections=None, +def streaming_accuracy(predictions, + labels, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Calculates how often `predictions` matches `labels`. @@ -462,13 +497,19 @@ def streaming_accuracy(predictions, labels, weights=None, tuple. """ return metrics.accuracy( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_precision(predictions, labels, weights=None, - metrics_collections=None, updates_collections=None, +def streaming_precision(predictions, + labels, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Computes the precision of the predictions with respect to the labels. @@ -512,13 +553,19 @@ def streaming_precision(predictions, labels, weights=None, tuple. 
""" return metrics.precision( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_recall(predictions, labels, weights=None, - metrics_collections=None, updates_collections=None, +def streaming_recall(predictions, + labels, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Computes the recall of the predictions with respect to the labels. @@ -560,12 +607,17 @@ def streaming_recall(predictions, labels, weights=None, tuple. """ return metrics.recall( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def _true_negatives(labels, predictions, weights=None, +def _true_negatives(labels, + predictions, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -597,20 +649,22 @@ def _true_negatives(labels, predictions, weights=None, either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope( - name, 'true_negatives', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'true_negatives', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) - is_true_negative = math_ops.logical_and(math_ops.equal(labels, False), - math_ops.equal(predictions, False)) + is_true_negative = math_ops.logical_and( + math_ops.equal(labels, False), math_ops.equal(predictions, False)) return _count_condition(is_true_negative, weights, metrics_collections, updates_collections) -def streaming_false_positive_rate(predictions, labels, weights=None, +def streaming_false_positive_rate(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -657,30 +711,35 @@ def streaming_false_positive_rate(predictions, labels, weights=None, either `metrics_collections` or `updates_collections` are not a list or tuple. 
""" - with variable_scope.variable_scope( - name, 'false_positive_rate', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'false_positive_rate', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) false_p, false_positives_update_op = metrics.false_positives( - labels, predictions, weights, metrics_collections=None, - updates_collections=None, name=None) + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) true_n, true_negatives_update_op = _true_negatives( - labels, predictions, weights, metrics_collections=None, - updates_collections=None, name=None) + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) def compute_fpr(fp, tn, name): return array_ops.where( - math_ops.greater(fp + tn, 0), - math_ops.div(fp, fp + tn), - 0, - name) + math_ops.greater(fp + tn, 0), math_ops.div(fp, fp + tn), 0, name) fpr = compute_fpr(false_p, true_n, 'value') - update_op = compute_fpr( - false_positives_update_op, true_negatives_update_op, 'update_op') + update_op = compute_fpr(false_positives_update_op, true_negatives_update_op, + 'update_op') if metrics_collections: ops.add_to_collections(metrics_collections, fpr) @@ -691,7 +750,9 @@ def streaming_false_positive_rate(predictions, labels, weights=None, return fpr, update_op -def streaming_false_negative_rate(predictions, labels, weights=None, +def streaming_false_negative_rate(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -738,30 +799,35 @@ def streaming_false_negative_rate(predictions, labels, weights=None, either `metrics_collections` or `updates_collections` are not a list or tuple. 
""" - with variable_scope.variable_scope( - name, 'false_negative_rate', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'false_negative_rate', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) false_n, false_negatives_update_op = metrics.false_negatives( - labels, predictions, weights, metrics_collections=None, - updates_collections=None, name=None) + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) true_p, true_positives_update_op = metrics.true_positives( - labels, predictions, weights, metrics_collections=None, - updates_collections=None, name=None) + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) def compute_fnr(fn, tp, name): return array_ops.where( - math_ops.greater(fn + tp, 0), - math_ops.div(fn, fn + tp), - 0, - name) + math_ops.greater(fn + tp, 0), math_ops.div(fn, fn + tp), 0, name) fnr = compute_fnr(false_n, true_p, 'value') - update_op = compute_fnr( - false_negatives_update_op, true_positives_update_op, 'update_op') + update_op = compute_fnr(false_negatives_update_op, true_positives_update_op, + 'update_op') if metrics_collections: ops.add_to_collections(metrics_collections, fnr) @@ -772,8 +838,11 @@ def streaming_false_negative_rate(predictions, labels, weights=None, return fnr, update_op -def _streaming_confusion_matrix_at_thresholds( - predictions, labels, thresholds, weights=None, includes=None): +def _streaming_confusion_matrix_at_thresholds(predictions, + labels, + thresholds, + weights=None, + includes=None): """Computes true_positives, false_negatives, true_negatives, false_positives. 
   This function creates up to four local variables, `true_positives`,
@@ -861,8 +930,8 @@ def _streaming_confusion_matrix_at_thresholds(
     if weights is not None:
       broadcast_weights = weights_broadcast_ops.broadcast_weights(
           math_ops.to_float(weights), predictions)
-      weights_tiled = array_ops.tile(array_ops.reshape(
-          broadcast_weights, [1, -1]), [num_thresholds, 1])
+      weights_tiled = array_ops.tile(
+          array_ops.reshape(broadcast_weights, [1, -1]), [num_thresholds, 1])
       thresh_tiled.get_shape().assert_is_compatible_with(
           weights_tiled.get_shape())
     else:
@@ -877,8 +946,9 @@
           math_ops.logical_and(label_is_pos, pred_is_pos))
       if weights_tiled is not None:
         is_true_positive *= weights_tiled
-      update_ops['tp'] = state_ops.assign_add(
-          true_positives, math_ops.reduce_sum(is_true_positive, 1))
+      update_ops['tp'] = state_ops.assign_add(true_positives,
+                                              math_ops.reduce_sum(
+                                                  is_true_positive, 1))
       values['tp'] = true_positives

     if 'fn' in includes:
@@ -887,8 +957,9 @@
           math_ops.logical_and(label_is_pos, pred_is_neg))
       if weights_tiled is not None:
         is_false_negative *= weights_tiled
-      update_ops['fn'] = state_ops.assign_add(
-          false_negatives, math_ops.reduce_sum(is_false_negative, 1))
+      update_ops['fn'] = state_ops.assign_add(false_negatives,
+                                              math_ops.reduce_sum(
+                                                  is_false_negative, 1))
       values['fn'] = false_negatives

     if 'tn' in includes:
@@ -897,8 +968,9 @@
           math_ops.logical_and(label_is_neg, pred_is_neg))
       if weights_tiled is not None:
         is_true_negative *= weights_tiled
-      update_ops['tn'] = state_ops.assign_add(
-          true_negatives, math_ops.reduce_sum(is_true_negative, 1))
+      update_ops['tn'] = state_ops.assign_add(true_negatives,
+                                              math_ops.reduce_sum(
+                                                  is_true_negative, 1))
       values['tn'] = true_negatives

     if 'fp' in includes:
@@ -907,36 +979,45 @@
           math_ops.logical_and(label_is_neg, pred_is_pos))
       if weights_tiled is not None:
         is_false_positive *= weights_tiled
-      update_ops['fp'] = state_ops.assign_add(
-          false_positives, math_ops.reduce_sum(is_false_positive, 1))
+      update_ops['fp'] = state_ops.assign_add(false_positives,
+                                              math_ops.reduce_sum(
+                                                  is_false_positive, 1))
       values['fp'] = false_positives

   return values, update_ops


-def streaming_true_positives_at_thresholds(
-    predictions, labels, thresholds, weights=None):
+def streaming_true_positives_at_thresholds(predictions,
+                                           labels,
+                                           thresholds,
+                                           weights=None):
   values, update_ops = _streaming_confusion_matrix_at_thresholds(
       predictions, labels, thresholds, weights=weights, includes=('tp',))
   return values['tp'], update_ops['tp']


-def streaming_false_negatives_at_thresholds(
-    predictions, labels, thresholds, weights=None):
+def streaming_false_negatives_at_thresholds(predictions,
+                                            labels,
+                                            thresholds,
+                                            weights=None):
   values, update_ops = _streaming_confusion_matrix_at_thresholds(
       predictions, labels, thresholds, weights=weights, includes=('fn',))
   return values['fn'], update_ops['fn']


-def streaming_false_positives_at_thresholds(
-    predictions, labels, thresholds, weights=None):
+def streaming_false_positives_at_thresholds(predictions,
+                                            labels,
+                                            thresholds,
+                                            weights=None):
   values, update_ops = _streaming_confusion_matrix_at_thresholds(
       predictions, labels, thresholds, weights=weights, includes=('fp',))
   return values['fp'], update_ops['fp']


-def streaming_true_negatives_at_thresholds(
-    predictions, labels, thresholds, weights=None):
+def streaming_true_negatives_at_thresholds(predictions,
+                                           labels,
+                                           thresholds,
+                                           weights=None):
   values, update_ops = _streaming_confusion_matrix_at_thresholds(
       predictions, labels, thresholds, weights=weights, includes=('tn',))
   return values['tn'], update_ops['tn']
@@ -996,8 +1077,8 @@ def streaming_curve_points(labels=None,
     either `metrics_collections` or `updates_collections` are not a list or
     tuple.
   """
-  with variable_scope.variable_scope(name, 'curve_points', (labels, predictions,
-                                                            weights)):
+  with variable_scope.variable_scope(name, 'curve_points',
+                                     (labels, predictions, weights)):
     if curve != 'ROC' and curve != 'PR':
       raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
     kepsilon = 1e-7  # to account for floating point imprecisions
@@ -1038,9 +1119,14 @@
   return points, update_op


-def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
-                  metrics_collections=None, updates_collections=None,
-                  curve='ROC', name=None):
+def streaming_auc(predictions,
+                  labels,
+                  weights=None,
+                  num_thresholds=200,
+                  metrics_collections=None,
+                  updates_collections=None,
+                  curve='ROC',
+                  name=None):
   """Computes the approximate AUC via a Riemann sum.

   The `streaming_auc` function creates four local variables, `true_positives`,
@@ -1097,14 +1183,24 @@
     tuple.
   """
   return metrics.auc(
-      predictions=predictions, labels=labels, weights=weights,
-      metrics_collections=metrics_collections, num_thresholds=num_thresholds,
-      curve=curve, updates_collections=updates_collections, name=name)
+      predictions=predictions,
+      labels=labels,
+      weights=weights,
+      metrics_collections=metrics_collections,
+      num_thresholds=num_thresholds,
+      curve=curve,
+      updates_collections=updates_collections,
+      name=name)


-def streaming_specificity_at_sensitivity(
-    predictions, labels, sensitivity, weights=None, num_thresholds=200,
-    metrics_collections=None, updates_collections=None, name=None):
+def streaming_specificity_at_sensitivity(predictions,
+                                         labels,
+                                         sensitivity,
+                                         weights=None,
+                                         num_thresholds=200,
+                                         metrics_collections=None,
+                                         updates_collections=None,
+                                         name=None):
   """Computes the specificity at a given sensitivity.

   The `streaming_specificity_at_sensitivity` function creates four local
@@ -1154,15 +1250,24 @@
     or `updates_collections` are not a list or tuple.
   """
   return metrics.specificity_at_sensitivity(
-      sensitivity=sensitivity, num_thresholds=num_thresholds,
-      predictions=predictions, labels=labels, weights=weights,
+      sensitivity=sensitivity,
+      num_thresholds=num_thresholds,
+      predictions=predictions,
+      labels=labels,
+      weights=weights,
       metrics_collections=metrics_collections,
-      updates_collections=updates_collections, name=name)
+      updates_collections=updates_collections,
+      name=name)


-def streaming_sensitivity_at_specificity(
-    predictions, labels, specificity, weights=None, num_thresholds=200,
-    metrics_collections=None, updates_collections=None, name=None):
+def streaming_sensitivity_at_specificity(predictions,
+                                         labels,
+                                         specificity,
+                                         weights=None,
+                                         num_thresholds=200,
+                                         metrics_collections=None,
+                                         updates_collections=None,
+                                         name=None):
   """Computes the sensitivity at a given specificity.

   The `streaming_sensitivity_at_specificity` function creates four local
@@ -1212,16 +1317,23 @@
     or `updates_collections` are not a list or tuple.
""" return metrics.sensitivity_at_specificity( - specificity=specificity, num_thresholds=num_thresholds, - predictions=predictions, labels=labels, weights=weights, + specificity=specificity, + num_thresholds=num_thresholds, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_precision_at_thresholds(predictions, labels, thresholds, +def streaming_precision_at_thresholds(predictions, + labels, + thresholds, weights=None, metrics_collections=None, - updates_collections=None, name=None): + updates_collections=None, + name=None): """Computes precision values for different `thresholds` on `predictions`. The `streaming_precision_at_thresholds` function creates four local variables, @@ -1266,14 +1378,21 @@ def streaming_precision_at_thresholds(predictions, labels, thresholds, """ return metrics.precision_at_thresholds( thresholds=thresholds, - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_recall_at_thresholds(predictions, labels, thresholds, - weights=None, metrics_collections=None, - updates_collections=None, name=None): +def streaming_recall_at_thresholds(predictions, + labels, + thresholds, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes various recall values for different `thresholds` on `predictions`. The `streaming_recall_at_thresholds` function creates four local variables, @@ -1316,14 +1435,21 @@ def streaming_recall_at_thresholds(predictions, labels, thresholds, """ return metrics.recall_at_thresholds( thresholds=thresholds, - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_false_positive_rate_at_thresholds( - predictions, labels, thresholds, weights=None, metrics_collections=None, - updates_collections=None, name=None): +def streaming_false_positive_rate_at_thresholds(predictions, + labels, + thresholds, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes various fpr values for different `thresholds` on `predictions`. The `streaming_false_positive_rate_at_thresholds` function creates two @@ -1365,20 +1491,19 @@ def streaming_false_positive_rate_at_thresholds( either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope( - name, 'false_positive_rate_at_thresholds', - (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'false_positive_rate_at_thresholds', + (predictions, labels, weights)): values, update_ops = _streaming_confusion_matrix_at_thresholds( predictions, labels, thresholds, weights, includes=('fp', 'tn')) # Avoid division by zero. 
epsilon = 1e-7 + def compute_fpr(fp, tn, name): return math_ops.div(fp, epsilon + fp + tn, name='fpr_' + name) fpr = compute_fpr(values['fp'], values['tn'], 'value') - update_op = compute_fpr( - update_ops['fp'], update_ops['tn'], 'update_op') + update_op = compute_fpr(update_ops['fp'], update_ops['tn'], 'update_op') if metrics_collections: ops.add_to_collections(metrics_collections, fpr) @@ -1389,9 +1514,13 @@ def streaming_false_positive_rate_at_thresholds( return fpr, update_op -def streaming_false_negative_rate_at_thresholds( - predictions, labels, thresholds, weights=None, metrics_collections=None, - updates_collections=None, name=None): +def streaming_false_negative_rate_at_thresholds(predictions, + labels, + thresholds, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes various fnr values for different `thresholds` on `predictions`. The `streaming_false_negative_rate_at_thresholds` function creates two @@ -1433,20 +1562,19 @@ def streaming_false_negative_rate_at_thresholds( either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope( - name, 'false_negative_rate_at_thresholds', - (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'false_negative_rate_at_thresholds', + (predictions, labels, weights)): values, update_ops = _streaming_confusion_matrix_at_thresholds( predictions, labels, thresholds, weights, includes=('fn', 'tp')) # Avoid division by zero. epsilon = 1e-7 + def compute_fnr(fn, tp, name): return math_ops.div(fn, epsilon + fn + tp, name='fnr_' + name) fnr = compute_fnr(values['fn'], values['tp'], 'value') - update_op = compute_fnr( - update_ops['fn'], update_ops['tp'], 'update_op') + update_op = compute_fnr(update_ops['fn'], update_ops['tp'], 'update_op') if metrics_collections: ops.add_to_collections(metrics_collections, fnr) @@ -1469,8 +1597,12 @@ def _at_k_name(name, k=None, class_id=None): @deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, ' 'and reshape labels from [batch_size] to [batch_size, 1].') -def streaming_recall_at_k(predictions, labels, k, weights=None, - metrics_collections=None, updates_collections=None, +def streaming_recall_at_k(predictions, + labels, + k, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Computes the recall@k of the predictions with respect to dense labels. @@ -1516,11 +1648,8 @@ def streaming_recall_at_k(predictions, labels, k, weights=None, tuple. """ in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k)) - return streaming_mean(in_top_k, - weights, - metrics_collections, - updates_collections, - name or _at_k_name('recall', k)) + return streaming_mean(in_top_k, weights, metrics_collections, + updates_collections, name or _at_k_name('recall', k)) # TODO(ptucker): Validate range of values in labels? @@ -1599,10 +1728,14 @@ def streaming_sparse_recall_at_k(predictions, are not a list or tuple. """ return metrics.recall_at_k( - k=k, class_id=class_id, - predictions=predictions, labels=labels, weights=weights, + k=k, + class_id=class_id, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) # TODO(ptucker): Validate range of values in labels? @@ -1684,10 +1817,14 @@ def streaming_sparse_precision_at_k(predictions, are not a list or tuple. 
""" return metrics.sparse_precision_at_k( - k=k, class_id=class_id, - predictions=predictions, labels=labels, weights=weights, + k=k, + class_id=class_id, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) # TODO(ptucker): Validate range of values in labels? @@ -1766,9 +1903,8 @@ def streaming_sparse_precision_at_top_k(top_k_predictions, ValueError: If `top_k_predictions` has rank < 2. """ default_name = _at_k_name('precision', class_id=class_id) - with ops.name_scope( - name, default_name, - (top_k_predictions, labels, weights)) as name_scope: + with ops.name_scope(name, default_name, + (top_k_predictions, labels, weights)) as name_scope: return metrics_impl._sparse_precision_at_top_k( # pylint: disable=protected-access labels=labels, predictions_idx=top_k_predictions, @@ -1848,8 +1984,8 @@ def sparse_recall_at_top_k(labels, are not a list or tuple. """ default_name = _at_k_name('recall', class_id=class_id) - with ops.name_scope(name, default_name, (top_k_predictions, labels, - weights)) as name_scope: + with ops.name_scope(name, default_name, + (top_k_predictions, labels, weights)) as name_scope: return metrics_impl._sparse_recall_at_top_k( # pylint: disable=protected-access labels=labels, predictions_idx=top_k_predictions, @@ -1919,9 +2055,13 @@ def streaming_sparse_average_precision_at_k(predictions, value matches `metric`. """ return metrics.sparse_average_precision_at_k( - k=k, predictions=predictions, labels=labels, weights=weights, + k=k, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) def streaming_sparse_average_precision_at_top_k(top_k_predictions, @@ -1987,7 +2127,9 @@ def streaming_sparse_average_precision_at_top_k(top_k_predictions, name=name) -def streaming_mean_absolute_error(predictions, labels, weights=None, +def streaming_mean_absolute_error(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -2035,12 +2177,18 @@ def streaming_mean_absolute_error(predictions, labels, weights=None, tuple. """ return metrics.mean_absolute_error( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_mean_relative_error(predictions, labels, normalizer, weights=None, +def streaming_mean_relative_error(predictions, + labels, + normalizer, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -2089,12 +2237,18 @@ def streaming_mean_relative_error(predictions, labels, normalizer, weights=None, tuple. 
""" return metrics.mean_relative_error( - normalizer=normalizer, predictions=predictions, labels=labels, - weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + normalizer=normalizer, + predictions=predictions, + labels=labels, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) -def streaming_mean_squared_error(predictions, labels, weights=None, +def streaming_mean_squared_error(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -2142,12 +2296,17 @@ def streaming_mean_squared_error(predictions, labels, weights=None, tuple. """ return metrics.mean_squared_error( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_root_mean_squared_error(predictions, labels, weights=None, +def streaming_root_mean_squared_error(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -2195,9 +2354,12 @@ def streaming_root_mean_squared_error(predictions, labels, weights=None, tuple. """ return metrics.root_mean_squared_error( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) def streaming_covariance(predictions, @@ -2253,8 +2415,8 @@ def streaming_covariance(predictions, ValueError: If labels and predictions are of different sizes or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope( - name, 'covariance', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'covariance', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) @@ -2298,22 +2460,22 @@ def streaming_covariance(predictions, # prev_mean_label is E[y_A] in the update equation prev_mean_label = update_mean_label - delta_mean_label - unweighted_batch_coresiduals = ( - (predictions - batch_mean_prediction) * (labels - batch_mean_label)) + unweighted_batch_coresiduals = ((predictions - batch_mean_prediction) * + (labels - batch_mean_label)) # batch_comoment is C_B in the update equation if weights is None: batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals) else: - batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals * - weights) + batch_comoment = math_ops.reduce_sum( + unweighted_batch_coresiduals * weights) # View delta_comoment as = C_AB - C_A in the update equation above. # Since C_A is stored in a var, by how much do we need to increment that var # to make the var = C_AB? 
- delta_comoment = (batch_comoment + - (prev_mean_prediction - batch_mean_prediction) * - (prev_mean_label - batch_mean_label) * - (prev_count * batch_count / update_count)) + delta_comoment = ( + batch_comoment + (prev_mean_prediction - batch_mean_prediction) * + (prev_mean_label - batch_mean_label) * + (prev_count * batch_count / update_count)) update_comoment = state_ops.assign_add(comoment, delta_comoment) covariance = array_ops.where( @@ -2387,8 +2549,8 @@ def streaming_pearson_correlation(predictions, `weights` is the wrong size, or if either `metrics_collections` or `updates_collections` are not a `list` or `tuple`. """ - with variable_scope.variable_scope( - name, 'pearson_r', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'pearson_r', + (predictions, labels, weights)): predictions, labels, weights = _remove_squeezable_dimensions( predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) @@ -2405,13 +2567,14 @@ def streaming_pearson_correlation(predictions, pearson_r = math_ops.truediv( cov, - math_ops.multiply(math_ops.sqrt(var_predictions), - math_ops.sqrt(var_labels)), + math_ops.multiply( + math_ops.sqrt(var_predictions), math_ops.sqrt(var_labels)), name='pearson_r') update_op = math_ops.truediv( update_cov, - math_ops.multiply(math_ops.sqrt(update_var_predictions), - math_ops.sqrt(update_var_labels)), + math_ops.multiply( + math_ops.sqrt(update_var_predictions), + math_ops.sqrt(update_var_labels)), name='update_op') if metrics_collections: @@ -2425,7 +2588,10 @@ def streaming_pearson_correlation(predictions, # TODO(nsilberman): add a 'normalized' flag so that the user can request # normalization if the inputs are not normalized. -def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, +def streaming_mean_cosine_distance(predictions, + labels, + dim, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -2471,12 +2637,11 @@ def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) radial_diffs = math_ops.multiply(predictions, labels) - radial_diffs = math_ops.reduce_sum(radial_diffs, - reduction_indices=[dim,], - keep_dims=True) - mean_distance, update_op = streaming_mean(radial_diffs, weights, - None, - None, + radial_diffs = math_ops.reduce_sum( + radial_diffs, reduction_indices=[ + dim, + ], keep_dims=True) + mean_distance, update_op = streaming_mean(radial_diffs, weights, None, None, name or 'mean_cosine_distance') mean_distance = math_ops.subtract(1.0, mean_distance) update_op = math_ops.subtract(1.0, update_op) @@ -2490,7 +2655,9 @@ def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, return mean_distance, update_op -def streaming_percentage_less(values, threshold, weights=None, +def streaming_percentage_less(values, + threshold, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -2530,9 +2697,12 @@ def streaming_percentage_less(values, threshold, weights=None, or tuple. """ return metrics.percentage_below( - values=values, threshold=threshold, weights=weights, + values=values, + threshold=threshold, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) def streaming_mean_iou(predictions, @@ -2584,9 +2754,13 @@ def streaming_mean_iou(predictions, tuple. 
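The delta_comoment expression in this hunk is the classic pairwise co-moment merge (Chan et al.): C_AB = C_A + C_B + (E[x_A] - E[x_B]) * (E[y_A] - E[y_B]) * n_A * n_B / n_AB, and streaming_pearson_correlation then simply divides the streamed covariance by the two streamed standard deviations. A NumPy sketch of the merge, with illustrative helper names, checking that batching does not change the result:

    import numpy as np

    def comoment(x, y):
      # Unnormalized co-moment: sum((x - mean(x)) * (y - mean(y))).
      return np.sum((x - x.mean()) * (y - y.mean()))

    def merge(c_a, mean_xa, mean_ya, n_a, x_b, y_b):
      # Mirrors delta_comoment: fold batch B into the running co-moment C_A.
      n_b, n_ab = len(x_b), n_a + len(x_b)
      correction = ((mean_xa - x_b.mean()) * (mean_ya - y_b.mean()) *
                    n_a * n_b / n_ab)
      return c_a + comoment(x_b, y_b) + correction

    rng = np.random.RandomState(0)
    x, y = rng.randn(100), rng.randn(100)
    c, n = 0.0, 0
    for i in range(0, 100, 10):
      mean_xa = x[:i].mean() if i else 0.0  # running means of what was seen
      mean_ya = y[:i].mean() if i else 0.0
      c = merge(c, mean_xa, mean_ya, n, x[i:i + 10], y[i:i + 10])
      n += 10
    assert np.isclose(c, comoment(x, y))  # streamed == single-pass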
""" return metrics.mean_iou( - num_classes=num_classes, predictions=predictions, labels=labels, - weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + num_classes=num_classes, + predictions=predictions, + labels=labels, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) def _next_array_size(required_size, growth_factor=1.5): @@ -2601,9 +2775,9 @@ def _next_array_size(required_size, growth_factor=1.5): tf.Tensor with dtype=int32 giving the next array size. """ exponent = math_ops.ceil( - math_ops.log(math_ops.cast(required_size, dtypes.float32)) - / math_ops.log(math_ops.cast(growth_factor, dtypes.float32))) - return math_ops.cast(math_ops.ceil(growth_factor ** exponent), dtypes.int32) + math_ops.log(math_ops.cast(required_size, dtypes.float32)) / math_ops.log( + math_ops.cast(growth_factor, dtypes.float32))) + return math_ops.cast(math_ops.ceil(growth_factor**exponent), dtypes.int32) def streaming_concat(values, @@ -2660,8 +2834,7 @@ def streaming_concat(values, if not 0 <= axis < ndim: raise ValueError('axis = %r not in [0, %r)' % (axis, ndim)) - fixed_shape = [dim.value for n, dim in enumerate(values_shape) - if n != axis] + fixed_shape = [dim.value for n, dim in enumerate(values_shape) if n != axis] if any(value is None for value in fixed_shape): raise ValueError('all dimensions of `values` other than the dimension to ' 'concatenate along must have statically known size') @@ -2804,14 +2977,14 @@ def _remove_squeezable_dimensions(predictions, labels, weights): # Use static rank. if weights_rank - predictions_rank == 1: weights = array_ops.squeeze(weights, [-1]) - elif (weights_rank is None) or ( - weights_shape.dims[-1].is_compatible_with(1)): + elif (weights_rank is + None) or (weights_shape.dims[-1].is_compatible_with(1)): # Use dynamic rank weights = control_flow_ops.cond( - math_ops.equal(array_ops.rank(weights), - math_ops.add(array_ops.rank(predictions), 1)), - lambda: array_ops.squeeze(weights, [-1]), - lambda: weights) + math_ops.equal( + array_ops.rank(weights), + math_ops.add(array_ops.rank(predictions), 1)), + lambda: array_ops.squeeze(weights, [-1]), lambda: weights) return predictions, labels, weights diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index c5fcc20abd4927c5408071bae8fa8620cd4d7eb2..ebb19db9339190c28d1b04f193f0e47a16cb31f2 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -1101,7 +1101,7 @@ class StreamingPrecisionTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) precision, update_op = metrics.streaming_precision(predictions, labels) with self.test_session() as sess: @@ -1265,7 +1265,7 @@ class StreamingRecallTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) recall, update_op = metrics.streaming_recall(predictions, labels) with self.test_session() as sess: @@ -1388,7 +1388,7 @@ class StreamingFPRTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), 
maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) fpr, update_op = metrics.streaming_false_positive_rate( predictions, labels) @@ -1516,7 +1516,7 @@ class StreamingFNRTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) fnr, update_op = metrics.streaming_false_negative_rate( predictions, labels) @@ -1737,7 +1737,7 @@ class StreamingAUCTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) auc, update_op = metrics.streaming_auc(predictions, labels) with self.test_session() as sess: @@ -2009,7 +2009,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) specificity, update_op = metrics.streaming_specificity_at_sensitivity( predictions, labels, sensitivity=0.7) @@ -2271,7 +2271,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) thresholds = [0, 0.5, 1.0] prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, labels, @@ -2282,12 +2282,14 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): with self.test_session() as sess: sess.run(variables.local_variables_initializer()) - # Run several updates, then verify idempotency. - sess.run([prec_op, rec_op]) + # Run several updates. + for _ in range(10): + sess.run([prec_op, rec_op]) + + # Then verify idempotency. 
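The recurring maxval=1 -> maxval=2 edits in these tests are substantive rather than cosmetic: for integer dtypes, random_uniform samples from the half-open range [0, maxval), so maxval=1 can only ever produce zeros and the "random labels" were degenerate. A quick TF 1.x demonstration (illustrative):

    import tensorflow as tf

    with tf.Session() as sess:
      zeros_only = tf.random_uniform((10, 3), maxval=1, dtype=tf.int64, seed=2)
      zeros_and_ones = tf.random_uniform((10, 3), maxval=2, dtype=tf.int64, seed=2)
      print(sess.run(tf.reduce_max(zeros_only)))      # always 0
      print(sess.run(tf.reduce_max(zeros_and_ones)))  # 1, with near certainty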
initial_prec = prec.eval() initial_rec = rec.eval() for _ in range(10): - sess.run([prec_op, rec_op]) self.assertAllClose(initial_prec, prec.eval()) self.assertAllClose(initial_rec, rec.eval()) @@ -2361,14 +2363,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): rec, rec_op = metrics.streaming_recall_at_thresholds( predictions, labels, thresholds, weights=weights) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - prec_low = array_ops.reshape(prec_low, shape=()) - prec_high = array_ops.reshape(prec_high, shape=()) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) - rec_low = array_ops.reshape(rec_low, shape=()) - rec_high = array_ops.reshape(rec_high, shape=()) + prec_low = prec[0] + prec_high = prec[1] + rec_low = rec[0] + rec_high = rec[1] sess.run(variables.local_variables_initializer()) sess.run([prec_op, rec_op]) @@ -2391,14 +2389,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): rec, rec_op = metrics.streaming_recall_at_thresholds( predictions, labels, thresholds, weights=weights) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - prec_low = array_ops.reshape(prec_low, shape=()) - prec_high = array_ops.reshape(prec_high, shape=()) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) - rec_low = array_ops.reshape(rec_low, shape=()) - rec_high = array_ops.reshape(rec_high, shape=()) + prec_low = prec[0] + prec_high = prec[1] + rec_low = rec[0] + rec_high = rec[1] sess.run(variables.local_variables_initializer()) sess.run([prec_op, rec_op]) @@ -2420,10 +2414,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, thresholds) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) + prec_low = prec[0] + prec_high = prec[1] + rec_low = rec[0] + rec_high = rec[1] sess.run(variables.local_variables_initializer()) sess.run([prec_op, rec_op]) @@ -2562,7 +2556,7 @@ class StreamingFPRThresholdsTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) thresholds = [0, 0.5, 1.0] fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( predictions, labels, thresholds) @@ -2794,7 +2788,7 @@ class StreamingFNRThresholdsTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) thresholds = [0, 0.5, 1.0] fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( predictions, labels, thresholds) diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index 7ff186bc2ad7204d934c322a04ad1c3f2aa383ab..0d6c71965cb23c6a2418f2862cab22c8236a7017 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -13,6 +13,34 @@ py_library( deps = [], ) +py_library( + name = "graph_matcher", + srcs = [ + "python/graph_matcher.py", + ], + srcs_version = "PY2AND3", + deps = [], +) + +py_test( + name = "graph_matcher_test", + size = "small", + srcs = ["python/graph_matcher_test.py"], + 
srcs_version = "PY2AND3", + deps = [ + ":graph_matcher", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + py_library( name = "input_to_ops", srcs = ["python/input_to_ops.py"], @@ -43,6 +71,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":common", + ":graph_matcher", ":input_to_ops", "//tensorflow/contrib/graph_editor:graph_editor_py", "//tensorflow/python:array_ops", @@ -58,6 +87,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":fold_batch_norms", + ":graph_matcher", "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", @@ -147,10 +177,11 @@ py_test( py_test( name = "quantize_parameterized_test", - size = "medium", + size = "large", srcs = ["python/quantize_parameterized_test.py"], srcs_version = "PY2AND3", deps = [ + ":fold_batch_norms", ":quantize", "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/quantize/python/copy_graph_test.py b/tensorflow/contrib/quantize/python/copy_graph_test.py index 0889f12de6aac53f70ecfa7b70fc19ac7b95a5fe..7ff9ad9f8412d7076bf12d6cf10772244444013f 100644 --- a/tensorflow/contrib/quantize/python/copy_graph_test.py +++ b/tensorflow/contrib/quantize/python/copy_graph_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for tensorflow.quantized.mangle.copy_graph.""" +"""Tests for copy_graph.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index c4166895108294148fd09ed95e6227fda17ef36f..647d4044001f7be701037d07dc46db86c0aa3a0e 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -21,7 +21,9 @@ from __future__ import print_function import re from tensorflow.contrib import graph_editor from tensorflow.contrib.quantize.python import common +from tensorflow.contrib.quantize.python import graph_matcher from tensorflow.contrib.quantize.python import input_to_ops +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn @@ -29,7 +31,7 @@ from tensorflow.python.ops import nn_ops def FoldBatchNorms(graph): - """Finds batch norm layers in the graph, folds them into preceding layers. + """Finds batch norm layers and folds them into preceding layers. Folding only affects the following layers: Conv2D, fully connected, depthwise convolution. @@ -40,10 +42,269 @@ def FoldBatchNorms(graph): Raises: ValueError: When batch norm folding fails. """ - # Fail immediately when the graph contains unsupported fused batch norm ops. - if any(op for op in graph.get_operations() if op.type == 'FusedBatchNorm'): - raise ValueError('Fused batch norm is not supported') + _FoldFusedBatchNorms(graph) + _FoldUnfusedBatchNorms(graph) + +def _FoldFusedBatchNorms(graph): + """Finds fused batch norm layers and folds them into preceding layers. 
+ + Folding only affects the following layers: Conv2D, fully connected, depthwise + convolution. + + Args: + graph: Graph to walk and modify. + + Raises: + ValueError: When batch norm folding fails. + """ + for match in _FindFusedBatchNorms(graph): + scope, sep, _ = match.layer_op.name.rpartition('/') + # Make sure new ops are added to `graph` and put on the same device as + # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope + # named `scope`. Otherwise, TF creates a unique scope whose name starts with + # `scope`. + with graph.as_default(), graph.name_scope(scope + sep), ops.device( + match.bn_op.device): + # new weights = old weights * gamma / sqrt(variance + epsilon) + # new biases = -mean * gamma / sqrt(variance + epsilon) + beta + multiplier_tensor = match.gamma_tensor * math_ops.rsqrt( + match.variance_tensor + match.bn_op.get_attr('epsilon')) + bias_tensor = math_ops.subtract( + match.beta_tensor, match.mean_tensor * multiplier_tensor, name='bias') + + # The shape of depthwise weights is different, so we need to reshape the + # multiplier_tensor to ensure that the scaled_weight_tensor has the + # expected shape. + if match.layer_op.type == 'DepthwiseConv2dNative': + new_shape = [ + match.weight_tensor.get_shape().as_list()[2], + match.weight_tensor.get_shape().as_list()[3] + ] + multiplier_tensor = array_ops.reshape( + multiplier_tensor, new_shape, name='scale_reshape') + + # TODO(suharshs): This naming of the following ops needs to carefully + # follow the naming expected by quantize.py. Generalize the quantize code + # to not require these delicate naming conventions. + scaled_weight_tensor = math_ops.multiply( + match.weight_tensor, multiplier_tensor, name='mul_fold') + + new_layer_tensor = _CloneWithNewOperands( + match.layer_op, match.input_tensor, scaled_weight_tensor) + + bias_add_tensor = math_ops.add( + new_layer_tensor, bias_tensor, name='add_fold') + + nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor, + match.output_tensor) + if nodes_modified_count != 1: + raise ValueError( + 'Unexpected inputs to op: %s' % match.output_tensor.name) + + +def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor): + """Clones layer_op with input_tensor and weight_tensor as new inputs.""" + new_layer_name = layer_op.name.split('/')[-1] + '_Fold' + if layer_op.type == 'Conv2D': + return nn_ops.conv2d( + input_tensor, + weight_tensor, + strides=layer_op.get_attr('strides'), + padding=layer_op.get_attr('padding'), + use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'), + data_format=layer_op.get_attr('data_format'), + name=new_layer_name) + elif layer_op.type == 'MatMul': + return math_ops.matmul( + input_tensor, + weight_tensor, + transpose_a=layer_op.get_attr('transpose_a'), + transpose_b=layer_op.get_attr('transpose_b'), + name=new_layer_name) + elif layer_op.type == 'DepthwiseConv2dNative': + return nn.depthwise_conv2d( + input_tensor, + weight_tensor, + strides=layer_op.get_attr('strides'), + padding=layer_op.get_attr('padding'), + name=new_layer_name) + else: + raise ValueError('Cannot handle operation of type: %s' % layer_op.type) + + +def _FindFusedBatchNorms(graph): + """Finds all ops and tensors related to found FusedBatchNorms. + + Args: + graph: Graph to inspect. + + Yields: + _FusedBatchNormMatches. 
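The two comment lines above ("new weights", "new biases") carry the whole folding argument: once the moments are fixed, batch norm is an affine map, so it can be pushed into the preceding linear layer. A NumPy check of that algebra, using a scalar weight for brevity (all names here are illustrative):

    import numpy as np

    rng = np.random.RandomState(0)
    x = rng.randn(8)                       # inputs to the linear layer
    w = 0.7                                # a single weight
    gamma, beta, mean, var, eps = 1.3, 0.2, 0.5, 2.0, 1e-3

    unfused = gamma * (x * w - mean) / np.sqrt(var + eps) + beta

    multiplier = gamma / np.sqrt(var + eps)   # multiplier_tensor above
    bias = beta - mean * multiplier           # bias_tensor above
    folded = x * (w * multiplier) + bias      # scaled weights, then bias add

    assert np.allclose(unfused, folded)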
+ """ + input_pattern = graph_matcher.OpTypePattern('*') + weight_pattern = graph_matcher.OpTypePattern('*') + gamma_pattern = graph_matcher.OpTypePattern('*') + beta_pattern = graph_matcher.OpTypePattern('*') + mean_pattern = graph_matcher.OpTypePattern('*') + variance_pattern = graph_matcher.OpTypePattern('*') + + conv_pattern = graph_matcher.OpTypePattern( + 'Conv2D|DepthwiseConv2dNative', inputs=[input_pattern, weight_pattern]) + # MatMul has a Reshape between it and FusedBatchNorm. + matmul_pattern = graph_matcher.OpTypePattern( + 'MatMul', inputs=[input_pattern, weight_pattern]) + matmul_reshape_pattern = graph_matcher.OpTypePattern( + 'Reshape', inputs=[matmul_pattern, + graph_matcher.OpTypePattern('*')]) + + conv_batch_norm_pattern = graph_matcher.OpTypePattern( + 'FusedBatchNorm', + inputs=[ + conv_pattern, gamma_pattern, beta_pattern, mean_pattern, + variance_pattern + ]) + matmul_batch_norm_pattern = graph_matcher.OpTypePattern( + 'FusedBatchNorm', + inputs=[ + matmul_reshape_pattern, gamma_pattern, beta_pattern, mean_pattern, + variance_pattern + ]) + matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern( + 'Reshape', + inputs=[matmul_batch_norm_pattern, + graph_matcher.OpTypePattern('*')]) + + conv_matcher = graph_matcher.GraphMatcher(conv_batch_norm_pattern) + matmul_matcher = graph_matcher.GraphMatcher(matmul_bn_output_reshape_pattern) + + def _GetCommonTensors(match_result): + """Gets tensors needed for FusedBatchNormMatch from match_result.""" + input_tensor = match_result.get_tensor(input_pattern) + weight_tensor = match_result.get_tensor(weight_pattern) + gamma_tensor = match_result.get_tensor(gamma_pattern) + beta_tensor = match_result.get_tensor(beta_pattern) + # FusedBatchNorm in training is different from that in inference. It takes + # empty 'mean' and empty 'variance', and produces the mean and the variance + # of the batch. Therefore, when is_training is true, mean_tensor and + # variance_tensor point to 1st and 2nd (0-based) output of bn_op, + # respectively; when is_training is false, they point to bn_op's inputs. + is_training = bn_op.get_attr('is_training') + if is_training: + mean_tensor = bn_op.outputs[1] + variance_tensor = bn_op.outputs[2] + else: + mean_tensor = match_result.get_tensor(mean_pattern) + variance_tensor = match_result.get_tensor(variance_pattern) + return (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor) + + for match_result in conv_matcher.match_graph(graph): + layer_op = match_result.get_op(conv_pattern) + bn_op = match_result.get_op(conv_batch_norm_pattern) + # In the case of convolution the output_tensor is the output of bn_op. + output_tensor = bn_op.outputs[0] + + (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor) = _GetCommonTensors(match_result) + yield _FusedBatchNormMatch( + layer_op=layer_op, + bn_op=bn_op, + output_tensor=output_tensor, + input_tensor=input_tensor, + weight_tensor=weight_tensor, + gamma_tensor=gamma_tensor, + beta_tensor=beta_tensor, + mean_tensor=mean_tensor, + variance_tensor=variance_tensor) + + for match_result in matmul_matcher.match_graph(graph): + layer_op = match_result.get_op(matmul_pattern) + bn_op = match_result.get_op(matmul_batch_norm_pattern) + # In the MatMul case, the output of batch norm is reshaped back into a + # 2D tensor, so the output_tensor is the output of the Reshape op. 
+ output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern) + output_tensor = output_reshape_op.outputs[0] + + (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor) = _GetCommonTensors(match_result) + yield _FusedBatchNormMatch( + layer_op=layer_op, + bn_op=bn_op, + output_tensor=output_tensor, + input_tensor=input_tensor, + weight_tensor=weight_tensor, + gamma_tensor=gamma_tensor, + beta_tensor=beta_tensor, + mean_tensor=mean_tensor, + variance_tensor=variance_tensor) + + +class _FusedBatchNormMatch(object): + """Contains all information related to a found FusedBatchNorm.""" + + def __init__(self, layer_op, bn_op, output_tensor, input_tensor, + weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor): + self._layer_op = layer_op + self._bn_op = bn_op + self._output_tensor = output_tensor + self._input_tensor = input_tensor + self._weight_tensor = weight_tensor + self._gamma_tensor = gamma_tensor + self._beta_tensor = beta_tensor + self._mean_tensor = mean_tensor + self._variance_tensor = variance_tensor + + @property + def layer_op(self): + return self._layer_op + + @property + def bn_op(self): + return self._bn_op + + @property + def output_tensor(self): + return self._output_tensor + + @property + def input_tensor(self): + return self._input_tensor + + @property + def weight_tensor(self): + return self._weight_tensor + + @property + def gamma_tensor(self): + return self._gamma_tensor + + @property + def beta_tensor(self): + return self._beta_tensor + + @property + def mean_tensor(self): + return self._mean_tensor + + @property + def variance_tensor(self): + return self._variance_tensor + + +def _FoldUnfusedBatchNorms(graph): + """Finds unfused batch norm layers and folds them into preceding layers. + + Folding only affects the following layers: Conv2D, fully connected, depthwise + convolution. + + Args: + graph: Graph to walk and modify. + + Raises: + ValueError: When batch norm folding fails. + """ input_to_ops_map = input_to_ops.InputToOps(graph) for bn in common.BatchNormGroups(graph): diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py index ddedb0a2c067a27d05dc1aff4c2b4c447dafe93a..5a66b38b155dd140c2d0e5ec405b9494ea0613c9 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.quantize.python import fold_batch_norms from tensorflow.python.framework import dtypes @@ -35,57 +34,32 @@ conv2d = layers.conv2d fully_connected = layers.fully_connected separable_conv2d = layers.separable_conv2d -_DEFAULT_BATCH_NORM_PARAMS = { - 'center': True, - 'scale': True, - 'decay': 1.0 - 0.003, - 'fused': False, -} - # TODO(suharshs): Use parameterized test once OSS TF supports it. 
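The TODO above refers to absl-style parameterized tests, which OSS TensorFlow did not yet ship; with them, the hand-rolled _RunTestOverParameters sweep below would collapse into decorators roughly like this (a sketch under that assumption, not part of the change):

    from absl.testing import parameterized
    from tensorflow.python.ops import nn_ops

    class FoldBatchNormsTest(parameterized.TestCase):

      @parameterized.parameters(
          # (relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm)
          (nn_ops.relu6, 'Relu6', False, False, False),
          (nn_ops.relu, 'Relu', True, True, True))
      def testFoldConv2d(self, relu, relu_op_name, with_bypass, has_scaling,
                         fused_batch_norm):
        pass  # body would be _TestFoldConv2d from the diff below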
class FoldBatchNormsTest(test_util.TensorFlowTestCase): def _RunTestOverParameters(self, test_fn): parameters_list = [ - # (relu, relu_op_name, with_bypass) - (nn_ops.relu6, 'Relu6', False), - (nn_ops.relu, 'Relu', False), - (nn_ops.relu6, 'Relu6', True), - (nn_ops.relu, 'Relu', True), + # (relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm) + (nn_ops.relu6, 'Relu6', False, False, False), + (nn_ops.relu, 'Relu', False, False, False), + (nn_ops.relu6, 'Relu6', True, False, False), + (nn_ops.relu, 'Relu', True, False, False), + (nn_ops.relu6, 'Relu6', False, True, False), + (nn_ops.relu, 'Relu', False, True, False), + (nn_ops.relu6, 'Relu6', True, True, False), + (nn_ops.relu, 'Relu', True, True, False), + # Fused batch norm always has scaling enabled. + (nn_ops.relu6, 'Relu6', False, True, True), + (nn_ops.relu, 'Relu', False, True, True), + (nn_ops.relu6, 'Relu6', True, True, True), + (nn_ops.relu, 'Relu', True, True, True), ] - for parameters in parameters_list: - test_fn(parameters[0], parameters[1], parameters[2]) - - def testFailsWithFusedBatchNorm(self): - self._RunTestOverParameters(self._TestFailsWithFusedBatchNorm) + for params in parameters_list: + test_fn(params[0], params[1], params[2], params[3], params[4]) - def _TestFailsWithFusedBatchNorm(self, relu, relu_op_name, with_bypass): - """Tests that batch norm fails when fused batch norm ops are present.""" - g = ops.Graph() - with g.as_default(): - batch_size, height, width = 5, 128, 128 - inputs = array_ops.zeros((batch_size, height, width, 3)) - out_depth = 3 if with_bypass else 32 - stride = 1 if with_bypass else 2 - activation_fn = None if with_bypass else relu - batch_norm_params = _DEFAULT_BATCH_NORM_PARAMS.copy() - batch_norm_params['fused'] = True - scope = 'test/test2' if with_bypass else 'test' - node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=batch_norm_params, - scope=scope) - if with_bypass: - node = math_ops.add(inputs, node, name='test/Add') - relu(node, name='test/' + relu_op_name) - - with self.assertRaises(ValueError): - fold_batch_norms.FoldBatchNorms(g) - - def _TestFoldConv2d(self, relu, relu_op_name, with_bypass): + def _TestFoldConv2d(self, relu, relu_op_name, with_bypass, has_scaling, + fused_batch_norm): """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. Args: @@ -93,6 +67,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): relu_op_name: String, name of the Relu* operation. with_bypass: Bool, when true there is an extra connection added from inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. 
""" g = ops.Graph() with g.as_default(): @@ -102,12 +78,17 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): stride = 1 if with_bypass else 2 activation_fn = None if with_bypass else relu scope = 'test/test2' if with_bypass else 'test' - node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) + node = conv2d( + inputs, + out_depth, [5, 5], + stride=stride, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), + scope=scope) if with_bypass: node = math_ops.add(inputs, node, name='test/Add') relu(node, name='test/' + relu_op_name) @@ -116,9 +97,10 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_mul = g.get_operation_by_name(scope + '/mul_fold') self.assertEqual(folded_mul.type, 'Mul') - self._AssertInputOpsAre(folded_mul, - [scope + '/weights/read', - scope + '/BatchNorm/batchnorm/mul']) + self._AssertInputOpsAre(folded_mul, [ + scope + '/weights/read', + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) + ]) self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold']) folded_conv = g.get_operation_by_name(scope + '/convolution_Fold') @@ -129,16 +111,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/convolution_Fold', - scope + '/BatchNorm/batchnorm/sub']) + self._AssertInputOpsAre(folded_add, [ + scope + '/convolution_Fold', + self._BathNormBiasName(scope, fused_batch_norm) + ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) def testFoldConv2d(self): self._RunTestOverParameters(self._TestFoldConv2d) - def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass): + def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass, + has_scaling, fused_batch_norm): """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. Tests that folding works even with an input shape where some dimensions are @@ -149,6 +133,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): relu_op_name: String, name of the Relu* operation. with_bypass: Bool, when true there is an extra connection added from inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. 
""" g = ops.Graph() with g.as_default(): @@ -165,7 +151,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): weights_initializer=self._WeightInit(0.09), activation_fn=activation_fn, normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), scope=scope) if with_bypass: node = math_ops.add(inputs, node, name='test/Add') @@ -176,7 +163,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_mul = g.get_operation_by_name(scope + '/mul_fold') self.assertEqual(folded_mul.type, 'Mul') self._AssertInputOpsAre(folded_mul, [ - scope + '/weights/read', scope + '/BatchNorm/batchnorm/mul' + scope + '/weights/read', + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) ]) self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold']) @@ -188,7 +176,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') self._AssertInputOpsAre(folded_add, [ - scope + '/convolution_Fold', scope + '/BatchNorm/batchnorm/sub' + scope + '/convolution_Fold', + self._BathNormBiasName(scope, fused_batch_norm) ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) @@ -196,62 +185,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): def testFoldConv2dUnknownShape(self): self._RunTestOverParameters(self._TestFoldConv2dUnknownShape) - def _TestFoldConv2dWithoutScale(self, relu, relu_op_name, with_bypass): - """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. - - Args: - relu: Callable that returns an Operation, a factory method for the Relu*. - relu_op_name: String, name of the Relu* operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Relu*. 
- """ - g = ops.Graph() - with g.as_default(): - batch_size, height, width = 5, 128, 128 - inputs = array_ops.zeros((batch_size, height, width, 3)) - out_depth = 3 if with_bypass else 32 - stride = 1 if with_bypass else 2 - activation_fn = None if with_bypass else relu - bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS) - bn_params['scale'] = False - scope = 'test/test2' if with_bypass else 'test' - node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=bn_params, - scope=scope) - if with_bypass: - node = math_ops.add(inputs, node, name='test/Add') - relu(node, name='test/' + relu_op_name) - - fold_batch_norms.FoldBatchNorms(g) - - folded_mul = g.get_operation_by_name(scope + '/mul_fold') - self.assertEqual(folded_mul.type, 'Mul') - self._AssertInputOpsAre(folded_mul, - [scope + '/weights/read', - scope + '/BatchNorm/batchnorm/Rsqrt']) - self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold']) - - folded_conv = g.get_operation_by_name(scope + '/convolution_Fold') - self.assertEqual(folded_conv.type, 'Conv2D') - self._AssertInputOpsAre(folded_conv, - [scope + '/mul_fold', inputs.op.name]) - self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) - - folded_add = g.get_operation_by_name(scope + '/add_fold') - self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/convolution_Fold', - scope + '/BatchNorm/batchnorm/sub']) - output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] - self._AssertOutputGoesToOps(folded_add, g, output_op_names) - - def testFoldConv2dWithoutScale(self): - self._RunTestOverParameters(self._TestFoldConv2dWithoutScale) - - def _TestFoldFullyConnectedLayer(self, relu, relu_op_name, with_bypass): + def _TestFoldFullyConnectedLayer(self, relu, relu_op_name, with_bypass, + has_scaling, fused_batch_norm): """Tests folding cases: inputs -> FC with batch norm -> Relu*. Args: @@ -259,6 +194,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): relu_op_name: String, name of the Relu* operation. with_bypass: Bool, when true there is an extra connection added from inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. 
""" g = ops.Graph() with g.as_default(): @@ -267,12 +204,15 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): out_depth = 256 if with_bypass else 128 activation_fn = None if with_bypass else relu scope = 'test/test2' if with_bypass else 'test' - node = fully_connected(inputs, out_depth, - weights_initializer=self._WeightInit(0.03), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) + node = fully_connected( + inputs, + out_depth, + weights_initializer=self._WeightInit(0.03), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), + scope=scope) if with_bypass: node = math_ops.add(inputs, node, name='test/Add') relu(node, name='test/' + relu_op_name) @@ -281,9 +221,10 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_mul = g.get_operation_by_name(scope + '/mul_fold') self.assertEqual(folded_mul.type, 'Mul') - self._AssertInputOpsAre(folded_mul, - [scope + '/weights/read', - scope + '/BatchNorm/batchnorm/mul']) + self._AssertInputOpsAre(folded_mul, [ + scope + '/weights/read', + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) + ]) self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold']) folded_conv = g.get_operation_by_name(scope + '/MatMul_Fold') @@ -294,71 +235,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/MatMul_Fold', - scope + '/BatchNorm/batchnorm/sub']) + self._AssertInputOpsAre(folded_add, [ + scope + '/MatMul_Fold', + self._BathNormBiasName(scope, fused_batch_norm) + ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) def testFoldFullyConnectedLayer(self): self._RunTestOverParameters(self._TestFoldFullyConnectedLayer) - def _TestFoldFullyConnectedLayerWithoutScale(self, relu, relu_op_name, - with_bypass): - """Tests folding cases: inputs -> FC with batch norm -> Relu*. - - Args: - relu: Callable that returns an Operation, a factory method for the Relu*. - relu_op_name: String, name of the Relu* operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Relu*. 
- """ - g = ops.Graph() - with g.as_default(): - batch_size, depth = 5, 256 - inputs = array_ops.zeros((batch_size, depth)) - out_depth = 256 if with_bypass else 128 - activation_fn = None if with_bypass else relu - bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS) - bn_params['scale'] = False - scope = 'test/test2' if with_bypass else 'test' - node = fully_connected(inputs, out_depth, - weights_initializer=self._WeightInit(0.03), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=bn_params, - scope=scope) - if with_bypass: - node = math_ops.add(inputs, node, name='test/Add') - relu(node, name='test/' + relu_op_name) - - fold_batch_norms.FoldBatchNorms(g) - - folded_mul = g.get_operation_by_name(scope + '/mul_fold') - self.assertEqual(folded_mul.type, 'Mul') - self._AssertInputOpsAre(folded_mul, - [scope + '/weights/read', - scope + '/BatchNorm/batchnorm/Rsqrt']) - self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold']) - - folded_conv = g.get_operation_by_name(scope + '/MatMul_Fold') - self.assertEqual(folded_conv.type, 'MatMul') - self._AssertInputOpsAre(folded_conv, - [scope + '/mul_fold', inputs.op.name]) - self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) - - folded_add = g.get_operation_by_name(scope + '/add_fold') - self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/MatMul_Fold', - scope + '/BatchNorm/batchnorm/sub']) - output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] - self._AssertOutputGoesToOps(folded_add, g, output_op_names) - - def testFoldFullyConnectedLayerWithoutScale(self): - self._RunTestOverParameters(self._TestFoldFullyConnectedLayerWithoutScale) - - def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass): + def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass, + has_scaling, fused_batch_norm): """Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*. Args: @@ -366,6 +254,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): relu_op_name: String, name of the Relu* operation. with_bypass: Bool, when true there is an extra connection added from inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. 
""" g = ops.Graph() with g.as_default(): @@ -374,13 +264,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): stride = 1 if with_bypass else 2 activation_fn = None if with_bypass else relu scope = 'test/test2' if with_bypass else 'test' - node = separable_conv2d(inputs, None, [5, 5], stride=stride, - depth_multiplier=1.0, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) + node = separable_conv2d( + inputs, + None, [5, 5], + stride=stride, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), + scope=scope) if with_bypass: node = math_ops.add(inputs, node, name='test/Add') relu(node, name='test/' + relu_op_name) @@ -396,9 +291,10 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): scale_reshape = g.get_operation_by_name(scope + '/scale_reshape') self.assertEqual(scale_reshape.type, 'Reshape') - self._AssertInputOpsAre(scale_reshape, - [scope + '/BatchNorm/batchnorm/mul', - scope + '/scale_reshape/shape']) + self._AssertInputOpsAre(scale_reshape, [ + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm), + scope + '/scale_reshape/shape' + ]) self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold']) folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold') @@ -409,77 +305,35 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/depthwise_Fold', - scope + '/BatchNorm/batchnorm/sub']) + self._AssertInputOpsAre(folded_add, [ + scope + '/depthwise_Fold', + self._BathNormBiasName(scope, fused_batch_norm) + ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) def testFoldDepthwiseConv2d(self): self._RunTestOverParameters(self._TestFoldDepthwiseConv2d) - def _TestFoldDepthwiseConv2dWithoutScale(self, relu, relu_op_name, - with_bypass): - """Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*. - - Args: - relu: Callable that returns an Operation, a factory method for the Relu*. - relu_op_name: String, name of the Relu* operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Relu*. 
- """ - g = ops.Graph() - with g.as_default(): - batch_size, height, width = 5, 128, 128 - inputs = array_ops.zeros((batch_size, height, width, 3)) - stride = 1 if with_bypass else 2 - activation_fn = None if with_bypass else relu - bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS) - bn_params['scale'] = False - scope = 'test/test2' if with_bypass else 'test' - node = separable_conv2d(inputs, None, [5, 5], stride=stride, - depth_multiplier=1.0, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=activation_fn, - normalizer_fn=batch_norm, - normalizer_params=bn_params, - scope=scope) - if with_bypass: - node = math_ops.add(inputs, node, name='test/Add') - relu(node, name='test/' + relu_op_name) - - fold_batch_norms.FoldBatchNorms(g) - - folded_mul = g.get_operation_by_name(scope + '/mul_fold') - self.assertEqual(folded_mul.type, 'Mul') - self._AssertInputOpsAre(folded_mul, - [scope + '/depthwise_weights/read', - scope + '/scale_reshape']) - self._AssertOutputGoesToOps(folded_mul, g, [scope + '/depthwise_Fold']) - - scale_reshape = g.get_operation_by_name(scope + '/scale_reshape') - self.assertEqual(scale_reshape.type, 'Reshape') - self._AssertInputOpsAre(scale_reshape, - [scope + '/BatchNorm/batchnorm/Rsqrt', - scope + '/scale_reshape/shape']) - self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold']) - - folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold') - self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative') - self._AssertInputOpsAre(folded_conv, - [scope + '/mul_fold', inputs.op.name]) - self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) - - folded_add = g.get_operation_by_name(scope + '/add_fold') - self.assertEqual(folded_add.type, 'Add') - self._AssertInputOpsAre(folded_add, - [scope + '/depthwise_Fold', - scope + '/BatchNorm/batchnorm/sub']) - output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] - self._AssertOutputGoesToOps(folded_add, g, output_op_names) - - def testFoldDepthwiseConv2dWithoutScale(self): - self._RunTestOverParameters(self._TestFoldDepthwiseConv2dWithoutScale) + def _BatchNormParams(self, scale=True, fused=False): + return { + 'center': True, + 'scale': scale, + 'decay': 1.0 - 0.003, + 'fused': fused + } + + def _BatchNormMultiplierName(self, scope, has_scaling, fused): + if has_scaling: + if fused: + return scope + '/mul' + return scope + '/BatchNorm/batchnorm/mul' + return scope + '/BatchNorm/batchnorm/Rsqrt' + + def _BathNormBiasName(self, scope, fused): + if fused: + return scope + '/bias' + return scope + '/BatchNorm/batchnorm/sub' def _WeightInit(self, stddev): """Returns a truncated normal variable initializer. diff --git a/tensorflow/contrib/quantize/python/graph_matcher.py b/tensorflow/contrib/quantize/python/graph_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..e3581cc55905a0af7d0464bc0ec673d3ed7f0363 --- /dev/null +++ b/tensorflow/contrib/quantize/python/graph_matcher.py @@ -0,0 +1,200 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities that match patterns in a tf.Graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class OpTypePattern(object): + """A tree pattern that matches TF expressions with certain op types.""" + + def __init__(self, op_type, name=None, inputs=None): + """Initializes an OpTypePattern. + + Args: + op_type: string that specifies the allowed types of the root. It can be + (1) an op type, e.g. 'Conv2D', + (2) '*', i.e. wildcard, or + (3) multiple op types separated by '|', e.g., 'Relu|Relu6'. + We could use regex strings, which might be worthwhile when we have many + similar TF op types. + name: Optional string. The name of the pattern that can be looked up in + MatchResult. + inputs: Optional list of `OpTypePattern`s or strings that specify the + patterns for the inputs of a matching op. If None, this pattern accepts + any inputs of a matching op. + """ + self._op_type = op_type + self._name = name + if inputs is None: + inputs = [] + self._inputs = [ + input_pattern if isinstance(input_pattern, OpTypePattern) else + OpTypePattern(input_pattern) for input_pattern in inputs + ] + + @property + def op_type(self): + return self._op_type + + @property + def inputs(self): + return self._inputs + + @property + def name(self): + return self._name + + +class MatchResult(object): + r"""Encapsulates the result of a match done by GraphMatcher. + + MatchResult contains a map from OpTypePattern to the matching op and tensor. + When the matching op has multiple output tensors, the matching tensor is the + output tensor used by the matching op of the parent pattern. E.g., when we + match graph + + - + + / \y0 y1/ \ + x split z + | + y (nodes are ops; edges are going up) + + against add_pattern defined as + + y1_pattern = OpTypePattern('*') + z_pattern = OpTypePattern('*') + add_pattern = OpTypePattern('+', inputs=[y1_pattern, z_pattern]) + + the matching op of `y1_pattern` is `split`, and the matching tensor of + `y1_pattern` + is `y1` not `y0`. + """ + + def __init__(self): + self._pattern_to_op_tensor = {} + self._name_to_pattern = {} + + def add(self, pattern, op, tensor): + self._pattern_to_op_tensor[pattern] = op, tensor + if pattern.name is not None: + if pattern.name in self._name_to_pattern: + raise ValueError( + 'Name %s is already bound to another pattern' % pattern.name) + self._name_to_pattern[pattern.name] = pattern + + def _to_pattern(self, pattern_or_name): + if isinstance(pattern_or_name, OpTypePattern): + return pattern_or_name + + if isinstance(pattern_or_name, str): + return self._name_to_pattern[pattern_or_name] + + raise ValueError('pattern_or_name has type %s. Expect OpTypePattern or str.' + % type(pattern_or_name)) + + def get_op(self, pattern_or_name): + return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][0] + + def get_tensor(self, pattern_or_name): + return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][1] + + +class GraphMatcher(object): + """Checks if a particular subgraph matches a given pattern.""" + + def __init__(self, pattern): + """Initializes a GraphMatcher. + + Args: + pattern: The `OpTypePattern` against which `GraphMatcher` matches + subgraphs. 
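Before the matching methods that follow, here is how the OpTypePattern / GraphMatcher / MatchResult trio composes end to end; a short, self-contained TF 1.x sketch with illustrative names:

    import tensorflow as tf
    from tensorflow.contrib.quantize.python import graph_matcher

    g = tf.Graph()
    with g.as_default():
      x = tf.placeholder(tf.float32, [4], name='x')
      y = tf.placeholder(tf.float32, [4], name='y')
      tf.add(tf.nn.relu(x), y, name='out')

    relu_pattern = graph_matcher.OpTypePattern('Relu|Relu6', name='act')
    add_pattern = graph_matcher.OpTypePattern('Add', inputs=[relu_pattern, '*'])
    matcher = graph_matcher.GraphMatcher(add_pattern)

    for match in matcher.match_graph(g):
      print(match.get_op(add_pattern).name)   # 'out'
      print(match.get_tensor('act').op.type)  # 'Relu', looked up by name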
+ """ + self._pattern = pattern + + def _match_pattern(self, pattern, op, tensor): + """Returns whether an TF expression rooted at `op` matches `pattern`. + + If there is a match, adds to `self._match_result` the matching op and tensor + with key `pattern`. + + Args: + pattern: An `OpTypePattern`. + op: A `tf.Operation` to match against the pattern. + tensor: the output `tf.Tensor` of `op` that is used by the matching op of + `pattern`'s parent. Can be None if `pattern` is already the root of the + pattern tree. + + Returns: + True if an TF expression rooted at `op` matches `pattern`. + """ + if pattern.op_type != '*': + if op.type not in pattern.op_type.split('|'): + return False + + self._match_result.add(pattern, op, tensor) + + if not pattern.inputs: + # If pattern.inputs is empty, skips the rest and accepts all the inputs. + return True + + return len(op.inputs) == len(pattern.inputs) and all([ + self._match_pattern(input_pattern, input_tensor.op, input_tensor) + for input_tensor, input_pattern in zip(op.inputs, pattern.inputs) + ]) + + def match_op(self, op): + """Matches `op` against `self._pattern`. + + Args: + op: `tf.Operation` to match against the pattern. + + Returns: + Returns a `MatchResult` if `op` matches the pattern; otherwise, returns + None. + """ + self._match_result = MatchResult() + if not self._match_pattern(self._pattern, op, tensor=None): + return None + return self._match_result + + def match_ops(self, ops): + """Matches each operation in `ops` against `self._pattern`. + + Args: + ops: collection of `tf.Operation` to match against the pattern. + + Yields: + `MatchResult` for each `tf.Operation` that matches the pattern. + """ + for op in ops: + match_result = self.match_op(op) + if match_result: + yield match_result + + def match_graph(self, graph): + """Matches each operation in `graph` against `self._pattern`. + + Args: + graph: `tf.Graph` containing operations to match. + + Yields: + `MatchResult` for each `tf.Operation` in `graph` that matches the pattern. + """ + # Python 3.3.2+ implements `yield from`, but for now: + for match_result in self.match_ops(graph.get_operations()): + yield match_result diff --git a/tensorflow/contrib/quantize/python/graph_matcher_test.py b/tensorflow/contrib/quantize/python/graph_matcher_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e1572865e423e569ee3b280036c0e02b71b70648 --- /dev/null +++ b/tensorflow/contrib/quantize/python/graph_matcher_test.py @@ -0,0 +1,130 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for graph_matcher.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python import ops as contrib_ops +from tensorflow.contrib.layers.python.layers import initializers +from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.quantize.python import graph_matcher +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest + + +class GraphMatcherTest(test_util.TensorFlowTestCase): + + def test_conv_layer(self): + g = ops.Graph() + with g.as_default(): + inputs = array_ops.placeholder(dtypes.float32, shape=[8, 5, 5, 3]) + + with contrib_ops.arg_scope( + [layers.batch_norm], fused=True, is_training=True, trainable=True): + layers.convolution( + inputs, + num_outputs=16, + kernel_size=3, + stride=1, + padding='VALID', + activation_fn=nn_ops.relu, + normalizer_fn=layers.batch_norm, + normalizer_params={}, + weights_initializer=initializers.xavier_initializer(), + weights_regularizer=None, + biases_initializer=init_ops.zeros_initializer(), + biases_regularizer=None, + reuse=None, + trainable=True, + scope=None) + + inputs_pattern = graph_matcher.OpTypePattern('*', name='inputs') + relu_pattern = graph_matcher.OpTypePattern( + 'Relu', + name='relu', + inputs=[ + graph_matcher.OpTypePattern( + 'FusedBatchNorm', + inputs=[ + graph_matcher.OpTypePattern( + 'Conv2D', inputs=[inputs_pattern, '*']), '*', '*', '*', + '*' + ]) + ]) + matcher = graph_matcher.GraphMatcher(relu_pattern) + match_results = list(matcher.match_graph(g)) + self.assertEqual(1, len(match_results)) + match_result = match_results[0] + self.assertEqual(match_result.get_tensor(inputs_pattern), inputs) + self.assertEqual(match_result.get_tensor('inputs'), inputs) + + def test_multiple_outputs(self): + # - + + # / \y0 y1/ \ + # x split z + # | + # y (nodes are ops; edges are going up) + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtypes.float32, shape=[1], name='x') + y = array_ops.placeholder(dtypes.float32, shape=[2], name='y') + y0, y1 = array_ops.split(y, num_or_size_splits=2, axis=0) + z = array_ops.placeholder(dtypes.float32, shape=[1], name='z') + math_ops.add(x, y0) + math_ops.subtract(y1, z) + + y1_pattern = graph_matcher.OpTypePattern('*') + minus_pattern = graph_matcher.OpTypePattern('Sub', inputs=[y1_pattern, '*']) + matcher = graph_matcher.GraphMatcher(minus_pattern) + + match_results = list(matcher.match_graph(g)) + self.assertEqual(1, len(match_results)) + match_result = match_results[0] + + self.assertEqual(y0.op, y1.op) + self.assertEqual(match_result.get_op(y1_pattern), y1.op) + self.assertEqual(match_result.get_tensor(y1_pattern), y1) + + def test_oneof_pattern(self): + # - + + # / \ / \ + # x y z + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtypes.float32, shape=[], name='x') + y = array_ops.placeholder(dtypes.float32, shape=[], name='y') + z = array_ops.placeholder(dtypes.float32, shape=[], name='z') + plus = x + y + minus = y - z + + add_or_sub_pattern = graph_matcher.OpTypePattern( + 'Add|Sub', inputs=['*', '*']) + matcher =
graph_matcher.GraphMatcher(add_or_sub_pattern) + self.assertEqual([ + match_result.get_op(add_or_sub_pattern) + for match_result in matcher.match_graph(g) + ], [plus.op, minus.op]) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py index b5a32a7266a4c3ddf9a481fd9b292ab0f1812a9a..31fcd66dfb7ab33a8cc1a852e8210a286c4cc63a 100644 --- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py +++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.quantize.python import fold_batch_norms from tensorflow.contrib.quantize.python import quantize from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -35,18 +36,11 @@ conv2d = layers.conv2d fully_connected = layers.fully_connected separable_conv2d = layers.separable_conv2d -_DEFAULT_BATCH_NORM_PARAMS = { - 'center': True, - 'scale': True, - 'decay': 1.0 - 0.003, - 'fused': False, -} - -# TODO(suharshs): Use parameterized test once OSS TF supports it. class QuantizeTest(test_util.TensorFlowTestCase): - def _RunTestOverParameters(self, test_fn): + def _RunWithoutBatchNormTestOverParameters(self, test_fn): + # TODO(suharshs): Use parameterized test once OSS TF supports it. parameters_list = [ # (activation, activation_op_name, with_bypass, delay) (nn_ops.relu6, 'Relu6', False, None), @@ -60,10 +54,10 @@ class QuantizeTest(test_util.TensorFlowTestCase): (array_ops.identity, 'Identity', True, None), (nn_ops.relu6, 'Relu6', True, 5000), (nn_ops.relu, 'Relu', True, 5000), - (array_ops.identity, 'Identity', True, 5000) + (array_ops.identity, 'Identity', True, 5000), ] - for parameters in parameters_list: - test_fn(parameters[0], parameters[1], parameters[2], parameters[3]) + for params in parameters_list: + test_fn(params[0], params[1], params[2], params[3]) def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name, with_bypass, delay): @@ -137,7 +131,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) def testQuantize_Conv2dWithoutBatchNorm(self): - self._RunTestOverParameters(self._TestQuantize_Conv2dWithoutBatchNorm) + self._RunWithoutBatchNormTestOverParameters( + self._TestQuantize_Conv2dWithoutBatchNorm) def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name, with_bypass, delay): @@ -210,7 +205,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) def testQuantize_FCWithoutBatchNorm(self): - self._RunTestOverParameters(self._TestQuantize_FCWithoutBatchNorm) + self._RunWithoutBatchNormTestOverParameters( + self._TestQuantize_FCWithoutBatchNorm) def _TestQuantize_DepthwiseConv2dWithoutBatchNorm( self, activation, activation_op_name, with_bypass, delay): @@ -284,11 +280,43 @@ class QuantizeTest(test_util.TensorFlowTestCase): self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) def testQuantize_DepthwiseConv2dWithoutBatchNorm(self): - self._RunTestOverParameters( + self._RunWithoutBatchNormTestOverParameters( self._TestQuantize_DepthwiseConv2dWithoutBatchNorm) + def _RunBatchNormTestOverParameters(self, test_fn): + # TODO(suharshs): Use parameterized test once OSS TF supports it. 
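The hand-enumerated `parameters_list` that follows is the full cross product of activation x with_bypass x delay x fused_batch_norm. Until parameterized tests land in OSS TF, the same 24 tuples could be generated instead of written out by hand; a sketch of such a refactor (hypothetical, not part of this change):

```python
import itertools

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops

_ACTIVATIONS = [(nn_ops.relu6, 'Relu6'), (nn_ops.relu, 'Relu'),
                (array_ops.identity, 'Identity')]

# itertools.product varies its rightmost argument fastest, so this yields
# the same 24 tuples, in the same order, as the literal list below.
parameters_list = [
    (act, name, with_bypass, delay, fused)
    for fused, with_bypass, delay, (act, name) in itertools.product(
        (False, True), (False, True), (None, 5000), _ACTIVATIONS)
]
```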
+ parameters_list = [ + # (activation, activation_op_name, with_bypass, delay, fused_batch_norm) + (nn_ops.relu6, 'Relu6', False, None, False), + (nn_ops.relu, 'Relu', False, None, False), + (array_ops.identity, 'Identity', False, None, False), + (nn_ops.relu6, 'Relu6', False, 5000, False), + (nn_ops.relu, 'Relu', False, 5000, False), + (array_ops.identity, 'Identity', False, 5000, False), + (nn_ops.relu6, 'Relu6', True, None, False), + (nn_ops.relu, 'Relu', True, None, False), + (array_ops.identity, 'Identity', True, None, False), + (nn_ops.relu6, 'Relu6', True, 5000, False), + (nn_ops.relu, 'Relu', True, 5000, False), + (array_ops.identity, 'Identity', True, 5000, False), + (nn_ops.relu6, 'Relu6', False, None, True), + (nn_ops.relu, 'Relu', False, None, True), + (array_ops.identity, 'Identity', False, None, True), + (nn_ops.relu6, 'Relu6', False, 5000, True), + (nn_ops.relu, 'Relu', False, 5000, True), + (array_ops.identity, 'Identity', False, 5000, True), + (nn_ops.relu6, 'Relu6', True, None, True), + (nn_ops.relu, 'Relu', True, None, True), + (array_ops.identity, 'Identity', True, None, True), + (nn_ops.relu6, 'Relu6', True, 5000, True), + (nn_ops.relu, 'Relu', True, 5000, True), + (array_ops.identity, 'Identity', True, 5000, True) + ] + for params in parameters_list: + test_fn(params[0], params[1], params[2], params[3], params[4]) + def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay): + with_bypass, delay, fused_batch_norm): """Tests quantization: inputs -> Conv2d with batch norm -> Activation. Args: @@ -298,25 +326,29 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true uses FusedBatchNorm. """ self._testQuantize_Conv2dWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=True) self._testQuantize_Conv2dWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=False) def testQuantize_Conv2dWithBatchNorm(self): - self._RunTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm) + self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm) def _testQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, use_ema): + with_bypass, delay, fused_batch_norm, + use_ema): """Tests quantization: inputs -> Conv2d with batch norm -> Activation. Args: @@ -326,6 +358,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true uses FusedBatchNorm. use_ema: Bool, when true uses EMA quantization for BN folded weights. """ graph = ops.Graph() @@ -337,39 +370,29 @@ class QuantizeTest(test_util.TensorFlowTestCase): stride = 1 if with_bypass else 2 out_depth = 3 if with_bypass else 32 scope = 'test/test2' if with_bypass else 'test' - node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=None, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) - # Manually fold the batch norm.
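For context, the manual folding deleted in the hunks below (now performed by `fold_batch_norms.FoldBatchNorms`) implements the standard batch-norm folding identity. A minimal NumPy sketch of that arithmetic, with illustrative names only:

```python
import numpy as np

# Batch norm computes gamma * (x - mean) / sqrt(variance + eps) + beta.
# Folding moves the multiplicative factor into the conv weights (the
# deleted 'mul_fold' op, fed by BatchNorm/batchnorm/mul) and the additive
# term into a bias added after the convolution (the deleted 'add_fold' op,
# fed by BatchNorm/batchnorm/sub):
def fold_batch_norm(weights, gamma, beta, mean, variance, eps=0.001):
  scale = gamma / np.sqrt(variance + eps)  # BatchNorm/batchnorm/mul
  folded_weights = weights * scale         # mul_fold
  folded_bias = beta - mean * scale        # BatchNorm/batchnorm/sub
  return folded_weights, folded_bias
```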
- weights = graph.get_operation_by_name(scope + '/weights/read').outputs[0] - bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul') - .outputs[0]) - mul_fold = math_ops.multiply(weights, bn_mult, name=scope + '/mul_fold') - stride = [stride, stride] - conv_fold = nn_ops.convolution( - input=inputs, - filter=mul_fold, + node = conv2d( + inputs, + out_depth, [5, 5], + stride=stride, padding='SAME', - strides=stride, - data_format='NHWC', - name=scope + '/convolution_Fold') - bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub') - .outputs[0]) - add_fold = math_ops.add(conv_fold, bn_bias, name=scope + '/add_fold') + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams(fused_batch_norm), + scope=scope) + # Manually add a bypass (optionally) and an activation. if with_bypass: - node = math_ops.add(inputs, add_fold, name='test/Add') - else: - node = add_fold + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) update_barrier = control_flow_ops.no_op(name='update_barrier') with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') + fold_batch_norms.FoldBatchNorms(graph) + quantize.Quantize( graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) @@ -413,7 +436,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay): + with_bypass, delay, fused_batch_norm): """Tests quantization: inputs -> FC with batch norm -> Activation. Args: @@ -423,25 +446,29 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true uses FusedBatchNorm. """ self._testQuantize_FCWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=True) self._testQuantize_FCWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=False) def testQuantize_FCWithBatchNorm(self): - self._RunTestOverParameters(self._TestQuantize_FCWithBatchNorm) + self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm) def _testQuantize_FCWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, use_ema): + with_bypass, delay, fused_batch_norm, + use_ema): """Tests quantization: inputs -> FC with batch norm -> Activation. Args: @@ -451,6 +478,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true uses FusedBatchNorm. use_ema: Bool, when true uses EMA quantization for BN folded weights. """ graph = ops.Graph() @@ -461,32 +489,27 @@ class QuantizeTest(test_util.TensorFlowTestCase): inputs = array_ops.zeros((batch_size, depth)) out_depth = 256 if with_bypass else 128 scope = 'test/test2' if with_bypass else 'test' - node = fully_connected(inputs, out_depth, - weights_initializer=self._WeightInit(0.03), - activation_fn=None, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) - # Manually fold the batch norm.
- weights = graph.get_operation_by_name(scope + '/weights/read').outputs[0] - bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul') - .outputs[0]) - mul_fold = math_ops.multiply(weights, bn_mult, name=scope + '/mul_fold') - fc_fold = math_ops.matmul(inputs, mul_fold, name=scope + '/MatMul_Fold') - bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub') - .outputs[0]) - add_fold = math_ops.add(fc_fold, bn_bias, name=scope + '/add_fold') + node = fully_connected( + inputs, + out_depth, + weights_initializer=self._WeightInit(0.03), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams(fused_batch_norm), + scope=scope) + # Manually add a bypass (optionally) and an activation. if with_bypass: - node = math_ops.add(inputs, add_fold, name='test/Add') - else: - node = add_fold + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) update_barrier = control_flow_ops.no_op(name='update_barrier') with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') + fold_batch_norms.FoldBatchNorms(graph) + quantize.Quantize( graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) @@ -530,7 +553,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) def _TestQuantize_DepthwiseConv2dWithBatchNorm( - self, activation, activation_op_name, with_bypass, delay): + self, activation, activation_op_name, with_bypass, delay, + fused_batch_norm): """Tests quantization: inputs -> DWConv2d with batch norm -> Activation. Args: @@ -540,26 +564,30 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true uses FusedBatchNorm. """ self._testQuantize_DepthwiseConv2dWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=True) self._testQuantize_DepthwiseConv2dWithBatchNorm( activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema=False) def testQuantize_DepthwiseConv2dWithBatchNorm(self): - self._RunTestOverParameters( - self._TestQuantize_DepthwiseConv2dWithoutBatchNorm) + self._RunBatchNormTestOverParameters( + self._TestQuantize_DepthwiseConv2dWithBatchNorm) def _testQuantize_DepthwiseConv2dWithBatchNorm( - self, activation, activation_op_name, with_bypass, delay, use_ema): + self, activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema): """Tests quantization: inputs -> DWConv2d with batch norm -> Activation. Args: @@ -569,6 +597,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with_bypass: Bool, when true there is an extra connection added from inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true uses FusedBatchNorm. use_ema: Bool, when true uses EMA quantization for BN folded weights. 
""" graph = ops.Graph() @@ -579,46 +608,30 @@ class QuantizeTest(test_util.TensorFlowTestCase): inputs = array_ops.zeros((batch_size, height, width, depth)) stride = 1 if with_bypass else 2 scope = 'test/test2' if with_bypass else 'test' - node = separable_conv2d(inputs, None, [5, 5], stride=stride, - depth_multiplier=1.0, padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=None, - normalizer_fn=batch_norm, - normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, - scope=scope) - # Manually fold the batch norm. - weights = (graph.get_operation_by_name(scope + '/depthwise_weights/read') - .outputs[0]) - bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul') - .outputs[0]) - new_shape = [ - weights.get_shape().as_list()[2], weights.get_shape().as_list()[3] - ] - bn_mult_reshaped = array_ops.reshape( - bn_mult, new_shape, name=scope + '/gamma_reshape') - mul_fold = math_ops.multiply( - weights, bn_mult_reshaped, name=scope + '/mul_fold') - stride = [1, stride, stride, 1] - conv_fold = nn_ops.depthwise_conv2d( - input=inputs, - filter=mul_fold, + node = separable_conv2d( + inputs, + None, [5, 5], + stride=stride, + depth_multiplier=1.0, padding='SAME', - strides=stride, - name=scope + '/depthwise_Fold') - bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub') - .outputs[0]) - add_fold = math_ops.add(conv_fold, bn_bias, name=scope + '/add_fold') + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams(fused_batch_norm), + scope=scope) + # Manually add a bypass (optionaly) and an activation. if with_bypass: - node = math_ops.add(inputs, add_fold, name='test/Add') - else: - node = add_fold + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) update_barrier = control_flow_ops.no_op(name='update_barrier') with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') + fold_batch_norms.FoldBatchNorms(graph) + quantize.Quantize( graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) quantization_node_name = 'FakeQuantWithMinMaxVars' @@ -660,6 +673,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): if delay else 'control_dependency') self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + def _BatchNormParams(self, fused=False): + return {'center': True, 'scale': True, 'decay': 1.0 - 0.003, 'fused': fused} + def _WeightInit(self, stddev): """Returns truncated normal variable initializer. 
diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 571d299ad9ed8da51013b2106f27b7678fb49c4d..29ba26d75dcce6ac1983f82dc2dfc03323e0ec5f 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -156,6 +156,7 @@ cuda_py_tests( "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", @@ -165,6 +166,7 @@ cuda_py_tests( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/eager:context", ], shard_count = 10, ) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index 2fa033632acb451762c60a68f659302102d6c3ab..12def6dcc8a5bb32329064e765588f99c5eb16fc 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -25,10 +25,12 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib import rnn as rnn_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as ops_lib from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl @@ -881,6 +883,7 @@ class LSTMTest(test.TestCase): # Smoke test, this should not raise an error rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) + @test_util.run_in_graph_and_eager_modes() def testDynamicRNNWithTupleStates(self): num_units = 3 input_size = 5 @@ -888,13 +891,20 @@ class LSTMTest(test.TestCase): num_proj = 4 max_length = 8 sequence_length = [4, 6] + in_graph_mode = context.in_graph_mode() with self.test_session(graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) - inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(None, input_size)) - ] + if in_graph_mode: + inputs = max_length * [ + array_ops.placeholder( + dtypes.float32, shape=(None, input_size)) + ] + else: + inputs = max_length * [ + constant_op.constant( + np.random.randn(batch_size, input_size).astype(np.float32)) + ] inputs_c = array_ops.stack(inputs) cell = rnn_cell.LSTMCell( num_units, @@ -924,21 +934,34 @@ class LSTMTest(test.TestCase): self.assertEqual(state_dynamic[0], state_dynamic.c) self.assertEqual(state_dynamic[1], state_dynamic.h) - variables_lib.global_variables_initializer().run() - - input_value = np.random.randn(batch_size, input_size) - outputs_static_v = sess.run(outputs_static, - feed_dict={inputs[0]: input_value}) - outputs_dynamic_v = sess.run(outputs_dynamic, - feed_dict={inputs[0]: input_value}) - self.assertAllEqual(outputs_static_v, outputs_dynamic_v) - - state_static_v = sess.run(state_static, - feed_dict={inputs[0]: input_value}) - state_dynamic_v = sess.run(state_dynamic, - feed_dict={inputs[0]: input_value}) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) + if in_graph_mode: + variables_lib.global_variables_initializer().run() + input_value = np.random.randn(batch_size, 
input_size) + outputs_static = sess.run( + outputs_static, feed_dict={ + inputs[0]: input_value + }) + outputs_dynamic = sess.run( + outputs_dynamic, feed_dict={ + inputs[0]: input_value + }) + state_static = sess.run( + state_static, feed_dict={ + inputs[0]: input_value + }) + state_dynamic = sess.run( + state_dynamic, feed_dict={ + inputs[0]: input_value + }) + + if in_graph_mode: + self.assertAllEqual(outputs_static, outputs_dynamic) + else: + self.assertAllEqual( + array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy()) + self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic)) + @test_util.run_in_graph_and_eager_modes() def testDynamicRNNWithNestedTupleStates(self): num_units = 3 input_size = 5 @@ -946,13 +969,20 @@ class LSTMTest(test.TestCase): num_proj = 4 max_length = 8 sequence_length = [4, 6] + in_graph_mode = context.in_graph_mode() with self.test_session(graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) - inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(None, input_size)) - ] + if in_graph_mode: + inputs = max_length * [ + array_ops.placeholder( + dtypes.float32, shape=(None, input_size)) + ] + else: + inputs = max_length * [ + constant_op.constant( + np.random.randn(batch_size, input_size).astype(np.float32)) + ] inputs_c = array_ops.stack(inputs) def _cell(i): @@ -993,20 +1023,34 @@ class LSTMTest(test.TestCase): sequence_length=sequence_length, scope=scope) - variables_lib.global_variables_initializer().run() - - input_value = np.random.randn(batch_size, input_size) - outputs_static_v = sess.run(outputs_static, - feed_dict={inputs[0]: input_value}) - outputs_dynamic_v = sess.run(outputs_dynamic, - feed_dict={inputs[0]: input_value}) - self.assertAllEqual(outputs_static_v, outputs_dynamic_v) - - state_static_v = sess.run(nest.flatten(state_static), - feed_dict={inputs[0]: input_value}) - state_dynamic_v = sess.run(nest.flatten(state_dynamic), - feed_dict={inputs[0]: input_value}) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) + if in_graph_mode: + input_value = np.random.randn(batch_size, input_size) + variables_lib.global_variables_initializer().run() + outputs_static = sess.run( + outputs_static, feed_dict={ + inputs[0]: input_value + }) + outputs_dynamic = sess.run( + outputs_dynamic, feed_dict={ + inputs[0]: input_value + }) + state_static = sess.run( + nest.flatten(state_static), feed_dict={ + inputs[0]: input_value + }) + state_dynamic = sess.run( + nest.flatten(state_dynamic), feed_dict={ + inputs[0]: input_value + }) + + if in_graph_mode: + self.assertAllEqual(outputs_static, outputs_dynamic) + else: + self.assertAllEqual( + array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy()) + state_static = [s.numpy() for s in nest.flatten(state_static)] + state_dynamic = [s.numpy() for s in nest.flatten(state_dynamic)] + self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic)) def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length): time_steps = 8 @@ -1015,21 +1059,22 @@ class LSTMTest(test.TestCase): input_size = 5 batch_size = 2 - input_values = np.random.randn(time_steps, batch_size, input_size) + input_values = np.random.randn(time_steps, batch_size, input_size).astype( + np.float32) if use_sequence_length: sequence_length = np.random.randint(0, time_steps, size=batch_size) else: sequence_length = None - ########### Step 1: Run static graph and generate readouts - with 
self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: - concat_inputs = array_ops.placeholder( - dtypes.float32, shape=(time_steps, batch_size, input_size)) - inputs = array_ops.unstack(concat_inputs) + in_graph_mode = context.in_graph_mode() + + # TODO(b/68017812): Eager ignores operation seeds, so we need to create a + # single cell and reuse it across the static and dynamic RNNs. Remove this + # special case once it is fixed. + if not in_graph_mode: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) - cell = rnn_cell.LSTMCell( num_units, use_peepholes=True, @@ -1037,63 +1082,85 @@ class LSTMTest(test.TestCase): num_proj=num_proj, state_is_tuple=False) + ########### Step 1: Run static graph and generate readouts + with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + if in_graph_mode: + concat_inputs = array_ops.placeholder( + dtypes.float32, shape=(time_steps, batch_size, input_size)) + else: + concat_inputs = constant_op.constant(input_values) + inputs = array_ops.unstack(concat_inputs) + initializer = init_ops.random_uniform_initializer( + -0.01, 0.01, seed=self._seed) + + # TODO(akshayka): Remove special case once b/68017812 is fixed. + if in_graph_mode: + cell = rnn_cell.LSTMCell( + num_units, + use_peepholes=True, + initializer=initializer, + num_proj=num_proj, + state_is_tuple=False) + with variable_scope.variable_scope("dynamic_scope"): outputs_static, state_static = rnn.static_rnn( cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32) - feeds = {concat_inputs: input_values} - - # Initialize - variables_lib.global_variables_initializer().run(feed_dict=feeds) - - # Generate gradients of sum of outputs w.r.t. inputs - static_gradients = gradients_impl.gradients( - outputs_static + [state_static], [concat_inputs]) - - # Generate gradients of individual outputs w.r.t. inputs - static_individual_gradients = nest.flatten([ - gradients_impl.gradients(y, [concat_inputs]) - for y in [outputs_static[0], outputs_static[-1], state_static] - ]) - - # Generate gradients of individual variables w.r.t. inputs - trainable_variables = ops_lib.get_collection( - ops_lib.GraphKeys.TRAINABLE_VARIABLES) - assert len(trainable_variables) > 1, ("Count of trainable variables: %d" % - len(trainable_variables)) - # pylint: disable=bad-builtin - static_individual_variable_gradients = nest.flatten([ - gradients_impl.gradients(y, trainable_variables) - for y in [outputs_static[0], outputs_static[-1], state_static] - ]) - - # Test forward pass - values_static = sess.run(outputs_static, feed_dict=feeds) - (state_value_static,) = sess.run((state_static,), feed_dict=feeds) - - # Test gradients to inputs and variables w.r.t. outputs & final state - static_grad_values = sess.run(static_gradients, feed_dict=feeds) - - static_individual_grad_values = sess.run(static_individual_gradients, - feed_dict=feeds) - - static_individual_var_grad_values = sess.run( - static_individual_variable_gradients, feed_dict=feeds) + if in_graph_mode: + # Generate gradients and run sessions to obtain outputs + feeds = {concat_inputs: input_values} + # Initialize + variables_lib.global_variables_initializer().run(feed_dict=feeds) + # Generate gradients of sum of outputs w.r.t. inputs + static_gradients = gradients_impl.gradients( + outputs_static + [state_static], [concat_inputs]) + # Generate gradients of individual outputs w.r.t. 
inputs + static_individual_gradients = nest.flatten([ + gradients_impl.gradients(y, [concat_inputs]) + for y in [outputs_static[0], outputs_static[-1], state_static] + ]) + # Generate gradients of individual variables w.r.t. inputs + trainable_variables = ops_lib.get_collection( + ops_lib.GraphKeys.TRAINABLE_VARIABLES) + assert len(trainable_variables) > 1, ( + "Count of trainable variables: %d" % len(trainable_variables)) + # pylint: disable=bad-builtin + static_individual_variable_gradients = nest.flatten([ + gradients_impl.gradients(y, trainable_variables) + for y in [outputs_static[0], outputs_static[-1], state_static] + ]) + # Test forward pass + values_static = sess.run(outputs_static, feed_dict=feeds) + (state_value_static,) = sess.run((state_static,), feed_dict=feeds) + + # Test gradients to inputs and variables w.r.t. outputs & final state + static_grad_values = sess.run(static_gradients, feed_dict=feeds) + + static_individual_grad_values = sess.run(static_individual_gradients, + feed_dict=feeds) + + static_individual_var_grad_values = sess.run( + static_individual_variable_gradients, feed_dict=feeds) ########## Step 2: Run dynamic graph and generate readouts with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: - concat_inputs = array_ops.placeholder( - dtypes.float32, shape=(time_steps, batch_size, input_size)) - inputs = array_ops.unstack(concat_inputs) + if in_graph_mode: + concat_inputs = array_ops.placeholder( + dtypes.float32, shape=(time_steps, batch_size, input_size)) + else: + concat_inputs = constant_op.constant(input_values) initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) - cell = rnn_cell.LSTMCell( - num_units, - use_peepholes=True, - initializer=initializer, - num_proj=num_proj, - state_is_tuple=False) + # TODO(akshayka): Remove this special case once b/68017812 is + # fixed. + if in_graph_mode: + cell = rnn_cell.LSTMCell( + num_units, + use_peepholes=True, + initializer=initializer, + num_proj=num_proj, + state_is_tuple=False) with variable_scope.variable_scope("dynamic_scope"): outputs_dynamic, state_dynamic = rnn.dynamic_rnn( @@ -1104,72 +1171,83 @@ class LSTMTest(test.TestCase): dtype=dtypes.float32) split_outputs_dynamic = array_ops.unstack(outputs_dynamic, time_steps) - feeds = {concat_inputs: input_values} + if in_graph_mode: + feeds = {concat_inputs: input_values} - # Initialize - variables_lib.global_variables_initializer().run(feed_dict=feeds) + # Initialize + variables_lib.global_variables_initializer().run(feed_dict=feeds) + + # Generate gradients of sum of outputs w.r.t. inputs + dynamic_gradients = gradients_impl.gradients( + split_outputs_dynamic + [state_dynamic], [concat_inputs]) + + # Generate gradients of several individual outputs w.r.t. inputs + dynamic_individual_gradients = nest.flatten([ + gradients_impl.gradients(y, [concat_inputs]) + for y in + [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] + ]) + + # Generate gradients of individual variables w.r.t. inputs + trainable_variables = ops_lib.get_collection( + ops_lib.GraphKeys.TRAINABLE_VARIABLES) + assert len(trainable_variables) > 1, ( + "Count of trainable variables: %d" % len(trainable_variables)) + dynamic_individual_variable_gradients = nest.flatten([ + gradients_impl.gradients(y, trainable_variables) + for y in + [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] + ]) - # Generate gradients of sum of outputs w.r.t. 
inputs - dynamic_gradients = gradients_impl.gradients( - split_outputs_dynamic + [state_dynamic], [concat_inputs]) - - # Generate gradients of several individual outputs w.r.t. inputs - dynamic_individual_gradients = nest.flatten([ - gradients_impl.gradients(y, [concat_inputs]) - for y in - [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] - ]) - - # Generate gradients of individual variables w.r.t. inputs - trainable_variables = ops_lib.get_collection( - ops_lib.GraphKeys.TRAINABLE_VARIABLES) - assert len(trainable_variables) > 1, ("Count of trainable variables: %d" % - len(trainable_variables)) - dynamic_individual_variable_gradients = nest.flatten([ - gradients_impl.gradients(y, trainable_variables) - for y in - [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] - ]) - - # Test forward pass - values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds) - (state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds) - - # Test gradients to inputs and variables w.r.t. outputs & final state - dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds) - - dynamic_individual_grad_values = sess.run(dynamic_individual_gradients, - feed_dict=feeds) - - dynamic_individual_var_grad_values = sess.run( - dynamic_individual_variable_gradients, feed_dict=feeds) + # Test forward pass + values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds) + (state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds) + + # Test gradients to inputs and variables w.r.t. outputs & final state + dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds) + + dynamic_individual_grad_values = sess.run(dynamic_individual_gradients, + feed_dict=feeds) + + dynamic_individual_var_grad_values = sess.run( + dynamic_individual_variable_gradients, feed_dict=feeds) ######### Step 3: Comparisons + if not in_graph_mode: + values_static = outputs_static + values_dynamic = split_outputs_dynamic + state_value_static = state_static + state_value_dynamic = state_dynamic + self.assertEqual(len(values_static), len(values_dynamic)) for (value_static, value_dynamic) in zip(values_static, values_dynamic): self.assertAllEqual(value_static, value_dynamic) self.assertAllEqual(state_value_static, state_value_dynamic) - self.assertAllEqual(static_grad_values, dynamic_grad_values) + if in_graph_mode: + + self.assertAllEqual(static_grad_values, dynamic_grad_values) - self.assertEqual( - len(static_individual_grad_values), len(dynamic_individual_grad_values)) - self.assertEqual( - len(static_individual_var_grad_values), - len(dynamic_individual_var_grad_values)) + self.assertEqual( + len(static_individual_grad_values), + len(dynamic_individual_grad_values)) + self.assertEqual( + len(static_individual_var_grad_values), + len(dynamic_individual_var_grad_values)) - for i, (a, b) in enumerate( - zip(static_individual_grad_values, dynamic_individual_grad_values)): - tf_logging.info("Comparing individual gradients iteration %d" % i) - self.assertAllEqual(a, b) + for i, (a, b) in enumerate( + zip(static_individual_grad_values, dynamic_individual_grad_values)): + tf_logging.info("Comparing individual gradients iteration %d" % i) + self.assertAllEqual(a, b) - for i, (a, b) in enumerate( - zip(static_individual_var_grad_values, - dynamic_individual_var_grad_values)): - tf_logging.info("Comparing individual variable gradients iteration %d" % - i) - self.assertAllEqual(a, b) + for i, (a, b) in enumerate( + zip(static_individual_var_grad_values, + 
dynamic_individual_var_grad_values)): + tf_logging.info("Comparing individual variable gradients iteration %d" % + i) + self.assertAllEqual(a, b) + @test_util.run_in_graph_and_eager_modes() def testDynamicEquivalentToStaticRNN(self): self._testDynamicEquivalentToStaticRNN( use_gpu=False, use_sequence_length=False) diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc index 95273e2b33ef88635a35249ed9f4c1b51c3b03d4..64973ccccdc962757a727d7183bd70e94edcfd1b 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc @@ -112,7 +112,7 @@ struct GatherTree { const int32 max_time = parent_ids.dimension(0); const int32 batch_size = parent_ids.dimension(1); const int32 beam_width = parent_ids.dimension(2); - beams.setConstant(-1); + beams.setConstant(end_token); auto DoWork = [&, ctx, end_token](int start_batch_beam, int limit_batch_beam) { @@ -138,10 +138,13 @@ struct GatherTree { beams(level, batch, beam) = step_ids(level, batch, parent); parent = parent_ids(level, batch, parent); } + // Not necessary when using a BeamSearchDecoder, but necessary + // when a user feeds in a possibly broken trajectory (i.e., non-eos + // entries in a beam following eos entries). bool finished = false; for (int32 time = 0; time < max_seq_len_b; ++time) { if (finished) { - beams(time, batch, beam) = -1; + beams(time, batch, beam) = end_token; } else if (beams(time, batch, beam) == end_token) { finished = true; } diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc index e71efc48cecf1f5721b6806ca00a8c9f25e5fcc2..bc28d492fe1a25afe0d0783539aa9e759e7b703f 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc @@ -46,24 +46,31 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time, const int32 initial_beam_ix = GET_IX(max_seq_len_b - 1, beam); beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix); int32 parent = ldg(parent_ids + initial_beam_ix); + bool found_bad = false; for (int32 level = max_seq_len_b - 2; level >= 0; --level) { const int32 level_beam_ix = GET_IX(level, beam); const int32 level_parent_ix = GET_IX(level, parent); if (parent < 0 || parent > beam_width) { beams[level_beam_ix] = -1; parent = -1; + found_bad = true; } else { beams[level_beam_ix] = ldg(step_ids + level_parent_ix); parent = ldg(parent_ids + level_parent_ix); } } - bool finished = false; - for (int32 time = 0; time < max_seq_len_b; ++time) { - const int32 level_beam_ix = GET_IX(time, beam); - if (finished) { - beams[level_beam_ix] = -1; - } else if (beams[level_beam_ix] == end_token) { - finished = true; + // Not necessary when using a BeamSearchDecoder, but necessary + // when a user feeds in a possibly broken trajectory (i.e., non-eos + // entries in a beam following eos entries). 
+ if (!found_bad) { + bool finished = false; + for (int32 time = 0; time < max_seq_len_b; ++time) { + const int32 level_beam_ix = GET_IX(time, beam); + if (finished) { + beams[level_beam_ix] = end_token; + } else if (beams[level_beam_ix] == end_token) { + finished = true; + } } } #undef GET_IX @@ -80,8 +87,8 @@ struct GatherTree { const int32 max_time = parent_ids.dimension(0); const int32 batch_size = parent_ids.dimension(1); const int32 beam_width = parent_ids.dimension(2); - // First kernel launch to zero things out - beams.device(d) = beams.constant(T(-1)); + // First kernel launch to "zero" things out + beams.device(d) = beams.constant(end_token); CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d); // clang-format off diff --git a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc index 231504bfbb3fb977fbe96a00e7b0898f481f0968..71539b6f592f0c8e53c4bb3801d1e35f34814966 100644 --- a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc @@ -53,11 +53,14 @@ REGISTER_OP("GatherTree") .Doc(R"doc( Calculates the full beams from the per-step ids and parent beam ids. -This op implements the following mathematical equations: +On CPU, if an out of bound parent id is found, an error is returned. +On GPU, if an out of bound parent id is found, a -1 is stored in the +corresponding output value and the execution for that beam returns early. -```python -TODO(ebrevdo): fill in -``` +For a given beam, past the time step containing the first decoded `end_token` +all values are filled in with `end_token`. + +TODO(ebrevdo): fill in the remainder of this docstring. step_ids: `[max_time, batch_size, beam_width]`. parent_ids: `[max_time, batch_size, beam_width]`. 
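To make the documented behavior concrete, here is a rough NumPy reference of the per-beam traceback and the `end_token` fill for a single batch entry (a sketch under the assumption that all parent ids are in range; the real CPU/GPU kernels additionally handle bad parents as described above):

```python
import numpy as np

def gather_tree_reference(step_ids, parent_ids, max_seq_len, end_token):
  """Sketch of GatherTree for one batch entry; inputs are [max_time, beam_width]."""
  max_time, beam_width = step_ids.shape
  beams = np.full((max_time, beam_width), end_token, dtype=step_ids.dtype)
  for beam in range(beam_width):
    # Walk the parent pointers backwards from the last valid time step.
    parent = beam
    for level in reversed(range(max_seq_len)):
      beams[level, beam] = step_ids[level, parent]
      parent = parent_ids[level, parent]
    # Past the first decoded end_token, fill everything with end_token.
    finished = False
    for time in range(max_seq_len):
      if finished:
        beams[time, beam] = end_token
      elif beams[time, beam] == end_token:
        finished = True
  return beams
```

On the `testGatherTreeOne` inputs below (`max_seq_len=3`, `end_token=10`) this reproduces the expected `[[2, 2, 2], [6, 5, 6], [7, 8, 9], [10, 10, 10]]`.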
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py index f30131487203069d513d9f4fe815a883dafd4ac8..277c5b6ef76bce8d59e47cf0026c6e2b1d5cf1e2 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py @@ -36,24 +36,26 @@ class GatherTreeTest(test.TestCase): def testGatherTreeOne(self): # (max_time = 4, batch_size = 1, beams = 3) + end_token = 10 step_ids = _transpose_batch_time( [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]]) max_sequence_lengths = [3] - expected_result = _transpose_batch_time( - [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]]) + expected_result = _transpose_batch_time([[[2, 2, 2], [6, 5, 6], [7, 8, 9], + [10, 10, 10]]]) beams = beam_search_ops.gather_tree( step_ids=step_ids, parent_ids=parent_ids, max_sequence_lengths=max_sequence_lengths, - end_token=10) + end_token=end_token) with self.test_session(use_gpu=True): self.assertAllEqual(expected_result, beams.eval()) def testBadParentValuesOnCPU(self): # (batch_size = 1, max_time = 4, beams = 3) # bad parent in beam 1 time 1 + end_token = 10 step_ids = _transpose_batch_time( [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( @@ -64,7 +66,7 @@ class GatherTreeTest(test.TestCase): step_ids=step_ids, parent_ids=parent_ids, max_sequence_lengths=max_sequence_lengths, - end_token=10) + end_token=end_token) with self.test_session(): with self.assertRaisesOpError( r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"): @@ -77,19 +79,20 @@ class GatherTreeTest(test.TestCase): return # (max_time = 4, batch_size = 1, beams = 3) # bad parent in beam 1 time 1; appears as a negative index at time 0 + end_token = 10 step_ids = _transpose_batch_time( [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]]) max_sequence_lengths = [3] - expected_result = _transpose_batch_time( - [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]]) + expected_result = _transpose_batch_time([[[2, -1, 2], [6, 5, 6], [7, 8, 9], + [10, 10, 10]]]) with ops.device("/device:GPU:0"): beams = beam_search_ops.gather_tree( step_ids=step_ids, parent_ids=parent_ids, max_sequence_lengths=max_sequence_lengths, - end_token=10) + end_token=end_token) with self.test_session(use_gpu=True): self.assertAllEqual(expected_result, beams.eval()) @@ -115,24 +118,24 @@ class GatherTreeTest(test.TestCase): self.assertEqual((max_time, batch_size, beam_width), beams.shape) beams_value = beams.eval() for b in range(batch_size): - # Past max_sequence_lengths[b], we emit all -1s. + # Past max_sequence_lengths[b], we emit all end tokens. b_value = beams_value[max_sequence_lengths[b]:, b, :] - self.assertAllClose(b_value, -1. * np.ones_like(b_value)) + self.assertAllClose(b_value, end_token * np.ones_like(b_value)) for batch, beam in itertools.product( range(batch_size), range(beam_width)): v = np.squeeze(beams_value[:, batch, beam]) if end_token in v: + found_bad = np.where(v == -1)[0] + self.assertEqual(0, len(found_bad)) found = np.where(v == end_token)[0] - # Should be up to 1 instance of end_token per beam. - self.assertEqual(len(found), 1) - found = found[0] + found = found[0] # First occurrence of end_token. 
# If an end_token is found, everything before it should be a - # valid id and everything after it should be -1. + # valid id and everything after it should be end_token. if found > 0: self.assertAllEqual( v[:found - 1] >= 0, np.ones_like(v[:found - 1], dtype=bool)) - self.assertAllClose( - v[found + 1:], -1 * np.ones_like(v[found + 1:])) + self.assertAllClose(v[found + 1:], + end_token * np.ones_like(v[found + 1:])) if __name__ == "__main__": diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 04e0719a1be90cb3b094109d737b4f0db5fa0ce2..805de16468a505b7c4eae76cadd32ff4667f0e07 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import collections +from contextlib import contextmanager import copy import threading import six @@ -38,6 +39,7 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import util +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -57,12 +59,15 @@ from tensorflow.python.training import training_util _INITIAL_LOSS = 1e7 _ZERO_LOSS = 0. -_DEFAULT_NAME_SCOPE = 'tpu_estimator' +_TPU_ESTIMATOR = 'tpu_estimator' _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' _BATCH_SIZE_KEY = 'batch_size' _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY] +# TODO(b/65703635): Flip the value and remove all dead code. +_WRAP_INPUT_FN_INTO_WHILE_LOOP = False + def _create_global_step(graph): graph = graph or ops.get_default_graph() @@ -81,17 +86,25 @@ def _create_global_step(graph): ops.GraphKeys.GLOBAL_STEP]) -def _create_iterations_per_loop(): - with variable_scope.variable_scope(_DEFAULT_NAME_SCOPE, - reuse=variable_scope.AUTO_REUSE): - return variable_scope.get_variable( - _ITERATIONS_PER_LOOP_VAR, - initializer=init_ops.zeros_initializer(), - shape=[], - dtype=dtypes.int32, - trainable=False, - collections=[], - use_resource=True) +def _create_or_get_iterations_per_loop(): + graph = ops.get_default_graph() + iter_vars = graph.get_collection(_TPU_ESTIMATOR) + if len(iter_vars) == 1: + return iter_vars[0] + elif len(iter_vars) > 1: + raise RuntimeError('Multiple iterations_per_loop_var in collection.') + + with ops.colocate_with(training_util.get_global_step()): + with variable_scope.variable_scope(_TPU_ESTIMATOR, + reuse=variable_scope.AUTO_REUSE): + return variable_scope.get_variable( + _ITERATIONS_PER_LOOP_VAR, + initializer=init_ops.zeros_initializer(), + shape=[], + dtype=dtypes.int32, + trainable=False, + collections=[_TPU_ESTIMATOR], + use_resource=True) def _sync_variables_ops(): @@ -127,64 +140,209 @@ _DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' _LOCAL_MASTERS = ('', 'local') -def _tpu_job(run_config, mode): - """Returns the job name to use to place TPU computations on. - - Args: - run_config: The tpu_config.RunConfig used for this custom estimator. - mode: A model_fn_lib.ModeKeys value. +class _TPUContext(object): + """A context that holds immutable states of TPU computation. - Returns: - A string containing the job name, or None if no job should be specified. 
+ This immutable object holds TPUEstimator config, train/eval batch size, and + `TPUEstimator.use_tpu`, which is expected to be passed around. It also + provides utility functions, based on the current state, to determine other + information commonly required by TPU computation, such as TPU device names, + TPU hosts, shard batch size, etc. - Raises: - ValueError: If the user needs to specify a tpu_job_name, because we are - unable to infer the job name automatically, or if the user-specified job - names are inappropriate. + N.B. As `mode` is not immutable state in Estimator, but essential to + distinguish between TPU training and evaluation, a common usage for + _TPUContext with `mode` is as follows: + ``` + with _ctx.with_mode(mode) as ctx: + if ctx.is_running_on_cpu(): + ... + ``` """ - # If the user specifies the tpu_job_name, use that. - if run_config.tpu_config.tpu_job_name: - return run_config.tpu_config.tpu_job_name - - # The tpu job is determined by the run_config. Right now, this method is - # required as tpu_config is not part of the RunConfig. - master = (run_config.evaluation_master if mode == model_fn_lib.ModeKeys.EVAL - else run_config.master) - if master in _LOCAL_MASTERS: - return None - - if (not run_config.session_config or - not run_config.session_config.cluster_def.job): - return _DEFAULT_JOB_NAME - cluster_def = run_config.session_config.cluster_def - job_names = set([job.name for job in cluster_def.job]) - if _DEFAULT_JOB_NAME in job_names: - # b/37868888 tracks allowing ClusterSpec propagation to reuse job names. - raise ValueError('Currently, tpu_worker is not an allowed job name.') - if len(job_names) == 1: - return cluster_def.job[0].name - if len(job_names) == 2: - if _DEFAULT_COORDINATOR_JOB_NAME in job_names: - job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME) - return job_names.pop() - # TODO(b/67716447): Include more sophisticated heuristics. - raise ValueError( - 'Could not infer TPU job name. Please specify a tpu_job_name as part of ' - 'your TPUConfig.') - - -def _is_running_on_cpu(use_tpu, mode, eval_batch_size): - """Determines whether the input_fn and model_fn should be invoked on CPU.""" - return ((not use_tpu) or mode == model_fn_lib.ModeKeys.PREDICT or - (mode == model_fn_lib.ModeKeys.EVAL and eval_batch_size is None)) - - -def _per_shard_batch_size(global_batch_size, run_config, use_tpu): - """Returns the batch size for each shard.""" - if use_tpu: - return global_batch_size // run_config.tpu_config.num_shards - else: - return global_batch_size + + def __init__(self, config, train_batch_size, eval_batch_size, use_tpu): + self._config = config + self._train_batch_size = train_batch_size + self._eval_batch_size = eval_batch_size + self._use_tpu = use_tpu + self._num_shards_or_none = self._config.tpu_config.num_shards + self._mode = None + + def _assert_mode(self): + if self._mode is None: + raise RuntimeError( + '`mode` needs to be set via contextmanager `with_mode`.') + return self._mode + + @property + def num_of_cores_per_host(self): + num_cores = self.num_cores + return min(num_cores, 8) + + @contextmanager + def with_mode(self, mode): + new_ctx = copy.copy(self) # Shallow copy is enough. + new_ctx._mode = mode # pylint: disable=protected-access + yield new_ctx + + @property + def mode(self): + return self._assert_mode() + + @property + def num_cores(self): + # TODO(xiejw): Add lazy num_shards initialization. 
+ return self._num_shards_or_none + + @property + def num_hosts(self): + return self.num_cores // self.num_of_cores_per_host + + @property + def config(self): + return self._config + + def is_input_sharded_per_core(self): + """Return true if input_fn is invoked per-core (as opposed to per-host).""" + self._assert_mode() + return (self._mode == model_fn_lib.ModeKeys.TRAIN and + not self._config.tpu_config.per_host_input_for_training) + + def is_running_on_cpu(self): + """Determines whether the input_fn and model_fn should be invoked on CPU.""" + mode = self._assert_mode() + return ((not self._use_tpu) or mode == model_fn_lib.ModeKeys.PREDICT or + (mode == model_fn_lib.ModeKeys.EVAL and + self._eval_batch_size is None)) + + @property + def batch_size_for_input_fn(self): + """Returns the shard batch size for `input_fn`.""" + mode = self._assert_mode() + # Special case for eval. + if mode == model_fn_lib.ModeKeys.EVAL and self._eval_batch_size is None: + return None + if self.is_running_on_cpu(): + if mode == model_fn_lib.ModeKeys.TRAIN: + return self._train_batch_size + if mode == model_fn_lib.ModeKeys.EVAL: + return self._eval_batch_size + return None + + global_batch_size = (self._train_batch_size if + mode == model_fn_lib.ModeKeys.TRAIN + else self._eval_batch_size) + # On TPU + return (global_batch_size // self.num_cores + if self.is_input_sharded_per_core() else global_batch_size) + + @property + def batch_size_for_model_fn(self): + """Returns the shard batch size for `model_fn`.""" + mode = self._assert_mode() + # Special case for eval. + if mode == model_fn_lib.ModeKeys.EVAL and self._eval_batch_size is None: + return None + if self.is_running_on_cpu(): + if mode == model_fn_lib.ModeKeys.TRAIN: + return self._train_batch_size + if mode == model_fn_lib.ModeKeys.EVAL: + return self._eval_batch_size + return None + + # On TPU, always sharded per core. + if mode == model_fn_lib.ModeKeys.TRAIN: + return self._train_batch_size // self.num_cores + else: + return self._eval_batch_size // self.num_cores + + @property + def master_job(self): + """Returns the job name to use to place TPU computations on. + + Returns: + A string containing the job name, or None if no job should be specified. + + Raises: + ValueError: If the user needs to specify a tpu_job_name, because we are + unable to infer the job name automatically, or if the user-specified job + names are inappropriate. + """ + run_config = self._config + # If the user specifies the tpu_job_name, use that. + if run_config.tpu_config.tpu_job_name: + return run_config.tpu_config.tpu_job_name + + # The tpu job is determined by the run_config. Right now, this method is + # required as tpu_config is not part of the RunConfig. + mode = self._assert_mode() + master = (run_config.evaluation_master if mode == model_fn_lib.ModeKeys.EVAL + else run_config.master) + if master in _LOCAL_MASTERS: + return None + + if (not run_config.session_config or + not run_config.session_config.cluster_def.job): + return _DEFAULT_JOB_NAME + cluster_def = run_config.session_config.cluster_def + job_names = set([job.name for job in cluster_def.job]) + if _DEFAULT_JOB_NAME in job_names: + # b/37868888 tracks allowing ClusterSpec propagation to reuse job names. 
+ raise ValueError('Currently, tpu_worker is not an allowed job name.') + if len(job_names) == 1: + return cluster_def.job[0].name + if len(job_names) == 2: + if _DEFAULT_COORDINATOR_JOB_NAME in job_names: + job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME) + return job_names.pop() + # TODO(b/67716447): Include more sophisticated heuristics. + raise ValueError( + 'Could not infer TPU job name. Please specify a tpu_job_name as part ' + 'of your TPUConfig.') + + @property + def tpu_host_placement_function(self): + """Returns the TPU host placement function.""" + master = self.master_job + def _placement_function(_sentinal=None, core_id=None, host_id=None): # pylint: disable=invalid-name + assert _sentinal is None + if core_id is not None and host_id is not None: + raise RuntimeError( + 'core_id and host_id can have only one non-None value.') + + if master is None: + return '/replica:0/task:0/device:CPU:0' + else: + # This assumes that if using more than 8 shards, + # the job configuration varies the 'task' field. + if core_id is not None: + host_id = core_id / 8 + return '/job:%s/task:%d/device:CPU:0' % (master, host_id) + return _placement_function + + @property + def tpu_device_placement_function(self): + master = self.master_job + job_device = '' if master is None else ('/job:%s' % master) + def _placement_function(i): + return '%s/task:%d/device:TPU:%d' % (job_device, i / 8, i % 8) + return _placement_function + + @property + def tpu_ordinal_function(self): + """Returns the TPU ordinal fn.""" + def _tpu_ordinal_function(index): + """Return the TPU ordinal associated with a shard. + + Required because the enqueue ops are placed on CPU. + + Args: + index: the shard index + + Returns: + The ordinal of the TPU device the shard's infeed should be placed on. + """ + return index % 8 + return _tpu_ordinal_function class _SIGNAL(object): @@ -319,11 +477,16 @@ class _InfeedThreadController(_InfeedOutfeedThreadBaseController): logging.info('Stop Infeed input thread.') return - iterations = signal - for i in range(iterations): - logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) + if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + # Enqueue batches for next loop. session.run(enqueue_ops) - count += 1 + else: + iterations = signal + for i in range(iterations): + logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) + session.run(enqueue_ops) + count += 1 + except Exception: # pylint: disable=broad-except logging.error( 'Failed running infeed, closing session.\n' @@ -346,17 +509,16 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): dequeue. 
""" - def __init__(self, run_config, mode, enqueue_fn, dequeue_ops=None): - self._tpu_job = _tpu_job(run_config, mode) - self._enqueue_fn = enqueue_fn + def __init__(self, ctx, enqueue_ops, dequeue_ops=None): + self._master_job = ctx.master_job + self._enqueue_ops = enqueue_ops self._dequeue_ops = dequeue_ops def begin(self): - self._enqueue_ops = self._enqueue_fn() - self._iterations_per_loop_var = _create_iterations_per_loop() - logging.info('TPU job name %s', self._tpu_job) - self._init_op = [tpu.initialize_system(job=self._tpu_job)] - self._finalize_op = [tpu.shutdown_system(job=self._tpu_job)] + logging.info('TPU job name %s', self._master_job) + self._iterations_per_loop_var = _create_or_get_iterations_per_loop() + self._init_op = [tpu.initialize_system(job=self._master_job)] + self._finalize_op = [tpu.shutdown_system(job=self._master_job)] def after_create_session(self, session, coord): logging.info('Init TPU system') @@ -378,6 +540,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): iterations = run_context.session.run(self._iterations_per_loop_var) self._infeed_thd_controller.send_next_batch_signal(iterations) if self._dequeue_ops is not None: + # TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop. logging.info('Dequeue next batch of data from outfeed.') self._outfeed_thd_controller.send_next_batch_signal(iterations) @@ -439,7 +602,7 @@ class _TPUStopAtStepHook(session_run_hook.SessionRunHook): if self._global_step_tensor is None: raise RuntimeError('Global step should be created.') - self._iterations_per_loop_var = _create_iterations_per_loop() + self._iterations_per_loop_var = _create_or_get_iterations_per_loop() def after_create_session(self, session, coord): global_step = session.run(self._global_step_tensor) @@ -474,360 +637,288 @@ class _SetEvalIterationsHook(session_run_hook.SessionRunHook): self._num_steps = num_steps def begin(self): - self._iterations_per_loop_var = _create_iterations_per_loop() + self._iterations_per_loop_var = _create_or_get_iterations_per_loop() def after_create_session(self, session, coord): self._iterations_per_loop_var.load(self._num_steps, session=session) -class _PerShardOutput(object): - """Wraps input_fn's outputs into per-shard outputs. - - Used so that the model_fn can distinguish between sharded input and unsharded - inputs (e.g., for export_savedmodel()). - """ - - def __init__(self, output): - self.output = output - - def as_list(self): - return self.output - +def generate_per_core_enqueue_ops_fn_for_host( + ctx, input_fn, inputs_structure_recorder): + """Generates infeed enqueue ops for per-core input_fn on a single host.""" + infeed_queue_holder = {'instance': None} + + def enqueue_ops_fn(): + """A fn returns enqueue_ops.""" + num_cores_per_host = ctx.num_of_cores_per_host + per_host_sharded_inputs = [] + for core_ordinal in range(num_cores_per_host): + with ops.name_scope('ordinal_%d' % (core_ordinal)): + inputs = input_fn() + if isinstance(inputs, tuple): + features, labels = inputs + else: + features, labels = inputs, None -class _InputsHolder(object): - """A inputs holder holds the `features` and `labels' for TPU system. 
+ inputs_structure_recorder.validate_and_record_structure( + features, labels) + flattened_inputs = ( + inputs_structure_recorder.flatten_features_and_labels( + features, labels)) + per_host_sharded_inputs.append(flattened_inputs) - Model inputs returned by the `input_fn` can have one of the following forms: + infeed_queue = tpu_feed.InfeedQueue( + number_of_tuple_elements=len(per_host_sharded_inputs[0])) + infeed_queue_holder['instance'] = infeed_queue + infeed_queue.set_configuration_from_sharded_input_tensors( + per_host_sharded_inputs) + + per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( + per_host_sharded_inputs, + tpu_ordinal_function=ctx.tpu_ordinal_function) + return per_host_enqueue_ops + return enqueue_ops_fn, (lambda: infeed_queue_holder['instance']) + + +class _InputPipeline(object): + """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue. + + `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from + call site. To be precise, based on the configuration in `_TPUContext`, it + invokes `input_fn` for all cores (usually multi-host TPU training) or for one + host (usually for single-host TPU evaluation), and sends all `features` and + `labels` returned by `input_fn` to TPU infeed. For per-core invocation, + `features` and `labels` are piped to infeed directly, one tuple for each + core. For per-host invocation, `features` and `labels` are split at host + (with respect to `batch_axis`) and piped to all cores accordingly. + + In addition, flatten/unflatten are handled by `_InputPipeline` also. Model + inputs returned by the `input_fn` can have one of the following forms: 1. features 2. (features, labels) Internally, form 1 is reformed to `(features, None)` as features and labels are passed separatedly to underlying methods. For TPU training, TPUEstimator - expects multiple `features` and `labels` tuples one for each shard. - - In addition, TPUEstimator allows various different structures for inputs - (namely `features` and `labels`). `features` can be `Tensor` or dict of - string name to `Tensor`, and `labels` could be `None`, `Tensor`, or dict of - string name to `Tensor`. TPU infeed/outfeed library expects flattened tensor - list. So, `features` and `labels` need to be flattened, before infeed enqueue, - and the structure of them needs to be recorded, in order to restore them after - infeed dequeue. - - `_InputsHolder` could hold the `features` and `labels` tuple for all shards - (usually multi-host TPU training) or for one host (usually for single-host TPU - evaluation), records the structure details (including presence, dict or single - tensor, dict names), validates the structure consistency cross all shards, and - encapsulates the flatten/unflatten logic. + may expect multiple `features` and `labels` tuples one for each core. + + TPUEstimator allows various different structures for inputs (namely `features` + and `labels`). `features` can be `Tensor` or dict of string name to `Tensor`, + and `labels` could be `None`, `Tensor`, or dict of string name to `Tensor`. + TPU infeed/outfeed library expects flattened tensor list. So, `features` and + `labels` need to be flattened, before infeed enqueue, and the structure of + them needs to be recorded, in order to restore them after infeed dequeue. """ - def __init__(self, features=None, labels=None, num_shards=None): - """Constructor. - - Args: - features: features for one host or a list of features one for each shard - (must be type `_PerShardOutput`). 
-        `labels` should be set also and this `_InputsHolder` is frozen to
-        prevent from future modification. If `None`, it is expected to add
-        features and labels for each shard by calling `append_tuple` later.
-      labels: labels for one host or a list of labels one for each shard
-        (must be type `_PerShardOutput`).
-      num_shards: Number of shards in the TPU system. Must be provided unless it
-        can be deduced from `features`.
-
-    Raises:
-      ValueError: If both `sharded_features` and `num_shards` are `None`.
-    """
-    # Holds the features and labels for all shards.
-    self._feature_list = []
-    self._label_list = []
-
-    # Holds the structure of inputs
-    self._feature_names = []
-    self._label_names = []
-    self._has_labels = False
-
-    # Internal state.
-    self._initialized = False
-    self._frozen = False
-    self._sharded = False
-
-    if features is None:
-      if num_shards is None:
-        raise ValueError(
-            '`features` and `num_shards` cannot be both None')
-      self._num_shards = num_shards
-    elif isinstance(features, _PerShardOutput):
-      self._from_sharded_inputs(features, labels, num_shards)
-    else:
-      if num_shards is None:
-        raise ValueError(
-            '`num_shards` cannot be None for unsharded features.')
-      self._from_unsharded_inputs(features, labels, num_shards)
-
-  def _from_unsharded_inputs(self, features, labels, num_shards):
-    """Initializes the inputs with unsharded features and labels."""
-    self._num_shards = num_shards
-    if labels is not None:
-      self._has_labels = True
-      self.append_tuple((features, labels))
-    else:
-      self.append_tuple(features)
-
-    self._sharded = False
-    self._frozen = True
-
-  def _from_sharded_inputs(self, sharded_features, sharded_labels, num_shards):
-    """Initializes the inputs with sharded features and labels."""
-    if not isinstance(sharded_features, _PerShardOutput):
-      raise ValueError('`sharded_features` must have type `_PerShardOutput`.')
-    features = sharded_features.as_list()
-
-    if num_shards is not None and num_shards != len(features):
-      raise ValueError(
-          '`num_shards` should be same as the length of sharded_features.')

+  class InputsStructureRecorder(object):
+    """Records the structure of `features` and `labels`."""
+
+    def __init__(self):
+      # Holds the structure of inputs
+      self._feature_names = []
+      self._label_names = []
+      self._has_labels = False
+
+      # Internal state.
+      self._initialized = False
+
+    def has_labels(self):
+      return self._has_labels
+
+    def validate_and_record_structure(self, features, labels):
+      """Validates and records the structure of `features` and `labels`."""
+      def _extract_key_names(tensor_or_dict):
+        if tensor_or_dict is None:
+          return []
+        return tensor_or_dict.keys() if isinstance(tensor_or_dict, dict) else []
+
+      # Extract structure.
+      has_labels = labels is not None
+      feature_names = _extract_key_names(features)
+      label_names = _extract_key_names(labels)
+
+      if self._initialized:
+        # Verify the structure is the same. The following should never happen.
+        assert feature_names == self._feature_names, 'feature keys mismatched'
+        assert label_names == self._label_names, 'label keys mismatched'
+        assert has_labels == self._has_labels, 'label presence mismatched'
+      else:
+        # Record structure.
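+        # Illustrative example (not from the original code): for
+        #   features = {'x': x_tensor, 'y': y_tensor}, labels = some_tensor
+        # this records feature_names == ['x', 'y'], label_names == [] and
+        # has_labels == True; every later input_fn invocation must produce
+        # the same structure.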
+ self._initialized = True + self._feature_names = feature_names + self._label_names = label_names + self._has_labels = has_labels + + def flatten_features_and_labels(self, features, labels): + """Flattens the `features` and `labels` to a single tensor list.""" + flattened_inputs = [] + if self._feature_names: + # We need a fixed ordering for enqueueing and dequeueing. + flattened_inputs.extend([features[name] + for name in self._feature_names]) + else: + flattened_inputs.append(features) - self._num_shards = len(features) - if not self._num_shards: - raise ValueError('`sharded_features` should not be empty.') + if labels is not None: + if self._label_names: + # We need a fixed ordering for enqueueing and dequeueing. + flattened_inputs.extend([labels[name] for name in self._label_names]) + else: + flattened_inputs.append(labels) + return flattened_inputs + + def unflatten_features_and_labels(self, flattened_inputs): + """Restores the flattened inputs to original features and labels form. + + Args: + flattened_inputs: Flattened inputs for each shard. + + Returns: + A tuple of (`features`, `labels`), where `labels` could be None. + Each one, if present, should have identical structure (single tensor vs + dict) as the one returned by input_fn. + + Raises: + ValueError: If the number of expected tensors from `flattened_inputs` + mismatches the recorded structure. + """ + expected_num_features = (len(self._feature_names) if self._feature_names + else 1) + if self._has_labels: + expected_num_labels = (len(self._label_names) if self._label_names + else 1) + else: + expected_num_labels = 0 - if sharded_labels is not None: - if not isinstance(sharded_labels, _PerShardOutput): - raise ValueError('sharded_labels` must have type `_PerShardOutput`.') + expected_num_tensors = expected_num_features + expected_num_labels - self._has_labels = True - labels = sharded_labels.as_list() - if self._num_shards != len(labels): + if expected_num_tensors != len(flattened_inputs): raise ValueError( - 'Length of `sharded_features` and `sharded_labels` mismatch.') - - if self._has_labels: - for (f, l) in zip(features, labels): - self.append_tuple((f, l)) - else: - for f in features: - self.append_tuple(f) - - self._sharded = True - self._frozen = True - - def _extract_key_names(self, tensor_or_dict): - if tensor_or_dict is None: - return [] - - return tensor_or_dict.keys() if isinstance(tensor_or_dict, dict) else [] - - def _validate(self, features, labels): - has_labels = labels is not None - feature_names = self._extract_key_names(features) - label_names = self._extract_key_names(labels) - - if self._initialized: - self._sharded = True - # The following should never happen. - assert feature_names == self._feature_names, 'feature keys mismatched' - assert label_names == self._label_names, 'label keys mismatched' - assert has_labels == self._has_labels, 'label presence mismatched' - else: - self._initialized = True - self._feature_names = feature_names - self._label_names = label_names - self._has_labels = has_labels - - @property - def sharded(self): - if not self._frozen: - raise RuntimeError('_InputsHolder has not been frozen yet.') - return self._sharded - - @property - def num_shards(self): - if not self._frozen: - raise RuntimeError('_InputsHolder has not been frozen yet.') - return self._num_shards - - def append_tuple(self, inputs): - """Appends `inputs` for one shard into holder. - - Args: - inputs: The return from `input_fn`, which could be features or tuple of - (features, labels). 
After the first `inputs` appended into - `_InputsHolder`, the structure of `features` and `labels is recorded. - Any future invocation should provide the `inputs` with same structure. - - Raises: - RuntimeError: If the internal data has been frozen already. - """ - if self._frozen: - raise RuntimeError('InputsHolder has frozen, which cannot be mutated.') - - # input_fn may return either features or (features, labels) - if isinstance(inputs, tuple): - features, labels = inputs - else: - features, labels = inputs, None - - self._validate(features, labels) - - self._feature_list.append(features) - if labels is not None: - self._label_list.append(labels) - - def as_features_and_labels_tuple(self): - """Returns features and labels as grouped tuple. - - This is intended to be used to pass features and labels for all shards from - input_fn to model_fn as the parent class `Estimator` does not have the - concept of shards. So, grouped tuple is required. - - Once called, the internal data is frozen and `append_tuple` cannot be - invoked anymore. - - Returns: - A tuple of features and labels. Both have type `_PerShardOutput`, holding - the inputs for all shards. `labels` could be `None`. - - Raises: - RuntimeError: If the internal data has not been initialized. - """ - self._frozen = True - if not self._initialized: - raise RuntimeError('InputsHolder has not been initialized.') - - assert len(self._feature_list) == self._num_shards - if not self._label_list or all(l is None for l in self._label_list): - return _PerShardOutput(self._feature_list), None - - assert len(self._label_list) == self._num_shards - return (_PerShardOutput(self._feature_list), - _PerShardOutput(self._label_list)) - - def as_sharded_flattened_inputs(self): - """Flatten the features and label as tensor lists for all shards. - - Flattened tensor list contains all tensors in `features` (dict) and `labels` - (dict). Conceptually, it has the predicated structure like: - - ```python - flatten_list = [] - for name in features: - flatten_list.append(features[name]) - for name in labels: - flatten_list.append(labels[name]) - ``` - - This method handles the label is None case and single tensor case nicely. - - Once called, the internal data is frozen and `append_tuple` cannot be - invokded anymore. - - Returns: - A list of flattened inputs one for each shard. - - Raises: - RuntimeError: If the internal data has not been initialized. - ValueError: If the inputs are sharded. 
- """ - self._frozen = True - if not self._initialized: - raise RuntimeError('InputsHolder has not been initialized.') - if not self._sharded: - raise ValueError('Inputs are not sharded.') - - sharded_inputs = [] - - for shard in range(self._num_shards): - flattened_inputs = self._as_flattened_inputs( - self._feature_list[shard], - self._label_list[shard] if self._has_labels else None) - sharded_inputs.append(flattened_inputs) - - return sharded_inputs - - def as_flattened_inputs(self): - """Flatten the features and label as a single tensor list for one host.""" - self._frozen = True - if not self._initialized: - raise RuntimeError('InputsHolder has not been initialized.') - if self._sharded: - raise ValueError('Inputs are sharded.') - - return self._as_flattened_inputs( - self._feature_list[0], - self._label_list[0] if self._has_labels else None) - - def _as_flattened_inputs(self, features, labels): - """Flattens the `features` and `labels` to a single tensor list.""" - flattened_inputs = [] - if self._feature_names: - # We need a fixed ordering for enqueueing and dequeueing. - flattened_inputs.extend([features[name] for name in self._feature_names]) - else: - flattened_inputs.append(features) - - if labels is not None: - if self._label_names: - # We need a fixed ordering for enqueueing and dequeueing. - flattened_inputs.extend([labels[name] for name in self._label_names]) + 'The number of flattened tensors mismatches expected num. ' + 'Expected {}, got {}'.format(expected_num_tensors, + len(flattened_inputs))) + if self._feature_names: + unflattened_features = dict( + zip(self._feature_names, flattened_inputs[:expected_num_features])) + else: + # Single tensor case + unflattened_features = flattened_inputs[0] + + if expected_num_labels == 0: + unflattened_label = None + elif self._label_names: + unflattened_label = dict(zip(self._label_names, + flattened_inputs[expected_num_features:])) else: - flattened_inputs.append(labels) - return flattened_inputs + # Single tensor case. + unflattened_label = flattened_inputs[expected_num_features] - def unflatten_features_and_labels(self, flattened_inputs): - """Restores the flattened inputs to original features and labels form. + return unflattened_features, unflattened_label - Once called, the internal data is frozen and `append_tuple` cannot be - invokded anymore. + def __init__(self, input_fn, batch_axis, ctx): + """Constructor. Args: - flattened_inputs: Flattened inputs for one each, which should be created - by the `as_sharded_flattened_inputs` API. - - Returns: - A tuple of (`features`, `labels`), where `labels` could be None. - Each one, if present, should have identical structure (single tensor vs - dict) as the one returned by input_fn. + input_fn: input fn for train or eval. + batch_axis: A python tuple of int values describing how each tensor + produced by the Estimator `input_fn` should be split across the TPU + compute shards. + ctx: A `_TPUContext` instance with mode. Raises: - RuntimeError: If the internal data has not been initialized. - ValueError: If the number of expected tensors from `flattened_inputs` - mismatches the recorded structure. + ValueError: If both `sharded_features` and `num_cores` are `None`. 
""" - self._frozen = True - if not self._initialized: - raise RuntimeError('InputsHolder has not been initialized.') - - expected_num_features = (len(self._feature_names) if self._feature_names - else 1) - if self._has_labels: - expected_num_labels = (len(self._label_names) if self._label_names - else 1) - else: - expected_num_labels = 0 - - expected_num_tensors = expected_num_features + expected_num_labels + self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder() + + self._sharded_per_core = ctx.is_input_sharded_per_core() + self._input_fn = input_fn + self._infeed_queue = None + self._ctx = ctx + self._batch_axis = batch_axis + + def generate_infeed_enqueue_ops_and_dequeue_fn(self): + """Generates infeed enqueue ops and dequeue_fn.""" + # While tf.while_loop is called, the body function, which invokes + # `enqueue_fn` passed in, is called to construct the graph. So, input_fn + # structure is recorded. + enqueue_ops = self._invoke_input_fn_and_record_structure() + + def dequeue_fn(): + """dequeue_fn is used by TPU to retrieve the tensors.""" + values = self._infeed_queue.generate_dequeue_op() + # The unflatten process uses the structure information recorded above. + return self._inputs_structure_recorder.unflatten_features_and_labels( + values) + + return (enqueue_ops, dequeue_fn) + + def _invoke_input_fn_and_record_structure(self): + if self._sharded_per_core: + # Per-Core input pipeline deployment. + tpu_host_placement_fn = self._ctx.tpu_host_placement_function + enqueue_ops = [] + infeed_queues = [] + + # Invoke input pipeline for each core and placed on the corresponding + # host. + num_hosts = self._ctx.num_hosts + for host_id in range(num_hosts): + host_device = tpu_host_placement_fn(host_id=host_id) + with ops.device(host_device): + with ops.name_scope('input_pipeline_task%d' % (host_id)): + enqueue_ops_fn, infeed_queue_getter = ( + generate_per_core_enqueue_ops_fn_for_host( + self._ctx, self._input_fn, self._inputs_structure_recorder)) + + if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + enqueue_ops.append(_wrap_computation_in_while_loop( + device=host_device, op_fn=enqueue_ops_fn)) + else: + enqueue_ops.append(enqueue_ops_fn()) + # Infeed_queue_getter must be called after enqueue_ops_fn is called. + infeed_queues.append(infeed_queue_getter()) + + # infeed_queue is used to generate dequeue ops. The only thing it uses for + # dequeue is dtypes and types. So, any one can be used. Here, grab the + # first one. + self._infeed_queue = infeed_queues[0] + return enqueue_ops - if expected_num_tensors != len(flattened_inputs): - raise ValueError( - 'The number of flattened tensors mismatches expected num. ' - 'Expected {}, got {}'.format(expected_num_tensors, - len(flattened_inputs))) - if self._feature_names: - unflattened_features = dict(zip(self._feature_names, - flattened_inputs[:expected_num_features])) - else: - # Single tensor case - unflattened_features = flattened_inputs[0] - - if expected_num_labels == 0: - unflattened_label = None - elif self._label_names: - unflattened_label = dict(zip(self._label_names, - flattened_inputs[expected_num_features:])) else: - # Single tensor case. - unflattened_label = flattened_inputs[expected_num_features] - - return unflattened_features, unflattened_label + # TODO(b/67051042): Extend this to multi-host support. 
+ host_id = 0 + host_device = self._ctx.tpu_host_placement_function(host_id=host_id) + def enqueue_fn(): + with ops.device(host_device): + with ops.name_scope('input_pipeline_task%d' % (host_id)): + inputs = self._input_fn() + if isinstance(inputs, tuple): + features, labels = inputs + else: + features, labels = inputs, None + self._inputs_structure_recorder.validate_and_record_structure( + features, labels) + unsharded_tensor_list = ( + self._inputs_structure_recorder.flatten_features_and_labels( + features, labels)) + + self._infeed_queue = tpu_feed.InfeedQueue( + tuple_types=[t.dtype for t in unsharded_tensor_list], + tuple_shapes=[t.shape for t in unsharded_tensor_list], + shard_dimensions=self._batch_axis) + self._infeed_queue.set_number_of_shards(self._ctx.num_cores) + + def placement_fn(core_id): + return self._ctx.tpu_host_placement_function(core_id=core_id) + return ( + self._infeed_queue.split_inputs_and_generate_enqueue_ops( + unsharded_tensor_list, + placement_function=placement_fn)) + + if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + return _wrap_computation_in_while_loop(device=host_device, + op_fn=enqueue_fn) + else: + return enqueue_fn() class _ModelFnWrapper(object): @@ -840,20 +931,17 @@ class _ModelFnWrapper(object): train and eval step. """ - def __init__(self, model_fn, config, params, mode, train_batch_size, - eval_batch_size): + def __init__(self, model_fn, config, params, ctx): self._model_fn = model_fn self._config = config self._params = params - self._mode = mode - self._train_batch_size = train_batch_size - self._eval_batch_size = eval_batch_size + self._ctx = ctx def call_without_tpu(self, features, labels): # Let CrossShardOptimizer be called without TPU in model_fn, since it's # common to set the train_op even when running evaluate() or predict(). with tpu_function.tpu_shard_context(1): - return self._call_model_fn(features, labels, use_tpu=False) + return self._call_model_fn(features, labels) def convert_to_single_tpu_train_step(self, dequeue_fn): """Converts user provided model_fn` as a single train step on TPU. @@ -883,7 +971,7 @@ class _ModelFnWrapper(object): features, labels = dequeue_fn() estimator_spec = self._verify_estimator_spec( - self._call_model_fn(features, labels, use_tpu=True)) + self._call_model_fn(features, labels)) loss, train_op = estimator_spec.loss, estimator_spec.train_op with ops.control_dependencies([train_op]): return array_ops.identity(loss) @@ -915,13 +1003,13 @@ class _ModelFnWrapper(object): A tuple of eval_fn and eval_metrics. The eval_fn representing the eval step for TPU. and eval_metrics is an `_EvalMetrics` instance. 
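    For illustration only, the returned eval step behaves roughly like:

        def eval_step(total_loss):
          features, labels = dequeue_fn()
          spec = model_fn(features, labels, ...)  # a `TPUEstimatorSpec`
          return total_loss + spec.loss

    and `_eval_on_tpu_system` runs it `iterations_per_loop` times inside a
    tf.while_loop (a sketch, not the literal implementation).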
""" - eval_metrics = _EvalMetrics() + eval_metrics = _EvalMetrics(self._ctx) def eval_step(total_loss): """Evaluation step function for use inside a while loop.""" features, labels = dequeue_fn() - tpu_estimator_spec = self._call_model_fn(features, labels, use_tpu=True) + tpu_estimator_spec = self._call_model_fn(features, labels) if not isinstance(tpu_estimator_spec, TPUEstimatorSpec): raise RuntimeError( 'estimator_spec used by TPU evaluation must have type' @@ -935,11 +1023,7 @@ class _ModelFnWrapper(object): return math_ops.add(total_loss, loss) return eval_step, eval_metrics - @property - def config(self): - return self._config - - def _call_model_fn(self, features, labels, use_tpu): + def _call_model_fn(self, features, labels): """Calls the model_fn with required parameters.""" model_fn_args = util.fn_args(self._model_fn) kwargs = {} @@ -950,12 +1034,11 @@ class _ModelFnWrapper(object): if 'labels' in model_fn_args: kwargs['labels'] = labels - else: - if labels is not None: - raise ValueError( - 'model_fn does not take labels, but input_fn returns labels.') + elif labels is not None: + raise ValueError( + 'model_fn does not take labels, but input_fn returns labels.') if 'mode' in model_fn_args: - kwargs['mode'] = self._mode + kwargs['mode'] = self._ctx.mode if 'config' in model_fn_args: kwargs['config'] = config if 'params' in model_fn_args: @@ -966,16 +1049,16 @@ class _ModelFnWrapper(object): 'model_fn ({}) does not include params argument, ' 'required by TPUEstimator to pass batch size as ' 'params[\'batch_size\']'.format(self._model_fn)) - if self._mode == model_fn_lib.ModeKeys.TRAIN: - params[_BATCH_SIZE_KEY] = _per_shard_batch_size( - self._train_batch_size, config, use_tpu) - elif (self._mode == model_fn_lib.ModeKeys.EVAL and - self._eval_batch_size is not None): - params[_BATCH_SIZE_KEY] = _per_shard_batch_size( - self._eval_batch_size, config, use_tpu) + + batch_size_for_model_fn = self._ctx.batch_size_for_model_fn + if batch_size_for_model_fn is not None: + params[_BATCH_SIZE_KEY] = batch_size_for_model_fn estimator_spec = self._model_fn(features=features, **kwargs) - if (not use_tpu) and isinstance(estimator_spec, TPUEstimatorSpec): + if (self._ctx.is_running_on_cpu() and + isinstance(estimator_spec, TPUEstimatorSpec)): + # The estimator_spec will be passed to `Estimator` directly, which expects + # type `EstimatorSpec`. return estimator_spec.as_estimator_spec() else: return estimator_spec @@ -998,7 +1081,8 @@ class _ModelFnWrapper(object): class _EvalMetrics(object): """Class wraps TPUEstimator.eval_metrics.""" - def __init__(self): + def __init__(self, ctx): + self._ctx = ctx self._metric_fn = None self._is_dict = False self._tensor_keys = [] @@ -1081,7 +1165,7 @@ class _EvalMetrics(object): raise RuntimeError('Eval metrics have not been recorded yet') return self._tensors - def to_metric_metric_ops_for_tpu(self, run_config, dummy_update_op): + def to_metric_metric_ops_for_tpu(self, dummy_update_op): """Creates the eval_metric_ops now based on the TPU outfeed. `eval_metric_ops` is defined in `EstimatorSpec`. From all shards, tensors @@ -1090,7 +1174,6 @@ class _EvalMetrics(object): metric fn. Args: - run_config: A `RunConfig` instance. dummy_update_op: A dummy update op. Returns: @@ -1102,9 +1185,7 @@ class _EvalMetrics(object): RuntimeError: If outfeed tensor is scalar. 
""" - num_shards = run_config.tpu_config.num_shards - job = _tpu_job(run_config, model_fn_lib.ModeKeys.EVAL) - job_device = '' if job is None else ('/job:%s' % job) + num_cores = self._ctx.num_cores # For each i, dequeue_ops[i] is a list containing the tensors from all # shards. This list is concatenated later. @@ -1113,8 +1194,9 @@ class _EvalMetrics(object): dequeue_ops.append([]) # Outfeed ops execute on each JF node. - for i in xrange(num_shards): - with ops.device('%s/task:%d/device:TPU:%d' % (job_device, i / 8, i % 8)): + tpu_device_placement_fn = self._ctx.tpu_device_placement_function + for i in xrange(num_cores): + with ops.device(tpu_device_placement_fn(i)): outfeed_tensors = tpu_ops.outfeed_dequeue_tuple( dtypes=self._tensor_dtypes, shapes=self._tensor_shapes) for j, item in enumerate(outfeed_tensors): @@ -1122,7 +1204,7 @@ class _EvalMetrics(object): # It is assumed evaluation always happends on single host TPU system. So, # place all ops on tpu host if possible. - with ops.device('{}/device:CPU:0'.format(job_device)): + with ops.device(self._ctx.tpu_host_placement_function(core_id=0)): for i, item in enumerate(dequeue_ops): if dequeue_ops[i][0].shape.ndims == 0: raise RuntimeError( @@ -1167,9 +1249,9 @@ class TPUEstimator(estimator_lib.Estimator): specify `train_batch_size` in constructor, and then get the batch size for each shard in `input_fn` and `model_fn` by `params['batch_size']`. If `TPUConfig.per_host_input_for_training` is `True`, `input_fn` is invoked per - host rather than per shard. In this case, a global batch size is transformed a + host rather than per core. In this case, a global batch size is transformed a per-host batch size in params for `input_fn`, but `model_fn` still gets - per-shard batch size. + per-core batch size. For evaluation, if `eval_batch_size` is None, it is executed on CPU, even if `use_tpu` is `True`. If `eval_batch_size` is not `None`, it is executed on @@ -1327,9 +1409,7 @@ class TPUEstimator(estimator_lib.Estimator): # We cannot store config and params in this constructor as parent # constructor might change them, such as assigning a temp dir for # config.model_dir. - model_function = _augment_model_fn(model_fn, train_batch_size, - eval_batch_size, use_tpu, - batch_axis) + model_function = self._augment_model_fn(model_fn, batch_axis) # Passing non-None params as wrapped model_fn has it. params = params or {} @@ -1338,12 +1418,13 @@ class TPUEstimator(estimator_lib.Estimator): model_dir=model_dir, config=config, params=params) - self._use_tpu = use_tpu - self._train_batch_size = train_batch_size - self._eval_batch_size = eval_batch_size self._iterations_per_training_loop = ( self._config.tpu_config.iterations_per_loop) + # All properties passed to _TPUContext are immutable. + self._ctx = _TPUContext(self._config, train_batch_size, eval_batch_size, + use_tpu) + def _create_global_step(self, graph): """Creates a global step suitable for TPUs. @@ -1359,10 +1440,10 @@ class TPUEstimator(estimator_lib.Estimator): return _create_global_step(graph) def _convert_train_steps_to_hooks(self, steps, max_steps): - if _is_running_on_cpu(self._use_tpu, model_fn_lib.ModeKeys.TRAIN, - self._eval_batch_size): - return super(TPUEstimator, self)._convert_train_steps_to_hooks( - steps, max_steps) + with self._ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as ctx: + if ctx.is_running_on_cpu(): + return super(TPUEstimator, self)._convert_train_steps_to_hooks( + steps, max_steps) # On TPU. 
if steps is None and max_steps is None: @@ -1380,9 +1461,9 @@ class TPUEstimator(estimator_lib.Estimator): steps, max_steps)] def _convert_eval_steps_to_hooks(self, steps): - if _is_running_on_cpu(self._use_tpu, model_fn_lib.ModeKeys.EVAL, - self._eval_batch_size): - return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps) + with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx: + if ctx.is_running_on_cpu(): + return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps) if steps is None: raise ValueError('Evaluate `steps` must be set on TPU. Cannot be `None`.') @@ -1422,197 +1503,115 @@ class TPUEstimator(estimator_lib.Estimator): if 'config' in input_fn_args: kwargs['config'] = config - # Setting the batch size in params first. This helps user to have same - # input_fn for use_tpu=True/False. - if mode == model_fn_lib.ModeKeys.TRAIN: - kwargs['params'][_BATCH_SIZE_KEY] = ( - _per_shard_batch_size(self._train_batch_size, config, self._use_tpu) - if not config.tpu_config.per_host_input_for_training else - self._train_batch_size) - elif (mode == model_fn_lib.ModeKeys.EVAL and - self._eval_batch_size is not None): - # For TPU evaluation, input_fn is invoked for one host (instead of shard). - kwargs['params'][_BATCH_SIZE_KEY] = self._eval_batch_size - - if _is_running_on_cpu(self._use_tpu, mode, self._eval_batch_size): - with ops.device('/device:CPU:0'): - return input_fn(**kwargs) - - job = _tpu_job(config, mode) - def placement_function(index): - if job is None: - return '/replica:0/task:0/device:CPU:0' - else: - return '/job:%s/task:%d/device:CPU:0' % (job, index / 8) + with self._ctx.with_mode(mode) as ctx: + # Setting the batch size in params first. This helps user to have same + # input_fn for use_tpu=True/False. + batch_size_for_input_fn = ctx.batch_size_for_input_fn + if batch_size_for_input_fn is not None: + kwargs['params'][_BATCH_SIZE_KEY] = batch_size_for_input_fn - if mode == model_fn_lib.ModeKeys.TRAIN: - if not config.tpu_config.per_host_input_for_training: - # Now for TPU training. - num_shards = config.tpu_config.num_shards - inputs = _InputsHolder(num_shards=num_shards) - for i in range(config.tpu_config.num_shards): - with ops.device(placement_function(i)): - inputs.append_tuple(input_fn(**kwargs)) - return inputs.as_features_and_labels_tuple() - else: - # TODO(xiejw): Extend this to multi-host support. - with ops.device(placement_function(0)): + if ctx.is_running_on_cpu(): + with ops.device('/device:CPU:0'): return input_fn(**kwargs) - # Now for TPU evaluation. - with ops.device(placement_function(0)): - return input_fn(**kwargs) - - -# TODO(b/64607814): Ensure batch_axis works with nested structures. -def _create_infeed_enqueue_ops_and_dequeue_fn(inputs_holder, run_config, - batch_axis, mode): - """Utility to convert input_fn to enqueue and dequeue fns for TPU. - - Args: - inputs_holder: An `_InputsHolder` holding features and labels. - run_config: A `RunConfig` instance. - batch_axis: A python list of batch dimensions. 
- mode: ModeKeys - - Returns: - A tuple of (dequeue_fn, enqueue_fn) - """ - if inputs_holder.sharded: - sharded_inputs = inputs_holder.as_sharded_flattened_inputs() - - infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(sharded_inputs[0])) - infeed_queue.set_configuration_from_sharded_input_tensors(sharded_inputs) - else: - unsharded_inputs = inputs_holder.as_flattened_inputs() - infeed_queue = tpu_feed.InfeedQueue( - tuple_types=[t.dtype for t in unsharded_inputs], - tuple_shapes=[t.shape for t in unsharded_inputs], - shard_dimensions=batch_axis) - infeed_queue.set_number_of_shards(inputs_holder.num_shards) - - def dequeue_fn(): - """dequeue_fn is used by the train_step in TPU to retrieve the tensors.""" - values = infeed_queue.generate_dequeue_op() - return inputs_holder.unflatten_features_and_labels(values) - - def tpu_ordinal_function(index): - """Return the TPU ordinal associated with a shard. - - Required because the enqueue ops are placed on CPU. - - Args: - index: the shard index - - Returns: - The ordinal of the TPU device the shard's infeed should be placed on. - """ - return index % 8 - - def enqueue_fn(): - """enqueue_fn is used to add ops to the graph to send tensors.""" - if inputs_holder.sharded: - return infeed_queue.generate_enqueue_ops( - sharded_inputs, tpu_ordinal_function=tpu_ordinal_function) - else: - job = _tpu_job(run_config, mode) - def placement_function(index): - if job is None: - return '/replica:0/task:0/device:CPU:0' - else: - # This assumes that if using more than 8 shards, - # the job configuration varies 'task'. - return '/job:%s/task:%d/device:CPU:0' % (job, index / 8) - return infeed_queue.split_inputs_and_generate_enqueue_ops( - unsharded_inputs, placement_function=placement_function) - - return (dequeue_fn, enqueue_fn) - - -def _augment_model_fn(model_fn, train_batch_size, eval_batch_size, use_tpu, - batch_axis): - """Returns a new model_fn, which wraps the TPU support.""" - - def _model_fn(features, labels, mode, config, params): - """A Estimator `model_fn` for TPUEstimator.""" - model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, mode, - train_batch_size, eval_batch_size) - - # TODO(jhseu): Move to PREDICT to TPU. - if _is_running_on_cpu(use_tpu, mode, eval_batch_size): - logging.info('Running %s on CPU', mode) - return model_fn_wrapper.call_without_tpu(features, labels) - - inputs = _InputsHolder(features=features, labels=labels, - num_shards=config.tpu_config.num_shards) - - dequeue_fn, enqueue_fn = _create_infeed_enqueue_ops_and_dequeue_fn( - inputs, config, batch_axis, mode) - - if mode == model_fn_lib.ModeKeys.TRAIN: - loss = _train_on_tpu_system(model_fn_wrapper, dequeue_fn) - hooks = [ - TPUInfeedOutfeedSessionHook(config, mode, enqueue_fn), - training.LoggingTensorHook( - {'loss': array_ops.identity(loss), - 'step': training.get_global_step()}, - every_n_secs=30) - ] - summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) - with ops.control_dependencies([loss]): - update_ops = _sync_variables_ops() - - # Validate the TPU training graph to catch basic errors - _validate_tpu_training_graph() - - return model_fn_lib.EstimatorSpec( - mode, - loss=loss, - training_hooks=hooks, - train_op=control_flow_ops.group(*update_ops)) - - # Now eval. 
-    total_loss, eval_metric_ops = _eval_on_tpu_system(
-        model_fn_wrapper, dequeue_fn)
-    iterations_per_loop_var = _create_iterations_per_loop()
-    mean_loss = math_ops.div(
-        total_loss,
-        math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype))
-
-    # Creates a dummy metric update_op for all metrics. Estimator expects all
-    # metrics in eval_metric_ops have update_op and calls them one by one. The
-    # real metric update_ops are invoked in a separated thread. So, here give
-    # Estimator the dummy op for all metrics.
-    with ops.control_dependencies([mean_loss]):
-      # After TPU evaluation computation is done (the mean_loss tensor), reads
-      # all variables back from TPU and updates the eval step counter properly.
-      internal_ops_to_run = _sync_variables_ops()
-      internal_ops_to_run.append(
-          _increase_eval_step_op(iterations_per_loop_var))
-      with ops.control_dependencies(internal_ops_to_run):
-        dummy_update_op = control_flow_ops.no_op()
-
-    eval_metric_ops, eval_update_ops = (
-        eval_metric_ops.to_metric_metric_ops_for_tpu(
-            config, dummy_update_op))
-    hooks = [
-        TPUInfeedOutfeedSessionHook(config, mode, enqueue_fn, eval_update_ops),
-    ]
-
-    return model_fn_lib.EstimatorSpec(
-        mode,
-        loss=mean_loss,
-        evaluation_hooks=hooks,
-        eval_metric_ops=eval_metric_ops)
-  return _model_fn
-
-
-def _eval_on_tpu_system(model_fn_wrapper, dequeue_fn):
+      # For TPU computation, input_fn should be invoked in a tf.while_loop for
+      # performance. While constructing the tf.while_loop, the structure of
+      # inputs returned by the `input_fn` needs to be recorded. The structure
+      # includes whether features or labels is dict or single Tensor, dict keys,
+      # tensor shapes, and dtypes. The recorded structure is used to create the
+      # infeed dequeue ops, which must be wrapped and passed as a fn, called
+      # inside the TPU computation, as the TPU computation is wrapped inside a
+      # tf.while_loop also. So, we either pass input_fn to model_fn or pass
+      # dequeue_fn to model_fn. Here, `input_fn` is passed directly as
+      # `features` in `model_fn` signature.
+      def _input_fn():
+        return input_fn(**kwargs)
+      return _input_fn
+
+  def _augment_model_fn(self, model_fn, batch_axis):
+    """Returns a new model_fn, which wraps the TPU support."""
+
+    def _model_fn(features, labels, mode, config, params):
+      """An Estimator `model_fn` for TPUEstimator."""
+      with self._ctx.with_mode(mode) as ctx:
+        model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)
+
+        # TODO(jhseu): Move to PREDICT to TPU.
+        if ctx.is_running_on_cpu():
+          logging.info('Running %s on CPU', mode)
+          return model_fn_wrapper.call_without_tpu(features, labels)
+
+        assert labels is None, '`labels` passed to `model_fn` must be `None`.'
+        # TPUEstimator._call_input_fn passes `input_fn` as features to here.
+        assert callable(features), '`input_fn` is not callable.'
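+        # Illustrative call flow (a sketch, not literal library code):
+        #   features = estimator._call_input_fn(...)  # returns `_input_fn`
+        #   spec = _model_fn(features=features, labels=None, ...)
+        # so on TPU, the `features` received here is the zero-argument
+        # `_input_fn` closure built in `_call_input_fn` above.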
+        input_fn = features
+
+        input_holders = _InputPipeline(input_fn, batch_axis, ctx)
+        enqueue_ops, dequeue_fn = (
+            input_holders.generate_infeed_enqueue_ops_and_dequeue_fn())
+
+        if mode == model_fn_lib.ModeKeys.TRAIN:
+          loss = _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)
+          hooks = [
+              TPUInfeedOutfeedSessionHook(ctx, enqueue_ops),
+              training.LoggingTensorHook(
+                  {'loss': array_ops.identity(loss),
+                   'step': training.get_global_step()},
+                  every_n_secs=30)
+          ]
+          summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss)
+          with ops.control_dependencies([loss]):
+            update_ops = _sync_variables_ops()
+
+          # Validate the TPU training graph to catch basic errors
+          _validate_tpu_training_graph()
+
+          return model_fn_lib.EstimatorSpec(
+              mode,
+              loss=loss,
+              training_hooks=hooks,
+              train_op=control_flow_ops.group(*update_ops))
+
+        # Now eval.
+        total_loss, eval_metric_ops = _eval_on_tpu_system(
+            ctx, model_fn_wrapper, dequeue_fn)
+        iterations_per_loop_var = _create_or_get_iterations_per_loop()
+        mean_loss = math_ops.div(
+            total_loss,
+            math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype))
+
+        # Creates a dummy metric update_op for all metrics. Estimator expects
+        # all metrics in eval_metric_ops to have an update_op and calls them
+        # one by one. The real metric update_ops are invoked in a separate
+        # thread. So, here we give Estimator the dummy op for all metrics.
+        with ops.control_dependencies([mean_loss]):
+          # After TPU evaluation computation is done (the mean_loss tensor),
+          # reads all variables back from TPU and updates the eval step counter
+          # properly.
+          internal_ops_to_run = _sync_variables_ops()
+          internal_ops_to_run.append(
+              _increase_eval_step_op(iterations_per_loop_var))
+          with ops.control_dependencies(internal_ops_to_run):
+            dummy_update_op = control_flow_ops.no_op()
+
+        eval_metric_ops, eval_update_ops = (
+            eval_metric_ops.to_metric_metric_ops_for_tpu(dummy_update_op))
+        hooks = [
+            TPUInfeedOutfeedSessionHook(ctx, enqueue_ops, eval_update_ops),
+        ]
+
+        return model_fn_lib.EstimatorSpec(
+            mode,
+            loss=mean_loss,
+            evaluation_hooks=hooks,
+            eval_metric_ops=eval_metric_ops)
+    return _model_fn
+
+
+def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
   """Executes `model_fn_wrapper` multiple times on all TPU shards."""
-  config = model_fn_wrapper.config.tpu_config
-  num_shards = config.num_shards
-  iterations_per_loop_var = _create_iterations_per_loop()
+  num_cores = ctx.num_cores
+  iterations_per_loop_var = _create_or_get_iterations_per_loop()

   single_tpu_eval_step, eval_metric_ops = (
       model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn))
@@ -1625,15 +1624,15 @@ def _eval_on_tpu_system(model_fn_wrapper, dequeue_fn):

   (loss,) = tpu.shard(multi_tpu_eval_steps_on_single_shard,
                       inputs=[],
-                      num_shards=num_shards,
+                      num_shards=num_cores,
                       outputs_from_all_shards=False)
   return loss, eval_metric_ops


-def _train_on_tpu_system(model_fn_wrapper, dequeue_fn):
+def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
   """Executes `model_fn_wrapper` multiple times on all TPU shards."""
-  num_shards = model_fn_wrapper.config.tpu_config.num_shards
-  iterations_per_loop_var = _create_iterations_per_loop()
+  num_cores = ctx.num_cores
+  iterations_per_loop_var = _create_or_get_iterations_per_loop()

   single_tpu_train_step = model_fn_wrapper.convert_to_single_tpu_train_step(
       dequeue_fn)
@@ -1647,11 +1646,27 @@ def _train_on_tpu_system(model_fn_wrapper, dequeue_fn):

   (loss,) = tpu.shard(multi_tpu_train_steps_on_single_shard,
                       inputs=[],
-                      num_shards=num_shards,
+                      num_shards=num_cores,
outputs_from_all_shards=False) return loss +def _wrap_computation_in_while_loop(device, op_fn): + """Wraps the ops generated by `op_fn` in tf.while_loop.""" + def computation(i): + with ops.control_dependencies(op_fn()): + return i + 1 + + iterations_per_loop_var = _create_or_get_iterations_per_loop() + # By setting parallel_iterations=1, the parallel execution in while_loop is + # basically turned off. + with ops.device(device): + iterations = array_ops.identity(iterations_per_loop_var) + return control_flow_ops.while_loop( + lambda i: i < iterations, + computation, [constant_op.constant(0)], parallel_iterations=1) + + def _validate_tpu_training_graph(): """Validate graph before running distributed training. diff --git a/tensorflow/contrib/training/python/training/bucket_ops.py b/tensorflow/contrib/training/python/training/bucket_ops.py index 5523cc375fc20dc167fee0eaa6f1682dc1892c3f..95fbc50cba73b25b748c31ecd443eb19c0b6fc8a 100644 --- a/tensorflow/contrib/training/python/training/bucket_ops.py +++ b/tensorflow/contrib/training/python/training/bucket_ops.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops @@ -47,7 +48,6 @@ _dtypes = input_py._dtypes _store_sparse_tensors = input_py._store_sparse_tensors _validate_keep_input = input_py._validate_keep_input _shapes = input_py._shapes -_smart_cond = input_py._smart_cond _which_queue = input_py._which_queue # pylint: enable=protected-access @@ -239,7 +239,7 @@ def bucket(tensors, ] return control_flow_ops.group(*enqueues, name="group_enqueues") - maybe_enqueue = _smart_cond( + maybe_enqueue = utils.smart_cond( keep_input, enqueue_which, control_flow_ops.no_op) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 013ed2e8fd49221ff0a3dc0845a254128ac295cc..eb440f66e89c3f59cb831317502774b181c7f897 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1411,7 +1411,7 @@ cc_library( hdrs = LIB_INTERNAL_PUBLIC_HEADERS, copts = tf_copts(), defines = tf_additional_lib_defines() + [ - "SNAPPY", + "TF_USE_SNAPPY", ] + tf_additional_verbs_lib_defines() + tf_additional_mpi_lib_defines() + tf_additional_gdr_lib_defines(), diff --git a/tensorflow/core/framework/api_def.proto b/tensorflow/core/framework/api_def.proto index 987caee25065d0316bde42a9db75fd4d2a171b8d..98c38efc0e9a8e2ca7caf6b666c8930eb7a32733 100644 --- a/tensorflow/core/framework/api_def.proto +++ b/tensorflow/core/framework/api_def.proto @@ -51,7 +51,8 @@ message ApiDef { // endpoints are deprecated). message Endpoint { // Name should be either like "CamelCaseName" or - // "Package.CamelCaseName". + // "Package.CamelCaseName". Client-language-specific ApiDefs may + // use a snake_case convention instead of CamelCase. string name = 1; // First GraphDef version at which the op is disallowed. @@ -74,7 +75,7 @@ message ApiDef { } repeated Arg in_arg = 4; repeated Arg out_arg = 5; - // List of post-rename in_arg names to specify new argument order. + // List of original in_arg names to specify new argument order. // Length of arg_order should be either empty to keep current order // or match size of in_arg. 
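  // For example (illustrative only): an op with input arguments (a, b)
  // could override the order with
  //   arg_order: "b"
  //   arg_order: "a"
  // The override must be a permutation of the op's original argument
  // names; MergeApiDefs rejects anything else.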
repeated string arg_order = 11; diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc index cfaca897ba8c4e11707b26ef8002c1d303ecc6b9..1e93e9be0955c9d62588e009e5a6d899ce33698d 100644 --- a/tensorflow/core/framework/op_gen_lib.cc +++ b/tensorflow/core/framework/op_gen_lib.cc @@ -412,6 +412,8 @@ void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) { api_in_arg->set_name(op_in_arg.name()); api_in_arg->set_rename_to(op_in_arg.name()); api_in_arg->set_description(op_in_arg.description()); + + *api_def->add_arg_order() = op_in_arg.name(); } for (const auto& op_out_arg : op_def.output_arg()) { auto* api_out_arg = api_def->add_out_arg(); @@ -503,6 +505,22 @@ Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) { } // Merge arg order if (new_api_def.arg_order_size() > 0) { + // Validate that new arg_order is correct. + if (new_api_def.arg_order_size() != base_api_def->arg_order_size()) { + return errors::FailedPrecondition( + "Invalid number of arguments ", new_api_def.arg_order_size(), " for ", + base_api_def->graph_op_name(), + ". Expected: ", base_api_def->arg_order_size()); + } + if (!std::is_permutation(new_api_def.arg_order().begin(), + new_api_def.arg_order().end(), + base_api_def->arg_order().begin())) { + return errors::FailedPrecondition( + "Invalid arg_order: ", str_util::Join(new_api_def.arg_order(), ", "), + " for ", base_api_def->graph_op_name(), + ". All elements in arg_order override must match base arg_order: ", + str_util::Join(base_api_def->arg_order(), ", ")); + } base_api_def->clear_arg_order(); std::copy( new_api_def.arg_order().begin(), new_api_def.arg_order().end(), diff --git a/tensorflow/core/framework/op_gen_lib_test.cc b/tensorflow/core/framework/op_gen_lib_test.cc index b7ee6db9912ee92433328d49e09108c0d5e29d34..da9b4dfbb1738c855c0bfc4752853d5d501d80a8 100644 --- a/tensorflow/core/framework/op_gen_lib_test.cc +++ b/tensorflow/core/framework/op_gen_lib_test.cc @@ -207,6 +207,8 @@ attr { name: "attr_a" rename_to: "attr_a" } +arg_order: "arg_a" +arg_order: "arg_b" )"; OpList op_list; protobuf::TextFormat::ParseFromString(kTestOpList, &op_list); // NOLINT @@ -331,8 +333,8 @@ op { name: "arg_c" rename_to: "arg_cc" } - arg_order: "arg_aa" arg_order: "arg_b" + arg_order: "arg_a" } )"; OpList op_list; @@ -351,8 +353,8 @@ op { EXPECT_EQ("arg_cc", api_def->out_arg(0).rename_to()); ASSERT_EQ(2, api_def->arg_order_size()); - EXPECT_EQ("arg_aa", api_def->arg_order(0)); - EXPECT_EQ("arg_b", api_def->arg_order(1)); + EXPECT_EQ("arg_b", api_def->arg_order(0)); + EXPECT_EQ("arg_a", api_def->arg_order(1)); } TEST(OpGenLibTest, ApiDefOverrideDescriptions) { @@ -411,5 +413,47 @@ op { auto status = api_map.LoadApiDef(api_def1); ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); } + +TEST(OpGenLibTest, ApiDefInvalidArgOrder) { + const string api_def1 = R"( +op { + graph_op_name: "testop" + arg_order: "arg_a" + arg_order: "unexpected_arg" +} +)"; + + const string api_def2 = R"( +op { + graph_op_name: "testop" + arg_order: "arg_a" +} +)"; + + const string api_def3 = R"( +op { + graph_op_name: "testop" + arg_order: "arg_a" + arg_order: "arg_a" +} +)"; + + OpList op_list; + protobuf::TextFormat::ParseFromString(kTestOpList, &op_list); // NOLINT + ApiDefMap api_map(op_list); + TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef)); + + // Loading with incorrect arg name in arg_order should fail. 
+ auto status = api_map.LoadApiDef(api_def1); + ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); + + // Loading with incorrect number of args in arg_order should fail. + status = api_map.LoadApiDef(api_def2); + ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); + + // Loading with the same argument twice in arg_order should fail. + status = api_map.LoadApiDef(api_def3); + ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 92b4843221054bc9aa11b9a762d7295171fd89fa..b2c193b050b35707703dbd539236d95cb37cd29c 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -1068,10 +1068,16 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, refiner->set_graph_def_version( std::min(refiner->graph_def_version(), gdef.versions().producer())); - return GraphConstructor::Construct( - opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner, - &results->return_tensors, &results->return_nodes, - &results->unused_input_map_keys); + if (results == nullptr) { + return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(), + &gdef.library(), g, refiner, nullptr, + nullptr, nullptr); + } else { + return GraphConstructor::Construct( + opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner, + &results->return_tensors, &results->return_nodes, + &results->unused_input_map_keys); + } } void CopyGraph(const Graph& src, Graph* dest) { diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 54d60cd7aa41354267e23d65e6540d070a4937d1..3f6183b6f1ecb92dcc99abccacda74ceaf72cce0 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -450,12 +450,16 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( } // Optimize the graph (function inlining, l1 optimizations, etc). + VLOG(1) << "Number of nodes in graph before OptimizeGraph: " + << new_item->graph.node_size(); Status optimize_status = OptimizeGraph(new_item->graph, &new_item->graph, cfg); if (!optimize_status.ok()) { LOG(ERROR) << "Graph preprocessing failed: " << optimize_status; return nullptr; } + VLOG(1) << "Number of nodes in graph after OptimizeGraph: " + << new_item->graph.node_size(); if (cfg.prune_graph) { VLOG(1) << "Pruning graph..."; @@ -464,7 +468,8 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( LOG(ERROR) << "Pruning failed: " << status.error_message(); return nullptr; } - VLOG(1) << "Pruning ran succesfully."; + VLOG(1) << "Number of nodes in graph after pruning: " + << new_item->graph.node_size(); } // Validate feed, fetch and init nodes diff --git a/tensorflow/core/kernels/mkl_transpose_op.cc b/tensorflow/core/kernels/mkl_transpose_op.cc index 89a1d5e8a7da50876df74a1b98e8485eadf50655..764d4c9400e5751de29b9651eebc1328fdd09d59 100644 --- a/tensorflow/core/kernels/mkl_transpose_op.cc +++ b/tensorflow/core/kernels/mkl_transpose_op.cc @@ -18,6 +18,9 @@ limitations under the License. 
#ifdef INTEL_MKL
#define EIGEN_USE_THREADS
+#include "tensorflow/core/framework/numeric_types.h"
+#define MKL_Complex8 tensorflow::complex64
+#define MKL_Complex16 tensorflow::complex128
#include "mkl_trans.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/kernels/transpose_op.h"
@@ -41,7 +44,7 @@ namespace tensorflow {
namespace {
template <typename T>
-void MKLTranspose2D(const char trans, const Tensor& in, Tensor* out) {}
+Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out);

// Documentation here: https://software.intel.com/en-us/node/520863
// Parameters: (ordering:row-major, operation:transpose, num_rows, num_cols,
@@ -54,70 +57,73 @@
    mkl_##PREFIX##omatcopy('R', trans, in.dim_size(0), in.dim_size(1), 1,   \
                           in.flat<T>().data(), in.dim_size(1),             \
                           out->flat<T>().data(), in.dim_size(0));          \
-    return Status::OK();
+    return Status::OK();                                                   \
  }

-  INSTANTIATE(float, s)
-  INSTANTIATE(double, d)
-  INSTANTIATE(complex64, c)
-  INSTANTIATE(complex128, z)
+INSTANTIATE(float, s)
+INSTANTIATE(double, d)
+INSTANTIATE(complex64, c)
+INSTANTIATE(complex128, z)
#undef INSTANTIATE

-  static const char kMKLTranspose = 'T';
-  static const char kMKLConjugateTranspose = 'C';
-
-  }  // namespace tensorflow
-
-  Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
-                                        gtl::ArraySlice<int32> perm,
-                                        Tensor* out) {
-    if (in.dims() == 2) {
-      switch (in.dtype()) {
-        case DT_FLOAT:
-          return MKLTranspose2D<float>(kMKLTranspose, in, out);
-        case DT_DOUBLE:
-          return MKLTranspose2D<double>(kMKLTranspose, in, out);
-        case DT_COMPLEX64:
-          return MKLTranspose2D<complex64>(kMKLTranspose, in, out);
-        case DT_COMPLEX128:
-          return MKLTranspose2D<complex128>(kMKLTranspose, in, out);
-        default:
-          break;
-      }
+static const char kMKLTranspose = 'T';
+static const char kMKLConjugateTranspose = 'C';
+
+}  // namespace
+
+Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
+                                      gtl::ArraySlice<int32> perm,
+                                      Tensor* out) {
+  if (in.dims() == 2) {
+    if (perm[0] == 0 && perm[1] == 1) {
+      return Status::OK();
+    }
+    switch (in.dtype()) {
+      case DT_FLOAT:
+        return MKLTranspose2D<float>(kMKLTranspose, in, out);
+      case DT_DOUBLE:
+        return MKLTranspose2D<double>(kMKLTranspose, in, out);
+      case DT_COMPLEX64:
+        return MKLTranspose2D<complex64>(kMKLTranspose, in, out);
+      case DT_COMPLEX128:
+        return MKLTranspose2D<complex128>(kMKLTranspose, in, out);
+      default:
+        break;
    }
-    // Fallback to eigen if transpose parameters not supported by MKL
-    typedef Eigen::ThreadPoolDevice CPUDevice;
-    return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
-                                     out);
  }
-
-  Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
-                                                 const Tensor& in,
-                                                 gtl::ArraySlice<int32> perm,
-                                                 Tensor* out) {
-    if (in.dims() == 2) {
-      // TODO(rmlarsen): By setting lda and ldb, we could use the MKL kernels
-      // for any transpose that can be reduced to swapping the last two
-      // dimensions in a rank-3 tensor. We can even run each outer dimension in
-      // a separate thread.
-      switch (in.dtype()) {
-        case DT_FLOAT:
-          return MKLTranspose2D<float>(kMKLTranspose, in, out);
-        case DT_DOUBLE:
-          return MKLTranspose2D<double>(kMKLTranspose, in, out);
-        case DT_COMPLEX64:
-          return MKLTranspose2D<complex64>(kMKLConjugateTranspose, in, out);
-        case DT_COMPLEX128:
-          return MKLTranspose2D<complex128>(kMKLConjugateTranspose, in, out);
-        default:
-          break;
-      }
+  // Fallback to Eigen if transpose parameters not supported by MKL
+  typedef Eigen::ThreadPoolDevice CPUDevice;
+  return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
+                                   out);
+}
+
+Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
+                                               const Tensor& in,
+                                               gtl::ArraySlice<int32> perm,
+                                               Tensor* out) {
+  if (in.dims() == 2 && perm[0] == 1 && perm[1] == 0) {
+    // TODO(rmlarsen): By setting lda and ldb, we could use the MKL kernels
+    // for any transpose that can be reduced to swapping the last two
+    // dimensions in a rank-3 tensor. We can even run each outer dimension in
+    // a separate thread.
+    switch (in.dtype()) {
+      case DT_FLOAT:
+        return MKLTranspose2D<float>(kMKLTranspose, in, out);
+      case DT_DOUBLE:
+        return MKLTranspose2D<double>(kMKLTranspose, in, out);
+      case DT_COMPLEX64:
+        return MKLTranspose2D<complex64>(kMKLConjugateTranspose, in, out);
+      case DT_COMPLEX128:
+        return MKLTranspose2D<complex128>(kMKLConjugateTranspose, in, out);
+      default:
+        break;
    }
-    // Fallback to eigen if transpose parameters not supported by MKL
-    typedef Eigen::ThreadPoolDevice CPUDevice;
-    return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<CPUDevice>(),
-                                              in, perm, out);
  }
+  // Fallback to Eigen if transpose parameters not supported by MKL
+  typedef Eigen::ThreadPoolDevice CPUDevice;
+  return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<CPUDevice>(), in,
+                                            perm, out);
+}
}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/transpose_functor.h b/tensorflow/core/kernels/transpose_functor.h
index 9781fe3b618594d8fd88d635208850c3765a8315..add4635331ee55495f5bc0d79fd040812078f1f8 100644
--- a/tensorflow/core/kernels/transpose_functor.h
+++ b/tensorflow/core/kernels/transpose_functor.h
@@ -201,17 +201,26 @@ Status DoTransposeImpl(const Device& d, const Tensor& in,
    case DT_COMPLEX64:
      if (conjugate) {
-        Transpose::run(d, in, perm, out);
+#if defined(__ANDROID__) and !defined(__clang__)
+        // Workaround for GCC compiler bug in Android toolchain.
+ return errors::Unimplemented( + "Conjugate transpose of complex64 not supported for GCC on " + "Android."); +#else + Transpose::run(d, in, perm, out); +#endif } else { - Transpose::run(d, in, perm, out); + Transpose::run(d, in, perm, out); } break; case DT_COMPLEX128: if (conjugate) { - Transpose::run(d, in, perm, out); + Transpose::run(d, in, perm, + out); } else { - Transpose::run(d, in, perm, out); + Transpose::run(d, in, perm, + out); } break; diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 2c14ea917c092a3009cd235b0d9d65cc252b3402..e4518a8e2fdfd5a4a23c86a4b287b6f9c7183ef8 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -467,7 +467,7 @@ def tf_additional_core_deps(): "//conditions:default": [], }) + select({ "//tensorflow:with_s3_support": [ - "//tensorflow/contrib/s3:s3_file_system", + "//tensorflow/core/platform/s3:s3_file_system", ], "//conditions:default": [], }) diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc index 3b17bac80896c6e042af4314b2947d97e45cbdf3..93a59348c8a5be1d7399f35aad8a4468a03d1f2b 100644 --- a/tensorflow/core/platform/posix/port.cc +++ b/tensorflow/core/platform/posix/port.cc @@ -29,7 +29,7 @@ limitations under the License. #include #include #include -#ifdef SNAPPY +#ifdef TF_USE_SNAPPY #include "snappy.h" #endif #if (defined(__APPLE__) && defined(__MACH__)) || defined(__FreeBSD__) @@ -126,7 +126,7 @@ void AdjustFilenameForLogging(string* filename) { } bool Snappy_Compress(const char* input, size_t length, string* output) { -#ifdef SNAPPY +#ifdef TF_USE_SNAPPY output->resize(snappy::MaxCompressedLength(length)); size_t outlen; snappy::RawCompress(input, length, &(*output)[0], &outlen); @@ -139,7 +139,7 @@ bool Snappy_Compress(const char* input, size_t length, string* output) { bool Snappy_GetUncompressedLength(const char* input, size_t length, size_t* result) { -#ifdef SNAPPY +#ifdef TF_USE_SNAPPY return snappy::GetUncompressedLength(input, length, result); #else return false; @@ -147,7 +147,7 @@ bool Snappy_GetUncompressedLength(const char* input, size_t length, } bool Snappy_Uncompress(const char* input, size_t length, char* output) { -#ifdef SNAPPY +#ifdef TF_USE_SNAPPY return snappy::RawUncompress(input, length, output); #else return false; diff --git a/tensorflow/contrib/s3/BUILD b/tensorflow/core/platform/s3/BUILD similarity index 100% rename from tensorflow/contrib/s3/BUILD rename to tensorflow/core/platform/s3/BUILD diff --git a/tensorflow/contrib/s3/s3_crypto.cc b/tensorflow/core/platform/s3/s3_crypto.cc similarity index 98% rename from tensorflow/contrib/s3/s3_crypto.cc rename to tensorflow/core/platform/s3/s3_crypto.cc index bbd66371e41c5ecf4c6edfcb3a115cae2fb4e933..d7062a59d2c88195b67cdf3c62cb14164e1038f0 100644 --- a/tensorflow/contrib/s3/s3_crypto.cc +++ b/tensorflow/core/platform/s3/s3_crypto.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/s3/s3_crypto.h" +#include "tensorflow/core/platform/s3/s3_crypto.h" #include #include diff --git a/tensorflow/contrib/s3/s3_crypto.h b/tensorflow/core/platform/s3/s3_crypto.h similarity index 100% rename from tensorflow/contrib/s3/s3_crypto.h rename to tensorflow/core/platform/s3/s3_crypto.h diff --git a/tensorflow/contrib/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc similarity index 99% rename from tensorflow/contrib/s3/s3_file_system.cc rename to tensorflow/core/platform/s3/s3_file_system.cc index daced83145353c52ae19e2b7e8491b5fcb31cc1f..51c85592bf43bdfb68c4ba90d19d28582560d6d4 100644 --- a/tensorflow/contrib/s3/s3_file_system.cc +++ b/tensorflow/core/platform/s3/s3_file_system.cc @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/s3/s3_file_system.h" -#include "tensorflow/contrib/s3/s3_crypto.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/s3/s3_file_system.h" +#include "tensorflow/core/platform/s3/s3_crypto.h" #include #include diff --git a/tensorflow/contrib/s3/s3_file_system.h b/tensorflow/core/platform/s3/s3_file_system.h similarity index 100% rename from tensorflow/contrib/s3/s3_file_system.h rename to tensorflow/core/platform/s3/s3_file_system.h diff --git a/tensorflow/contrib/s3/s3_file_system_test.cc b/tensorflow/core/platform/s3/s3_file_system_test.cc similarity index 99% rename from tensorflow/contrib/s3/s3_file_system_test.cc rename to tensorflow/core/platform/s3/s3_file_system_test.cc index 949281fad4a6c6d67f12d4de4e6be0e5e4d025ea..0b42f5fcec0041a01a571b1e38dedaa7ef191c22 100644 --- a/tensorflow/contrib/s3/s3_file_system_test.cc +++ b/tensorflow/core/platform/s3/s3_file_system_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/s3/s3_file_system.h" +#include "tensorflow/core/platform/s3/s3_file_system.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc index 85b53e07c439e02f63d4600c57c925f3b8d843b9..e327d53949caf7e2d30e6deba0be2848f010afc2 100644 --- a/tensorflow/core/platform/windows/port.cc +++ b/tensorflow/core/platform/windows/port.cc @@ -20,7 +20,7 @@ limitations under the License. 
#include #include #include -#ifdef SNAPPY +#ifdef TF_USE_SNAPPY #include "snappy.h" #endif @@ -118,7 +118,7 @@ void AdjustFilenameForLogging(string* filename) { } bool Snappy_Compress(const char* input, size_t length, string* output) { -#ifdef SNAPPY +#ifdef TF_USE_SNAPPY output->resize(snappy::MaxCompressedLength(length)); size_t outlen; snappy::RawCompress(input, length, &(*output)[0], &outlen); @@ -131,7 +131,7 @@ bool Snappy_Compress(const char* input, size_t length, string* output) { bool Snappy_GetUncompressedLength(const char* input, size_t length, size_t* result) { -#ifdef SNAPPY +#ifdef TF_USE_SNAPPY return snappy::GetUncompressedLength(input, length, result); #else return false; @@ -139,7 +139,7 @@ bool Snappy_GetUncompressedLength(const char* input, size_t length, } bool Snappy_Uncompress(const char* input, size_t length, char* output) { -#ifdef SNAPPY +#ifdef TF_USE_SNAPPY return snappy::RawUncompress(input, length, output); #else return false; diff --git a/tensorflow/examples/learn/iris.py b/tensorflow/examples/learn/iris.py index 33e8d45801409fa112e27f40b1732c43cda72bc2..0a50b3ba87d70a58794bc35009dc76de2cb71d1e 100644 --- a/tensorflow/examples/learn/iris.py +++ b/tensorflow/examples/learn/iris.py @@ -17,47 +17,94 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np -from sklearn import datasets -from sklearn import metrics -from sklearn import model_selection +import os +import urllib import tensorflow as tf +# Data sets +IRIS_TRAINING = 'iris_training.csv' +IRIS_TRAINING_URL = 'http://download.tensorflow.org/data/iris_training.csv' -X_FEATURE = 'x' # Name of the input feature. +IRIS_TEST = 'iris_test.csv' +IRIS_TEST_URL = 'http://download.tensorflow.org/data/iris_test.csv' + +FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width'] + + +def maybe_download_iris_data(file_name, download_url): + """Downloads the data file if needed and returns the number of records.""" + if not os.path.exists(file_name): + raw = urllib.urlopen(download_url).read() + with open(file_name, 'w') as f: + f.write(raw) + + # The first line is a comma-separated string whose first field is the total + # number of records in the file. + with open(file_name, 'r') as f: + first_line = f.readline() + num_elements = first_line.split(',')[0] + return int(num_elements) + + +def input_fn(file_name, num_data, batch_size, is_training): + """Creates an input_fn required by Estimator train/evaluate.""" + + def _parse_csv(rows_string_tensor): + """Takes the string input tensor and returns tuple of (features, labels).""" + # The last column is the label. + num_features = len(FEATURE_KEYS) + num_columns = num_features + 1 + columns = tf.decode_csv(rows_string_tensor, + record_defaults=[[]] * num_columns) + features = dict(zip(FEATURE_KEYS, columns[:num_features])) + labels = tf.cast(columns[num_features], tf.int32) + return features, labels + + def _input_fn(): + """The input_fn.""" + dataset = tf.data.TextLineDataset([file_name]) + # Skip the first line (which does not have data). + dataset = dataset.skip(1) + dataset = dataset.map(_parse_csv) + + if is_training: + # This small dataset fits in memory, so to achieve a truly random + # shuffle the shuffle buffer size is set to the total number of + # elements in the dataset.
+ dataset = dataset.shuffle(num_data) + dataset = dataset.repeat() + + dataset = dataset.batch(batch_size) + iterator = dataset.make_one_shot_iterator() + features, labels = iterator.get_next() + return features, labels + + return _input_fn def main(unused_argv): - # Load dataset. - iris = datasets.load_iris() - x_train, x_test, y_train, y_test = model_selection.train_test_split( - iris.data, iris.target, test_size=0.2, random_state=42) + tf.logging.set_verbosity(tf.logging.INFO) + + num_training_data = maybe_download_iris_data( + IRIS_TRAINING, IRIS_TRAINING_URL) + num_test_data = maybe_download_iris_data(IRIS_TEST, IRIS_TEST_URL) # Build 3 layer DNN with 10, 20, 10 units respectively. feature_columns = [ - tf.feature_column.numeric_column( - X_FEATURE, shape=np.array(x_train).shape[1:])] + tf.feature_column.numeric_column(key, shape=1) for key in FEATURE_KEYS] classifier = tf.estimator.DNNClassifier( feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3) # Train. - train_input_fn = tf.estimator.inputs.numpy_input_fn( - x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True) - classifier.train(input_fn=train_input_fn, steps=200) - - # Predict. - test_input_fn = tf.estimator.inputs.numpy_input_fn( - x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False) - predictions = classifier.predict(input_fn=test_input_fn) - y_predicted = np.array(list(p['class_ids'] for p in predictions)) - y_predicted = y_predicted.reshape(np.array(y_test).shape) - - # Score with sklearn. - score = metrics.accuracy_score(y_test, y_predicted) - print('Accuracy (sklearn): {0:f}'.format(score)) - - # Score with tensorflow. + train_input_fn = input_fn(IRIS_TRAINING, num_training_data, batch_size=32, + is_training=True) + classifier.train(input_fn=train_input_fn, steps=400) + + # Eval. + test_input_fn = input_fn(IRIS_TEST, num_test_data, batch_size=32, + is_training=False) scores = classifier.evaluate(input_fn=test_input_fn) print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy'])) diff --git a/tensorflow/examples/learn/random_forest_mnist.py b/tensorflow/examples/learn/random_forest_mnist.py index 3c09990ea1eecdf7b5dff95b0fb60197cd0787b7..72c935cdae2196a1309097e4e6f15bd6f22f96a5 100644 --- a/tensorflow/examples/learn/random_forest_mnist.py +++ b/tensorflow/examples/learn/random_forest_mnist.py @@ -1,4 +1,4 @@ - # Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
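Note on the iris.py rewrite above: it replaces the sklearn-based flow with the `Dataset`-based `input_fn` pattern that the other examples in this change also move toward. A minimal, self-contained sketch of the same pipeline shape, under the assumption of a hypothetical CSV file `data.csv` with a header row, two float feature columns named `f0`/`f1`, and an integer label (all names here are illustrative, not part of the patch):

```python
import tensorflow as tf

# Hypothetical file and feature names, for illustration only.
CSV_FILE = 'data.csv'
FEATURE_KEYS = ['f0', 'f1']


def make_input_fn(file_name, batch_size, is_training):
  """Returns an input_fn yielding (features, labels) batches from a CSV."""

  def _parse_csv(line):
    # Each row: two float features followed by an integer label.
    f0, f1, label = tf.decode_csv(line, record_defaults=[[0.0], [0.0], [0]])
    return dict(zip(FEATURE_KEYS, [f0, f1])), label

  def _input_fn():
    dataset = tf.data.TextLineDataset([file_name]).skip(1)  # skip the header
    dataset = dataset.map(_parse_csv)
    if is_training:
      # A shuffle buffer at least as large as the dataset gives a uniform
      # shuffle; repeat() lets Estimator.train() bound training via `steps`.
      dataset = dataset.shuffle(buffer_size=1000).repeat()
    dataset = dataset.batch(batch_size)
    return dataset.make_one_shot_iterator().get_next()

  return _input_fn
```

An input_fn built this way plugs into `tf.estimator.DNNClassifier.train`/`evaluate` exactly as in the iris.py example above.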
@@ -21,18 +21,14 @@ import argparse import sys import tempfile -# pylint: disable=g-backslash-continuation -from tensorflow.contrib.learn.python.learn\ - import metric_spec -from tensorflow.contrib.learn.python.learn.estimators\ - import estimator -from tensorflow.contrib.tensor_forest.client\ - import eval_metrics -from tensorflow.contrib.tensor_forest.client\ - import random_forest -from tensorflow.contrib.tensor_forest.python\ - import tensor_forest +import numpy + +from tensorflow.contrib.learn.python.learn import metric_spec +from tensorflow.contrib.tensor_forest.client import eval_metrics +from tensorflow.contrib.tensor_forest.client import random_forest +from tensorflow.contrib.tensor_forest.python import tensor_forest from tensorflow.examples.tutorials.mnist import input_data +from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.platform import app FLAGS = None @@ -41,16 +37,15 @@ FLAGS = None def build_estimator(model_dir): """Build an estimator.""" params = tensor_forest.ForestHParams( - num_classes=10, num_features=784, - num_trees=FLAGS.num_trees, max_nodes=FLAGS.max_nodes) + num_classes=10, + num_features=784, + num_trees=FLAGS.num_trees, + max_nodes=FLAGS.max_nodes) graph_builder_class = tensor_forest.RandomForestGraphs if FLAGS.use_training_loss: graph_builder_class = tensor_forest.TrainingLossForest - # Use the SKCompat wrapper, which gives us a convenient way to split - # in-memory data like MNIST into batches. - return estimator.SKCompat(random_forest.TensorForestEstimator( - params, graph_builder_class=graph_builder_class, - model_dir=model_dir)) + return random_forest.TensorForestEstimator( + params, graph_builder_class=graph_builder_class, model_dir=model_dir) def train_and_eval(): @@ -62,18 +57,30 @@ def train_and_eval(): mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=False) - est.fit(x=mnist.train.images, y=mnist.train.labels, - batch_size=FLAGS.batch_size) + train_input_fn = numpy_io.numpy_input_fn( + x={'images': mnist.train.images}, + y=mnist.train.labels.astype(numpy.int32), + batch_size=FLAGS.batch_size, + num_epochs=None, + shuffle=True) + est.fit(input_fn=train_input_fn, steps=None) metric_name = 'accuracy' - metric = {metric_name: - metric_spec.MetricSpec( - eval_metrics.get_metric(metric_name), - prediction_key=eval_metrics.get_prediction_key(metric_name))} - - results = est.score(x=mnist.test.images, y=mnist.test.labels, - batch_size=FLAGS.batch_size, - metrics=metric) + metric = { + metric_name: + metric_spec.MetricSpec( + eval_metrics.get_metric(metric_name), + prediction_key=eval_metrics.get_prediction_key(metric_name)) + } + + test_input_fn = numpy_io.numpy_input_fn( + x={'images': mnist.test.images}, + y=mnist.test.labels.astype(numpy.int32), + num_epochs=1, + batch_size=FLAGS.batch_size, + shuffle=False) + + results = est.evaluate(input_fn=test_input_fn, metrics=metric) for key in sorted(results): print('%s: %s' % (key, results[key])) diff --git a/tensorflow/examples/learn/text_classification_character_rnn.py b/tensorflow/examples/learn/text_classification_character_rnn.py index 1fc9388a1a026013ad14f8d1deeccbed817d1c88..86adc056add508c309b3a5b93e58e9c195995642 100644 --- a/tensorflow/examples/learn/text_classification_character_rnn.py +++ b/tensorflow/examples/learn/text_classification_character_rnn.py @@ -30,7 +30,6 @@ import sys import numpy as np import pandas -from sklearn import metrics import tensorflow as tf FLAGS = None @@ -46,8 +45,8 @@ def char_rnn_model(features, labels, mode): byte_vectors = 
tf.one_hot(features[CHARS_FEATURE], 256, 1., 0.) byte_list = tf.unstack(byte_vectors, axis=1) - cell = tf.contrib.rnn.GRUCell(HIDDEN_SIZE) - _, encoding = tf.contrib.rnn.static_rnn(cell, byte_list, dtype=tf.float32) + cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE) + _, encoding = tf.nn.static_rnn(cell, byte_list, dtype=tf.float32) logits = tf.layers.dense(encoding, MAX_LABEL, activation=None) @@ -98,28 +97,20 @@ def main(unused_argv): train_input_fn = tf.estimator.inputs.numpy_input_fn( x={CHARS_FEATURE: x_train}, y=y_train, - batch_size=len(x_train), + batch_size=128, num_epochs=None, shuffle=True) classifier.train(input_fn=train_input_fn, steps=100) - # Predict. + # Eval. test_input_fn = tf.estimator.inputs.numpy_input_fn( x={CHARS_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False) - predictions = classifier.predict(input_fn=test_input_fn) - y_predicted = np.array(list(p['class'] for p in predictions)) - y_predicted = y_predicted.reshape(np.array(y_test).shape) - # Score with sklearn. - score = metrics.accuracy_score(y_test, y_predicted) - print('Accuracy (sklearn): {0:f}'.format(score)) - - # Score with tensorflow. scores = classifier.evaluate(input_fn=test_input_fn) - print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy'])) + print('Accuracy: {0:f}'.format(scores['accuracy'])) if __name__ == '__main__': diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 21cdaec47788b1287157eb1a6782d9f5accfed24..b7aa7bbf6b5d99fd32cadb51d7965ce7c43bb29e 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1978,6 +1978,7 @@ py_library( ":tensor_array_ops", ":util", ":variable_scope", + "//tensorflow/python/eager:context", ], ) @@ -2638,6 +2639,7 @@ py_library( ":init_ops", ":io_ops", ":io_ops_gen", + ":layers_base", ":lib", ":lookup_ops", ":math_ops", diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index da17be05b7d434f593e4ae7e8560933c5142528b..9580e8484751442426da924db23a818747f3e758 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -396,12 +396,11 @@ def implicit_grad(f): return grad_fn -def _get_arg_spec(f, params): +def _get_arg_spec(f, params, param_args): args = tf_inspect.getargspec(f).args if params is None: if not args: - raise ValueError("When params is None the differentiated function cannot" - " only take arguments by *args and **kwds.") + return range(len(param_args)) return range(len(args)) elif all(isinstance(x, six.string_types) for x in params): return [args.index(n) for n in params] @@ -560,10 +559,9 @@ def val_and_grad_function(f, params=None): ValueError: if the params are not all strings or all integers. """ - parameter_positions = _get_arg_spec(f, params) - def decorated(*args, **kwds): """Computes the value and gradient of the decorated function.""" + parameter_positions = _get_arg_spec(f, params, args) dy = kwds.pop("dy", None) if dy is not None: dy = ops.convert_to_tensor(dy) @@ -616,10 +614,9 @@ def make_vjp(f, params=None): """ - parameter_positions = _get_arg_spec(f, params) - def decorated(*args, **kwds): """Computes the value and gradient of the decorated function.""" + parameter_positions = _get_arg_spec(f, params, args) assert not kwds, "The gradient function can't take keyword arguments." 
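The `_get_arg_spec` change above defers parameter-position resolution from decoration time to call time, which is what allows differentiating functions declared only with `*args` (previously a hard error). A rough sketch of the behavior this enables, assuming eager execution has been enabled the way this patch's own docstrings do it (internal modules of this era, subject to change):

```python
from tensorflow.contrib.eager.python import tfe  # import subject to change
from tensorflow.python.eager import backprop

tfe.enable_eager_execution()


def f(*args):
  # No named parameters: with params=None, the differentiated positions are
  # now inferred from the arguments actually passed at call time.
  return args[0] * args[0]


grad_fn = backprop.gradients_function(f)
print(grad_fn(1.0)[0].numpy())  # d(x * x)/dx at x = 1.0 -> 2.0
```

This mirrors the `testArgsGradientFunction` case added to backprop_test.py further down.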
tape.push_new_tape() sources = [] diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 95d5f0adcb44384fa99cb51feb21b17c5ef32c4e..9ba5913c65e2e6b0cc28ba015761a173db7fad37 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -292,7 +292,7 @@ class BackpropTest(test.TestCase): self.assertEqual(grad.numpy(), 6.0) def testGradientTapeVariable(self): - v = resource_variable_ops.ResourceVariable(1.0) + v = resource_variable_ops.ResourceVariable(1.0, name='v') with backprop.GradientTape() as g: y = v * v grad = g.gradient(y, [v])[0] @@ -381,6 +381,14 @@ class BackpropTest(test.TestCase): [tensor_shape.TensorShape(s).as_proto() for s in shape_list], backprop.make_attr([pywrap_tensorflow.TF_ATTR_SHAPE], shape_list)) + def testArgsGradientFunction(self): + + def f(*args): + return args[0] * args[0] + + grad = backprop.gradients_function(f) + self.assertAllEqual(grad(1.0)[0], 2.0) + def testMultiValueConvertToTensor(self): x = resource_variable_ops.ResourceVariable( initial_value=array_ops.constant([1.0]), name='x') @@ -449,7 +457,8 @@ class BackpropTest(test.TestCase): add_n.append(1) context.context().add_post_execution_callback(callback) - v = resource_variable_ops.ResourceVariable(constant_op.constant(2.0)) + v = resource_variable_ops.ResourceVariable(constant_op.constant(2.0), + name='v') def fn(): outputs = [] for _ in range(20): diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index aa7cba56defa2a7cc858e70c0dff07d07e154517..58581283d27eb3e96b21d86ca04f402129d8a8a9 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -26,7 +26,6 @@ import threading from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import device as pydev from tensorflow.python.framework import errors -from tensorflow.python.util import compat from tensorflow.python.util import tf_contextlib GRAPH_MODE = 0 @@ -103,11 +102,16 @@ class Context(object): if self._context_handle is not None: return assert self._context_devices is None - opts = pywrap_tensorflow.TF_NewSessionOptions( - target=compat.as_bytes(""), config=self._config) - with errors.raise_exception_on_not_ok_status() as status: - self._context_handle = pywrap_tensorflow.TFE_NewContext(opts, status) - pywrap_tensorflow.TF_DeleteSessionOptions(opts) + opts = pywrap_tensorflow.TFE_NewContextOptions() + try: + with errors.raise_exception_on_not_ok_status() as status: + if self._config is not None: + config_str = self._config.SerializeToString() + pywrap_tensorflow.TFE_ContextOptionsSetConfig( + opts, config_str, len(config_str), status) + self._context_handle = pywrap_tensorflow.TFE_NewContext(opts, status) + finally: + pywrap_tensorflow.TFE_DeleteContextOptions(opts) # Store list of devices self._context_devices = [] with errors.raise_exception_on_not_ok_status() as status: diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index da49517cf9446092a5895af4871a37f9c4d5598e..e675ee8988f73ea9489d2a2ccfa2d4d0ba565b3d 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -79,6 +79,22 @@ def capture_tensors(captures): _scoped_captures.tensors = old +def capture_value(tensor_map, value, dtype, name): + """Capture a value from outside the function, to pass in as an extra arg.""" + captured_value = tensor_map.get(ops.tensor_id(value), None) + if captured_value is None: + captured_value = graph_placeholder( + 
dtype=dtype or value.dtype, shape=value.shape, name=name) + if captured_value.dtype == dtypes.resource: + captured_value._handle_data = value._handle_data # pylint: disable=protected-access + tensor_map[ops.tensor_id(value)] = (value, captured_value) + else: + captured_value = captured_value[1] + tape.record_operation("captured_value", [captured_value], [value], + lambda x: [x]) + return captured_value + + def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False): """Captures a Tensor while building a graph mode function. @@ -100,18 +116,33 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False): if tensor_map is None: # Capturing is not enabled. return constant_op.constant(value.numpy()) - captured_value = tensor_map.get(ops.tensor_id(value), None) - if captured_value is None: - captured_value = graph_placeholder( - dtype=dtype or value.dtype, shape=value.shape, name=name) - if captured_value.dtype == dtypes.resource: - captured_value._handle_data = value._handle_data # pylint: disable=protected-access - tensor_map[ops.tensor_id(value)] = (value, captured_value) - else: - captured_value = captured_value[1] - tape.record_operation("captured_value", [captured_value], [value], - lambda x: [x]) - return captured_value + return capture_value(tensor_map, value, dtype, name) + + +class CapturingGraph(ops.Graph): + + def __init__(self, captures): + super(CapturingGraph, self).__init__() + self._building_function = True + self.captures = captures + + def create_op( + self, + op_type, + inputs, + dtypes, # pylint: disable=redefined-outer-name + input_types=None, + name=None, + attrs=None, + op_def=None, + compute_shapes=True, + compute_device=True): + for i, inp in enumerate(inputs): + if inp.graph is not self: + inputs[i] = capture_value(self.captures, inp, inp.dtype, inp.op.name) + return super(CapturingGraph, self).create_op( + op_type, inputs, dtypes, input_types, name, attrs, op_def, + compute_shapes, compute_device) # TODO(apassos): it'd be really nice if we could scope this registration. @@ -325,6 +356,8 @@ class _GraphModeFunction(object): name="FunctionCall", compute_shapes=False) result = op.outputs + if not result: + return op for i, s in enumerate(self._output_shapes): result[i].set_shape(s) else: @@ -381,7 +414,8 @@ def _get_defun_inputs(args): def _defun_internal(name, func, args, kwds): """Defines and returns graph-mode version of func.""" with context.graph_mode(): - tmp_graph = ops.Graph() + captures = {} + tmp_graph = CapturingGraph(captures) # Copy the graph collections to ensure summaries and other things work. This # lets the function access (but not mutate) collections of the containing # graph, such as the global step and the summary writer collections. 
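The new `CapturingGraph.create_op` above is the mechanism that lets a `defun`-compiled function close over tensors it did not create: any op input living in a different graph is rerouted through `capture_value`, which records a placeholder in `captures` and arranges for the original tensor to be fed as an extra argument at call time. A rough usage sketch in graph mode with a session, mirroring the `testGraphModeCaptureVariable` test added below (internal APIs of this era, subject to change):

```python
import tensorflow as tf
from tensorflow.python.eager import function
from tensorflow.python.ops import resource_variable_ops

v = resource_variable_ops.ResourceVariable(1.0)  # lives in the outer graph


@function.defun
def double():
  # Reading `v` pulls a tensor from another graph; CapturingGraph swaps it
  # for a placeholder capture instead of raising a cross-graph error.
  return v * 2.0


with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(double()))  # 2.0
```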
@@ -392,7 +426,6 @@ def _defun_internal(name, func, args, kwds): with tmp_graph.as_default(): func_inputs = _get_defun_inputs(args) - captures = {} with capture_tensors(captures): func_outputs = func(*func_inputs, **kwds) ids = list(sorted(captures.keys())) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index fb647f5c2110c89390b866c5e4ed4aa157b6d070..33bedb59f3a239edf0096d774b75b9e79811554f 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import clip_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables class FunctionTest(test.TestCase): @@ -56,7 +57,7 @@ class FunctionTest(test.TestCase): self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) def testGraphModeWithGradients(self): - v = resource_variable_ops.ResourceVariable(1.0) + v = resource_variable_ops.ResourceVariable(1.0, name='v') @function.defun def step(): @@ -68,6 +69,23 @@ class FunctionTest(test.TestCase): self.assertAllEqual(step(), 2.0) + def testGraphModeCaptureVariable(self): + with context.graph_mode(), self.test_session() as sess: + + class HasAVar(object): + + def __init__(self): + self.v = resource_variable_ops.ResourceVariable(1.0) + + def call(self): + return self.v * 2 + + o = HasAVar() + variables.global_variables_initializer().run() + call = function.defun(o.call) + op = call() + self.assertAllEqual(sess.run(op), 2.0) + def testTensorConversionWithDefun(self): @function.defun @@ -138,7 +156,7 @@ class FunctionTest(test.TestCase): g(constant_op.constant(1.0)) def testGradientTensorConversionWithDefun(self): - three = resource_variable_ops.ResourceVariable(3.0) + three = resource_variable_ops.ResourceVariable(3.0, name='v') @function.defun def f(x): diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py index 3aba164630d7968b37c09f9bf69518b615f84f70..0ec83636a0fd086fb725cc206715c5fc40c243e1 100644 --- a/tensorflow/python/eager/graph_callable.py +++ b/tensorflow/python/eager/graph_callable.py @@ -312,11 +312,21 @@ def _graph_callable_internal(func, shape_and_dtypes): Returns: Callable graph object. """ + container = tf_ops.get_default_graph()._container # pylint: disable=protected-access + container_prefix = tf_ops.get_default_graph()._container_prefix # pylint: disable=protected-access with context.graph_mode(): # This graph will store both the initialization and the call version of the # wrapped function. It will later be used by the backprop code to build the # backprop graph, if necessary. tmp_graph = tf_ops.Graph() + # Inherit the container from the original graph to create resources at user + # expected containers. Also inherits the container prefix, since this is + # used for error checking when isolating Eager execution (the container + # prefix at creation must match the container prefix when used, and + # variables returned from the graph callable will be used in the outside + # context). + tmp_graph._container = container # pylint: disable=protected-access + tmp_graph._container_prefix = container_prefix # pylint: disable=protected-access with tmp_graph.as_default(): # Placeholders for the non-variable inputs. 
func_inputs = _get_graph_callable_inputs(shape_and_dtypes) diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index e2e20f0d717b7dec7a968222b2d76f315b2b538f..31e9933c6f702393eb21b10c5bdd770739056032 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.util import compat @@ -47,8 +48,8 @@ class ServingInputReceiver(collections.namedtuple( """A return type for a serving_input_receiver_fn. The expected return values are: - features: A dict of string to `Tensor` or `SparseTensor`, specifying the - features to be passed to the model. + features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or + `SparseTensor`, specifying the features to be passed to the model. receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying input nodes where this receiver expects to be fed by default. Typically, this is a single placeholder expecting serialized `tf.Example` protos. @@ -193,13 +194,14 @@ def build_all_signature_defs(receiver_tensors, raise ValueError('export_outputs must be a dict.') signature_def_map = {} + excluded_signatures = {} for output_key, export_output in export_outputs.items(): signature_name = '{}'.format(output_key or 'None') try: signature = export_output.as_signature_def(receiver_tensors) signature_def_map[signature_name] = signature - except ValueError: - pass + except ValueError as e: + excluded_signatures[signature_name] = str(e) if receiver_tensors_alternatives: for receiver_name, receiver_tensors_alt in ( @@ -213,8 +215,10 @@ def build_all_signature_defs(receiver_tensors, try: signature = export_output.as_signature_def(receiver_tensors_alt) signature_def_map[signature_name] = signature - except ValueError: - pass + except ValueError as e: + excluded_signatures[signature_name] = str(e) + + _log_signature_report(signature_def_map, excluded_signatures) # The above calls to export_output.as_signature_def should return only # valid signatures; if there is a validity problem, they raise ValueError, @@ -224,6 +228,46 @@ def build_all_signature_defs(receiver_tensors, if signature_def_utils.is_valid_signature(v)} +_FRIENDLY_METHOD_NAMES = { + signature_constants.CLASSIFY_METHOD_NAME: 'Classify', + signature_constants.REGRESS_METHOD_NAME: 'Regress', + signature_constants.PREDICT_METHOD_NAME: 'Predict', +} + + +def _log_signature_report(signature_def_map, excluded_signatures): + """Log a report of which signatures were produced.""" + sig_names_by_method_name = collections.defaultdict(list) + + # We'll collect whatever method_names are present, but also we want to make + # sure to output a line for each of the three standard methods even if they + # have no signatures. 
+ for method_name in _FRIENDLY_METHOD_NAMES: + sig_names_by_method_name[method_name] = [] + + for signature_name, sig in signature_def_map.items(): + sig_names_by_method_name[sig.method_name].append(signature_name) + + # TODO(b/67733540): consider printing the full signatures, not just names + for method_name, sig_names in sig_names_by_method_name.items(): + if method_name in _FRIENDLY_METHOD_NAMES: + method_name = _FRIENDLY_METHOD_NAMES[method_name] + logging.info('Signatures INCLUDED in export for {}: {}'.format( + method_name, sig_names if sig_names else 'None')) + + if excluded_signatures: + logging.info('Signatures EXCLUDED from export because they cannot ' + 'be served via TensorFlow Serving APIs:') + for signature_name, message in excluded_signatures.items(): + logging.info('\'{}\' : {}'.format(signature_name, message)) + + if not signature_def_map: + logging.warn('Export includes no signatures!') + elif (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + not in signature_def_map): + logging.warn('Export includes no default signature!') + + # When we create a timestamped directory, there is a small chance that the # directory already exists because another worker is also writing exports. # In this case we just wait one second to get a new timestamp and try again. diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py index 7c7f92872ebb10c9679c07aa3bb15bfbf5021b4d..863af6d41d985043542b03375372fe564c283b82 100644 --- a/tensorflow/python/estimator/export/export_output.py +++ b/tensorflow/python/estimator/export/export_output.py @@ -150,6 +150,9 @@ class RegressionOutput(ExportOutput): return signature_def_utils.regression_signature_def(examples, self.value) +_SINGLE_OUTPUT_DEFAULT_NAME = 'output' + + class PredictOutput(ExportOutput): """Represents the output of a generic prediction head. @@ -162,16 +165,15 @@ class PredictOutput(ExportOutput): """Constructor for PredictOutput. Args: - outputs: A dict of string to `Tensor` representing the predictions. + outputs: A `Tensor` or a dict of string to `Tensor` representing the + predictions. Raises: ValueError: if any of the output keys are not strings, or any of the output values are not `Tensor`s.
""" if not isinstance(outputs, dict): - raise ValueError( - 'Prediction outputs must be given as a dict of string to Tensor; ' - 'got {}'.format(outputs)) + outputs = {_SINGLE_OUTPUT_DEFAULT_NAME: outputs} for key, value in outputs.items(): if not isinstance(key, six.string_types): raise ValueError( diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py index 035a9a143e6ffa18ae78ef2544614f342363b22d..7090e53d807817db7d66ed0ee1307d7e38e9e87e 100644 --- a/tensorflow/python/estimator/export/export_output_test.py +++ b/tensorflow/python/estimator/export/export_output_test.py @@ -199,20 +199,18 @@ class ExportOutputTest(test.TestCase): signature_constants.CLASSIFY_METHOD_NAME) self.assertEqual(actual_signature_def, expected_signature_def) - def test_predict_output_constructor(self): - """Tests that no errors are raised when input is expected.""" + def test_predict_outputs_valid(self): + """Tests that no errors are raised when provided outputs are valid.""" outputs = { "output0": constant_op.constant([0]), - u"output1": constant_op.constant([1]), + u"output1": constant_op.constant(["foo"]), } export_output_lib.PredictOutput(outputs) - def test_predict_output_outputs_invalid(self): - with self.assertRaisesRegexp( - ValueError, - "Prediction outputs must be given as a dict of string to Tensor"): - export_output_lib.PredictOutput(constant_op.constant([0])) + # Single Tensor is OK too + export_output_lib.PredictOutput(constant_op.constant([0])) + def test_predict_outputs_invalid(self): with self.assertRaisesRegexp( ValueError, "Prediction output key must be a string"): diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index c681ffb514cc6bb9a9984cbc2c667644794ab1e5..a01bf02deb4e1b947abade7c14e0ac58b40bd3d3 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -47,6 +47,7 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.client import device_lib from tensorflow.python.client import session from tensorflow.python.eager import context +from tensorflow.python.eager import tape from tensorflow.python.framework import device as pydev from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -391,6 +392,66 @@ def with_c_api(cls): return cls +class IsolateTest(object): + """A context manager which isolates resources in its block. + + Provides an Eager-agnostic abstraction for preventing the sharing of + variables and other resources. + + In graph mode, resource handle ops are only executed in a particular Session, + isolating them from resources with the same name in other Graphs. In Eager, + separate Sessions do not exist, so resources (particularly ResourceVariables) + would be shared implicitly if a resource of the same name were created + anywhere in a Python process. Multiple handles to the same resource would + cause several issues, and so this type of sharing will raise an exception. + + Using resources with the same name in a single Python process may be useful + (especially for unit tests), so this context manager provides an abstraction + for isolating resources. Using a resource created in one Isolation environment + in another is an error. 
+ + Example usage in Eager mode: + + ```python + import tensorflow as tf + # Import subject to change + from tensorflow.contrib.eager.python import tfe + + tfe.enable_eager_execution() + + for hyperparameter in [1, 2, 3]: + with tfe.IsolateTest(): + v = tfe.Variable(name="v", initial_value=hyperparameter) + # train model, test results ... + ``` + + IsolateTest is currently exposed through contrib.eager, but it creates a new + default Graph and provides equivalent safety in graph mode. + """ + + def __init__(self): + if context.in_eager_mode() and tape.could_possibly_record(): + raise ValueError("Cannot isolate Eager execution with an active tape.") + # In Eager, Graphs set a container which isolates resources, and maintain a + # VariableStore which caches ResourceVariable objects created through + # get_variable. So setting the default Graph has the side effect of + # isolating Eager resources. + with context.eager_mode(): + # Create the graph in Eager mode, as this provides stricter semantics + # (i.e. has a unique container prefix). This prevents implicit sharing + # when a Graph-mode graph is created and then Eager mode is enabled (an + # error through enable_eager_execution, but common with context managers + # in unit tests). + self._graph_as_default_context_manager = ops.Graph().as_default() + + def __enter__(self): + self._graph_as_default_context_manager.__enter__() + + def __exit__(self, type_arg, value_arg, traceback_arg): + return self._graph_as_default_context_manager.__exit__( + type_arg, value_arg, traceback_arg) + + def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None, use_gpu=False, force_gpu=False, reset_test=True): @@ -440,9 +501,8 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None, with context.device("/device:CPU:0"): f(self, **kwargs) - eager_graph = graph or ops.Graph() with context.eager_mode(): - with eager_graph.as_default(): + with IsolateTest(): run_eager_mode() return decorated diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 6129fa2e0d06e3ac271ace515a0e3ab8fb98ac9d..b2f8d62095f75ba55344a63401525ea998a70b47 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -27,12 +27,16 @@ from google.protobuf import text_format from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.client import session +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -325,5 +329,72 @@ class TestUtilTest(test_util.TensorFlowTestCase): self.assertEqual(a_rand, b_rand) +@test_util.with_c_api +class IsolationTest(test_util.TensorFlowTestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_variable_reuse_exception(self): + with test_util.IsolateTest(), session.Session(): + first_container_variable = resource_variable_ops.ResourceVariable( + name="first_container_variable", + initial_value=1) + if context.in_graph_mode(): + self.evaluate([variables.global_variables_initializer()]) + with test_util.IsolateTest(): + if 
context.in_graph_mode(): + with self.assertRaises(RuntimeError): + self.evaluate(first_container_variable.read_value()) + else: + with self.assertRaises(ValueError): + first_container_variable.read_value() + + @test_util.run_in_graph_and_eager_modes() + def test_variable_reuse_exception_nested(self): + with test_util.IsolateTest(), session.Session(): + first_container_variable = resource_variable_ops.ResourceVariable( + name="first_container_variable", + initial_value=1) + if context.in_graph_mode(): + self.evaluate([variables.global_variables_initializer()]) + with test_util.IsolateTest(), session.Session(): + if context.in_graph_mode(): + with self.assertRaises(RuntimeError): + self.evaluate(first_container_variable.read_value()) + else: + with self.assertRaises(ValueError): + first_container_variable.read_value() + + @test_util.run_in_graph_and_eager_modes() + def test_no_sharing(self): + with test_util.IsolateTest(), session.Session(): + first_container_variable = resource_variable_ops.ResourceVariable( + name="same_name", + initial_value=1) + if context.in_graph_mode(): + self.evaluate([variables.global_variables_initializer()]) + with test_util.IsolateTest(), session.Session(): + second_container_variable = resource_variable_ops.ResourceVariable( + name="same_name", + initial_value=2) + if context.in_graph_mode(): + self.evaluate([variables.global_variables_initializer()]) + self.assertEqual( + 2, self.evaluate(second_container_variable.read_value())) + self.assertEqual(1, self.evaluate(first_container_variable.read_value())) + + def test_graph_mode_isolation(self): + with context.graph_mode(): + # Even if we've (accidentally) called IsolateTest in Graph mode, it should + # provide Eager isolation. + with test_util.IsolateTest(): + with context.eager_mode(): + first_container_variable = resource_variable_ops.ResourceVariable( + name="first_container_variable", + initial_value=1) + with context.eager_mode(): + with self.assertRaises(ValueError): + first_container_variable.read_value() + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 9a164449900eec2c1711dc4ad720565ce7a11ea8..2a6d78e805d259e58d375fb55b899c6f1be77469 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -2312,6 +2312,7 @@ cuda_py_test( "//tensorflow/python:control_flow_ops", "//tensorflow/python:data_flow_grad", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", "//tensorflow/python:nn_grad", @@ -2320,6 +2321,7 @@ cuda_py_test( "//tensorflow/python:sparse_grad", "//tensorflow/python:tensor_array_grad", "//tensorflow/python:variables", + "//tensorflow/python/eager:context", ], shard_count = 10, tags = ["no_windows"], diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py index b4fd89bd03784082575be592f19792ca18d0b899..8848c15e765236c2ae2817cce1510c4c1ab04740 100644 --- a/tensorflow/python/kernel_tests/qr_op_test.py +++ b/tensorflow/python/kernel_tests/qr_op_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops @@ -140,11 +141,11 @@ def _GetQrOpTest(dtype_, 
shape_, full_matrices_, use_static_shape_): x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1])) for i in range(new_first_dim): if full_matrices_: - np_q_reshape[i,:,:], _ = \ - np.linalg.qr(x_reshape[i,:,:], mode="complete") + np_q_reshape[i, :, :], _ = np.linalg.qr( + x_reshape[i, :, :], mode="complete") else: - np_q_reshape[i,:,:], _ = \ - np.linalg.qr(x_reshape[i,:,:], mode="reduced") + np_q_reshape[i, :, :], _ = np.linalg.qr( + x_reshape[i, :, :], mode="reduced") np_q = np.reshape(np_q_reshape, q_dims) CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:])) CheckApproximation(self, x_np, q_tf_val, r_tf_val) @@ -153,6 +154,46 @@ def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_): return Test +class QrGradOpTest(test.TestCase): + pass + + +def _GetQrGradOpTest(dtype_, shape_, full_matrices_): + + def Test(self): + np.random.seed(42) + a = np.random.uniform(low=-1.0, high=1.0, size=shape_).astype(dtype_) + if dtype_ in [np.complex64, np.complex128]: + a += 1j * np.random.uniform( + low=-1.0, high=1.0, size=shape_).astype(dtype_) + # Optimal stepsize for central difference is O(epsilon^{1/3}). + epsilon = np.finfo(dtype_).eps + delta = 0.1 * epsilon**(1.0 / 3.0) + if dtype_ in [np.float32, np.complex64]: + tol = 3e-2 + else: + tol = 1e-6 + with self.test_session(use_gpu=True): + tf_a = constant_op.constant(a) + tf_b = linalg_ops.qr(tf_a, full_matrices=full_matrices_) + for b in tf_b: + x_init = np.random.uniform( + low=-1.0, high=1.0, size=shape_).astype(dtype_) + if dtype_ in [np.complex64, np.complex128]: + x_init += 1j * np.random.uniform( + low=-1.0, high=1.0, size=shape_).astype(dtype_) + theoretical, numerical = gradient_checker.compute_gradient( + tf_a, + tf_a.get_shape().as_list(), + b, + b.get_shape().as_list(), + x_init_value=x_init, + delta=delta) + self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol) + + return Test + + if __name__ == "__main__": for dtype in np.float32, np.float64, np.complex64, np.complex128: for rows in 1, 2, 5, 10, 32, 100: @@ -168,4 +209,21 @@ if __name__ == "__main__": _AddTest(QrOpTest, "Qr", name, _GetQrOpTest(dtype, shape, full_matrices, use_static_shape)) + + # TODO(pfau): Get working with complex types. 
+ # TODO(pfau): Get working with full_matrices when rows != cols + # TODO(pfau): Get working when rows < cols + # TODO(pfau): Get working with placeholders (dynamic shapes) + for full_matrices in False, True: + for dtype in np.float32, np.float64: + for rows in 1, 2, 5, 10: + for cols in 1, 2, 5, 10: + if rows == cols or (not full_matrices and rows > cols): + for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10): + shape = batch_dims + (rows, cols) + name = "%s_%s_full_%s" % (dtype.__name__, + "_".join(map(str, shape)), + full_matrices) + _AddTest(QrGradOpTest, "QrGrad", name, + _GetQrGradOpTest(dtype, shape, full_matrices)) test.main() diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 23676223dc661fbdfc1141687c8b760a161cbd24..10f9a72c7bbb76ab9b880a997f3134607e3c925b 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -181,7 +181,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): @test_util.run_in_graph_and_eager_modes() def testInitFnDtype(self): v = resource_variable_ops.ResourceVariable( - initial_value=lambda: 1, dtype=dtypes.float32) + initial_value=lambda: 1, dtype=dtypes.float32, name="var0") self.assertEqual(dtypes.float32, v.value().dtype) @test_util.run_in_graph_and_eager_modes() @@ -192,26 +192,27 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): @test_util.run_in_graph_and_eager_modes() def testInitializeAllVariables(self): - v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.float32) + v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.float32, + name="var0") self.evaluate(variables.global_variables_initializer()) self.assertEqual(1.0, self.evaluate(v.value())) @test_util.run_in_graph_and_eager_modes() def testOperatorOverload(self): - v = resource_variable_ops.ResourceVariable(1.0) + v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) self.assertEqual(2.0, self.evaluate(v + v)) @test_util.run_in_graph_and_eager_modes() def testAssignMethod(self): - v = resource_variable_ops.ResourceVariable(1.0) + v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) self.evaluate(v.assign(2.0)) self.assertEqual(2.0, self.evaluate(v.value())) @test_util.run_in_graph_and_eager_modes() def testLoad(self): - v = resource_variable_ops.ResourceVariable(1.0) + v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) v.load(2.0) self.assertEqual(2.0, self.evaluate(v.value())) @@ -237,21 +238,21 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): @test_util.run_in_graph_and_eager_modes() def testAssignAddMethod(self): - v = resource_variable_ops.ResourceVariable(1.0) + v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) self.evaluate(v.assign_add(1.0)) self.assertEqual(2.0, self.evaluate(v.value())) @test_util.run_in_graph_and_eager_modes() def testAssignSubMethod(self): - v = resource_variable_ops.ResourceVariable(3.0) + v = resource_variable_ops.ResourceVariable(3.0, name="var0") self.evaluate(variables.global_variables_initializer()) self.evaluate(v.assign_sub(1.0)) self.assertEqual(2.0, self.evaluate(v.value())) @test_util.run_in_graph_and_eager_modes() def testDestroyResource(self): - v =
resource_variable_ops.ResourceVariable(3.0) + v = resource_variable_ops.ResourceVariable(3.0, name="var0") self.evaluate(variables.global_variables_initializer()) self.assertEqual(3.0, self.evaluate(v.value())) self.evaluate(resource_variable_ops.destroy_resource_op(v.handle)) @@ -309,12 +310,15 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.evaluate(variables.global_variables_initializer()) w = resource_variable_ops.var_handle_op( - dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var4") + dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var4", + # Needed in Eager since we get a unique container name by default. + container=ops.get_default_graph()._container) w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype) self.assertEqual(300.0, self.evaluate(w_read)) x = resource_variable_ops.var_handle_op( - dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5") + dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5", + container=ops.get_default_graph()._container) with self.assertRaisesOpError("Resource .*/var5/.* does not exist"): x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype) self.evaluate(x_read) @@ -328,7 +332,9 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.evaluate(variables.global_variables_initializer()) w = resource_variable_ops.var_handle_op( - dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var6") + dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var6", + # Needed in Eager since we get a unique container name by default. + container=ops.get_default_graph()._container) w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype) self.assertEqual(300.0, self.evaluate(w_read)) @@ -438,6 +444,21 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): resource_variable_ops.destroy_resource_op(var._handle, ignore_lookup_error=False) + def testSharingViaResourceVariableObject(self): + with context.eager_mode(): + _ = resource_variable_ops.ResourceVariable(1.0, name="var0") + with self.assertRaisesRegexp(ValueError, + "'var0' already created"): + _ = resource_variable_ops.ResourceVariable(2.0, name="var0") + with ops.Graph().as_default(): + _ = resource_variable_ops.ResourceVariable(2.0, name="var0") + + def testVariableNameMissing(self): + with context.eager_mode(): + with self.assertRaisesRegexp(ValueError, + "Variables need to have explicit names"): + _ = resource_variable_ops.ResourceVariable(1.0) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index a644e6a44fa7f0e1bec7e3ea664a8a79a202ad05..d8f4b439e37981f3d21181feae9baa8d492ee1d5 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -26,9 +26,12 @@ import numpy as np from tensorflow.contrib import rnn as contrib_rnn from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as ops_lib from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl @@ -82,9 +85,13 @@ class RNNTest(test.TestCase): self._seed = 23489 
np.random.seed(self._seed) + @test_util.run_in_graph_and_eager_modes() def testInvalidSequenceLengthShape(self): cell = Plus1RNNCell() - inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))] + if context.in_graph_mode(): + inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))] + else: + inputs = [constant_op.constant(np.ones((3, 4)))] with self.assertRaisesRegexp(ValueError, "must be a vector"): rnn.dynamic_rnn( cell, @@ -92,45 +99,77 @@ class RNNTest(test.TestCase): dtype=dtypes.float32, sequence_length=[[4]]) + @test_util.run_in_graph_and_eager_modes() def testBatchSizeFromInput(self): cell = Plus1RNNCell() + in_graph_mode = context.in_graph_mode() # With static batch size - inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, 5)) + if in_graph_mode: + inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, 5)) + initial_state = array_ops.placeholder(dtypes.float32, shape=(3, 5)) + else: + inputs = np.zeros((3, 4, 5), dtype=np.float32) + initial_state = np.zeros((3, 5), dtype=np.float32) + # - Without initial_state outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) - self.assertEqual(3, outputs.shape[0].value) - self.assertEqual(3, state.shape[0].value) + if in_graph_mode: + self.assertEqual(3, outputs.shape[0].value) + self.assertEqual(3, state.shape[0].value) + else: + self.assertEqual(3, outputs.shape[0]) + self.assertEqual(3, state.shape[0]) + # - With initial_state outputs, state = rnn.dynamic_rnn( - cell, - inputs, - initial_state=array_ops.placeholder(dtypes.float32, shape=(3, 5))) - self.assertEqual(3, outputs.shape[0].value) - self.assertEqual(3, state.shape[0].value) + cell, inputs, initial_state=initial_state) + if in_graph_mode: + self.assertEqual(3, outputs.shape[0].value) + self.assertEqual(3, state.shape[0].value) + else: + self.assertEqual(3, outputs.shape[0]) + self.assertEqual(3, state.shape[0]) + # Without static batch size - inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 5)) - # - Without initial_state - outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) - self.assertEqual(None, outputs.shape[0].value) - self.assertEqual(None, state.shape[0].value) - # - With initial_state - outputs, state = rnn.dynamic_rnn( - cell, - inputs, - initial_state=array_ops.placeholder(dtypes.float32, shape=(None, 5))) - self.assertEqual(None, outputs.shape[0].value) - self.assertEqual(None, state.shape[0].value) + # Tensor shapes are fully determined in Eager mode, so only run this + # test in graph mode. 
+ if in_graph_mode: + inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 5)) + # - Without initial_state + outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) + self.assertEqual(None, outputs.shape[0].value) + self.assertEqual(None, state.shape[0].value) + # - With initial_state + outputs, state = rnn.dynamic_rnn( + cell, + inputs, + initial_state=array_ops.placeholder(dtypes.float32, shape=(None, 5))) + self.assertEqual(None, outputs.shape[0].value) + self.assertEqual(None, state.shape[0].value) + @test_util.run_in_graph_and_eager_modes() def testScalarStateIsAccepted(self): cell = ScalarStateRNNCell() - inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) + in_graph_mode = context.in_graph_mode() + + if in_graph_mode: + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) + else: + inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32) + with self.test_session() as sess: outputs, state = rnn.dynamic_rnn( cell, inputs, dtype=dtypes.float32, sequence_length=[4]) - outputs, state = sess.run( - [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]}) - self.assertAllEqual(outputs, [[[1], [2], [3], [4]]]) - self.assertEqual(state, 4) + if in_graph_mode: + outputs, state = sess.run( + [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]}) + + if in_graph_mode: + self.assertAllEqual(outputs, np.array([[[1], [2], [3], [4]]])) + self.assertEqual(state, 4) + else: + self.assertAllEqual(outputs.numpy(), np.array([[[1], [2], [3], [4]]])) + self.assertEqual(state.numpy(), 4) ######### Benchmarking RNN code diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py index ec263591e10307abf5a40e21ff6d995c10602dcb..8a76fe3ce55bbdea1677f83fe075ed3bdc8d875d 100644 --- a/tensorflow/python/ops/linalg_grad.py +++ b/tensorflow/python/ops/linalg_grad.py @@ -81,6 +81,36 @@ def _CholeskyGrad(op, grad): return grad_a * 0.5 +@ops.RegisterGradient("Qr") +def _QrGrad(op, dq, dr): + """Gradient for Qr.""" + q, r = op.outputs + if q.dtype.is_complex: + raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype) + if (r.shape.ndims is None or r.shape.as_list()[-2] is None or + r.shape.as_list()[-1] is None): + raise NotImplementedError("QrGrad not implemented with dynamic shapes.") + if r.shape[-2].value != r.shape[-1].value: + raise NotImplementedError("QrGrad not implemented when ncols > nrows " + "or full_matrices is true and ncols != nrows.") + + qdq = math_ops.matmul(q, dq, adjoint_a=True) + qdq_ = qdq - _linalg.adjoint(qdq) + rdr = math_ops.matmul(r, dr, adjoint_b=True) + rdr_ = rdr - _linalg.adjoint(rdr) + tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0) + + def _TriangularSolve(x, r): + """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri.""" + return _linalg.adjoint( + linalg_ops.matrix_triangular_solve( + r, _linalg.adjoint(x), lower=False, adjoint=False)) + + grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r)) + grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r) + return grad_a + grad_b + + @ops.RegisterGradient("MatrixSolve") def _MatrixSolveGrad(op, grad): """Gradient for MatrixSolve.""" @@ -105,7 +135,7 @@ def _MatrixSolveLsGrad(op, grad): # b) Implement a symmetric rank-k update op instead of computing # x*z + transpose(x*z). This pattern occurs other places in TensorFlow. - def _overdetermined(op, grad): + def _Overdetermined(op, grad): """Gradients for the overdetermined case of MatrixSolveLs. 
This is the backprop for the solution to the normal equations of the first @@ -130,7 +160,7 @@ def _MatrixSolveLsGrad(op, grad): grad_b = math_ops.matmul(a, z) return (grad_a, grad_b, None) - def _underdetermined(op, grad): + def _Underdetermined(op, grad): """Gradients for the underdetermined case of MatrixSolveLs. This is the backprop for the solution to the normal equations of the second @@ -162,16 +192,16 @@ def _MatrixSolveLsGrad(op, grad): matrix_shape = op.inputs[0].get_shape()[-2:] if matrix_shape.is_fully_defined(): if matrix_shape[-2] >= matrix_shape[-1]: - return _overdetermined(op, grad) + return _Overdetermined(op, grad) else: - return _underdetermined(op, grad) + return _Underdetermined(op, grad) else: # We have to defer determining the shape to runtime and use # conditional execution of the appropriate graph. matrix_shape = array_ops.shape(op.inputs[0])[-2:] return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1], - lambda: _overdetermined(op, grad), - lambda: _underdetermined(op, grad)) + lambda: _Overdetermined(op, grad), + lambda: _Underdetermined(op, grad)) @ops.RegisterGradient("MatrixTriangularSolve") diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index dd3f167145af3d21320ca9d963f132959cfe574f..c94ddb06275d9bfdd5c4e79ff8efe49925a99274 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -49,6 +49,16 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode): container=container) if graph_mode: return handle + + # We do not want two distinct ResourceVariable objects for the same + # underlying resource in the runtime. + # When in eager mode, explicitly ensure so here. When in graph mode, it's + # ensured by always generating different variable names. + exists = gen_resource_variable_ops.var_is_initialized_op(handle) + if exists: + raise ValueError("variable object with name '%s' already created. Use " + "get_variable() if reuse is desired." % + shared_name) with context.graph_mode(), ops.Graph().as_default(): h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype, shared_name=shared_name, @@ -270,6 +280,15 @@ class ResourceVariable(variables.Variable): collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] self._save_slice_info = None self._in_graph_mode = context.in_graph_mode() + # Save the graph's container prefix for error checking. Reading the value of + # the ResourceVariable from another Graph in Eager mode is an error. + self._container_prefix = ops.get_default_graph()._container_prefix # pylint: disable=protected-access + if not self._in_graph_mode and not name: + # TODO(ashankar,josh11b): make this unnecessary using the same + # logic as in layer + raise ValueError("Variables need to have explicit names when eager " + "execution is enabled") + with ops.control_dependencies(None): with ops.name_scope(name, "Variable", [] if init_from_fn else [initial_value]) as name: @@ -577,7 +596,15 @@ class ResourceVariable(variables.Variable): Returns: the read operation. + Raises: + ValueError: if the ResourceVariable was created in another isolation + environment or graph. 
""" + if (not self._in_graph_mode and + self._container_prefix != ops.get_default_graph()._container_prefix): # pylint: disable=protected-access + raise ValueError( + "Attempted to read a variable from another isolation environment" + " or Graph") with ops.name_scope("Read"): # Ensure we read the variable in the same device as the handle. with ops.device(self._handle_device): diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index b174956e6041db918aa4b8f5a391bfbc60aa6bda..21c7ed361dc8d613d3332905ded1952dfe34681c 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -27,6 +27,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -576,8 +577,9 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if varscope.caching_device is None: - varscope.set_caching_device(lambda op: op.device) + if context.in_graph_mode(): + if varscope.caching_device is None: + varscope.set_caching_device(lambda op: op.device) batch_size = _best_effort_input_batch_size(flat_input) if initial_state is not None: @@ -595,7 +597,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, ["Expected shape for Tensor %s is " % x.name, packed_shape, " but saw shape: ", x_shape]) - if sequence_length is not None: + if context.in_graph_mode() and sequence_length is not None: # Perform some shape validation with ops.control_dependencies( [_assert_has_shape(sequence_length, [batch_size])]): @@ -718,14 +720,19 @@ def _dynamic_rnn_loop(cell, size=time_steps, tensor_array_name=base_name + name) - output_ta = tuple(_create_ta("output_%d" % i, - _infer_state_dtype(dtype, state)) - for i in range(len(flat_output_size))) - input_ta = tuple(_create_ta("input_%d" % i, flat_input[i].dtype) - for i in range(len(flat_input))) - - input_ta = tuple(ta.unstack(input_) - for ta, input_ in zip(input_ta, flat_input)) + in_graph_mode = context.in_graph_mode() + if in_graph_mode: + output_ta = tuple(_create_ta("output_%d" % i, + _infer_state_dtype(dtype, state)) + for i in range(len(flat_output_size))) + input_ta = tuple(_create_ta("input_%d" % i, flat_input[i].dtype) + for i in range(len(flat_input))) + input_ta = tuple(ta.unstack(input_) + for ta, input_ in zip(input_ta, flat_input)) + else: + output_ta = tuple([0 for _ in range(time_steps.numpy())] + for i in range(len(flat_output_size))) + input_ta = flat_input def _time_step(time, output_ta_t, state): """Take a time step of the dynamic RNN. @@ -739,10 +746,13 @@ def _dynamic_rnn_loop(cell, The tuple (time + 1, output_ta_t with updated flow, new_state). 
""" - input_t = tuple(ta.read(time) for ta in input_ta) - # Restore some shape information - for input_, shape in zip(input_t, inputs_got_shape): - input_.set_shape(shape[1:]) + if in_graph_mode: + input_t = tuple(ta.read(time) for ta in input_ta) + # Restore some shape information + for input_, shape in zip(input_t, inputs_got_shape): + input_.set_shape(shape[1:]) + else: + input_t = tuple(ta[time.numpy()] for ta in input_ta) input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t) call_cell = lambda: cell(input_t, state) @@ -764,8 +774,12 @@ def _dynamic_rnn_loop(cell, # Pack state if using state tuples output = nest.flatten(output) - output_ta_t = tuple( - ta.write(time, out) for ta, out in zip(output_ta_t, output)) + if in_graph_mode: + output_ta_t = tuple( + ta.write(time, out) for ta, out in zip(output_ta_t, output)) + else: + for ta, out in zip(output_ta_t, output): + ta[time.numpy()] = out return (time + 1, output_ta_t, new_state) @@ -777,16 +791,20 @@ def _dynamic_rnn_loop(cell, swap_memory=swap_memory) # Unpack final output if not using output tuples. - final_outputs = tuple(ta.stack() for ta in output_final_ta) - - # Restore some shape information - for output, output_size in zip(final_outputs, flat_output_size): - shape = _concat( - [const_time_steps, const_batch_size], output_size, static=True) - output.set_shape(shape) + if in_graph_mode: + final_outputs = tuple(ta.stack() for ta in output_final_ta) + # Restore some shape information + for output, output_size in zip(final_outputs, flat_output_size): + shape = _concat( + [const_time_steps, const_batch_size], output_size, static=True) + output.set_shape(shape) + else: + final_outputs = output_final_ta final_outputs = nest.pack_sequence_as( structure=cell.output_size, flat_sequence=final_outputs) + if not in_graph_mode: + final_outputs = array_ops.stack(final_outputs, axis=0) return (final_outputs, final_state) @@ -967,8 +985,9 @@ def raw_rnn(cell, loop_fn, # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if varscope.caching_device is None: - varscope.set_caching_device(lambda op: op.device) + if context.in_graph_mode(): + if varscope.caching_device is None: + varscope.set_caching_device(lambda op: op.device) time = constant_op.constant(0, dtype=dtypes.int32) (elements_finished, next_input, initial_state, emit_structure, @@ -1166,8 +1185,9 @@ def static_rnn(cell, # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if varscope.caching_device is None: - varscope.set_caching_device(lambda op: op.device) + if context.in_graph_mode(): + if varscope.caching_device is None: + varscope.set_caching_device(lambda op: op.device) # Obtain the first sequence of the input first_input = inputs diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 5c624a9c126e6b2cfaccc4da8e7acbd4e325bb64..36c09c20c217bbe16dd33fa4c19c7cb4cad99139 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -30,12 +30,25 @@ limitations under the License. 
 %rename("%s") TFE_Py_TapeDeleteTrace;
 %rename("%s") TFE_Py_TapeRecordOperation;
 %rename("%s") TFE_Py_TapeExport;
-
+%rename("%s") TFE_NewContextOptions;
+%rename("%s") TFE_ContextOptionsSetConfig;
+%rename("%s") TFE_DeleteContextOptions;
 %{
 #include "tensorflow/python/eager/pywrap_tfe.h"
 %}
 
+%typemap(in) (const void* proto) {
+  char* c_string;
+  Py_ssize_t py_size;
+  // PyBytes_AsStringAndSize() does not copy the data; it simply exposes the
+  // input's internal buffer.
+  if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
+    // Python has raised an error (likely TypeError or UnicodeEncodeError).
+    SWIG_fail;
+  }
+  $1 = static_cast<const void*>(c_string);
+}
+
 %typemap(out) TF_DataType {
   $result = PyInt_FromLong($1);
 }
@@ -165,3 +178,4 @@ limitations under the License.
 %typemap(in, numinputs=0) TF_Status *out_status;
 %typemap(freearg) (TF_Status* out_status);
 %typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status);
+%typemap(in) (const void* proto);
diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py
index 564befeb0b56146fee169cbcd031f0d5ce3e1a82..240ea61aa5f8553852044f84b61d010bfbca69d1 100644
--- a/tensorflow/python/saved_model/signature_def_utils_impl.py
+++ b/tensorflow/python/saved_model/signature_def_utils_impl.py
@@ -56,9 +56,13 @@ def build_signature_def(inputs=None, outputs=None, method_name=None):
 def regression_signature_def(examples, predictions):
   """Creates regression signature from given examples and predictions.
 
+  This function produces signatures intended for use with the TensorFlow Serving
+  Regress API (tensorflow_serving/apis/prediction_service.proto), and so
+  constrains the input and output types to those allowed by TensorFlow Serving.
+
   Args:
-    examples: `Tensor`.
-    predictions: `Tensor`.
+    examples: A string `Tensor`, expected to accept serialized tf.Examples.
+    predictions: A float `Tensor`.
 
   Returns:
     A regression-flavored signature_def.
@@ -93,10 +97,15 @@ def regression_signature_def(examples, predictions):
 def classification_signature_def(examples, classes, scores):
   """Creates classification signature from given examples and predictions.
 
+  This function produces signatures intended for use with the TensorFlow Serving
+  Classify API (tensorflow_serving/apis/prediction_service.proto), and so
+  constrains the input and output types to those allowed by TensorFlow Serving.
+
   Args:
-    examples: `Tensor`.
-    classes: `Tensor`.
-    scores: `Tensor`.
+    examples: A string `Tensor`, expected to accept serialized tf.Examples.
+    classes: A string `Tensor`. Note that the ClassificationResponse message
+      requires that class labels are strings, not integers or anything else.
+    scores: A float `Tensor`.
 
   Returns:
     A classification-flavored signature_def.
@@ -140,6 +149,10 @@ def classification_signature_def(examples, classes, scores):
 def predict_signature_def(inputs, outputs):
   """Creates prediction signature from given inputs and outputs.
 
+  This function produces signatures intended for use with the TensorFlow Serving
+  Predict API (tensorflow_serving/apis/prediction_service.proto). This API
+  imposes no constraints on the input and output types.
+
   Args:
     inputs: dict of string to `Tensor`.
     outputs: dict of string to `Tensor`.
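
The three docstring additions above all describe the same contract: serialized tf.Example strings in, string class labels and float scores out, with only the Predict API left unconstrained. A minimal usage sketch, assuming the TF 1.x endpoint tf.saved_model.signature_def_utils; the placeholder name, labels, and score values are illustrative, not part of this change:

    import tensorflow as tf

    # Serving's Classify API feeds serialized tf.Example protos, hence the
    # string dtype on the input placeholder.
    serialized = tf.placeholder(tf.string, shape=[None], name="tf_example")
    # The ClassificationResponse message requires string labels and float
    # scores, matching the docstring constraints above.
    classes = tf.constant([["spam"], ["ham"]])
    scores = tf.constant([[0.8], [0.2]], dtype=tf.float32)

    signature = tf.saved_model.signature_def_utils.classification_signature_def(
        examples=serialized, classes=classes, scores=scores)

For tensors that fit neither the Classify nor the Regress constraints, predict_signature_def(inputs={...}, outputs={...}) is the unconstrained alternative.
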
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 704017c244625e171a587789253fdb047cad0599..36f97960ddd3b90872453fb4fc7c9e47b7368e49 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -32,7 +32,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
+from tensorflow.python.layers import utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
@@ -413,22 +413,6 @@ def _as_original_type(original_tensors, tensor_list):
     return tensor_list
 
 
-def _smart_cond(pred, if_true, if_false):
-  """A `tf.cond` that does nothing when the condition is static."""
-  pred = ops.convert_to_tensor(pred)
-  static_pred = tensor_util.constant_value(pred)
-  if static_pred is not None:
-    if static_pred:
-      return if_true()
-    else:
-      return if_false()
-  else:
-    return control_flow_ops.cond(
-        pred,
-        if_true,
-        if_false)
-
-
 def _store_sparse_tensors(tensor_list, enqueue_many, keep_input,
                           shared_map_ops=None):
   """Store SparseTensors for feeding into batch, etc.
@@ -480,13 +464,13 @@ def _store_sparse_tensors(tensor_list, enqueue_many, keep_input,
   map_op_name = shared_map_op.name if shared_map_op else None
   def _maybe_store_sparse(t, map_op_name, keep_input):
     """Conditionally store a single sparse Tensor."""
-    return _smart_cond(
+    return utils.smart_cond(
         keep_input,
         lambda: _store_sparse(t, shared_name=map_op_name),
         lambda: constant_op.constant(-1, dtypes.int64))
   def _maybe_store_many_sparse(t, map_op_name, keep_input):
     """Conditionally store multiple sparse Tensors."""
-    out_tensor = _smart_cond(
+    out_tensor = utils.smart_cond(
         keep_input,
         lambda: _store_many_sparse(t, shared_name=map_op_name),
         lambda: -1 * array_ops.ones(array_ops.shape(t)[0:1], dtypes.int64))
@@ -667,7 +651,7 @@ def _enqueue_join(queue, tensor_list_list, enqueue_many, keep_input):
     enqueue_ops = [enqueue_fn(_select_which_to_enqueue(x, keep_input))
                    for x in tensor_list_list]
   else:
-    enqueue_ops = [_smart_cond(
+    enqueue_ops = [utils.smart_cond(
         keep_input,
         lambda: enqueue_fn(tl),  # pylint:disable=cell-var-from-loop
        control_flow_ops.no_op) for tl in tensor_list_list]
@@ -684,7 +668,7 @@ def _enqueue(queue, tensor_list, threads, enqueue_many, keep_input):
     enqueue_ops = [
         enqueue_fn(_select_which_to_enqueue(tensor_list, keep_input))] * threads
   else:
-    enqueue_ops = [_smart_cond(
+    enqueue_ops = [utils.smart_cond(
        keep_input,
        lambda: enqueue_fn(tensor_list),
        control_flow_ops.no_op)] * threads
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index b1926f4eaf69a3e7e83629f962e2f6f6d170137b..c4c1df22eb5b6116c2d5415d4babdfbd8fefef5d 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -557,7 +557,14 @@ class BaseSaverBuilder(object):
         if not isinstance(var, resource_variable_ops.ResourceVariable):
           raise ValueError("Can only save/restore ResourceVariables when "
                            "eager execution is enabled, type: %s." % type(var))
-        names_to_saveables[var._shared_name] = var
+        set_var = names_to_saveables.setdefault(var._shared_name, var)
+        if set_var is not var:
+          raise ValueError(
+              ("Two different ResourceVariable objects with the same "
+               "shared_name '%s' were passed to the Saver. 
This likely means " + "that they were created in different Graphs or isolation " + "contexts, and may not be checkpointed together.") % ( + var._shared_name,)) # pylint: enable=protected-access return names_to_saveables diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index aeb8eaffe875fd49d573a541f5095b644684ce4d..4abff1d106ade4a38435a58464b6089adcb2c732 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -233,7 +233,8 @@ class SaverTest(test.TestCase): def testResourceSaveRestoreCachingDevice(self): save_path = os.path.join(self.get_temp_dir(), "resource_cache") with self.test_session(graph=ops_lib.Graph()) as sess: - v = resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0") + v = resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0", + name="v") if context.in_graph_mode(): self.evaluate(variables.global_variables_initializer()) else: diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index 4e72d025a22cb90428e396c0cfdd1a7c545222eb..1703cae1e5d6d45522660a9fa3b395586d789ece 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -95,6 +95,7 @@ do_pylint() { "^tensorflow/python/platform/default/_googletest\.py.*\[E0102.*function\salready\sdefined "\ "^tensorflow/python/feature_column/feature_column_test\.py.*\[E0110.*abstract-class-instantiated "\ "^tensorflow/contrib/layers/python/layers/feature_column\.py.*\[E0110.*abstract-class-instantiated "\ +"^tensorflow/contrib/eager/python/metrics_impl\.py.*\[E0202.*method-hidden "\ "^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator "\ "^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable" diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 0997fffc8a28336534d91802494e53c8e286fb25..15f8cfb72eb1a14732f44bad6dd15f953df8b81d 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -157,7 +157,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): mkl_repository( name = "mkl", urls = [ - "http://mirror.bazel.build/github.com/01org/mkl-dnn/releases/download/v0.9/mklml_lnx_2018.0.20170720.tgz", + "https://mirror.bazel.build/github.com/01org/mkl-dnn/releases/download/v0.9/mklml_lnx_2018.0.20170720.tgz", # "https://github.com/01org/mkl-dnn/releases/download/v0.9/mklml_lnx_2018.0.20170720.tgz", ], sha256 = "57ba56c4c243f403ff78f417ff854ef50b9eddf4a610a917b7c95e7fa8553a4b", @@ -174,7 +174,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "mkl_dnn", urls = [ "https://github.com/01org/mkl-dnn/archive/b01e3a55a07be62172e713bcd2644c5176360212.tar.gz", - "http://mirror.bazel.build/github.com/01org/mkl-dnn/archive/b01e3a55a07be62172e713bcd2644c5176360212.tar.gz", + "https://mirror.bazel.build/github.com/01org/mkl-dnn/archive/b01e3a55a07be62172e713bcd2644c5176360212.tar.gz", ], sha256 = "0d529ad4c49dc799e6df07c2b88b115d0668735da15fb3b3862d28d33fa68165", strip_prefix = "mkl-dnn-b01e3a55a07be62172e713bcd2644c5176360212", @@ -185,7 +185,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "eigen_archive", urls = [ "https://bitbucket.org/eigen/eigen/get/429aa5254200.tar.gz", - "http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/429aa5254200.tar.gz", + "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/429aa5254200.tar.gz", ], sha256 = "61d8b6fc4279dd1dda986fb1677d15e3d641c07a3ea5abe255790b1f0c0c14e9", strip_prefix = "eigen-eigen-429aa5254200", @@ -198,7 
+198,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "970285762565c7890c6c087d262b0a18286e7d0384f13a37786d8521773bc969", strip_prefix = "tools-0e906ebc527eab1cdbf7adabff5b474da9562e9f/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf", urls = [ - "http://mirror.bazel.build/github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz", + "https://mirror.bazel.build/github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz", # "https://github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz", ], ) @@ -206,7 +206,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "libxsmm_archive", urls = [ - "http://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.8.1.tar.gz", + "https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.8.1.tar.gz", # "https://github.com/hfp/libxsmm/archive/1.8.1.tar.gz", ], sha256 = "2ade869c3f42f23b5263c7d594aa3c7e5e61ac6a3afcaf5d6e42899d2a7986ce", @@ -222,7 +222,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "ortools_archive", urls = [ - "http://mirror.bazel.build/github.com/google/or-tools/archive/253f7955c6a1fd805408fba2e42ac6d45b312d15.tar.gz", + "https://mirror.bazel.build/github.com/google/or-tools/archive/253f7955c6a1fd805408fba2e42ac6d45b312d15.tar.gz", # "https://github.com/google/or-tools/archive/253f7955c6a1fd805408fba2e42ac6d45b312d15.tar.gz", ], sha256 = "932075525642b04ac6f1b50589f1df5cd72ec2f448b721fd32234cf183f0e755", @@ -233,7 +233,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "com_googlesource_code_re2", urls = [ - "http://mirror.bazel.build/github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz", + "https://mirror.bazel.build/github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz", # "https://github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz", ], sha256 = "bd63550101e056427c9e7ff12a408c1c8b74e9803f393ca916b2926fc2c4906f", @@ -243,7 +243,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "gemmlowp", urls = [ - "http://mirror.bazel.build/github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip" + "https://mirror.bazel.build/github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip" # "https://github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip", ], sha256 = "dd2557072bde12141419cb8320a9c25e6ec41a8ae53c2ac78c076a347bb46d9d", @@ -253,7 +253,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "farmhash_archive", urls = [ - "http://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz", + "https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz", # "https://github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz", ], sha256 = "6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0", @@ -269,7 +269,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "highwayhash", urls = [ - "http://mirror.bazel.build/github.com/google/highwayhash/archive/dfcb97ca4fe9277bf9dc1802dd979b071896453b.tar.gz", + "https://mirror.bazel.build/github.com/google/highwayhash/archive/dfcb97ca4fe9277bf9dc1802dd979b071896453b.tar.gz", # 
"https://github.com/google/highwayhash/archive/dfcb97ca4fe9277bf9dc1802dd979b071896453b.tar.gz", ], sha256 = "0f30a15b1566d93f146c8d149878a06e91d9bb7ec2cfd76906df62a82be4aac9", @@ -280,7 +280,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "nasm", urls = [ - "http://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2", + "https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2", "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.12.02.tar.bz2/d15843c3fb7db39af80571ee27ec6fad/nasm-2.12.02.tar.bz2", ], sha256 = "00b0891c678c065446ca59bcee64719d0096d54d6886e6e472aeee2e170ae324", @@ -291,7 +291,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): temp_workaround_http_archive( name = "jpeg", urls = [ - "http://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.1.tar.gz", + "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.1.tar.gz", # "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.1.tar.gz", ], sha256 = "c15a9607892113946379ccea3ca8b85018301b200754f209453ab21674268e77", @@ -303,7 +303,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "png_archive", urls = [ - "http://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.2.53.tar.gz", + "https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.2.53.tar.gz", # "https://github.com/glennrp/libpng/archive/v1.2.53.tar.gz", ], sha256 = "716c59c7dfc808a4c368f8ada526932be72b2fcea11dd85dc9d88b1df1dfe9c2", @@ -314,7 +314,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "sqlite_archive", urls = [ - "http://mirror.bazel.build/www.sqlite.org/2017/sqlite-amalgamation-3200000.zip", + "https://mirror.bazel.build/www.sqlite.org/2017/sqlite-amalgamation-3200000.zip", "http://www.sqlite.org/2017/sqlite-amalgamation-3200000.zip", ], sha256 = "208780b3616f9de0aeb50822b7a8f5482f6515193859e91ed61637be6ad74fd4", @@ -325,7 +325,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "gif_archive", urls = [ - "http://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz", + "https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz", "http://pilotfiber.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz", ], sha256 = "34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1", @@ -336,7 +336,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "six_archive", urls = [ - "http://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", + "https://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", "https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", ], sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a", @@ -347,7 +347,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "org_python_pypi_backports_weakref", urls = [ - "http://mirror.bazel.build/pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz", + "https://mirror.bazel.build/pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz", "https://pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz", ], sha256 = 
"8813bf712a66b3d8b85dc289e1104ed220f1878cf981e2fe756dfaabe9a82892", @@ -358,7 +358,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "com_github_andreif_codegen", urls = [ - "http://mirror.bazel.build/github.com/andreif/codegen/archive/1.0.tar.gz", + "https://mirror.bazel.build/github.com/andreif/codegen/archive/1.0.tar.gz", # "https://github.com/andreif/codegen/archive/1.0.tar.gz", ], sha256 = "2dadd04a2802de27e0fe5a19b76538f6da9d39ff244036afa00c1bba754de5ee", @@ -371,7 +371,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): licenses = ["notice"], # Python 2.0 sha256_urls = { "b5556e921715ddb9242c076cae3963f483aa47266c5e37ea4c187f77cc79501c": [ - "http://mirror.bazel.build/docs.python.org/2.7/_sources/license.txt", + "https://mirror.bazel.build/docs.python.org/2.7/_sources/license.txt", "https://docs.python.org/2.7/_sources/license.txt", ], }, @@ -387,7 +387,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): patched_http_archive( name = "protobuf_archive", urls = [ - "http://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", ], sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a", strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9", @@ -410,7 +410,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "com_google_protobuf", urls = [ - "http://mirror.bazel.build/github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", ], sha256 = "6d43b9d223ce09e5d4ce8b0060cb8a7513577a35a64c7e3dad10f0703bf3ad93", strip_prefix = "protobuf-0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66", @@ -420,7 +420,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "com_google_protobuf_cc", urls = [ - "http://mirror.bazel.build/github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", ], sha256 = "6d43b9d223ce09e5d4ce8b0060cb8a7513577a35a64c7e3dad10f0703bf3ad93", strip_prefix = "protobuf-0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66", @@ -429,7 +429,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "nsync", urls = [ - "http://mirror.bazel.build/github.com/google/nsync/archive/ad722c76c6e6653f66be2e1f69521b7f7517da55.tar.gz", + "https://mirror.bazel.build/github.com/google/nsync/archive/ad722c76c6e6653f66be2e1f69521b7f7517da55.tar.gz", # "https://github.com/google/nsync/archive/ad722c76c6e6653f66be2e1f69521b7f7517da55.tar.gz", ], sha256 = "7dd8ca49319f77e8226cd020a9210a525f88ac26e7041c59c95418223a1cdf55", @@ -439,7 +439,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "com_google_googletest", urls = [ - "http://mirror.bazel.build/github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip", + "https://mirror.bazel.build/github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip", # "https://github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip", ], sha256 = "9cbca84c4256bed17df2c8f4d00c912c19d247c11c9ba6647cd6dd5b5c996b8d", @@ -449,7 +449,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): 
native.http_archive( name = "com_github_gflags_gflags", urls = [ - "http://mirror.bazel.build/github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz", + "https://mirror.bazel.build/github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz", # "https://github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz", ], sha256 = "4d222fab8f1ede4709cdff417d15a1336f862d7334a81abf76d09c15ecf9acd1", @@ -465,7 +465,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "pcre", sha256 = "ccdf7e788769838f8285b3ee672ed573358202305ee361cfec7a4a4fb005bbc7", urls = [ - "http://mirror.bazel.build/ftp.exim.org/pub/pcre/pcre-8.39.tar.gz", + "https://mirror.bazel.build/ftp.exim.org/pub/pcre/pcre-8.39.tar.gz", "http://ftp.exim.org/pub/pcre/pcre-8.39.tar.gz", ], strip_prefix = "pcre-8.39", @@ -476,7 +476,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "swig", sha256 = "58a475dbbd4a4d7075e5fe86d4e54c9edde39847cdb96a3053d87cb64a23a453", urls = [ - "http://mirror.bazel.build/ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz", + "https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz", "http://ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz", "http://pilotfiber.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz", ], @@ -488,7 +488,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "curl", sha256 = "ff3e80c1ca6a068428726cd7dd19037a47cc538ce58ef61c59587191039b2ca6", urls = [ - "http://mirror.bazel.build/curl.haxx.se/download/curl-7.49.1.tar.gz", + "https://mirror.bazel.build/curl.haxx.se/download/curl-7.49.1.tar.gz", "https://curl.haxx.se/download/curl-7.49.1.tar.gz", ], strip_prefix = "curl-7.49.1", @@ -518,7 +518,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): patched_http_archive( name = "grpc", urls = [ - "http://mirror.bazel.build/github.com/grpc/grpc/archive/781fd6f6ea03645a520cd5c675da67ab61f87e4b.tar.gz", + "https://mirror.bazel.build/github.com/grpc/grpc/archive/781fd6f6ea03645a520cd5c675da67ab61f87e4b.tar.gz", # "https://github.com/grpc/grpc/archive/781fd6f6ea03645a520cd5c675da67ab61f87e4b.tar.gz", ], sha256 = "2004635e6a078acfac8ffa71738397796be4f8fb72f572cc44ecee5d99511d9f", @@ -542,7 +542,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "linenoise", sha256 = "7f51f45887a3d31b4ce4fa5965210a5e64637ceac12720cfce7954d6a2e812f7", urls = [ - "http://mirror.bazel.build/github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz", + "https://mirror.bazel.build/github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz", # "https://github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz", ], strip_prefix = "linenoise-c894b9e59f02203dbe4e2be657572cf88c4230c3", @@ -554,7 +554,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): temp_workaround_http_archive( name = "llvm", urls = [ - "http://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/bb3c660e87f59abb665570a31b01ab125ec4c10e.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/bb3c660e87f59abb665570a31b01ab125ec4c10e.tar.gz", "https://github.com/llvm-mirror/llvm/archive/bb3c660e87f59abb665570a31b01ab125ec4c10e.tar.gz", ], sha256 = "caab6d7978e6771cb4e9b5b89607c5370de8aa642913c6c14e892468194c94e4", @@ -566,7 +566,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "lmdb", urls = [ - 
"http://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", + "https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", # "https://github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", ], sha256 = "108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326", @@ -577,7 +577,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "jsoncpp_git", urls = [ - "http://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", + "https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", # "https://github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", ], sha256 = "07d34db40593d257324ec5fb9debc4dc33f29f8fb44e33a2eeb35503e61d0fe2", @@ -593,7 +593,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "boringssl", urls = [ - "https://github.com/google/boringssl/archive/a0fb951d2a26a8ee746b52f3ba81ab011a0af778.tar.gz", + "https://mirror.bazel.build/github.com/google/boringssl/archive/a0fb951d2a26a8ee746b52f3ba81ab011a0af778.tar.gz", ], sha256 = "524ba98a56300149696481b4cb9ddebd0c7b7ac9b9f6edee81da2d2d7e5d2bb3", strip_prefix = "boringssl-a0fb951d2a26a8ee746b52f3ba81ab011a0af778", @@ -602,7 +602,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "zlib_archive", urls = [ - "http://mirror.bazel.build/zlib.net/zlib-1.2.8.tar.gz", + "https://mirror.bazel.build/zlib.net/zlib-1.2.8.tar.gz", "http://zlib.net/fossils/zlib-1.2.8.tar.gz", ], sha256 = "36658cb768a54c1d4dec43c3116c27ed893e88b02ecfcb44f2166f9c0b7f2a0d", @@ -618,7 +618,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "fft2d", urls = [ - "http://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz", + "https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz", "http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz", ], sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296", @@ -628,7 +628,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): temp_workaround_http_archive( name = "snappy", urls = [ - "http://mirror.bazel.build/github.com/google/snappy/archive/1.1.4.tar.gz", + "https://mirror.bazel.build/github.com/google/snappy/archive/1.1.4.tar.gz", # "https://github.com/google/snappy/archive/1.1.4.tar.gz", ], sha256 = "2f7504c73d85bac842e893340333be8cb8561710642fc9562fccdd9d2c3fcc94", @@ -640,7 +640,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): temp_workaround_http_archive( name = "nccl_archive", urls = [ - "http://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", + "https://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", # "https://github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", ], sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176", @@ -665,7 +665,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "junit", jar_sha256 = "59721f0805e223d84b90677887d9ff567dc534d7c502ca903c0c2b17f05c116a", jar_urls = [ - "http://mirror.bazel.build/repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar", + "https://mirror.bazel.build/repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar", "http://repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar", "http://maven.ibiblio.org/maven2/junit/junit/4.12/junit-4.12.jar", ], @@ 
-678,7 +678,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "org_hamcrest_core", jar_sha256 = "66fdef91e9739348df7a096aa384a5685f4e875584cce89386a7a47251c4d8e9", jar_urls = [ - "http://mirror.bazel.build/repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar", + "https://mirror.bazel.build/repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar", "http://repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar", "http://maven.ibiblio.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar", ], @@ -689,7 +689,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): temp_workaround_http_archive( name = "jemalloc", urls = [ - "http://mirror.bazel.build/github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", + "https://mirror.bazel.build/github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", # "https://github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", ], sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8", @@ -726,7 +726,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "com_google_pprof", urls = [ - "http://mirror.bazel.build/github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz", + "https://mirror.bazel.build/github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz", # "https://github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz", ], sha256 = "e0928ca4aa10ea1e0551e2d7ce4d1d7ea2d84b2abbdef082b0da84268791d0c4", @@ -737,8 +737,8 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.new_http_archive( name = "cub_archive", urls = [ - "http://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip", - "https://github.com/NVlabs/cub/archive/1.7.4.zip", + "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip", + # "https://github.com/NVlabs/cub/archive/1.7.4.zip", ], sha256 = "20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31", strip_prefix = "cub-1.7.4", @@ -754,7 +754,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "cython", sha256 = "6dcd30b5ceb887b2b965ee7ceb82ea3acb5f0642fe2206c7636b45acea4798e5", urls = [ - "http://mirror.bazel.build/github.com/cython/cython/archive/3732784c45cfb040a5b0936951d196f83a12ea17.tar.gz", + "https://mirror.bazel.build/github.com/cython/cython/archive/3732784c45cfb040a5b0936951d196f83a12ea17.tar.gz", "https://github.com/cython/cython/archive/3732784c45cfb040a5b0936951d196f83a12ea17.tar.gz", ], strip_prefix = "cython-3732784c45cfb040a5b0936951d196f83a12ea17", @@ -764,7 +764,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "bazel_toolchains", urls = [ - "http://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/b2b4b38433bf2d1159360855ea4004378308711b.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/b2b4b38433bf2d1159360855ea4004378308711b.tar.gz", # "https://github.com/bazelbuild/bazel-toolchains/archive/b2b4b38433bf2d1159360855ea4004378308711b.tar.gz", ], sha256 = "46187270ca04ff8109980f45c3438fabfe48695e163789096eb82ee097ffe685",
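
Every workspace.bzl hunk above makes the same mechanical change: the mirror.bazel.build copy switches to https:// and stays first in the urls list, with the upstream URL listed second or commented out where only the mirror is meant to be fetched (as with boringssl and cub above). A hypothetical new entry following that convention might look like the sketch below; the repository name, hash, and paths are placeholders, not part of this change:

    # Inside tf_workspace() in workspace.bzl; "example_archive" and its
    # sha256 are illustrative placeholders only.
    native.new_http_archive(
        name = "example_archive",
        urls = [
            # https mirror first, upstream second (or commented out).
            "https://mirror.bazel.build/github.com/example/example/archive/v1.0.tar.gz",
            # "https://github.com/example/example/archive/v1.0.tar.gz",
        ],
        sha256 = "<sha256 of the v1.0 tarball>",  # placeholder
        strip_prefix = "example-1.0",
        build_file = str(Label("//third_party:example.BUILD")),
    )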