Commit b3d5ec90 authored by: Vijay Vasudevan, committed by: GitHub

Merge pull request #13866 from vrv/branch_172924803

Branch 172924803
......@@ -5,7 +5,7 @@ http_archive(
sha256 = "110fe68753413777944b473c25eed6368c4a0487cee23a7bac1b13cc49d3e257",
strip_prefix = "rules_closure-4af89ef1db659eb41f110df189b67d4cf14073e1",
urls = [
"http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz",
"https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", # 2017-08-28
],
)
......
......@@ -348,6 +348,7 @@ filegroup(
"//tensorflow/compiler/xla/service/llvm_ir:all_files",
"//tensorflow/compiler/xla/tests:all_files",
"//tensorflow/compiler/xla/tools:all_files",
"//tensorflow/compiler/xla/tools/parser:all_files",
"//tensorflow/contrib:all_files",
"//tensorflow/contrib/all_reduce:all_files",
"//tensorflow/contrib/android:all_files",
......@@ -421,7 +422,6 @@ filegroup(
"//tensorflow/contrib/remote_fused_graph/pylib:all_files",
"//tensorflow/contrib/resampler:all_files",
"//tensorflow/contrib/rnn:all_files",
"//tensorflow/contrib/s3:all_files",
"//tensorflow/contrib/saved_model:all_files",
"//tensorflow/contrib/saved_model/cc/saved_model:all_files",
"//tensorflow/contrib/seq2seq:all_files",
......@@ -475,6 +475,7 @@ filegroup(
"//tensorflow/core/platform/cloud:all_files",
"//tensorflow/core/platform/default/build_config:all_files",
"//tensorflow/core/platform/hadoop:all_files",
"//tensorflow/core/platform/s3:all_files",
"//tensorflow/core/profiler:all_files",
"//tensorflow/core/profiler/internal:all_files",
"//tensorflow/core/profiler/internal/advisor:all_files",
......
......@@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"tf_cuda_cc_test",
"tf_cc_test",
"tf_copts",
"tf_cuda_library",
......@@ -50,7 +51,7 @@ tf_cuda_library(
],
)
tf_cc_test(
tf_cuda_cc_test(
name = "c_api_test",
srcs = ["c_api_test.cc"],
deps = [
......
......@@ -54,9 +54,23 @@ string DeviceName(tensorflow::Device* d) {
extern "C" {
TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
size_t proto_len, TF_Status* status) {
TF_SetConfig(&options->session_options, proto, proto_len, status);
}
void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
options->policy = policy;
}
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
TF_Graph* graph = TF_NewGraph();
TF_Session* session = TF_NewSession(graph, opts, status);
TF_Session* session = TF_NewSession(graph, &opts->session_options, status);
if (status->status.ok()) {
if (session->device_mgr == nullptr || session->devices.empty()) {
status->status = tensorflow::errors::InvalidArgument(
......@@ -71,9 +85,10 @@ TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
}
TFE_Context* ret = new TFE_Context(session);
ret->policy = opts->policy;
ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime(
ret->session->device_mgr, opts->options.env, TF_GRAPH_DEF_VERSION,
&ret->func_lib_def, {}));
ret->session->device_mgr, opts->session_options.options.env,
TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {}));
ret->rendezvous =
new tensorflow::IntraProcessRendezvous(ret->session->device_mgr);
......@@ -408,8 +423,10 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
namespace {
tensorflow::Status ValidateInputTypeAndPlacement(
tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op,
const tensorflow::OpKernel* kernel) {
TFE_Context* ctx, tensorflow::Device* host_device,
tensorflow::Device* op_device, TFE_Op* op,
const tensorflow::OpKernel* kernel,
std::vector<TFE_TensorHandle*>* copied_tensors) {
const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
if (memtypes.size() != op->inputs.size()) {
return tensorflow::errors::InvalidArgument(
......@@ -421,11 +438,42 @@ tensorflow::Status ValidateInputTypeAndPlacement(
const tensorflow::Device* actual_device =
op->input_devices[i] == nullptr ? host_device : op->input_devices[i];
if (expected_device != actual_device) {
return tensorflow::errors::InvalidArgument(
"cannot compute ", op->name, " as input #", i,
" was expected to be on ", expected_device->name(),
" but is actually on ", actual_device->name(),
" (operation running on ", op_device->name(), ")");
switch (ctx->policy) {
case TFE_DEVICE_PLACEMENT_EXPLICIT:
return tensorflow::errors::InvalidArgument(
"cannot compute ", op->name, " as input #", i,
" was expected to be on ", expected_device->name(),
" but is actually on ", actual_device->name(),
" (operation running on ", op_device->name(), ")");
case TFE_DEVICE_PLACEMENT_WARN:
LOG(WARNING) << "before computing " << op->name << " input #" << i
<< " was expected to be on " << expected_device->name()
<< " but is actually on " << actual_device->name()
<< " (operation running on " << op_device->name()
<< "). This triggers a copy which can be a performance "
"bottleneck.";
break;
case TFE_DEVICE_PLACEMENT_SILENT: // Do nothing.
break;
}
// We are only here if the policy is warn or silent copies, so we should
// trigger a copy.
TFE_TensorHandle original{op->inputs[i], op->input_devices[i]};
TF_Status* s = TF_NewStatus();
TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice(
&original, ctx, expected_device->name().c_str(), s);
if (!s->status.ok()) {
tensorflow::Status status = s->status;
delete s;
return tensorflow::errors::Internal(
"Failed copying input tensor from ", actual_device->name(), " to ",
expected_device->name(), " in order to run ", op->name, ": ",
status.error_message());
}
op->inputs[i] = copied_tensor->t;
copied_tensors->push_back(copied_tensor);
op->input_devices[i] = copied_tensor->d;
delete s;
}
if (op->inputs[i].dtype() != kernel->input_type(i)) {
return tensorflow::errors::InvalidArgument(
......@@ -468,10 +516,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
}
tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
}
status->status = ValidateInputTypeAndPlacement(ctx->devices()[0], device, op,
kernel->kernel());
std::vector<TFE_TensorHandle*> copied_tensors;
status->status = ValidateInputTypeAndPlacement(
ctx, ctx->devices()[0], device, op, kernel->kernel(), &copied_tensors);
output_memory_types = &kernel->kernel()->output_memory_types();
if (!status->status.ok()) {
for (auto* t : copied_tensors) {
TFE_DeleteTensorHandle(t);
}
return;
}
// WARNING: kernel->Run utilizes the FunctionLibraryRuntime
......@@ -483,6 +535,9 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
// sense for FunctionLibraryRuntime to ensure thread-safe access to
// FunctionLibraryDefinition?).
status->status = kernel->Run(&op->inputs, &outputs);
for (auto* t : copied_tensors) {
TFE_DeleteTensorHandle(t);
}
if (!status->status.ok()) return;
*num_retvals = std::min<int>(*num_retvals, outputs.size());
for (int i = 0; i < *num_retvals; ++i) {
......
......@@ -43,14 +43,46 @@ limitations under the License.
extern "C" {
#endif
typedef struct TFE_ContextOptions TFE_ContextOptions;
// Return a new options object.
TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions();
// Set the config in TF_ContextOptions.options.
// config should be a serialized tensorflow.ConfigProto proto.
// If config was not parsed successfully as a ConfigProto, record the
// error information in *status.
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig(
TFE_ContextOptions* options, const void* proto, size_t proto_len,
TF_Status* status);
// Controls how to act when we try to run an operation on a given device but
// some input tensors are not on that device.
typedef enum TFE_ContextDevicePlacementPolicy {
// The default: running operations with input tensors on the wrong device will
// fail.
TFE_DEVICE_PLACEMENT_EXPLICIT = 0,
// Copy the tensor to the right device but log a warning.
TFE_DEVICE_PLACEMENT_WARN = 1,
// Silently copy the tensor, which has a performance cost since the
// operation will be blocked till the copy completes.
TFE_DEVICE_PLACEMENT_SILENT = 2,
} TFE_ContextDevicePlacementPolicy;
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
// Destroy an options object.
TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);
// "Context" under which operations/functions are executed. It encapsulates
// things like the available devices, resource manager etc.
//
// TODO(ashankar): Merge with TF_Session?
typedef struct TFE_Context TFE_Context;
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts,
TF_Status* status);
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(
const TFE_ContextOptions* opts, TF_Status* status);
TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status);
TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
TF_Status* status);
......
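For context, a minimal usage sketch of the options API declared above (not part of this diff; the function name is illustrative). It only calls functions from this header and mirrors the pattern used by the tests later in this change:

#include "tensorflow/c/eager/c_api.h"

void ExampleSilentPlacementContext() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  // Opt into silent copies: inputs on the "wrong" device are copied instead of
  // causing TFE_Execute to fail.
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  // The options object is no longer needed once the context exists.
  TFE_DeleteContextOptions(opts);
  if (TF_GetCode(status) == TF_OK) {
    /* ... build TFE_Ops and call TFE_Execute here ... */
    TFE_DeleteContext(ctx, status);
  }
  TF_DeleteStatus(status);
}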
......@@ -35,9 +35,16 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
struct TFE_ContextOptions {
TF_SessionOptions session_options;
TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_EXPLICIT};
};
struct TFE_Context {
explicit TFE_Context(TF_Session* s) : session(s) {}
TFE_ContextDevicePlacementPolicy policy;
// TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph.
TF_Session* session;
tensorflow::Rendezvous* rendezvous;
......
......@@ -62,10 +62,10 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
void BM_InitOp(int iters) {
tensorflow::testing::StopTiming();
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
tensorflow::testing::StartTiming();
......@@ -84,10 +84,10 @@ BENCHMARK(BM_InitOp);
void BM_Execute(int iters) {
tensorflow::testing::StopTiming();
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul = MatMulOp(ctx, m, m);
......@@ -109,9 +109,9 @@ BENCHMARK(BM_Execute);
TEST(CAPI, Context) {
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
......@@ -150,9 +150,9 @@ TEST(CAPI, TensorHandle) {
TEST(CAPI, TensorHandleCopyBetweenDevices) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
......@@ -216,12 +216,58 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) {
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}
TEST(CAPI, TensorHandleSilentCopy) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
const int num_devices = TF_DeviceListCount(devices);
// Disable the test if no GPU is present.
if (num_devices > 1) {
const int device_to_use = 1;
const string name(TF_DeviceListName(devices, device_to_use, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hgpu =
TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
TFE_OpSetDevice(matmul, name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);
}
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
TFE_DeleteContext(ctx, status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}
TEST(CAPI, Execute) {
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul = MatMulOp(ctx, m, m);
......@@ -285,10 +331,10 @@ string MatMulFunction() {
TEST(CAPI, FunctionDefAndExecute) {
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
string function_def = MatMulFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
......@@ -326,10 +372,10 @@ TEST(CAPI, FunctionDefAndExecute) {
void BM_ExecuteFunction(int iters) {
tensorflow::testing::StopTiming();
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
string function_def = MatMulFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
......@@ -406,10 +452,10 @@ TEST(CAPI, Variables) {
// Variables use resource handles, so this is really a test for resource
// tensor handling.
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
......@@ -446,10 +492,10 @@ TEST(CAPI, Variables) {
void BM_ReadVariable(int iters) {
tensorflow::testing::StopTiming();
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
......
......@@ -138,6 +138,11 @@ class ComputationBuilder {
ComputationDataHandle ConstantR2(
std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
ComputationDataHandle ConstantFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout);
template <typename NativeT>
ComputationDataHandle ConstantFromArray(const Array<NativeT>& values);
template <typename NativeT>
ComputationDataHandle ConstantR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout);
template <typename NativeT>
......@@ -910,48 +915,54 @@ ComputationDataHandle ComputationBuilder::ConstantR2(
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout) {
return ConstantOp([&values, &layout](Literal* literal) {
literal->PopulateR2FromArray2DWithLayout(values, layout);
literal->PopulateFromArrayWithLayout(values, layout);
});
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantFromArray(
const Array<NativeT>& values) {
return ConstantOp(
[&values](Literal* literal) { literal->PopulateFromArray(values); });
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
return ConstantFromArrayWithLayout(values, layout);
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D(
const Array2D<NativeT>& values) {
return ConstantOp(
[&values](Literal* literal) { literal->PopulateR2FromArray2D(values); });
return ConstantFromArray(values);
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
return ConstantOp([&values, &layout](Literal* literal) {
literal->PopulateR3FromArray3DWithLayout(values, layout);
});
return ConstantFromArrayWithLayout(values, layout);
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D(
const Array3D<NativeT>& values) {
return ConstantOp(
[&values](Literal* literal) { literal->PopulateR3FromArray3D(values); });
return ConstantFromArray(values);
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout(
const Array4D<NativeT>& values, const Layout& layout) {
return ConstantOp([&values, &layout](Literal* literal) {
literal->PopulateR4FromArray4DWithLayout(values, layout);
});
return ConstantFromArrayWithLayout(values, layout);
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
const Array4D<NativeT>& values) {
return ConstantOp(
[&values](Literal* literal) { literal->PopulateR4FromArray4D(values); });
return ConstantFromArray(values);
}
} // namespace xla
......
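As a hedged illustration (not part of this change), the new rank-generic entry point lets callers avoid choosing among the R2/R3/R4 variants, which now simply forward to it; the helper function below is hypothetical:

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"

xla::ComputationDataHandle ConstantMatrix(xla::ComputationBuilder* builder,
                                          const xla::Array2D<float>& values) {
  // Uses the default layout for the array's rank; ConstantFromArrayWithLayout
  // is available when a specific layout is required.
  return builder->ConstantFromArray(values);
}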
......@@ -83,6 +83,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
return CreateDefaultLayoutForRank(shape.dimensions_size());
}
/* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) {
return CreateDefaultLayoutForRank(rank);
}
/* static */ Layout LayoutUtil::GetDefaultLayoutForR2() {
return CreateDefaultLayoutForRank(2);
}
......
......@@ -40,6 +40,7 @@ class LayoutUtil {
static Layout GetDefaultLayoutForShape(const Shape& shape);
// Helper functions that create default layouts for various ranks.
static Layout GetDefaultLayoutForRank(int64 rank);
static Layout GetDefaultLayoutForR2();
static Layout GetDefaultLayoutForR3();
static Layout GetDefaultLayoutForR4();
......
......@@ -206,9 +206,9 @@ void AllocateFlags() {
flag_values->xla_gpu_disable_multi_streaming(),
"If true, multi-streaming in the GPU backend is disabled."),
tensorflow::Flag(
"xla_dump_debug_json_to",
flag_values->mutable_xla_dump_debug_json_to(),
"Dump compilation artifacts as JSON into this directory."),
"xla_dump_hlo_proto_to",
flag_values->mutable_xla_dump_hlo_proto_to(),
"Dump compilation artifacts as proto binary into this directory."),
tensorflow::Flag(
"xla_test_all_output_layouts",
bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),
......
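A hedged sketch of the caller-side effect of the rename (not in this diff): code that previously populated xla_dump_debug_json_to now sets the new proto field; the directory path below is illustrative:

#include "tensorflow/compiler/xla/xla.pb.h"

void EnableHloProtoDump(xla::DebugOptions* debug_options) {
  // The compilers changed later in this diff read this field and write one
  // HloProto binary per module into the given directory.
  debug_options->set_xla_dump_hlo_proto_to("/tmp/hlo_dumps");
}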
......@@ -334,6 +334,11 @@ class Literal {
// WithLayout use the default XLA layout for the literal's linear
// representation in memory.
template <typename NativeT>
static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
template <typename NativeT>
static std::unique_ptr<Literal> CreateFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout);
template <typename NativeT>
static std::unique_ptr<Literal> CreateR2FromArray2D(
const Array2D<NativeT>& values);
template <typename NativeT>
......@@ -481,6 +486,11 @@ class Literal {
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout);
template <typename NativeT>
void PopulateFromArray(const Array<NativeT>& values);
template <typename NativeT>
void PopulateFromArrayWithLayout(const Array<NativeT>& values,
const Layout& layout);
template <typename NativeT>
void PopulateR2FromArray2D(const Array2D<NativeT>& values);
template <typename NativeT>
void PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
......@@ -816,33 +826,42 @@ template <typename NativeT>
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
/* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout) {
auto literal = MakeUnique<Literal>();
literal->PopulateR2FromArray2DWithLayout(values, layout);
literal->PopulateFromArrayWithLayout(values, layout);
return literal;
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateFromArray(
const Array<NativeT>& values) {
return CreateFromArrayWithLayout(
values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2D(
const Array2D<NativeT>& values) {
return CreateR2FromArray2DWithLayout(values,
LayoutUtil::GetDefaultLayoutForR2());
return CreateFromArray(values);
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
auto literal = MakeUnique<Literal>();
literal->PopulateR3FromArray3DWithLayout(values, layout);
return literal;
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3D(
const Array3D<NativeT>& values) {
return CreateR3FromArray3DWithLayout(values,
LayoutUtil::GetDefaultLayoutForR3());
return CreateFromArray(values);
}
template <typename NativeT>
......@@ -901,16 +920,13 @@ template <typename NativeT>
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4D(
const Array4D<NativeT>& values) {
return CreateR4FromArray4DWithLayout(values,
LayoutUtil::GetDefaultLayoutForR4());
return CreateFromArray(values);
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4DWithLayout(
const Array4D<NativeT>& values, const Layout& layout) {
auto literal = MakeUnique<Literal>();
literal->PopulateR4FromArray4DWithLayout(values, layout);
return literal;
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
......@@ -1070,82 +1086,53 @@ void Literal::PopulateR2(
}
template <typename NativeT>
void Literal::PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
const Layout& layout) {
void Literal::PopulateFromArrayWithLayout(const Array<NativeT>& values,
const Layout& layout) {
*mutable_shape() = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(),
{values.height(), values.width()}, AsInt64Slice(layout.minor_to_major()));
primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
AsInt64Slice(layout.minor_to_major()));
Reserve(values.num_elements());
values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
NativeT value) { this->Set(indices, value); });
}
const int64 dim1_size = values.width();
const int64 dim0_size = values.height();
CHECK_EQ(dim0_size, shape().dimensions(0));
CHECK_EQ(dim1_size, shape().dimensions(1));
Reserve(dim1_size * dim0_size);
for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) {
Set({dim0, dim1}, values(dim0, dim1));
}
}
template <typename NativeT>
void Literal::PopulateFromArray(const Array<NativeT>& values) {
PopulateFromArrayWithLayout(
values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
}
template <typename NativeT>
void Literal::PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
const Layout& layout) {
PopulateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
PopulateFromArray(values);
}
template <typename NativeT>
void Literal::PopulateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
const Layout& layout) {
*mutable_shape() = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(),
{values.n1(), values.n2(), values.n3()},
AsInt64Slice(layout.minor_to_major()));
CHECK_EQ(values.n1(), shape().dimensions(0));
CHECK_EQ(values.n2(), shape().dimensions(1));
CHECK_EQ(values.n3(), shape().dimensions(2));
Reserve(values.n1() * values.n2() * values.n3());
for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) {
for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) {
for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) {
Set({dim0, dim1, dim2}, values(dim0, dim1, dim2));
}
}
}
PopulateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
PopulateFromArray(values);
}
template <typename NativeT>
void Literal::PopulateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
const Layout& layout) {
*mutable_shape() = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(),
{values.planes(), values.depth(), values.height(), values.width()},
AsInt64Slice(layout.minor_to_major()));
CHECK_EQ(values.n1(), shape().dimensions(0));
CHECK_EQ(values.n2(), shape().dimensions(1));
CHECK_EQ(values.n3(), shape().dimensions(2));
CHECK_EQ(values.n4(), shape().dimensions(3));
Reserve(values.n1() * values.n2() * values.n3() * values.n4());
for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) {
for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) {
for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) {
for (int64 dim3 = 0; dim3 < values.n4(); ++dim3) {
Set({dim0, dim1, dim2, dim3}, values(dim0, dim1, dim2, dim3));
}
}
}
}
PopulateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
PopulateFromArray(values);
}
template <typename NativeT, typename FnType>
......
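A hedged, self-contained sketch (not part of this commit) of the new rank-generic literal constructor; the old R2/R3/R4 creators above now forward to it:

#include <memory>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/literal_util.h"

std::unique_ptr<xla::Literal> MakeExampleLiteral() {
  xla::Array2D<float> values({{1.0f, 2.0f, 3.0f},
                              {4.0f, 5.0f, 6.0f}});
  // Equivalent to the old Literal::CreateR2FromArray2D(values), with the
  // default layout chosen from the array's rank.
  return xla::Literal::CreateFromArray(values);
}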
......@@ -37,20 +37,6 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
return (serialized1 == serialized2);
}
StatusOr<string> ToJson(const tensorflow::protobuf::Message& message) {
string json_output;
tensorflow::protobuf::util::JsonPrintOptions json_options;
json_options.add_whitespace = true;
json_options.always_print_primitive_fields = true;
auto status = tensorflow::protobuf::util::MessageToJsonString(
message, &json_output, json_options);
if (!status.ok()) {
return InternalError("MessageToJsonString failed: %s",
status.error_message().data());
}
return json_output;
}
namespace {
string SanitizeFilename(const string& file_name) {
......@@ -65,17 +51,6 @@ string SanitizeFilename(const string& file_name) {
} // namespace
Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message,
const string& directory, const string& file_name) {
TF_ASSIGN_OR_RETURN(const string json_output, ToJson(message));
tensorflow::Env* env = tensorflow::Env::Default();
TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory));
string safe_file_name = SanitizeFileName(file_name) + ".json";
const string path = tensorflow::io::JoinPath(directory, safe_file_name);
return tensorflow::WriteStringToFile(env, path, json_output);
}
Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
const string& directory, const string& file_name) {
tensorflow::Env* env = tensorflow::Env::Default();
......
......@@ -32,17 +32,12 @@ namespace protobuf_util {
extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
const tensorflow::protobuf::Message& m2);
// Returns 'message' as a JSON string.
StatusOr<string> ToJson(const tensorflow::protobuf::Message& message);
// Writes the given message in binary proto or JSON format to the path formed by
// joining 'directory/file_name.pb' (or file_name.json). The 'directory' is
// recursively created if it doesn't already exist, and the 'file_name' is
// sanitized by replacing illegal characters with underscore '_'.
// Writes the given message in binary proto to the path formed by joining
// 'directory/file_name.pb'. The 'directory' is recursively created if it
// doesn't already exist, and the 'file_name' is sanitized by replacing
// illegal characters with underscore '_'.
Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
const string& directory, const string& file_name);
Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message,
const string& directory, const string& file_name);
} // namespace protobuf_util
} // namespace xla
......
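For reference, a hedged usage sketch of the remaining helper (the wrapper function and directory path are illustrative); it mirrors how the CPU and GPU compilers call it later in this change:

#include <string>

#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"

tensorflow::Status DumpHloProto(const xla::HloProto& proto,
                                const std::string& module_name) {
  // Writes /tmp/hlo_dumps/<sanitized module_name>.pb, creating the directory
  // if necessary.
  return xla::protobuf_util::DumpProtoToDirectory(proto, "/tmp/hlo_dumps",
                                                  module_name);
}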
......@@ -2064,6 +2064,29 @@ tf_cc_test(
],
)
cc_library(
name = "hlo_runner",
srcs = ["hlo_runner.cc"],
hdrs = ["hlo_runner.h"],
deps = [
":executable",
":hlo",
":transfer_manager",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//third_party/eigen3",
],
)
# -----------------------------------------------------------------------------
filegroup(
......
......@@ -475,8 +475,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
// ownership is std::moved.
const bool embed_ir_in_executable =
module->config().debug_options().xla_embed_ir_in_executable();
const string dump_debug_json_to =
module->config().debug_options().xla_dump_debug_json_to();
const string xla_dump_hlo_proto_to =
module->config().debug_options().xla_dump_hlo_proto_to();
if (options::CpuParallelBackendRequested(module->config())) {
VLOG(1) << "Using parallel cpu backend";
......@@ -496,10 +496,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
if (!dump_debug_json_to.empty()) {
if (!xla_dump_hlo_proto_to.empty()) {
HloProto proto = MakeHloProto(*module, *assignment);
TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
proto, dump_debug_json_to, module->name()));
TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
proto, xla_dump_hlo_proto_to, module->name()));
}
// If we are using the parallel CPU backend, we need to create map from
......@@ -603,12 +603,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
if (!dump_debug_json_to.empty()) {
if (!xla_dump_hlo_proto_to.empty()) {
HloProto proto = MakeHloProto(*module, *assignment);
TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
proto, dump_debug_json_to, module->name()));
TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
proto, xla_dump_hlo_proto_to, module->name()));
}
// Each computation is a single function. Emit all embedded computations
// before the entry computation. The order of computations returned from
// GetEmbeddedComputations guarantees that a called computation occurs
......@@ -775,12 +774,12 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
const string dump_debug_json_to =
module->config().debug_options().xla_dump_debug_json_to();
if (!dump_debug_json_to.empty()) {
const string xla_dump_hlo_proto_to =
module->config().debug_options().xla_dump_hlo_proto_to();
if (!xla_dump_hlo_proto_to.empty()) {
HloProto proto = MakeHloProto(*module, *assignment);
TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
proto, dump_debug_json_to, module->name()));
TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
proto, xla_dump_hlo_proto_to, module->name()));
}
IrEmitter ir_emitter(*module, *assignment, &llvm_module,
......
......@@ -136,6 +136,8 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
instruction->opcode() == HloOpcode::kCall ||
instruction->opcode() == HloOpcode::kCustomCall ||
instruction->opcode() == HloOpcode::kSelectAndScatter ||
instruction->opcode() == HloOpcode::kGetTupleElement ||
instruction->opcode() == HloOpcode::kBitcast ||
(instruction->opcode() == HloOpcode::kConvolution &&
PotentiallyImplementedAsEigenConvolution(*instruction)) ||
PotentiallyImplementedAsEigenDot(*instruction) ||
......
......@@ -318,12 +318,12 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
// print one ourselves.
XLA_VLOG_LINES(2, buffer_assignment->ToString());
const string dump_debug_json_to =
module->config().debug_options().xla_dump_debug_json_to();
if (!dump_debug_json_to.empty()) {
const string xla_dump_hlo_proto_to =
module->config().debug_options().xla_dump_hlo_proto_to();
if (!xla_dump_hlo_proto_to.empty()) {
HloProto proto = MakeHloProto(*module, *buffer_assignment);
TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
proto, dump_debug_json_to, module->name()));
TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
proto, xla_dump_hlo_proto_to, module->name()));
}
IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(),
......
......@@ -373,8 +373,8 @@ string HloComputation::ToString(int nested_level) const {
for (int i = 0; i < nested_level; i++) {
s << " ";
}
s << name() << " " << ShapeUtil::HumanString(ComputeProgramShape())
<< " { \n";
s << "%" << name() << " " << ShapeUtil::HumanString(ComputeProgramShape())
<< " {\n";
for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
for (int i = 0; i < nested_level; i++) {
s << " ";
......
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include <set>
#include <string>
#include <utility>
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace se = ::perftools::gputools;
namespace xla {
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromHloProtoFile(const char* filename,
const DebugOptions& debug_options) {
HloProto proto;
TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
filename, &proto));
HloModuleConfig config;
config.set_debug_options(debug_options);
TF_ASSIGN_OR_RETURN(auto module, HloModule::CreateFromProto(
proto.hlo_module(),
VersionedComputationHandle(), config));
return std::move(module);
}
// Define this in .cc file to avoid having to include eigen or forward declare
// these types in the header.
struct HloRunner::EigenThreadPoolWrapper {
std::unique_ptr<EigenThreadPoolWrapper> pool;
std::unique_ptr<Eigen::ThreadPoolDevice> device;
};
HloRunner::HloRunner() {}
HloRunner::HloRunner(se::Platform* platform) {
BackendOptions backend_options;
backend_options.set_platform(platform);
backend_ = Backend::CreateBackend(backend_options).ConsumeValueOrDie();
VLOG(1) << "Created HloRunner for platform: " << platform->Name();
}
HloRunner::~HloRunner() {
// Deallocate all the memory allocated during the tests.
for (auto& allocation : allocations_) {
backend().default_stream_executor()->Deallocate(&allocation);
}
}
StatusOr<se::DeviceMemoryBase> HloRunner::Execute(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
Shape* result_shape) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
backend().compiler()->Compile(std::move(module),
backend().default_stream_executor()));
se::Stream stream(backend().default_stream_executor());
stream.Init();
ExecutableRunOptions run_options;
run_options.set_stream(&stream);
run_options.set_allocator(backend().memory_allocator());
run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool());
run_options.set_intra_op_thread_pool(
backend().eigen_intra_op_thread_pool_device());
HloExecutionProfile hlo_execution_profile;
ServiceExecutableRunOptions service_run_options(
run_options, backend().StreamBorrower(),
backend().inter_op_thread_pool());
TF_ASSIGN_OR_RETURN(
se::DeviceMemoryBase result,
executable->ExecuteOnStream(&service_run_options, arguments,
&hlo_execution_profile));
TF_RET_CHECK(stream.BlockHostUntilDone());
allocations_.push_back(result);
*result_shape = executable->result_shape();
if (ShapeUtil::IsTuple(*result_shape)) {
// We must record element buffers of tuples as well to avoid leaks.
DCHECK(!ShapeUtil::IsNestedTuple(*result_shape));
TF_ASSIGN_OR_RETURN(
std::vector<se::DeviceMemoryBase> element_buffers,
backend().transfer_manager()->ShallowCopyTupleFromDevice(
backend().default_stream_executor(), result, *result_shape));
// A tuple may contain the same buffer in more than one element. Keep track
// of the buffers already added to avoid duplicates in allocations_.
std::set<void*> added_opaques;
for (auto element_buffer : element_buffers) {
if (added_opaques.count(element_buffer.opaque()) == 0) {
CHECK(element_buffer.opaque() != nullptr);
added_opaques.insert(element_buffer.opaque());
allocations_.push_back(element_buffer);
}
}
}
return result;
}
se::DeviceMemoryBase HloRunner::TransferToDevice(const Literal& literal) {
// Allocate memory on the device using the stream executor.
int64 allocation_size =
backend().transfer_manager()->GetByteSizeRequirement(literal.shape());
se::DeviceMemoryBase allocation =
backend().default_stream_executor()->AllocateArray<uint8>(
allocation_size);
allocations_.push_back(allocation);
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToDevice(
backend().default_stream_executor(), literal, &allocation));
return allocation;
}
std::unique_ptr<Literal> HloRunner::TransferFromDevice(
const Shape& shape, se::DeviceMemoryBase device_base) {
auto literal = MakeUnique<Literal>();
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromDevice(
backend().default_stream_executor(), device_base, shape, shape,
literal.get()));
return literal;
}
std::unique_ptr<Literal> HloRunner::ExecuteAndTransfer(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
Shape result_shape;
se::DeviceMemoryBase device_base =
Execute(std::move(module), arguments, &result_shape).ValueOrDie();
return TransferFromDevice(result_shape, device_base);
}
template <>
std::unique_ptr<Literal> HloRunner::Execute(
std::unique_ptr<HloModule> module,
const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>>& literals) {
std::vector<se::DeviceMemoryBase> arguments;
for (const auto& literal : literals) {
arguments.push_back(TransferToDevice(*literal));
}
return ExecuteAndTransfer(std::move(module), arguments);
}
template <>
std::unique_ptr<Literal> HloRunner::Execute(
std::unique_ptr<HloModule> module,
const tensorflow::gtl::ArraySlice<Literal*>& literals) {
std::vector<se::DeviceMemoryBase> arguments;
for (const auto& literal : literals) {
arguments.push_back(TransferToDevice(*literal));
}
return ExecuteAndTransfer(std::move(module), arguments);
}
Backend& HloRunner::backend() {
if (!backend_) {
backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
VLOG(1) << "executing on platform " << backend().platform()->Name();
}
return *backend_;
}
} // namespace xla
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
// A base class for running an HloModule. This executes the given HloModule on a
// certain backend directly without using the client interface. HloModule can be
// explicitly built, or loaded from a serialization file (e.g., hlo proto file).
class HloRunner {
public:
HloRunner();
HloRunner(::perftools::gputools::Platform* platform);
~HloRunner();
// Reads the binary proto file in xla.HloProto format, creates and returns the
// HloModule.
static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloProtoFile(
const char* filename, const DebugOptions& debug_options);
// Executes the given module with given literals as input and returns the
// result as a Literal. The LiteralPtr type accepts Literal* or
// std::unique_ptr<Literal>.
template <typename LiteralPtr>
std::unique_ptr<Literal> Execute(
std::unique_ptr<HloModule> module,
const tensorflow::gtl::ArraySlice<LiteralPtr>& literals);
// Executes the given module and returns a global data handle.
StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments,
Shape* result_shape);
// Transfers the given literal to the device and returns the data handle.
perftools::gputools::DeviceMemoryBase TransferToDevice(
const Literal& literal);
// Transfers the array referred to by the given handle from the device and
// returns as a Literal.
std::unique_ptr<Literal> TransferFromDevice(
const Shape& shape, perftools::gputools::DeviceMemoryBase device_base);
// Executes the given module and return the result as a Literal.
std::unique_ptr<Literal> ExecuteAndTransfer(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments);
// If backend is not created in the constructor, creates and returns the
// default backend. If creation fails, crashes the program.
//
// This creates the backend lazily so it's possible to instantiate an
// HloRunner in a program without any backends linked in.
Backend& backend();
private:
struct EigenThreadPoolWrapper;
std::vector<perftools::gputools::DeviceMemoryBase> allocations_;
std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
std::unique_ptr<Backend> backend_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
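A hedged sketch (not part of this change) of how a test or tool might drive the new runner; the file path and function name are illustrative, and default DebugOptions are assumed:

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/xla.pb.h"

std::unique_ptr<xla::Literal> RunDumpedModule(const char* hlo_proto_path) {
  xla::HloRunner runner;  // lazily creates the default backend
  std::unique_ptr<xla::HloModule> module =
      xla::HloRunner::ReadModuleFromHloProtoFile(hlo_proto_path,
                                                 xla::DebugOptions())
          .ConsumeValueOrDie();
  // Execute with no arguments and transfer the result back as a Literal.
  tensorflow::gtl::ArraySlice<xla::Literal*> no_args;
  return runner.Execute(std::move(module), no_args);
}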
......@@ -58,14 +58,32 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution(
return {};
}
// We only support folding the RHS.
const int64 kRhsOperandIndex = 1;
auto& operand = *convolution.operand(kRhsOperandIndex);
if (operand.opcode() == HloOpcode::kTranspose && operand.user_count() == 1) {
return transposable_conv_operands(convolution, {kRhsOperandIndex});
const ConvolutionDimensionNumbers& dnums =
convolution.convolution_dimension_numbers();
TransposeFolding::OperandIndices operand_set;
for (int64 i = 0; i < convolution.operand_count(); ++i) {
auto& operand = *convolution.operand(i);
if (operand.opcode() == HloOpcode::kTranspose &&
operand.user_count() == 1) {
const auto& transpose_dimensions = operand.dimensions();
// We can transpose the LHS so long as it doesn't move around spatial
// dimensions because ConvolutionDimensionNumbers doesn't have different
// fields for input and output spatial dimensions.
if (i == 0 &&
std::any_of(dnums.spatial_dimensions().begin(),
dnums.spatial_dimensions().end(),
[&](const int64 spatial_dimension) {
return transpose_dimensions[spatial_dimension] !=
spatial_dimension;
})) {
continue;
}
operand_set.push_back(i);
}
}
return {};
return transposable_conv_operands(convolution, operand_set);
}
using InstructionOperandsPair =
......@@ -98,40 +116,61 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair) {
// Returns whether the module is changed.
bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
auto& convolution = *pair.first;
// We only support fusing the RHS transpose into convolution.
//
// ConvolutionDimensionNumbers doesn't make enough of a distinction between
// the output and the activations.
//
// TODO(b/37125184): Support transposing the LHS too.
if (pair.second.size() != 1 || pair.second.front() != 1) {
return false;
}
auto& operand_indices = pair.second;
const ConvolutionDimensionNumbers& dnums =
convolution.convolution_dimension_numbers();
HloInstruction& transpose = *convolution.mutable_operand(1);
CHECK_EQ(transpose.opcode(), HloOpcode::kTranspose);
const auto& transpose_dimensions = transpose.dimensions();
HloInstruction& transpose_operand = *transpose.mutable_operand(0);
// Everything remains the same except for the kernel dimension numbers. We
// need to apply the transpose permutation to the original shape to figure out
// what the new logical dimensions are.
ConvolutionDimensionNumbers new_dnums = dnums;
new_dnums.set_kernel_input_feature_dimension(
transpose_dimensions[dnums.kernel_input_feature_dimension()]);
new_dnums.set_kernel_output_feature_dimension(
transpose_dimensions[dnums.kernel_output_feature_dimension()]);
for (auto& kernel_spatial_dimension :
*new_dnums.mutable_kernel_spatial_dimensions()) {
kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
HloInstruction* new_lhs;
const int64 kLhsIdx = 0;
if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) !=
operand_indices.end()) {
HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx);
const auto& transpose_dimensions = transpose.dimensions();
HloInstruction& transpose_operand = *transpose.mutable_operand(0);
// Everything remains the same except for the input/output dimension
// numbers. We need to apply the transpose permutation to the original shape
// to figure out what the new logical dimensions are.
new_dnums.set_input_batch_dimension(
transpose_dimensions[dnums.input_batch_dimension()]);
new_dnums.set_input_feature_dimension(
transpose_dimensions[dnums.input_feature_dimension()]);
for (const auto& spatial_dimension : dnums.spatial_dimensions()) {
CHECK_EQ(spatial_dimension, transpose_dimensions[spatial_dimension]);
}
new_lhs = &transpose_operand;
} else {
new_lhs = convolution.mutable_operand(kLhsIdx);
}
HloInstruction* new_rhs;
const int64 kRhsIdx = 1;
if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) !=
operand_indices.end()) {
HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx);
const auto& transpose_dimensions = transpose.dimensions();
HloInstruction& transpose_operand = *transpose.mutable_operand(0);
// Everything remains the same except for the kernel dimension numbers. We
// need to apply the transpose permutation to the original shape to figure
// out what the new logical dimensions are.
new_dnums.set_kernel_input_feature_dimension(
transpose_dimensions[dnums.kernel_input_feature_dimension()]);
new_dnums.set_kernel_output_feature_dimension(
transpose_dimensions[dnums.kernel_output_feature_dimension()]);
for (auto& kernel_spatial_dimension :
*new_dnums.mutable_kernel_spatial_dimensions()) {
kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
}
new_rhs = &transpose_operand;
} else {
new_rhs = convolution.mutable_operand(kRhsIdx);
}
auto new_conv = HloInstruction::CreateConvolve(
convolution.shape(), convolution.mutable_operand(0), &transpose_operand,
convolution.window(), new_dnums);
convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums);
TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
&convolution, std::move(new_conv)));
......
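A hedged, self-contained illustration (not part of this change) of the dimension-number remapping above, for an LHS transpose that only swaps the batch and feature dimensions (permutation {1, 0, 2, 3}); this is the case exercised by the updated FoldConvTransposeLhs test below:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // transpose_dimensions[i] is the operand dimension that feeds output
  // dimension i of the transpose.
  const std::vector<int64_t> transpose_dimensions = {1, 0, 2, 3};
  const int64_t input_batch_dimension = 0;    // in the convolution's input
  const int64_t input_feature_dimension = 1;
  // After folding, the convolution reads the transpose's operand directly, so
  // each logical dimension is pushed through the permutation.
  assert(transpose_dimensions[input_batch_dimension] == 1);
  assert(transpose_dimensions[input_feature_dimension] == 0);
  // Spatial dimensions 2 and 3 are fixed points of the permutation, which is
  // exactly what CanFoldOperandsIntoConvolution requires for the LHS.
  return 0;
}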
......@@ -313,8 +313,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(1));
}
// Test that a transpose of the activations does not get folded into
// convolution.
// Test that a transpose of the activations gets folded into convolution.
TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
auto builder = HloComputation::Builder("entry_computation");
HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
......@@ -348,18 +347,25 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
module.AddEntryComputation(builder.Build(conv));
FoldTranspose(&module);
// Instructions after folding: transpose_x, y, and the convolution.
// Instructions after folding: x, y, and the convolution.
std::unordered_set<HloInstruction*> instruction_set(
entry_computation->instructions().begin(),
entry_computation->instructions().end());
CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
CHECK_EQ(1, instruction_set.erase(transpose_x))
<< "transpose_x is not in entry_computation.";
CHECK_EQ(1, instruction_set.erase(conv))
<< "transpose_x is not in entry_computation.";
CHECK_EQ(0, instruction_set.size())
<< "entry_computation should contain exactly 4 instructions.";
EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
EXPECT_EQ(1, instruction_set.size())
<< "entry_computation should contain exactly 3 instructions.";
HloInstruction* new_conv = *instruction_set.begin();
EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
EXPECT_EQ(dnums.input_feature_dimension(),
new_conv->convolution_dimension_numbers().input_batch_dimension());
EXPECT_EQ(
dnums.input_batch_dimension(),
new_conv->convolution_dimension_numbers().input_feature_dimension());
EXPECT_EQ(dnums.spatial_dimensions(0),
new_conv->convolution_dimension_numbers().spatial_dimensions(0));
EXPECT_EQ(dnums.spatial_dimensions(1),
new_conv->convolution_dimension_numbers().spatial_dimensions(1));
}
} // namespace
......
......@@ -20,6 +20,7 @@ limitations under the License.
#include <stack>
#include <unordered_map>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
......@@ -1843,10 +1844,17 @@ UserComputation::GetEmbeddedComputations(
XLA_VLOG_LINES(3, session_computation_.DebugString());
std::vector<VersionedComputationHandle> computations;
std::vector<int64> sorted_handles;
for (const auto& handle_request : session_computation_.requests()) {
int64 handle_value = handle_request.first;
sorted_handles.push_back(handle_request.first);
}
std::sort(sorted_handles.begin(), sorted_handles.end());
for (int64 handle : sorted_handles) {
const auto& handle_request = session_computation_.requests().find(handle);
CHECK(handle_request != session_computation_.requests().end());
int64 handle_value = handle_request->first;
if (handle_value <= version) {
const OperationRequest& request = handle_request.second;
const OperationRequest& request = handle_request->second;
switch (request.request().op_case()) {
case OpRequest::kCallRequest: {
CHECK_EQ(1, request.embedded_computation_versions_size());
......
......@@ -102,6 +102,32 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
return true;
}
// Constructs and returns the new shape with the given minor_to_major order in
// its Layout.
StatusOr<Shape> MakeShapeWithLayoutInternal(
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> minor_to_major) {
if (dimensions.size() != minor_to_major.size()) {
return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
dimensions.size(), minor_to_major.size());
}
if (element_type == OPAQUE || element_type == TUPLE) {
return InvalidArgument("Unsupported element type: %s",
PrimitiveType_Name(element_type).c_str());
}
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
auto min2maj = shape.mutable_layout()->mutable_minor_to_major();
min2maj->Clear();
for (int64 value : minor_to_major) {
min2maj->Add(value);
}
if (!shape.has_layout()) {
return InvalidArgument("Shape has no layout.");
}
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
return shape;
}
} // namespace
/* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
......@@ -152,16 +178,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
/* static */ Shape ShapeUtil::MakeShapeWithLayout(
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> minor_to_major) {
CHECK_EQ(dimensions.size(), minor_to_major.size());
Shape shape = MakeShape(element_type, dimensions);
auto min2maj = shape.mutable_layout()->mutable_minor_to_major();
min2maj->Clear();
for (int64 value : minor_to_major) {
min2maj->Add(value);
}
DCHECK(shape.has_layout());
TF_DCHECK_OK(ValidateShape(shape));
return shape;
return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major)
.ValueOrDie();
}
/* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
......@@ -499,11 +517,10 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
// Extract the layout minor-to-major and set it.
TF_ASSIGN_OR_RETURN(std::vector<int64> min2maj,
comma_list_to_int64s(layout_string));
TF_RET_CHECK(dimensions.size() == min2maj.size());
result =
ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj);
TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal(
primitive_type, dimensions, min2maj));
}
TF_DCHECK_OK(ShapeUtil::ValidateShape(result));
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result));
return std::move(result);
}
......
......@@ -102,28 +102,18 @@ cc_library(
deps = [
":literal_test_util",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_execution_profile",
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/compiler/xla/service:hlo_runner",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
"//third_party/eigen3",
],
)
......
......@@ -19,24 +19,9 @@ limitations under the License.
#include <string>
#include <utility>
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
......@@ -45,22 +30,6 @@ namespace se = ::perftools::gputools;
namespace xla {
// Define this in .cc file to avoid having to include eigen or forward declare
// these types in the header.
struct HloTestBase::EigenThreadPoolWrapper {
std::unique_ptr<EigenThreadPoolWrapper> pool;
std::unique_ptr<Eigen::ThreadPoolDevice> device;
};
HloTestBase::HloTestBase() {}
HloTestBase::~HloTestBase() {
// Deallocate all the memory allocated during the tests.
for (auto& allocation : allocations_) {
backend().default_stream_executor()->Deallocate(&allocation);
}
}
/* static */
std::unique_ptr<HloModule> HloTestBase::CreateNewModule() {
HloModuleConfig config;
......@@ -80,98 +49,25 @@ StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments,
Shape* result_shape) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
backend().compiler()->Compile(std::move(module),
backend().default_stream_executor()));
se::Stream stream(backend().default_stream_executor());
stream.Init();
ExecutableRunOptions run_options;
run_options.set_stream(&stream);
run_options.set_allocator(backend().memory_allocator());
run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool());
run_options.set_intra_op_thread_pool(
backend().eigen_intra_op_thread_pool_device());
HloExecutionProfile hlo_execution_profile;
ServiceExecutableRunOptions service_run_options(
run_options, backend().StreamBorrower(),
backend().inter_op_thread_pool());
TF_ASSIGN_OR_RETURN(
se::DeviceMemoryBase result,
executable->ExecuteOnStream(&service_run_options, arguments,
&hlo_execution_profile));
TF_RET_CHECK(stream.BlockHostUntilDone());
allocations_.push_back(result);
*result_shape = executable->result_shape();
if (ShapeUtil::IsTuple(*result_shape)) {
// We must record element buffers of tuples as well to avoid leaks.
DCHECK(!ShapeUtil::IsNestedTuple(*result_shape));
TF_ASSIGN_OR_RETURN(
std::vector<se::DeviceMemoryBase> element_buffers,
backend().transfer_manager()->ShallowCopyTupleFromDevice(
backend().default_stream_executor(), result, *result_shape));
// A tuple may contain the same buffer in more than one element. Keep track
// of the buffers already added to avoid duplicates in allocations_.
std::set<void*> added_opaques;
for (auto element_buffer : element_buffers) {
if (added_opaques.count(element_buffer.opaque()) == 0) {
CHECK(element_buffer.opaque() != nullptr);
added_opaques.insert(element_buffer.opaque());
allocations_.push_back(element_buffer);
}
}
}
return result;
return runner_.Execute(std::move(module), arguments, result_shape);
}
se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) {
// Allocate memory on the device using the stream executor.
int64 allocation_size =
backend().transfer_manager()->GetByteSizeRequirement(literal.shape());
se::DeviceMemoryBase allocation =
backend().default_stream_executor()->AllocateArray<uint8>(
allocation_size);
allocations_.push_back(allocation);
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToDevice(
backend().default_stream_executor(), literal, &allocation));
return allocation;
return runner_.TransferToDevice(literal);
}
std::unique_ptr<Literal> HloTestBase::TransferFromDevice(
const Shape& shape, se::DeviceMemoryBase device_base) {
auto literal = MakeUnique<Literal>();
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromDevice(
backend().default_stream_executor(), device_base, shape, shape,
literal.get()));
return literal;
return runner_.TransferFromDevice(shape, device_base);
}
std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
Shape result_shape;
se::DeviceMemoryBase device_base =
Execute(std::move(module), arguments, &result_shape).ValueOrDie();
return TransferFromDevice(result_shape, device_base);
return runner_.ExecuteAndTransfer(std::move(module), arguments);
}
Backend& HloTestBase::backend() {
if (!backend_) {
backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
VLOG(1) << "executing on platform " << backend().platform()->Name();
}
return *backend_;
}
Backend& HloTestBase::backend() { return runner_.backend(); }
/* static */
string HloTestBase::TestName() {
......
......@@ -21,12 +21,12 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
......@@ -39,10 +39,9 @@ namespace xla {
// building a graph of HLO instructions to run.
class HloTestBase : public ::testing::Test {
protected:
struct EigenThreadPoolWrapper;
HloTestBase();
HloTestBase() {}
~HloTestBase() override;
~HloTestBase() override {}
// Creates a new HLO module for a test. The module created will have
// TestName() for its name; it will also automatically populate its debug
......@@ -102,23 +101,12 @@ class HloTestBase : public ::testing::Test {
static string TestName();
// Creates (if necessary) and returns the default backend. If creation fails,
// crashes the program.
//
// This creates the backend lazily so it's possible to instantiate an
// HloTestBase in a program without any backends linked in.
// Returns the backend owned by the HloRunner.
Backend& backend();
// This vector contains handles of all the device memory allocations performed
// by the test. These are deallocated on destruction of the test object.
std::vector<perftools::gputools::DeviceMemoryBase> allocations_;
HloRunner runner_;
ErrorSpec error_spec_{0.0001};
std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
private:
std::unique_ptr<Backend> backend_; // Lazily populated. Access via backend().
};
} // namespace xla
......
......@@ -210,6 +210,18 @@ tf_cc_binary(
],
)
tf_cc_binary(
name = "hlo_proto_to_json",
srcs = ["hlo_proto_to_json.cc"],
deps = [
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
# -----------------------------------------------------------------------------
filegroup(
......
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Usage:
// hlo_proto_to_json --input_file=some_binary_proto
// --output_file=path_to_dump_output
//
// Reads one serialized HLO module, converts it into JSON format, and dumps it
// to the given output file. some_binary_proto is obtained by serializing an
// HLO module to disk using the --xla_dump_hlo_proto_to debug option.
#include <stdio.h>
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
using tensorflow::Env;
using xla::string;
namespace xla {
namespace tools {
StatusOr<string> ToJson(const tensorflow::protobuf::Message& message) {
string json_output;
tensorflow::protobuf::util::JsonPrintOptions json_options;
json_options.add_whitespace = true;
json_options.always_print_primitive_fields = true;
auto status = tensorflow::protobuf::util::MessageToJsonString(
message, &json_output, json_options);
if (!status.ok()) {
return InternalError("MessageToJsonString failed: %s",
status.error_message().data());
}
return json_output;
}
void RealMain(const string& input, const string& output) {
HloProto hlo_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), input,
&hlo_proto))
<< "Can't open, read, or parse input file " << input;
auto statusor = ToJson(hlo_proto);
QCHECK(statusor.ok()) << "Error converting " << input << " to JSON."
<< statusor.status();
TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), output,
statusor.ValueOrDie()));
}
} // namespace tools
} // namespace xla
int main(int argc, char** argv) {
string input_file, output_file;
const std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("input_file", &input_file, "file to convert."),
tensorflow::Flag("output_file", &output_file, "converted file"),
};
const string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
QCHECK(parse_ok && argc == 1) << "\n" << usage;
QCHECK(!input_file.empty()) << "--input_file is required";
QCHECK(!output_file.empty()) << "--output_file is required";
xla::tools::RealMain(input_file, output_file);
return 0;
}
# Build file for the Hlo parser.
licenses(["notice"]) # Apache 2.0
package(
default_visibility = [":friends"],
)
package_group(
name = "friends",
includes = [
"//tensorflow/compiler/xla:friends",
],
)
# Filegroup used to collect source files for dependency checking.
filegroup(
name = "c_srcs",
data = glob([
"**/*.cc",
"**/*.h",
]),
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
cc_library(
name = "hlo_lexer",
srcs = ["hlo_lexer.cc"],
hdrs = [
"hlo_lexer.h",
"hlo_token.h",
],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
],
)
cc_library(
name = "hlo_parser",
srcs = ["hlo_parser.cc"],
hdrs = ["hlo_parser.h"],
deps = [
":hlo_lexer",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_cc_test(
name = "hlo_parser_test",
size = "small",
srcs = ["hlo_parser_test.cc"],
deps = [
":hlo_parser",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# -----------------------------------------------------------------------------
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
# HloModule string syntax
TODO: Support subcomputations (for fusion, reduce, while, ...).
TODO: Support ops that require extra attributes, e.g. dimensions, strides.
```yacc
hlo_module
: 'HloModule' name computation
;
computation
: 'ENTRY' name param_list '->' shape instruction_list
;
instruction_list
: '{' instruction_list1 '}'
;
instruction_list1
: instruction
| instruction_list1 instruction
;
instruction
: name '=' shape opcode operands
;
operands
: '(' operands1 ')'
;
operands1
: /*empty*/
| operand
| operands1 ',' operand
;
operand
: shape name
;
param_list
: '(' param_list1 ')'
;
param_list1
: /*empty*/
| param
| param_list1 ',' param
;
param
: name shape
;
shape
: shape_val_
| '(' tuple_elements ')'
;
tuple_elements
: /*empty*/
| shape (',' shape)*
;
name
: identifier ':'
| '%' identifier
;
identifier
: [a-zA-Z_][a-zA-Z0-9_.-]*
;
```
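
For illustration, here is a minimal sketch (assuming only the `xla::tools::Parse()` entry point declared in `hlo_parser.h`) of feeding a module written in this syntax to the parser. The module text mirrors the `AddConstants` case in `hlo_parser_test.cc`; the error handling is illustrative only.

```c++
#include <iostream>

#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"

int main() {
  // A module with two constants and an add, written in the syntax above.
  const char* text = R"(HloModule add_constants_module:

ENTRY %add_constants () -> f32[] {
  %constant = f32[] constant(3.14)
  %add = f32[] add(f32[] %constant, f32[] %constant)
}

)";
  auto result = xla::tools::Parse(text);
  if (!result.ok()) {
    std::cerr << "parse error: " << result.status().ToString() << "\n";
    return 1;
  }
  // A successful parse round-trips back to the HloModule::ToString() form.
  std::cout << result.ValueOrDie()->ToString();
  return 0;
}
```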
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h"
#include <unordered_map>
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla {
namespace tools {
using tensorflow::StringPiece;
namespace {
constexpr int kEOF = -1;
constexpr int kError = -2;
// [a-zA-Z0-9_.-]
bool IsIdentifierChar(char c) {
return isalnum(static_cast<unsigned char>(c)) || c == '-' || c == '.' ||
c == '_';
}
} // namespace
int HloLexer::GetNextChar() {
int current_char = PeekCurrentChar();
if (current_char != kEOF && current_char != kError) {
current_ptr_++;
}
return current_char;
}
int HloLexer::PeekCurrentChar() const {
if (current_ptr_ == buf_.end()) {
return kEOF;
}
char current_char = *current_ptr_;
if (current_char == 0) {
// '\0' should not appear in the middle of the string.
return kError;
}
return static_cast<unsigned char>(current_char);
}
bool HloLexer::CanDereference(const char* ptr) const {
return ptr < buf_.end() && ptr >= buf_.begin();
}
StringPiece HloLexer::StringPieceFromPointers(const char* begin,
const char* end) const {
CHECK(begin <= end);
CHECK(begin == buf_.end() || CanDereference(begin));
CHECK(end == buf_.end() || CanDereference(end));
return StringPiece(begin, end - begin);
}
tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers(
const char* begin, const char* end) const {
CHECK(begin <= end);
CHECK(begin == buf_.end() || CanDereference(begin));
CHECK(end == buf_.end() || CanDereference(end));
return tensorflow::RegexpStringPiece(begin, end - begin);
}
TokKind HloLexer::LexToken() {
while (true) {
token_start_ = current_ptr_;
int current_char = GetNextChar();
switch (current_char) {
default:
// [a-zA-Z_]
if (isalpha(static_cast<unsigned char>(current_char)) ||
current_char == '_') {
return LexIdentifier();
}
return TokKind::kError;
case kEOF:
// Hit the end of the input buffer.
return TokKind::kEof;
case kError:
// Hit an invalid character in the input buffer.
return TokKind::kError;
case ' ':
case '\t':
case '\n':
case '\r':
// Ignore whitespace.
continue;
case '0':
case '1':
case '2':
case '3':
case '4':
case '5':
case '6':
case '7':
case '8':
case '9':
case '-':
if (current_char == '-' && PeekCurrentChar() == '>') {
current_ptr_++;
return TokKind::kArrow;
}
return LexDigitOrNegative();
case '=':
return TokKind::kEqual;
case ',':
return TokKind::kComma;
case '%':
return LexPercent();
case ':':
return TokKind::kColon;
case '[':
return TokKind::kLsquare;
case ']':
return TokKind::kRsquare;
case '{':
return TokKind::kLbrace;
case '}':
return TokKind::kRbrace;
case '(':
return TokKind::kLparen;
case ')':
return TokKind::kRparen;
}
}
}
// Lex a shape, name, keyword, or opcode.
// shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})?
// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*:
// keyword ::= HloModule, ENTRY, ...
// opcode ::= add, greater-than, ...
TokKind HloLexer::LexIdentifier() {
{
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
// 'consumable' will be advanced iff its prefix matches the pattern.
static LazyRE2 shape_pattern = {
R"(^(\w*\d*)\[([\d,]*)\](?:\s*{([\d,]*)})?)"};
if (RE2::Consume(&consumable, *shape_pattern)) {
auto status_or_shape = ShapeUtil::ParseShapeString(
StringPieceFromPointers(token_start_, consumable.begin()));
if (status_or_shape.ok()) {
// This is a shape string.
shape_val_ = status_or_shape.ValueOrDie();
current_ptr_ = consumable.begin();
return TokKind::kShape;
}
}
}
while (IsIdentifierChar(PeekCurrentChar())) {
current_ptr_++;
}
// If followed by ':', it's a name.
if (PeekCurrentChar() == ':') {
str_val_.assign(token_start_, current_ptr_);
current_ptr_++; // skip ':'
return TokKind::kName;
}
StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_);
// See if this is a keyword.
#define KEYWORD(STR) \
do { \
if (identifier == #STR) { \
return TokKind::kw_##STR; \
} \
} while (false)
KEYWORD(true);
KEYWORD(false);
KEYWORD(HloModule);
KEYWORD(ENTRY);
#undef KEYWORD
// See if this is an opcode.
auto opcode = StringToHloOpcode(identifier.ToString());
if (opcode.ok()) {
opcode_val_ = opcode.ValueOrDie();
return TokKind::kOpcode;
}
current_ptr_ = token_start_ + 1;
return TokKind::kError;
}
// Lex names after a % character.
// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*
TokKind HloLexer::LexPercent() {
const char* name_start = current_ptr_;
if (isalpha(static_cast<unsigned char>(PeekCurrentChar())) ||
PeekCurrentChar() == '_') {
current_ptr_++;
while (IsIdentifierChar(PeekCurrentChar())) {
current_ptr_++;
}
str_val_.assign(name_start, current_ptr_);
return TokKind::kName;
}
return TokKind::kError;
}
// Lex integer and floating-point values.
// int [-]?[0-9]+
// fp with exp [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
TokKind HloLexer::LexDigitOrNegative() {
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
static LazyRE2 float_pattern = {
R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"};
if (RE2::Consume(&consumable, *float_pattern)) {
current_ptr_ = consumable.begin();
tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(),
&decimal_val_);
return TokKind::kDecimal;
}
static LazyRE2 int_pattern = {R"([-]?\d+)"};
if (RE2::Consume(&consumable, *int_pattern)) {
current_ptr_ = consumable.begin();
tensorflow::strings::safe_strto64(
StringPieceFromPointers(token_start_, current_ptr_), &int64_val_);
return TokKind::kInt;
}
return TokKind::kError;
}
StringPiece HloLexer::GetCurrentLine() const {
const char* start = token_start_;
const char* end = current_ptr_;
if (!CanDereference(start) || !CanDereference(end)) {
return "LINE OUT OF RANGE";
}
while (start > buf_.begin() && *start != '\n') {
start--;
}
while (end < buf_.end() && *end != '\n') {
end++;
}
return StringPieceFromPointers(start, end);
}
} // namespace tools
} // namespace xla
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
#include <string>
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_token.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace tools {
// Lexer for the HloModule::ToString() format text.
class HloLexer {
public:
explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) {
current_ptr_ = buf_.begin();
}
TokKind Lex() { return current_kind_ = LexToken(); }
TokKind GetKind() const { return current_kind_; }
string GetStrVal() const {
CHECK(GetKind() == TokKind::kName);
return str_val_;
}
Shape GetShapeVal() const {
CHECK(GetKind() == TokKind::kShape);
return shape_val_;
}
HloOpcode GetOpcodeVal() const {
CHECK(GetKind() == TokKind::kOpcode);
return opcode_val_;
}
int64 GetInt64Val() const {
CHECK(GetKind() == TokKind::kInt);
return int64_val_;
}
double GetDecimalVal() const {
CHECK(GetKind() == TokKind::kDecimal);
return decimal_val_;
}
// Returns the line of text that is currently being lexed.
tensorflow::StringPiece GetCurrentLine() const;
private:
// Returns the current character. If it's neither the end of the input buffer
// nor an invalid character, moves the pointer forward.
int GetNextChar();
// Returns the current character.
int PeekCurrentChar() const;
// Creates a StringPiece with the given begin and end. Exits if begin > end or
// if either pointer is out of the range of the current buffer.
tensorflow::StringPiece StringPieceFromPointers(const char* begin,
const char* end) const;
tensorflow::RegexpStringPiece RegexpStringPieceFromPointers(
const char* begin, const char* end) const;
// Returns true if the given ptr is dereferenceable within the range of the
// current buffer.
bool CanDereference(const char* ptr) const;
TokKind LexToken();
TokKind LexIdentifier();
TokKind LexPercent();
TokKind LexShape();
TokKind LexConstant();
TokKind LexDigitOrNegative();
const tensorflow::StringPiece buf_;
const char* current_ptr_;
// Information about the current token.
const char* token_start_;
TokKind current_kind_;
string str_val_;
Shape shape_val_;
HloOpcode opcode_val_;
int64 int64_val_;
double decimal_val_;
};
} // namespace tools
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace tools {
namespace {
using tensorflow::StringPiece;
using tensorflow::strings::StrCat;
// Parser for the HloModule::ToString() format text.
class HloParser {
public:
explicit HloParser(StringPiece str) : lexer_(str) {}
// Runs the parser. Returns false if an error occurred.
bool Run();
// Returns the parsed HloModule.
std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }
// Returns the error information.
string GetError() const { return tensorflow::str_util::Join(error_, "\n"); }
private:
// ParseXXX returns false if an error occurred.
bool ParseHloModule();
bool ParseComputation();
bool ParseInstructionList(HloComputation::Builder* builder);
bool ParseInstruction(HloComputation::Builder* builder);
bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
bool ParseOperands(std::vector<HloInstruction*>* operands,
const int expected_size);
bool ParseParamList();
bool ParseName(string* result);
bool ParseShape(Shape* result);
bool ParseOpcode(HloOpcode* result);
bool ParseInt64(int64* result);
bool ParseDecimal(double* result);
bool ParseBool(bool* result);
bool ParseToken(TokKind kind, const string& msg);
// Records the current parsing line and the given message as an error. Always
// returns false.
bool TokenError(StringPiece msg);
// If the current token is 'kind', eats it (i.e. lexes the next token) and
// returns true.
bool EatIfPresent(TokKind kind);
// Adds the instruction to the pool. Returns false and emits an error if the
// instruction already exists.
bool AddInstruction(const string& name, HloInstruction* instruction);
// The map from the instruction name to the instruction. This does not own the
// instructions.
std::unordered_map<string, HloInstruction*> instruction_pool_;
HloLexer lexer_;
std::unique_ptr<HloModule> module_;
std::vector<string> error_;
};
bool HloParser::TokenError(StringPiece msg) {
error_.push_back(
StrCat("was parsing \"", lexer_.GetCurrentLine(), "\"; ", msg));
return false;
}
bool HloParser::Run() {
lexer_.Lex();
return ParseHloModule();
}
// ::= 'HloModule' name computation
bool HloParser::ParseHloModule() {
if (lexer_.GetKind() != TokKind::kw_HloModule) {
return TokenError("expects HloModule");
}
// Eat 'HloModule'
lexer_.Lex();
string name;
if (!ParseName(&name)) {
return false;
}
module_ = MakeUnique<HloModule>(name);
return ParseComputation();
}
// computation ::= 'ENTRY' name param_list '->' shape instruction_list
bool HloParser::ParseComputation() {
string name;
if (!ParseToken(TokKind::kw_ENTRY, "expects 'ENTRY'") || !ParseName(&name)) {
return false;
}
auto builder = MakeUnique<HloComputation::Builder>(name);
Shape shape;
if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'") ||
!ParseShape(&shape) || !ParseInstructionList(builder.get())) {
return false;
}
module_->AddEntryComputation(builder->Build());
return true;
}
// instruction_list ::= '{' instruction_list1 '}'
// instruction_list1 ::= (instruction)+
bool HloParser::ParseInstructionList(HloComputation::Builder* builder) {
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of instruction list.")) {
return false;
}
do {
if (!ParseInstruction(builder)) {
return false;
}
} while (lexer_.GetKind() != TokKind::kRbrace);
return ParseToken(TokKind::kRbrace,
"expects '}' at the end of instruction list.");
}
// instruction ::= name '=' shape opcode operands
bool HloParser::ParseInstruction(HloComputation::Builder* builder) {
string name;
Shape shape;
HloOpcode opcode;
std::vector<HloInstruction*> operands;
if (!ParseName(&name) ||
!ParseToken(TokKind::kEqual, "expects '=' in instruction") ||
!ParseShape(&shape) || !ParseOpcode(&opcode)) {
return false;
}
switch (opcode) {
case HloOpcode::kParameter: {
int64 parameter_number;
return ParseToken(TokKind::kLparen,
"expects '(' before parameter number") &&
ParseInt64(&parameter_number) &&
ParseToken(TokKind::kRparen,
"expects ')' after parameter number") &&
AddInstruction(
name, builder->AddInstruction(HloInstruction::CreateParameter(
parameter_number, shape, name)));
}
case HloOpcode::kConstant: {
std::unique_ptr<Literal> literal;
return ParseToken(TokKind::kLparen,
"expects '(' before parameter number") &&
ParseLiteral(&literal, shape) &&
ParseToken(TokKind::kRparen,
"expects ')' after parameter number") &&
AddInstruction(
name, builder->AddInstruction(
HloInstruction::CreateConstant(std::move(literal))));
}
// Unary ops.
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kBitcast:
case HloOpcode::kCeil:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kIsFinite:
case HloOpcode::kFloor:
case HloOpcode::kLog:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSort:
case HloOpcode::kTanh: {
return ParseOperands(&operands, /*expected_size=*/1) &&
AddInstruction(name,
builder->AddInstruction(HloInstruction::CreateUnary(
shape, opcode, operands[0])));
}
// Binary ops.
case HloOpcode::kAdd:
case HloOpcode::kDivide:
case HloOpcode::kMultiply:
case HloOpcode::kSubtract:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kNe:
case HloOpcode::kDot:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kAnd:
case HloOpcode::kOr:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical: {
return ParseOperands(&operands, /*expected_size=*/2) &&
AddInstruction(
name, builder->AddInstruction(HloInstruction::CreateBinary(
shape, opcode, operands[0], operands[1])));
}
// Ternary ops.
case HloOpcode::kClamp:
case HloOpcode::kSelect: {
return ParseOperands(&operands, /*expected_size=*/3) &&
AddInstruction(
name,
builder->AddInstruction(HloInstruction::CreateTernary(
shape, opcode, operands[0], operands[1], operands[2])));
}
// Other supported ops.
case HloOpcode::kConvert: {
return ParseOperands(&operands, /*expected_size=*/1) &&
AddInstruction(
name, builder->AddInstruction(
HloInstruction::CreateConvert(shape, operands[0])));
}
case HloOpcode::kCrossReplicaSum: {
return ParseOperands(&operands, /*expected_size=*/1) &&
AddInstruction(name, builder->AddInstruction(
HloInstruction::CreateCrossReplicaSum(
shape, operands[0])));
}
case HloOpcode::kReshape: {
return ParseOperands(&operands, /*expected_size=*/1) &&
AddInstruction(
name, builder->AddInstruction(
HloInstruction::CreateReshape(shape, operands[0])));
}
case HloOpcode::kBroadcast:
case HloOpcode::kCall:
case HloOpcode::kCustomCall:
case HloOpcode::kConcatenate:
case HloOpcode::kReducePrecision:
case HloOpcode::kConvolution:
case HloOpcode::kGetTupleElement:
case HloOpcode::kMap:
case HloOpcode::kPad:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kReverse:
case HloOpcode::kRng:
case HloOpcode::kSlice:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kTranspose:
case HloOpcode::kTuple:
case HloOpcode::kWhile:
case HloOpcode::kFusion:
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kRecv:
case HloOpcode::kSend:
case HloOpcode::kUpdate:
case HloOpcode::kIndex:
case HloOpcode::kTrace:
return TokenError(StrCat("parsing not yet implemented for op: ",
HloOpcodeString(opcode)));
}
}
bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape) {
switch (shape.element_type()) {
case PRED:
bool b;
if (!ParseBool(&b)) {
return false;
}
*literal = Literal::CreateR0<bool>(b);
return true;
case S32:
int64 i;
if (!ParseInt64(&i)) {
return false;
}
*literal = Literal::CreateR0<int32>(i);
return true;
case F32:
double d;
if (!ParseDecimal(&d)) {
return false;
}
*literal = Literal::CreateR0<float>(d);
return true;
default:
return TokenError(StrCat("unsupported constant in shape: ",
ShapeUtil::HumanString(shape)));
}
}
// operands ::= '(' operands1 ')'
// operands1
// ::= /*empty*/
// ::= operand (, operand)*
// operand ::= shape name
bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
const int expected_size) {
if (!ParseToken(TokKind::kLparen,
"expects '(' at the beginning of operands")) {
return false;
}
if (lexer_.GetKind() == TokKind::kRparen) {
// empty
} else {
do {
Shape shape;
string name;
if (!ParseShape(&shape) || !ParseName(&name)) {
return false;
}
HloInstruction* instruction =
tensorflow::gtl::FindPtrOrNull(instruction_pool_, name);
if (!instruction) {
return TokenError(StrCat("instruction does not exist: ", name));
}
operands->push_back(instruction);
} while (EatIfPresent(TokKind::kComma));
}
if (expected_size != operands->size()) {
return TokenError(StrCat("expects ", expected_size, " operands, but has ",
operands->size(), " operands"));
}
return ParseToken(TokKind::kRparen, "expects ')' at the end of operands");
}
// param_list ::= '(' param_list1 ')'
// param_list1
// ::= /*empty*/
// ::= param (',' param)*
// param ::= name shape
bool HloParser::ParseParamList() {
if (!ParseToken(TokKind::kLparen,
"expects '(' at the beginning of param list")) {
return false;
}
if (lexer_.GetKind() == TokKind::kRparen) {
// empty
} else {
do {
Shape shape;
if (!ParseToken(TokKind::kName, "expects name in parameter") ||
!ParseShape(&shape)) {
return false;
}
} while (EatIfPresent(TokKind::kComma));
}
return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
}
// shape ::= shape_val_
// shape ::= '(' tuple_elements ')'
// tuple_elements
// ::= /*empty*/
// ::= shape (',' shape)*
bool HloParser::ParseShape(Shape* result) {
if (EatIfPresent(TokKind::kLparen)) { // Tuple
std::vector<Shape> shapes;
if (lexer_.GetKind() == TokKind::kRparen) {
/*empty*/
} else {
// shape (',' shape)*
do {
shapes.emplace_back();
if (!ParseShape(&shapes.back())) {
return false;
}
} while (EatIfPresent(TokKind::kComma));
}
*result = ShapeUtil::MakeTupleShape(shapes);
return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
}
if (lexer_.GetKind() != TokKind::kShape) {
return TokenError("expects shape");
}
*result = lexer_.GetShapeVal();
lexer_.Lex();
return true;
}
bool HloParser::ParseName(string* result) {
VLOG(1) << "ParseName";
if (lexer_.GetKind() != TokKind::kName) {
return TokenError("expects name");
}
*result = lexer_.GetStrVal();
lexer_.Lex();
return true;
}
bool HloParser::ParseOpcode(HloOpcode* result) {
VLOG(1) << "ParseOpcode";
if (lexer_.GetKind() != TokKind::kOpcode) {
return TokenError("expects opcode");
}
*result = lexer_.GetOpcodeVal();
lexer_.Lex();
return true;
}
bool HloParser::ParseInt64(int64* result) {
VLOG(1) << "ParseInt64";
if (lexer_.GetKind() != TokKind::kInt) {
return TokenError("expects integer");
}
*result = lexer_.GetInt64Val();
lexer_.Lex();
return true;
}
bool HloParser::ParseDecimal(double* result) {
switch (lexer_.GetKind()) {
case TokKind::kDecimal:
*result = lexer_.GetDecimalVal();
break;
case TokKind::kInt:
*result = static_cast<double>(lexer_.GetInt64Val());
break;
default:
return TokenError("expects decimal or integer");
}
lexer_.Lex();
return true;
}
bool HloParser::ParseBool(bool* result) {
if (lexer_.GetKind() != TokKind::kw_true &&
lexer_.GetKind() != TokKind::kw_false) {
return TokenError("expects true or false");
}
*result = lexer_.GetKind() == TokKind::kw_true;
lexer_.Lex();
return true;
}
bool HloParser::ParseToken(TokKind kind, const string& msg) {
if (lexer_.GetKind() != kind) {
return TokenError(msg);
}
lexer_.Lex();
return true;
}
bool HloParser::EatIfPresent(TokKind kind) {
if (lexer_.GetKind() != kind) {
return false;
}
lexer_.Lex();
return true;
}
bool HloParser::AddInstruction(const string& name,
HloInstruction* instruction) {
auto result = instruction_pool_.insert({name, instruction});
if (!result.second) {
return TokenError(StrCat("instruction already exists: ", name));
}
return true;
}
} // namespace
StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str) {
HloParser parser(str);
if (!parser.Run()) {
return InvalidArgument("Syntax error: %s", parser.GetError().c_str());
}
return parser.ConsumeHloModule();
}
} // namespace tools
} // namespace xla
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace tools {
// The API of the HLO parser. Given a string in the HloModule::ToString()
// format, returns the parsed HloModule.
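//
// Example (an illustrative sketch):
//   StatusOr<std::unique_ptr<HloModule>> result =
//       Parse("HloModule m: ENTRY %e () -> f32[] { %c = f32[] constant(42) }");
//   if (result.ok()) {
//     std::unique_ptr<HloModule> module = result.ConsumeValueOrDie();
//   }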
StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str);
} // namespace tools
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include <string>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace tools {
namespace {
struct TestData {
string test_name;
string module_string;
};
string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
return data.param.test_name;
}
std::vector<TestData> CreateTestCases() {
// clang-format off
return std::vector<TestData>({
// ax + y
{
"AxpyParam",
R"(HloModule axpy_module:
ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
%alpha = f32[2,4]{1,0} parameter(0)
%x = f32[2,4]{1,0} parameter(1)
%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %alpha, f32[2,4]{1,0} %x)
%y = f32[2,4]{1,0} parameter(2)
%add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)"
},
// pred constant
{
"ConstantPred",
R"(HloModule constant_pred_module:
ENTRY %constant_pred () -> pred[] {
%constant = pred[] constant(true)
}
)"
},
// s32 constant
{
"ConstantS32",
R"(HloModule constant_s32_module:
ENTRY %constant_s32 () -> s32[] {
%constant = s32[] constant(-42)
}
)"
},
// f32 constant, but the value is not a decimal
{
"ConstantF32", R"(HloModule ConstantF32_module:
ENTRY %ConstantF32.v4 () -> f32[] {
%constant = f32[] constant(42)
}
)"
},
// constant + constant
{
"AddConstants",
R"(HloModule add_constants_module:
ENTRY %add_constants () -> f32[] {
%constant = f32[] constant(3.14)
%add = f32[] add(f32[] %constant, f32[] %constant)
}
)"
},
// v1 > v2 ? v1 : v2
{
"SelectR1F32",
R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module:
ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
%v1 = f32[4]{0} parameter(0)
%v2 = f32[4]{0} parameter(1)
%greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2)
%select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2)
}
)"
}
});
// clang-format on
}
class HloParserTest : public ::testing::Test,
public ::testing::WithParamInterface<TestData> {
protected:
void ExpectSuccess() {
const string& original = GetParam().module_string;
auto result = Parse(original);
TF_EXPECT_OK(result.status());
EXPECT_EQ(original, result.ValueOrDie()->ToString());
}
};
TEST_P(HloParserTest, Run) { ExpectSuccess(); }
INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest,
::testing::ValuesIn(CreateTestCases()),
TestDataToString);
TEST_F(HloParserTest, Empty) {
const string original = "";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, Garbage) {
const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, WrongOpcode) {
const string original = R"(HloModule wrong_opcode:
ENTRY %blabla (x: f32[], y: f32[]) -> f32[] {
%x = f32[]{} parameter(0)
%y = f32[]{} parameter(1)
%le = pred[]{} le(f32[]{} %x, f32[]{} %y)
}
)";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, WrongShape) {
const string original = R"(HloModule wrong_opcode:
ENTRY %blabla (x: g32[]) -> g32[] {
%x = g32[]{} parameter(0)
}
)";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, WrongOperandsSize) {
const string original = R"(HloModule wrong_opcode:
ENTRY %blabla (x: f32[]) -> pred[] {
%x = f32[]{} parameter(0)
%eq = pred[]{} equal-to(f32[]{} %x)
}
)";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, OperandNotFound) {
const string original = R"(HloModule operand_not_found:
ENTRY %blabla (x: f32[]) -> pred[] {
%x = f32[]{} parameter(0)
%eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y)
}
)";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, MoreConstants) {
const string original = R"(HloModule SelectScalarS32True_module:
ENTRY %SelectScalarS32True.v4 () -> s32[] {
%constant.2 = pred[] constant(true)
%constant.1 = s32[] constant(-42)
%constant = s32[] constant(42)
%select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
}
)";
auto result = Parse(original);
TF_EXPECT_OK(result.status());
// Constant instructions have no name. The string will be parsed successfully
// but the constant names will not be exactly the same.
}
TEST_F(HloParserTest, ConstantWithExp) {
const string original = R"(HloModule ConstantWithExp_module:
ENTRY %ConstantWithExp.v4 () -> f32[] {
%constant.1 = f32[] constant(3e+2)
}
)";
auto result = Parse(original);
TF_EXPECT_OK(result.status());
// The string will be parsed successfully but the output strings are not
// exactly the same, because "3e2" is parsed into value 300 and will be
// printed as "300".
}
TEST_F(HloParserTest, Tuple) {
const string original = R"(HloModule EmptyTupleCreate_module:
ENTRY %EmptyTupleCreate.v1 () -> () {
%tuple = () tuple()
}
)";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
} // namespace
} // namespace tools
} // namespace xla
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
namespace xla {
namespace tools {
// Defines the different kinds of tokens in an HLO module string.
enum class TokKind {
// Markers
kEof,
kError,
// Tokens with no info.
kEqual, // =
kComma, // ,
kColon, // :
kLsquare,
kRsquare, // [ ]
kLbrace,
kRbrace, // { }
kLparen,
kRparen, // ( )
kArrow, // ->
// Keywords
kw_HloModule,
kw_ENTRY,
kw_true,
kw_false,
// Typed tokens.
kName, // %foo
kShape, // f32[2,3]{1,0}
kOpcode, // add
kInt, // 42
kDecimal, // 4.2
};
} // namespace tools
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
......@@ -82,8 +82,8 @@ message DebugOptions {
// Dump all HLO modules as text into the provided directory path.
string xla_generate_hlo_text_to = 7;
// Dump compilation artifacts as JSON into this directory.
string xla_dump_debug_json_to = 8;
// Dump compilation artifacts in binary proto into this directory.
string xla_dump_hlo_proto_to = 8;
// Instrument the computation to collect per-HLO cycle counts.
bool xla_hlo_profile = 9;
......
......@@ -69,6 +69,28 @@ tf_cc_test(
],
)
cc_library(
name = "adaptive_shared_batch_scheduler",
hdrs = ["adaptive_shared_batch_scheduler.h"],
deps = [
":batch_scheduler",
"//tensorflow/contrib/batching/util:periodic_function_dynamic",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "adaptive_shared_batch_scheduler_test",
srcs = ["adaptive_shared_batch_scheduler_test.cc"],
deps = [
":adaptive_shared_batch_scheduler",
"//tensorflow/contrib/batching/test_util:fake_clock_env",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "basic_batch_scheduler",
hdrs = ["basic_batch_scheduler.h"],
......
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#include <functional>
#include <memory>
#include <queue>
#include <unordered_map>
#include <vector>
#include "tensorflow/contrib/batching/batch_scheduler.h"
#include "tensorflow/contrib/batching/util/periodic_function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace serving {
namespace internal {
template <typename TaskType>
class ASBSBatch;
template <typename TaskType>
class ASBSQueue;
} // namespace internal
// Shared batch scheduler designed to minimize latency. The scheduler keeps
// track of a number of queues (one per model or model version) which are
// continuously enqueuing requests. The scheduler groups the requests into
// batches which it periodically sends off for processing (see
// shared_batch_scheduler.h for more details). The AdaptiveSharedBatchScheduler
// prioritizes batches by age (i.e. the batch's oldest request) irrespective of
// queue. The scheduler will process the oldest batch at an adjustable rate,
// regardless of batch size. The user can provide feedback to help set this rate
// to achieve some goal (e.g. minimize overall latency, limit CPU usage, etc.).
//
// The rate (or rather, the corresponding period) is adjusted each time a batch
// is processed, using an exponentially weighted moving average to smooth
// potentially noisy feedback:
// ewma_feedback = ((N - 1) * ewma_feedback + feedback()) / N
// period *= (1 + K * ewma_feedback)
//
// Some potential use cases:
// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing
// involves serial processing by a device, from a latency perspective it is
// desirable to keep the device evenly loaded, avoiding the need to wait for
// the device to process prior batches.
// feedback = num_pending_on_device() - desired_pending.
// CPU utilization - If the batch processing is cpu dominated, you can reap
// latency gains when underutilized by increasing the processing rate, but
// back the rate off when the load increases to avoid overload.
// feedback = cpu_rate() - desired_cpu_rate.
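//
// Example configuration (an illustrative sketch; MyTask and
// GetNumPendingOnDevice() are hypothetical):
//   AdaptiveSharedBatchScheduler<MyTask>::Options options;
//   options.scheduling_period_feedback = [] {
//     // Positive when the device backlog exceeds the target, which slows the
//     // scheduling rate; negative values speed it up.
//     return GetNumPendingOnDevice() - 2.0;
//   };
//   std::shared_ptr<AdaptiveSharedBatchScheduler<MyTask>> scheduler;
//   TF_CHECK_OK(
//       AdaptiveSharedBatchScheduler<MyTask>::Create(options, &scheduler));
//   std::unique_ptr<BatchScheduler<MyTask>> queue;
//   TF_CHECK_OK(scheduler->AddQueue(
//       {}, [](std::unique_ptr<Batch<MyTask>> batch) { /* process batch */ },
//       &queue));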
template <typename TaskType>
class AdaptiveSharedBatchScheduler
: public std::enable_shared_from_this<
AdaptiveSharedBatchScheduler<TaskType>> {
public:
struct Options {
// The name to use for the pool of batch threads.
string thread_pool_name = {"batch_threads"};
// Number of batch processing threads; equivalently the maximum number of
// concurrently running batches.
int64 num_batch_threads = port::NumSchedulableCPUs();
// The environment to use (typically only overridden by test code).
Env* env = Env::Default();
// Initial batch scheduling period in microseconds. Will be adjusted over time
// whenever scheduling_period_feedback returns non-zero values.
double initial_scheduling_period_micros = 500;
// Minimum batch scheduling period in microseconds. Recommend setting this
// value greater than 0, otherwise it may take a while to recover from a
// sustained time of negative scheduling_period_feedback (which may occur
// under low load).
double min_scheduling_period_micros = 100;
// Maximum batch scheduling period in microseconds.
double max_scheduling_period_micros = 10000;
// Feedback function used to modify the scheduling period each time a batch
// is scheduled. Should return values roughly O(1), with positive values
// resulting in an increased period.
std::function<double()> scheduling_period_feedback = [] { return 0.; };
// To handle potentially noisy scheduling_period_feedback, the period is
// adjusted using an exponentially weighted moving average over the previous
// feedback_smoothing_batches batches. Must be greater than 0.
int64 feedback_smoothing_batches = 10;
};
// Ownership is shared between the caller of Create() and any queues created
// via AddQueue().
static Status Create(
const Options& options,
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler);
struct QueueOptions {
// Maximum size of each batch.
int max_batch_size = 1000;
// Maximum number of enqueued (i.e. non-scheduled) batches.
int max_enqueued_batches = 10;
};
using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;
// Adds queue (and its callback) to be managed by this scheduler.
Status AddQueue(const QueueOptions& options,
BatchProcessor process_batch_callback,
std::unique_ptr<BatchScheduler<TaskType>>* queue);
private:
// Gives ASBSQueue access to AddBatch, RemoveQueue, and GetEnv.
friend class internal::ASBSQueue<TaskType>;
explicit AdaptiveSharedBatchScheduler(const Options& options);
// Batch scheduling function which runs every scheduling_period_ microseconds.
void ProcessOneBatch();
// Notifies scheduler of non-empty batch which is eligible for processing.
void AddBatch(internal::ASBSBatch<TaskType>*);
// Removes queue from scheduler.
void RemoveQueue(const internal::ASBSQueue<TaskType>* queue);
Env* GetEnv() const { return options_.env; }
const Options options_;
struct BatchCompare {
bool operator()(const internal::ASBSBatch<TaskType>* a,
const internal::ASBSBatch<TaskType>* b);
};
// Collection of batches added by AddBatch, ordered by age. Owned by scheduler
// until they are released for processing.
std::priority_queue<const internal::ASBSBatch<TaskType>*,
std::vector<internal::ASBSBatch<TaskType>*>, BatchCompare>
batches_ GUARDED_BY(mu_);
// Unowned queues and callbacks added by AddQueue.
std::unordered_map<const internal::ASBSQueue<TaskType>*, BatchProcessor>
queues_and_callbacks_ GUARDED_BY(mu_);
mutex mu_;
// Responsible for running ProcessOneBatch. PeriodicFunction was used in order
// to check for deletion so that the thread can be shut down.
std::unique_ptr<PeriodicFunction> scheduling_thread_;
// Responsible for running the batch processing callbacks.
std::unique_ptr<thread::ThreadPool> batch_thread_pool_;
// Time interval in microseconds between successive ProcessOneBatch calls.
double scheduling_period_;
// Exponentially weighted moving average of
// options_.scheduling_period_feedback() evaluated in each ProcessOneBatch
// call.
double ewma_feedback_ = 0;
TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler);
};
//////////////////////////////////////////////////////////
// Implementation details follow. API users need not read.
namespace internal {
// Consolidates tasks into batches, passing them off to the
// AdaptiveSharedBatchScheduler for processing.
template <typename TaskType>
class ASBSQueue : public BatchScheduler<TaskType> {
public:
using QueueOptions =
typename AdaptiveSharedBatchScheduler<TaskType>::QueueOptions;
ASBSQueue(std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
const QueueOptions& options);
~ASBSQueue() override;
// Adds task to current batch. Fails if the task size is larger than the batch
// size or if the current batch is full and this queue's number of outstanding
// batches is at its maximum.
Status Schedule(std::unique_ptr<TaskType>* task) override;
// Number of tasks waiting to be scheduled.
size_t NumEnqueuedTasks() const override;
// Number of size 1 tasks which could currently be scheduled without failing.
size_t SchedulingCapacity() const override;
// Notifies queue that a batch is about to be scheduled; the queue should not
// place any more tasks in this batch.
void ReleaseBatch(const ASBSBatch<TaskType>* batch);
private:
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_;
const QueueOptions options_;
// Owned by scheduler_.
ASBSBatch<TaskType>* current_batch_ GUARDED_BY(mu_) = nullptr;
int64 num_enqueued_batches_ GUARDED_BY(mu_) = 0;
int64 num_enqueued_tasks_ GUARDED_BY(mu_) = 0;
mutable mutex mu_;
TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue);
};
// Batch which remembers when and by whom it was created.
template <typename TaskType>
class ASBSBatch : public Batch<TaskType> {
public:
ASBSBatch(ASBSQueue<TaskType>* queue, int64 creation_time_micros)
: queue_(queue), creation_time_micros_(creation_time_micros) {}
~ASBSBatch() override {}
ASBSQueue<TaskType>* queue() const { return queue_; }
int64 creation_time_micros() const { return creation_time_micros_; }
private:
ASBSQueue<TaskType>* queue_;
const int64 creation_time_micros_;
TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch);
};
} // namespace internal
// ---------------- AdaptiveSharedBatchScheduler ----------------
template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::Create(
const Options& options,
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler) {
if (options.num_batch_threads < 1) {
return errors::InvalidArgument("num_batch_threads must be positive; was ",
options.num_batch_threads);
}
if (options.min_scheduling_period_micros < 0) {
return errors::InvalidArgument(
"min_scheduling_period_micros must be >= 0; was ",
options.min_scheduling_period_micros);
}
if (options.min_scheduling_period_micros >
options.initial_scheduling_period_micros) {
return errors::InvalidArgument(
"initial_scheduling_period_micros (",
options.initial_scheduling_period_micros,
") must be >= min_scheduling_period_micros (",
options.min_scheduling_period_micros, ")");
}
if (options.initial_scheduling_period_micros >
options.max_scheduling_period_micros) {
return errors::InvalidArgument(
"initial_scheduling_period_micros (",
options.initial_scheduling_period_micros,
") must be <= max_scheduling_period_micros (",
options.max_scheduling_period_micros, ")");
}
if (options.feedback_smoothing_batches < 1) {
return errors::InvalidArgument(
"feedback_smoothing_batches must be positive; was ",
options.feedback_smoothing_batches);
}
scheduler->reset(new AdaptiveSharedBatchScheduler<TaskType>(options));
return Status::OK();
}
template <typename TaskType>
AdaptiveSharedBatchScheduler<TaskType>::AdaptiveSharedBatchScheduler(
const Options& options)
: options_(options),
scheduling_period_(options.initial_scheduling_period_micros) {
PeriodicFunction::Options opts;
opts.thread_name_prefix = "scheduling_thread";
opts.env = GetEnv();
scheduling_thread_.reset(
new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts));
batch_thread_pool_.reset(new thread::ThreadPool(
GetEnv(), options.thread_pool_name, options.num_batch_threads));
}
template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::AddQueue(
const QueueOptions& options, BatchProcessor process_batch_callback,
std::unique_ptr<BatchScheduler<TaskType>>* queue) {
if (options.max_batch_size <= 0) {
return errors::InvalidArgument("max_batch_size must be positive; was ",
options.max_batch_size);
}
if (options.max_enqueued_batches <= 0) {
return errors::InvalidArgument(
"max_enqueued_batches must be positive; was ",
options.max_enqueued_batches);
}
internal::ASBSQueue<TaskType>* asbs_queue_raw;
queue->reset(asbs_queue_raw = new internal::ASBSQueue<TaskType>(
this->shared_from_this(), options));
mutex_lock l(mu_);
queues_and_callbacks_[asbs_queue_raw] = process_batch_callback;
return Status::OK();
}
template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::AddBatch(
internal::ASBSBatch<TaskType>* batch) {
mutex_lock l(mu_);
batches_.push(batch);
}
template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::RemoveQueue(
const internal::ASBSQueue<TaskType>* queue) {
mutex_lock l(mu_);
queues_and_callbacks_.erase(queue);
}
template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::ProcessOneBatch() {
static const double kFeedbackMultiplier = .001;
internal::ASBSBatch<TaskType>* batch = nullptr;
BatchProcessor callback;
const int64 start_time_micros = GetEnv()->NowMicros();
{
mutex_lock l(mu_);
if (!batches_.empty()) {
batch = batches_.top();
batches_.pop();
callback = queues_and_callbacks_[batch->queue()];
}
}
if (batch != nullptr) {
double feedback = options_.scheduling_period_feedback();
const int64 N = options_.feedback_smoothing_batches;
ewma_feedback_ = ((N - 1) * ewma_feedback_ + feedback) / N;
scheduling_period_ *= (1 + kFeedbackMultiplier * ewma_feedback_);
if (scheduling_period_ < options_.min_scheduling_period_micros) {
scheduling_period_ = options_.min_scheduling_period_micros;
} else if (scheduling_period_ > options_.max_scheduling_period_micros) {
scheduling_period_ = options_.max_scheduling_period_micros;
}
// Queue may destroy itself after ReleaseBatch is called.
batch->queue()->ReleaseBatch(batch);
batch_thread_pool_->Schedule([callback, batch] {
callback(std::unique_ptr<Batch<TaskType>>(batch));
});
}
const int64 sleep_time =
scheduling_period_ - (GetEnv()->NowMicros() - start_time_micros);
if (sleep_time > 0) {
GetEnv()->SleepForMicroseconds(sleep_time);
}
}
template <typename TaskType>
bool AdaptiveSharedBatchScheduler<TaskType>::BatchCompare::operator()(
const internal::ASBSBatch<TaskType>* a,
const internal::ASBSBatch<TaskType>* b) {
return a->creation_time_micros() > b->creation_time_micros();
}
// ---------------- ASBSQueue ----------------
namespace internal {
template <typename TaskType>
ASBSQueue<TaskType>::ASBSQueue(
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
const QueueOptions& options)
: scheduler_(scheduler), options_(options) {}
template <typename TaskType>
ASBSQueue<TaskType>::~ASBSQueue() {
// Wait until last batch has been scheduled.
const int kSleepMicros = 1000;
for (;;) {
{
mutex_lock l(mu_);
if (num_enqueued_batches_ == 0) {
break;
}
}
scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros);
}
scheduler_->RemoveQueue(this);
}
template <typename TaskType>
Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
bool added_new_batch = false;
size_t size = (*task)->size();
if (size > options_.max_batch_size) {
return errors::InvalidArgument("Task size ", size,
" is larger than maximum batch size ",
options_.max_batch_size);
}
{
mutex_lock l(mu_);
// Current batch is full, create another if allowed.
if (current_batch_ &&
current_batch_->size() + size > options_.max_batch_size) {
if (num_enqueued_batches_ >= options_.max_enqueued_batches) {
return errors::Unavailable("The batch scheduling queue is full");
}
current_batch_->Close();
current_batch_ = nullptr;
}
if (!current_batch_) {
added_new_batch = true;
num_enqueued_batches_++;
current_batch_ =
new ASBSBatch<TaskType>(this, scheduler_->GetEnv()->NowMicros());
}
current_batch_->AddTask(std::move(*task));
num_enqueued_tasks_++;
}
if (added_new_batch) scheduler_->AddBatch(current_batch_);
return Status::OK();
}
template <typename TaskType>
void ASBSQueue<TaskType>::ReleaseBatch(const ASBSBatch<TaskType>* batch) {
mutex_lock l(mu_);
num_enqueued_batches_--;
num_enqueued_tasks_ -= batch->num_tasks();
if (batch == current_batch_) {
current_batch_->Close();
current_batch_ = nullptr;
}
}
template <typename TaskType>
size_t ASBSQueue<TaskType>::NumEnqueuedTasks() const {
mutex_lock l(mu_);
return num_enqueued_tasks_;
}
template <typename TaskType>
size_t ASBSQueue<TaskType>::SchedulingCapacity() const {
mutex_lock l(mu_);
const int current_batch_capacity =
current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
const int spare_batches =
options_.max_enqueued_batches - num_enqueued_batches_;
return spare_batches * options_.max_batch_size + current_batch_capacity;
}
} // namespace internal
} // namespace serving
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
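For reference, the scheduling-period update performed in ProcessOneBatch above is a clamped multiplicative adjustment driven by an exponentially weighted moving average of the feedback signal. The short sketch below is not part of this patch; it is a hedged Python transcription of that arithmetic (kFeedbackMultiplier = 0.001, feedback_smoothing_batches = N) and reproduces the numbers asserted in the FeedbackSmoothing test further down.

# Hedged reference sketch (not part of this patch): the period update done by
# AdaptiveSharedBatchScheduler::ProcessOneBatch, transcribed into Python.
def update_period(period, ewma, feedback, n,
                  min_period=0.0, max_period=float("inf"), multiplier=0.001):
  # ewma_feedback_ = ((N - 1) * ewma_feedback_ + feedback) / N
  ewma = ((n - 1) * ewma + feedback) / n
  # scheduling_period_ *= (1 + kFeedbackMultiplier * ewma_feedback_), clamped
  # to [min_scheduling_period_micros, max_scheduling_period_micros].
  period *= 1 + multiplier * ewma
  period = min(max(period, min_period), max_period)
  return period, ewma

# With initial_scheduling_period_micros = 1000, feedback_smoothing_batches = 3
# and feedback = -300, as in the FeedbackSmoothing test:
period, ewma = update_period(1000.0, 0.0, -300, 3)   # period == 900.0, ewma == -100.0
period, ewma = update_period(period, ewma, -300, 3)  # period ~= 750.0, ewma ~= -166.7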
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h"
#include "tensorflow/contrib/batching/test_util/fake_clock_env.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace serving {
namespace anonymous {
class FakeTask : public BatchTask {
public:
explicit FakeTask(size_t size) : size_(size) {}
~FakeTask() override = default;
size_t size() const override { return size_; }
private:
const size_t size_;
TF_DISALLOW_COPY_AND_ASSIGN(FakeTask);
};
// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on
// that task. Returns the resulting status.
Status ScheduleTask(size_t task_size, BatchScheduler<FakeTask>* scheduler) {
std::unique_ptr<FakeTask> task(new FakeTask(task_size));
Status status = scheduler->Schedule(&task);
// Schedule() should have consumed 'task' iff it returned Status::OK.
CHECK_EQ(status.ok(), task == nullptr);
return status;
}
// Creates a thread that waits on 'start' and then advances the fake clock in
// 'env' in a loop until 'stop' is notified. Useful for allowing objects that
// use the clock to be destroyed.
std::unique_ptr<Thread> CreateFakeClockAdvancerThread(
test_util::FakeClockEnv* env, Notification* start, Notification* stop) {
return std::unique_ptr<Thread>(Env::Default()->StartThread(
{}, "FakeClockAdvancerThread", [env, start, stop] {
start->WaitForNotification();
while (!stop->HasBeenNotified()) {
env->AdvanceByMicroseconds(10);
Env::Default()->SleepForMicroseconds(10);
}
}));
}
TEST(AdaptiveSharedBatchSchedulerTest, Basic) {
for (const bool delete_scheduler_early : {false, true}) {
for (const bool delete_queue_1_early : {false, true}) {
int queue_0_tasks = 0;
auto queue_0_callback =
[&queue_0_tasks](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
for (int i = 0; i < batch->num_tasks(); i++) {
queue_0_tasks += batch->task(i).size();
}
};
int queue_1_tasks = 0;
auto queue_1_callback =
[&queue_1_tasks](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
for (int i = 0; i < batch->num_tasks(); i++) {
queue_1_tasks += batch->task(i).size();
}
};
{
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create({}, &scheduler));
// Create two queues.
std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
TF_ASSERT_OK(scheduler->AddQueue({}, queue_0_callback, &queue_0));
std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
TF_ASSERT_OK(scheduler->AddQueue({}, queue_1_callback, &queue_1));
if (delete_scheduler_early) {
// Delete our copy of the scheduler. The queues should keep it alive
// under the covers.
scheduler = nullptr;
}
// Submit tasks to the two queues, and (optionally) remove the queues.
TF_ASSERT_OK(ScheduleTask(1, queue_0.get()));
TF_ASSERT_OK(ScheduleTask(2, queue_1.get()));
TF_ASSERT_OK(ScheduleTask(3, queue_0.get()));
TF_ASSERT_OK(ScheduleTask(4, queue_1.get()));
if (delete_queue_1_early) {
queue_1 = nullptr;
}
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
}
EXPECT_EQ(queue_0_tasks, 9);
EXPECT_EQ(queue_1_tasks, 6);
}
}
}
TEST(AdaptiveSharedBatchSchedulerTest, BadOptions) {
using Scheduler = AdaptiveSharedBatchScheduler<FakeTask>;
std::shared_ptr<Scheduler> scheduler;
Scheduler::Options options;
options.num_batch_threads = 0;
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
options = Scheduler::Options();
options.min_scheduling_period_micros = 50;
options.max_scheduling_period_micros = 100;
options.initial_scheduling_period_micros = 1;
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
options = Scheduler::Options();
options.min_scheduling_period_micros = 50;
options.max_scheduling_period_micros = 100;
options.initial_scheduling_period_micros = 1000;
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
options = Scheduler::Options();
options.min_scheduling_period_micros = 100;
options.max_scheduling_period_micros = 50;
options.initial_scheduling_period_micros = 75;
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
options = Scheduler::Options();
options.feedback_smoothing_batches = 0;
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
}
TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) {
test_util::FakeClockEnv env(Env::Default());
Notification start_teardown, stop_teardown;
std::unique_ptr<Thread> teardown_thread =
CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
{
AdaptiveSharedBatchScheduler<FakeTask>::Options options;
options.initial_scheduling_period_micros = 1000;
options.env = &env;
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
int queue_0_tasks = 0;
int queue_1_tasks = 0;
auto queue_0_callback = [&queue_0_tasks,
&env](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
for (int i = 0; i < batch->num_tasks(); i++) {
queue_0_tasks += batch->task(i).size();
}
env.SleepForMicroseconds(1);
};
auto queue_1_callback = [&queue_1_tasks,
&env](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
for (int i = 0; i < batch->num_tasks(); i++) {
queue_1_tasks += batch->task(i).size();
}
env.SleepForMicroseconds(1);
};
AdaptiveSharedBatchScheduler<FakeTask>::QueueOptions queue_options;
queue_options.max_batch_size = 10;
queue_options.max_enqueued_batches = 0;
// Queue must have max_enqueued_batches > 0.
EXPECT_FALSE(
scheduler->AddQueue(queue_options, queue_0_callback, &queue_0).ok());
queue_options.max_enqueued_batches = 2;
TF_ASSERT_OK(
scheduler->AddQueue(queue_options, queue_0_callback, &queue_0));
queue_options.max_batch_size = 0;
// Queue must have max_batch_size > 0.
EXPECT_FALSE(
scheduler->AddQueue(queue_options, queue_1_callback, &queue_1).ok());
queue_options.max_batch_size = 2;
queue_options.max_enqueued_batches = 1;
TF_ASSERT_OK(
scheduler->AddQueue(queue_options, queue_1_callback, &queue_1));
// Wait for scheduling_thread to sleep.
env.BlockUntilThreadsAsleep(1);
// Task larger than max_batch_size shouldn't schedule.
EXPECT_FALSE(ScheduleTask(15, queue_0.get()).ok());
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
env.AdvanceByMicroseconds(1);
// Task larger than max_batch_size shouldn't schedule.
EXPECT_FALSE(ScheduleTask(3, queue_1.get()).ok());
TF_ASSERT_OK(ScheduleTask(1, queue_1.get()));
TF_ASSERT_OK(ScheduleTask(1, queue_1.get()));
env.AdvanceByMicroseconds(1);
// Exceeds max_enqueued_batches, shouldn't schedule.
EXPECT_FALSE(ScheduleTask(1, queue_1.get()).ok());
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
// Exceeds max_enqueued_batches, shouldn't schedule.
EXPECT_FALSE(ScheduleTask(6, queue_0.get()).ok());
TF_ASSERT_OK(ScheduleTask(4, queue_0.get()));
// Batches should be processed in order from oldest to newest.
env.AdvanceByMicroseconds(1000);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(queue_0_tasks, 10);
EXPECT_EQ(queue_1_tasks, 0);
env.AdvanceByMicroseconds(1000);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(queue_0_tasks, 10);
EXPECT_EQ(queue_1_tasks, 2);
env.AdvanceByMicroseconds(1000);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(queue_0_tasks, 19);
EXPECT_EQ(queue_1_tasks, 2);
start_teardown.Notify();
}
stop_teardown.Notify();
}
TEST(AdaptiveSharedBatchSchedulerTest, RateFeedback) {
test_util::FakeClockEnv env(Env::Default());
Notification start_teardown, stop_teardown;
std::unique_ptr<Thread> teardown_thread =
CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
{
double feedback = 0;
AdaptiveSharedBatchScheduler<FakeTask>::Options options;
options.initial_scheduling_period_micros = 1000;
options.min_scheduling_period_micros = 200;
options.max_scheduling_period_micros = 2000;
options.env = &env;
options.scheduling_period_feedback = [&feedback] { return feedback; };
options.feedback_smoothing_batches = 1;
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
std::unique_ptr<BatchScheduler<FakeTask>> queue;
int scheduled_items = 0;
auto queue_callback = [&scheduled_items,
&env](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
scheduled_items = 0;
for (int i = 0; i < batch->num_tasks(); i++) {
scheduled_items += batch->task(i).size();
}
env.SleepForMicroseconds(1);
};
TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
// Wait for scheduling_thread to sleep.
env.BlockUntilThreadsAsleep(1);
// Enqueue 6 batches.
for (int i = 0; i < 6; i++) {
TF_ASSERT_OK(ScheduleTask(900 + i, queue.get()));
env.AdvanceByMicroseconds(1);
}
feedback = -500;
env.AdvanceByMicroseconds(994);
env.BlockUntilThreadsAsleep(2); // scheduling period = 500 usec.
EXPECT_EQ(scheduled_items, 900);
env.AdvanceByMicroseconds(500);
env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec.
EXPECT_EQ(scheduled_items, 901);
feedback = 0;
env.AdvanceByMicroseconds(250);
env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec.
EXPECT_EQ(scheduled_items, 902);
feedback = 10000; // large feedback should hit max_scheduling_period.
env.AdvanceByMicroseconds(250);
env.BlockUntilThreadsAsleep(2); // scheduling period = 2000 usec.
EXPECT_EQ(scheduled_items, 903);
feedback = -10000; // large feedback should hit min_scheduling_period.
env.AdvanceByMicroseconds(1999);
// No callback scheduled, only scheduling thread sleeping.
env.BlockUntilThreadsAsleep(1);
EXPECT_EQ(scheduled_items, 903);
env.AdvanceByMicroseconds(1);
env.BlockUntilThreadsAsleep(2); // scheduling period = 200 usec.
EXPECT_EQ(scheduled_items, 904);
env.AdvanceByMicroseconds(200);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(scheduled_items, 905);
start_teardown.Notify();
}
stop_teardown.Notify();
}
TEST(AdaptiveSharedBatchSchedulerTest, FeedbackSmoothing) {
test_util::FakeClockEnv env(Env::Default());
Notification start_teardown, stop_teardown;
std::unique_ptr<Thread> teardown_thread =
CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
{
double feedback = 0;
AdaptiveSharedBatchScheduler<FakeTask>::Options options;
options.initial_scheduling_period_micros = 1000;
options.env = &env;
options.scheduling_period_feedback = [&feedback] { return feedback; };
options.feedback_smoothing_batches = 3;
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
std::unique_ptr<BatchScheduler<FakeTask>> queue;
int scheduled_items = 0;
auto queue_callback = [&scheduled_items,
&env](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
scheduled_items = 0;
for (int i = 0; i < batch->num_tasks(); i++) {
scheduled_items += batch->task(i).size();
}
env.SleepForMicroseconds(1);
};
TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
// Wait for scheduling_thread to sleep.
env.BlockUntilThreadsAsleep(1);
// Enqueue 4 batches.
for (int i = 0; i < 4; i++) {
TF_ASSERT_OK(ScheduleTask(900 + i, queue.get()));
env.AdvanceByMicroseconds(1);
}
feedback = -300;
env.AdvanceByMicroseconds(996);
env.BlockUntilThreadsAsleep(2);
// ewma_feedback = 100, scheduling_period = 900.
EXPECT_EQ(scheduled_items, 900);
env.AdvanceByMicroseconds(899);
// No callback scheduled, only scheduling thread sleeping.
env.BlockUntilThreadsAsleep(1);
EXPECT_EQ(scheduled_items, 900);
env.AdvanceByMicroseconds(1);
env.BlockUntilThreadsAsleep(2);
// ewma_feedback = 167, scheduling_period = 750.
EXPECT_EQ(scheduled_items, 901);
env.AdvanceByMicroseconds(749);
// No callback scheduled, only scheduling thread sleeping.
env.BlockUntilThreadsAsleep(1);
EXPECT_EQ(scheduled_items, 901);
feedback = 1000 / 3.;
env.AdvanceByMicroseconds(1);
env.BlockUntilThreadsAsleep(2);
// ewma_feedback = 0, scheduling_period = 750.
EXPECT_EQ(scheduled_items, 902);
env.AdvanceByMicroseconds(749);
// No callback scheduled, only scheduling thread sleeping.
env.BlockUntilThreadsAsleep(1);
EXPECT_EQ(scheduled_items, 902);
env.AdvanceByMicroseconds(1);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(scheduled_items, 903);
start_teardown.Notify();
}
stop_teardown.Notify();
}
TEST(AdaptiveSharedBatchSchedulerTest, QueueCapacityInfo) {
test_util::FakeClockEnv env(Env::Default());
Notification start_teardown, stop_teardown;
std::unique_ptr<Thread> teardown_thread =
CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
{
AdaptiveSharedBatchScheduler<FakeTask>::Options options;
options.initial_scheduling_period_micros = 1000;
options.env = &env;
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
std::unique_ptr<BatchScheduler<FakeTask>> queue;
int scheduled_items = 0;
auto queue_callback = [&scheduled_items,
&env](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
scheduled_items = 0;
for (int i = 0; i < batch->num_tasks(); i++) {
scheduled_items += batch->task(i).size();
}
env.SleepForMicroseconds(1);
};
AdaptiveSharedBatchScheduler<FakeTask>::QueueOptions queue_options;
queue_options.max_batch_size = 10;
queue_options.max_enqueued_batches = 10;
TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue));
// Wait for scheduling_thread to sleep.
env.BlockUntilThreadsAsleep(1);
// Enqueue 3 tasks.
EXPECT_EQ(queue->NumEnqueuedTasks(), 0);
EXPECT_EQ(queue->SchedulingCapacity(), 100);
TF_ASSERT_OK(ScheduleTask(5, queue.get()));
EXPECT_EQ(queue->NumEnqueuedTasks(), 1);
EXPECT_EQ(queue->SchedulingCapacity(), 95);
env.AdvanceByMicroseconds(1);
TF_ASSERT_OK(ScheduleTask(6, queue.get()));
EXPECT_EQ(queue->NumEnqueuedTasks(), 2);
EXPECT_EQ(queue->SchedulingCapacity(), 84);
env.AdvanceByMicroseconds(1);
TF_ASSERT_OK(ScheduleTask(1, queue.get()));
EXPECT_EQ(queue->NumEnqueuedTasks(), 3);
EXPECT_EQ(queue->SchedulingCapacity(), 83);
env.AdvanceByMicroseconds(998);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(scheduled_items, 5);
env.AdvanceByMicroseconds(1000);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(scheduled_items, 7);
start_teardown.Notify();
}
stop_teardown.Notify();
}
} // namespace anonymous
} // namespace serving
} // namespace tensorflow
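The capacities asserted in QueueCapacityInfo follow directly from ASBSQueue::SchedulingCapacity: whole batches that could still be opened times max_batch_size, plus the room left in the currently open batch. Below is a hedged sketch (not part of the patch) of that bookkeeping with the test's queue options (max_batch_size = 10, max_enqueued_batches = 10).

# Hedged sketch (not part of this patch) of ASBSQueue::SchedulingCapacity.
def scheduling_capacity(max_batch_size, max_enqueued_batches,
                        num_enqueued_batches, current_batch_size=None):
  # Room left in the open batch, if there is one.
  current_batch_capacity = (
      max_batch_size - current_batch_size if current_batch_size is not None
      else 0)
  # Whole batches that could still be opened.
  spare_batches = max_enqueued_batches - num_enqueued_batches
  return spare_batches * max_batch_size + current_batch_capacity

print(scheduling_capacity(10, 10, 0))     # 100 -- empty queue
print(scheduling_capacity(10, 10, 1, 5))  # 95  -- one open batch holding a size-5 task
print(scheduling_capacity(10, 10, 2, 6))  # 84  -- size-6 task forced a second batch open
print(scheduling_capacity(10, 10, 2, 7))  # 83  -- size-1 task joined that second batch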
......@@ -78,7 +78,7 @@ template <typename TaskType>
class Batch {
public:
Batch() = default;
~Batch(); // Blocks until the batch is closed.
virtual ~Batch(); // Blocks until the batch is closed.
// Appends 'task' to the batch. After calling AddTask(), the newly-added task
// can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1).
......
......@@ -14,7 +14,7 @@
# ==============================================================================
include (ExternalProject)
set(cub_URL https://github.com/NVlabs/cub/archive/1.7.4.zip)
set(cub_URL https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip)
set(cub_HASH SHA256=20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31)
set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
......
......@@ -15,7 +15,7 @@
include (ExternalProject)
set(gif_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/gif_archive/giflib-5.1.4/)
set(gif_URL http://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz)
set(gif_URL https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz)
set(gif_HASH SHA256=34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1)
set(gif_INSTALL ${CMAKE_BINARY_DIR}/gif/install)
set(gif_BUILD ${CMAKE_BINARY_DIR}/gif/src/gif)
......
......@@ -15,7 +15,7 @@
include (ExternalProject)
set(jpeg_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/jpeg_archive)
set(jpeg_URL http://mirror.bazel.build/www.ijg.org/files/jpegsrc.v9a.tar.gz)
set(jpeg_URL https://mirror.bazel.build/www.ijg.org/files/jpegsrc.v9a.tar.gz)
set(jpeg_HASH SHA256=3a753ea48d917945dd54a2d97de388aa06ca2eb1066cbfdc6652036349fe05a7)
set(jpeg_BUILD ${CMAKE_CURRENT_BINARY_DIR}/jpeg/src/jpeg)
set(jpeg_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/jpeg/install)
......
......@@ -15,7 +15,7 @@
include (ExternalProject)
set(lmdb_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/lmdb)
set(lmdb_URL http://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz)
set(lmdb_URL https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz)
set(lmdb_HASH SHA256=108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326)
set(lmdb_BUILD ${CMAKE_BINARY_DIR}/lmdb/src/lmdb)
set(lmdb_INSTALL ${CMAKE_BINARY_DIR}/lmdb/install)
......
......@@ -47,4 +47,4 @@ ExternalProject_Add(snappy
)
# actually enables snappy in the source code
add_definitions(-DSNAPPY)
\ No newline at end of file
add_definitions(-DTF_USE_SNAPPY)
......@@ -86,7 +86,7 @@ cuda_py_test(
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python/eager:graph_callable",
"//tensorflow/python:platform_test",
"//tensorflow/python/eager:test",
"//tensorflow/python:variables",
],
)
......@@ -132,11 +132,12 @@ py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:layers_base",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:function",
],
)
......@@ -146,6 +147,10 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":metrics",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
],
)
......@@ -160,6 +165,8 @@ py_library(
deps = [
":datasets",
":metrics",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:function",
],
)
......
......@@ -86,7 +86,7 @@ class EvaluatorTest(test.TestCase):
for v in e.metric_variables:
p = v.name.split("/")[0]
prefix_count[p] = prefix_count.get(p, 0) + 1
self.assertEqual({"outer-mean": 2, "mean": 2}, prefix_count)
self.assertEqual({"outer_mean": 2, "mean": 2}, prefix_count)
def testDataset(self):
e = SimpleEvaluator(IdentityModel())
......
......@@ -18,6 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
......@@ -25,55 +29,69 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
_to_replace = re.compile("[^A-Za-z0-9.]")
class Metric(object):
"""A metric holds state for aggregating statistics over an evaluation run.
Users will use Evaluator.add_metric() to add Metric objects to their
evaluation, call them in each step, and then use
Evaluator.all_metric_results() at the end.
evaluation, call them in each step (treating the object as a callable),
and then use Evaluator.all_metric_results() at the end.
Descendants will implement:
* call(): Should follow this pattern:
if not self.built:
self.var = self.add_variable(...)
self.add_update(self.var.assign_add(...))
* aggregate(): Adds in the state from a list of metrics of the same type
as `self`. (Default of summing all the variables will be fine for most
descendants.)
* result(): Computes and returns a final value for the metric
* `build()`: All variables should be created in this method, by calling
`self.add_variable()` as in: `self.var = self.add_variable(...)`
build() will be called in the first invocation of `__call__()`, with
the same arguments passed `call()`.
* `call()`: Has all updates to variables, as in:
self.var.assign_add(...)
* `result()`: Computes and returns a final value for the metric
from the variables in `self`.
Descendants may override, but usually won't need to:
* `aggregate()`: Adds in the state from a list of metrics of the same type
as `self`. (Default is to sum all the variables.)
* `reset()`: Reset all variables to their initial state. (Default is to
zero all the variables.)
Note that users should not call `aggregate()` or `reset()`, they are for
use by TensorFlow infrastructure.
"""
def __init__(self, name=None):
self.built = False
self._built = False
self._vars = []
self._updates = []
self._name = name or self.__class__.__name__
# TODO(josh11b): Need some way to make sure two Metrics in the same
# Network have distinct names. Maybe we can get a unique name from
# a name/variable scope?
# TODO(josh11b): self._in_graph_mode = context.in_graph_mode()
name = name or self.__class__.__name__
# Replace things like spaces in name to create a valid scope name.
scope_name = _to_replace.sub("_", name)
# We create the variable scope now to get the unique name that will
# be used as a variable prefix when build() calls add_variable().
with variable_scope.variable_scope(
None, default_name=scope_name, use_resource=True, reuse=False) as scope:
pos = scope.name.rfind(scope_name)
self._name = name + scope.name[pos + len(scope_name):]
self._scope = scope
if context.in_graph_mode():
# We make self.call() into a graph callable here, so that we can
# return a single op that performs all of the variable updates.
self.call = function.defun(self.call)
# ---- API for users ----
def __call__(self, *args, **kwargs):
# TODO(josh11b): If self._in_graph_mode is true, make self.call() into a
# graph callable here, so that variable updates happen without requiring
# a separate fetch.
# TODO(josh11b): Do we need a separate build() method to separate
# initialization from each update? If so, how do we get the arguments
# to it? We *could* just pass in *args and **kwargs...
if not self.built:
# TODO(ashankar): Set up container isolation so there is no chance
# distinct metrics objects accidentally share variables.
# TODO(josh11b): Replace things like spaces in self._name to create
# a valid scope name.
with variable_scope.variable_scope(
self._name, use_resource=True, reuse=False):
ret = self.call(*args, **kwargs)
self.built = True
else:
ret = self.call(*args, **kwargs)
return ret
"""Returns op to execute to update this metric for these inputs.
Returns None if eager execution is enabled.
Args:
*args:
**kwargs: A mini-batch of inputs to the Metric, passed on to `call()`.
"""
if not self._built:
with variable_scope.variable_scope(self._scope):
self.build(*args, **kwargs)
self._built = True
return self.call(*args, **kwargs)
@property
def name(self):
......@@ -84,10 +102,43 @@ class Metric(object):
return self._vars
# ---- To be implemented by descendants ---
def build(self, *args, **kwargs):
"""Method to create variables.
Called by `__call__()` before `call()` for the first time.
Args:
*args:
**kwargs: The arguments to the first invocation of `__call__()`.
`build()` may use the shape and/or dtype of these arguments
when deciding how to create variables.
"""
raise NotImplementedError("Metrics must define a build() member function")
def call(self, *args, **kwargs):
"""Accumulates statistics for the metric."""
"""Accumulates statistics for the metric. Users should use __call__ instead.
Note: This function is executed as a graph function in graph mode.
This means:
a) Operations on the same resource are executed in textual order.
This should make it easier to do things like add the updated
value of a variable to another, for example.
b) You don't need to worry about collecting the update ops to execute.
All update ops added to the graph by this function will be executed.
As a result, code should generally work the same way with graph or
eager execution.
Args:
*args:
**kwargs: A mini-batch of inputs to the Metric, as passed to
`__call__()`.
"""
raise NotImplementedError("Metrics must define a call() member function")
def result(self): # TODO(josh11b): Add an optional summary_writer parameter.
"""Computes and returns a final value for the metric."""
raise NotImplementedError("Metrics must define a result() member function")
# We can support two different strategies for doing data-parallel
# distributed metric computations:
# * Put metric variables on the first device and rely on small
......@@ -123,16 +174,19 @@ class Metric(object):
self._vars[i].assign_add(math_ops.add_n([m._vars[i] for m in metrics]))
# pylint: enable=protected-access
def result(self): # TODO(josh11b): Add an optional summary_writer parameter.
"""Computes and returns a final value for the metric."""
raise NotImplementedError("Metrics must define a result() member function")
def reset(self):
"""Reset this metric to a freshly initialized state.
Default implementation zeros all the metric variables.
"""
for v in self._vars:
v.assign(math_ops.zeros_like(v))
# ---- For use by descendants ---
def add_variable(self, name, shape=None, dtype=None, initializer=None):
"""***Only for use by descendants of Metric***."""
if self.built:
raise RuntimeError("Can't call add_variable() after a Metric has been "
"built in the first call().")
if self._built:
raise RuntimeError("Can't call add_variable() except in build().")
v = variable_scope.get_variable(name, shape, dtype, initializer,
trainable=False, use_resource=True)
self._vars.append(v)
......@@ -144,6 +198,15 @@ class Mean(Metric):
# TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64?
# Or defaults to type of the input if it is tf.float32, else tf.float64?
def build(self, values, weights=None):
del values, weights # build() does not use call's arguments
self.numer = self.add_variable(name="numer", shape=(),
dtype=dtypes.float64,
initializer=init_ops.zeros_initializer)
self.denom = self.add_variable(name="denom", shape=(),
dtype=dtypes.float64,
initializer=init_ops.zeros_initializer)
def call(self, values, weights=None):
"""Accumulate statistics for computing the mean.
......@@ -154,13 +217,6 @@ class Mean(Metric):
values: Tensor with the per-example value.
weights: Optional weighting of each example. Defaults to 1.
"""
if not self.built: # False only in the first call().
self.numer = self.add_variable(name="numer", shape=(),
dtype=dtypes.float64,
initializer=init_ops.zeros_initializer)
self.denom = self.add_variable(name="denom", shape=(),
dtype=dtypes.float64,
initializer=init_ops.zeros_initializer)
if weights is None:
self.denom.assign_add(
math_ops.cast(array_ops.size(values), dtypes.float64))
......@@ -179,6 +235,10 @@ class Mean(Metric):
class Accuracy(Mean):
"""Calculates how often `predictions` matches `labels`."""
def build(self, labels, predictions, weights=None):
del labels, predictions, weights
super(Accuracy, self).build(None) # Arguments are unused
def call(self, labels, predictions, weights=None):
"""Accumulate accuracy statistics.
......
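To make the new build()/call()/result() contract concrete, here is a hedged sketch of a descendant metric. It is illustrative only (a hypothetical `Max` metric, not part of this change) and assumes the `Metric` base class from tensorflow/contrib/eager/python/metrics.py behaves as documented above: build() creates the variables on the first __call__(), call() performs the updates, result() reads them out. In graph mode the op returned by `m(values)` would be fetched like the `accumulate` op in the tests below; under eager execution the update simply runs in place.

# Hypothetical descendant metric -- a sketch, not part of this patch.
from tensorflow.contrib.eager.python import metrics
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops


class Max(metrics.Metric):
  """Tracks the largest value seen across all calls."""

  def build(self, values):
    # All variables are created here, on the first __call__().
    del values  # build() does not use call's arguments.
    # Zero-initialized, so this sketch assumes non-negative inputs.
    self.maximum = self.add_variable(
        name="maximum", shape=(), dtype=dtypes.float64,
        initializer=init_ops.zeros_initializer)

  def call(self, values):
    # Runs as a graph function in graph mode, eagerly otherwise; the update
    # it adds is returned by __call__() so no separate fetch is needed.
    self.maximum.assign(
        math_ops.maximum(
            self.maximum,
            math_ops.reduce_max(math_ops.cast(values, dtypes.float64))))

  def result(self):
    return self.maximum.read_value()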
......@@ -19,7 +19,11 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.eager.python import metrics
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
class MetricsTest(test.TestCase):
......@@ -56,6 +60,53 @@ class MetricsTest(test.TestCase):
m([7], [2]) # 0 correct, weight 1
self.assertEqual(2.5/5, m.result().numpy())
def testTwoMeans(self):
# Verify two metrics with the same class and name don't
# accidentally share state.
m1 = metrics.Mean()
m2 = metrics.Mean()
m1(0)
m2(2)
self.assertEqual(0, m1.result().numpy())
self.assertEqual(2, m2.result().numpy())
self.assertNotEqual(m1.name, m2.name)
def testNamesWithSpaces(self):
# Verify two metrics with the same class and name don't
# accidentally share state.
m1 = metrics.Mean("has space")
m2 = metrics.Mean("has space")
m2(2)
m1(0)
self.assertEqual(m1.name, "has space")
self.assertEqual(m1.numer.name, "has_space/numer:0")
self.assertEqual(m2.name, "has space_1")
self.assertEqual(m2.numer.name, "has_space_1/numer:0")
def testGraph(self):
with context.graph_mode(), self.test_session() as sess:
m = metrics.Mean()
p = array_ops.placeholder(dtypes.float32)
accumulate = m(p)
variables.global_variables_initializer().run()
sess.run(accumulate, feed_dict={p: [1, 10, 100]})
sess.run(accumulate, feed_dict={p: 1000})
sess.run(accumulate, feed_dict={p: [10000, 100000]})
self.assertAllEqual(m.result().eval(), 111111.0/6)
def testTwoMeansGraph(self):
# Verify two metrics with the same class and name don't
# accidentally share state.
with context.graph_mode(), self.test_session() as sess:
m1 = metrics.Mean()
m2 = metrics.Mean()
accumulate1 = m1(0)
accumulate2 = m2(2)
variables.global_variables_initializer().run()
sess.run([accumulate1, accumulate2])
self.assertEqual(0, m1.result().eval())
self.assertEqual(2, m2.result().eval())
if __name__ == "__main__":
test.main()
......@@ -22,6 +22,7 @@ import os
from tensorflow.contrib.eager.python import saver as _saver
from tensorflow.python.eager import context
from tensorflow.python.eager import graph_callable
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
......@@ -29,7 +30,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
class SaverTest(test.TestCase):
......@@ -38,7 +38,7 @@ class SaverTest(test.TestCase):
return '/device:GPU:0' if context.num_gpus() else '/device:CPU:0'
def testBasics(self):
with context.eager_mode(), ops.device(self._dev()):
with ops.device(self._dev()):
v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
def model():
return array_ops.constant(2.0) * v1
......@@ -54,8 +54,42 @@ class SaverTest(test.TestCase):
saver.restore(ckpt_prefix)
self.assertEqual(v1.read_value().numpy(), 1.0)
def testRestoreOnCreate(self):
def testSameNameNoClobbering(self):
with context.eager_mode(), ops.device(self._dev()):
# Note that this test purposefully uses Graphs rather than
# IsolateTest. Users are more likely to accidentally create the same
# variable name this way.
first_graph = ops.Graph()
with first_graph.as_default():
v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1')
with ops.Graph().as_default():
v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1')
saver = _saver.Saver([v1_first_graph, v1_second_graph])
ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
with self.assertRaisesRegexp(ValueError, 'v1'):
saver.save(ckpt_prefix)
def testDifferentGraphError(self):
with context.eager_mode(), ops.device(self._dev()):
with ops.Graph().as_default():
v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
with ops.Graph().as_default():
saver = _saver.Saver([v1])
ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
with self.assertRaisesRegexp(ValueError, 'Graph'):
saver.save(ckpt_prefix)
def testSameObjectOK(self):
with context.eager_mode(), ops.device(self._dev()):
v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
# While different objects with the same shared_name are not good, passing
# in the same object multiple times is fine.
saver = _saver.Saver([v1, v1])
ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
saver.save(ckpt_prefix)
def testRestoreOnCreate(self):
with ops.device(self._dev()):
def model(init_val):
v1 = resource_variable_ops.ResourceVariable(init_val, name='v1')
return array_ops.constant(1.0) * v1, v1
......@@ -71,12 +105,9 @@ class SaverTest(test.TestCase):
# Value is from checkpoint, but not from argument.
ret, _ = model(2.0)
self.assertEqual(ret.numpy(), 1.0)
# Create it a second time won't re-assign the checkpoint value.
v1_2 = resource_variable_ops.ResourceVariable(3.0, name='v1')
self.assertEqual(v1_2.read_value().numpy(), 3.0)
def testRestoreNotFound(self):
with context.eager_mode(), ops.device(self._dev()):
with ops.device(self._dev()):
def model(v):
return array_ops.constant(1.0) * v
......@@ -92,7 +123,7 @@ class SaverTest(test.TestCase):
_ = model(resource_variable_ops.ResourceVariable(1.0, name='v2'))
def testSaveRestoreGraphCallable(self):
with context.eager_mode(), ops.device(self._dev()):
with ops.device(self._dev()):
@graph_callable.graph_callable(
[graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
def model(x):
......
......@@ -53,6 +53,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@in_eager_mode
@@in_graph_mode
@@IsolateTest
@@run_test_in_graph_and_eager_modes
"""
......@@ -84,6 +85,7 @@ from tensorflow.python.eager.execution_callbacks import nan_callback
from tensorflow.python.eager.execution_callbacks import seterr
from tensorflow.python.framework.ops import enable_eager_execution
from tensorflow.python.framework.ops import eager_run as run
from tensorflow.python.framework.test_util import IsolateTest
from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes
from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
from tensorflow.python.util.all_util import remove_undocumented
......
......@@ -24,7 +24,11 @@ the full-batch version.
approach for computing the initial cluster assignments that is expensive but is
typically less prone to getting stuck in bad local minima.
We provide distributed implementations of both full-batch and mini-batch
K-Means algorithm. Both K-Means++ and random initialization are supported.
The user can also choose between **Cosine** and **Squared Euclidean** distance
metrics.
**[k-MC2](https://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/view/12147/11759)**
provides a very fast seeding method that produces high-quality centers
comparable to K-Means++ seeding. k-MC2 works particularly well if it is combined
with Mini-batch K-Means.
We provide distributed implementations of both the full-batch and mini-batch
K-Means algorithms. K-Means++, k-MC2 and random initialization are supported. The user
can also choose between **Cosine** and **Squared Euclidean** distance metrics.
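As a usage illustration (not part of the patch), the new seeding can be selected through the `initial_clusters` and `kmc2_chain_length` arguments added to the `KMeans` class in tensorflow/contrib/factorization/python/ops/clustering_ops.py. The tensor shapes and values below are made up, and the tuple returned by `training_graph()` is described in that file.

# Hedged sketch of selecting k-MC2 seeding; shapes and values are illustrative.
import tensorflow as tf
from tensorflow.contrib.factorization.python.ops import clustering_ops

# k-MC2 assumes the batch has already been randomly permuted.
points = tf.random_shuffle(tf.random_normal([10000, 32]))

kmeans = clustering_ops.KMeans(
    inputs=points,
    num_clusters=100,
    initial_clusters=clustering_ops.KMC2_INIT,      # the new 'kmc2' option
    distance_metric=clustering_ops.SQUARED_EUCLIDEAN_DISTANCE,
    use_mini_batch=True,
    kmc2_chain_length=200)  # one new center per 200-point subset of the batch

training_graph = kmeans.training_graph()  # see the class docs for the tuple layout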
......@@ -224,6 +224,58 @@ class KmeansPlusPlusInitializationOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization").Device(DEVICE_CPU),
KmeansPlusPlusInitializationOp);
// Implementation of one single Markov Chain for the k-MC^2 algorithm
class KMC2ChainInitializationOp : public OpKernel {
public:
explicit KMC2ChainInitializationOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context,
context->MatchSignature({DT_FLOAT, DT_INT64}, {DT_INT64}));
}
void Compute(OpKernelContext* context) override {
const Tensor& distances_tensor = context->input(0);
const Tensor& seed_tensor = context->input(1);
OP_REQUIRES(context, TensorShapeUtils::IsVector(distances_tensor.shape()),
InvalidArgument("Input distances should be a vector."));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
InvalidArgument("Input seed should be a scalar."));
const int64 num_points = distances_tensor.dim_size(0);
const int64 seed = seed_tensor.scalar<int64>()();
OP_REQUIRES(context, num_points > 0,
InvalidArgument("Expected distances_tensor.size() > 0."));
random::PhiloxRandom random(seed);
random::SimplePhilox rng(&random);
auto distances = distances_tensor.flat<float>();
// Set the initial state of the Markov chain to be the first candidate.
int64 selected_index = 0;
float selected_distance = distances(selected_index);
// Build a Markov chain of length num_points.
for (int64 i = 1; i < num_points; ++i) {
const float candidate_distance = distances(i);
// Set the next state of the Markov chain to be the candidate with
// probability min(1, candidate_distance/selected_distance).
if (candidate_distance > rng.RandFloat() * selected_distance) {
selected_index = i;
selected_distance = candidate_distance;
}
}
Tensor* output_sampled_index_tensor;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({}),
&output_sampled_index_tensor));
auto output = output_sampled_index_tensor->scalar<int64>();
// Return the last state of the Markov chain as the new center.
output() = selected_index;
}
};
REGISTER_KERNEL_BUILDER(Name("KMC2ChainInitialization").Device(DEVICE_CPU),
KMC2ChainInitializationOp);
// Operator for computing the nearest neighbors for a set of points.
class NearestNeighborsOp : public OpKernel {
public:
......
......@@ -116,6 +116,62 @@ RUN_BM_KmeansPlusPlusInitialization(k3RetriesPerSample);
#undef RUN_BM_KmeansPlusPlusInitialization
#undef BENCHMARK_KMEANS_PLUS_PLUS
Graph* SetUpKMC2Initialization(int num_points) {
Graph* g = new Graph(OpRegistry::Global());
Tensor distances(DT_FLOAT, TensorShape({num_points}));
Tensor seed(DT_INT64, TensorShape({}));
distances.flat<float>().setRandom();
seed.flat<int64>().setConstant(12345);
TF_CHECK_OK(
NodeBuilder("KMC2ChainInitializationOp", "KMC2ChainInitialization")
.Input(test::graph::Constant(g, distances))
.Input(test::graph::Constant(g, seed))
.Finalize(g, nullptr /* node */));
return g;
}
template <int num_points, int num_to_sample, int num_dims>
void BM_KMC2Initialization(int iters) {
testing::StopTiming();
testing::ItemsProcessed(static_cast<int64>(iters) * num_points * num_dims *
num_to_sample);
testing::UseRealTime();
Graph* g = SetUpKMC2Initialization(num_points);
testing::StartTiming();
test::Benchmark("cpu", g).Run(iters);
}
#define BENCHMARK_KMC2(p, c, d) \
void BM_KMC2Initialization_##p##_##c##_##d(int iters) { \
BM_KMC2Initialization<p, c, d>(iters); \
} \
BENCHMARK(BM_KMC2Initialization_##p##_##c##_##d);
#define RUN_BM_KMC2Initialization \
BENCHMARK_KMC2(k10Points, k2Centers, k100Dim); \
BENCHMARK_KMC2(k10Points, k5Centers, k100Dim); \
BENCHMARK_KMC2(k10Points, k10Centers, k100Dim); \
BENCHMARK_KMC2(k100Points, k10Centers, k100Dim); \
BENCHMARK_KMC2(k100Points, k20Centers, k100Dim); \
BENCHMARK_KMC2(k100Points, k50Centers, k100Dim); \
BENCHMARK_KMC2(k100Points, k100Centers, k100Dim); \
BENCHMARK_KMC2(k1kPoints, k100Centers, k100Dim); \
BENCHMARK_KMC2(k1kPoints, k200Centers, k100Dim); \
BENCHMARK_KMC2(k1kPoints, k500Centers, k100Dim); \
BENCHMARK_KMC2(k1kPoints, k1kCenters, k100Dim); \
BENCHMARK_KMC2(k10kPoints, k100Centers, k100Dim); \
BENCHMARK_KMC2(k10kPoints, k200Centers, k100Dim); \
BENCHMARK_KMC2(k10kPoints, k500Centers, k100Dim); \
BENCHMARK_KMC2(k10kPoints, k1kCenters, k100Dim); \
BENCHMARK_KMC2(k1MPoints, k100Centers, k100Dim); \
BENCHMARK_KMC2(k1MPoints, k200Centers, k100Dim); \
BENCHMARK_KMC2(k1MPoints, k500Centers, k100Dim); \
BENCHMARK_KMC2(k1MPoints, k1kCenters, k100Dim)
RUN_BM_KMC2Initialization;
#undef RUN_BM_KMC2Initialization
#undef BENCHMARK_KMC2
Graph* SetUpNearestNeighbors(int num_dims, int num_points, int num_centers,
int k) {
Graph* g = new Graph(OpRegistry::Global());
......
......@@ -44,6 +44,25 @@ num_retries_per_sample: Scalar. For each row that is sampled, this parameter
samples: Matrix of shape (num_to_sample, d). The sampled rows.
)");
REGISTER_OP("KMC2ChainInitialization")
.Input("distances: float32")
.Input("seed: int64")
.Output("index: int64")
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"(
Returns the index of a data point that should be added to the seed set.
Entries in distances are assumed to be squared distances of candidate points to
the already sampled centers in the seed set. The op constructs one Markov chain
of the k-MC^2 algorithm and returns the index of one candidate point to be added
as an additional cluster center.
distances: Vector with squared distances to the closest previously sampled
cluster center for each candidate point.
seed: Scalar. Seed for initializing the random number generator.
index: Scalar with the index of the sampled point.
)");
REGISTER_OP("NearestNeighbors")
.Input("points: float32")
.Input("centers: float32")
......
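For readers who want the semantics of the new op without tracing the kernel, the chain it builds is the simple Metropolis-style acceptance loop sketched below. This is a hedged NumPy transcription, not the kernel itself; in particular the kernel draws its random numbers from a Philox generator seeded with `seed`, so exact indices will differ.

# Hedged NumPy sketch of what KMC2ChainInitialization computes (not the kernel).
import numpy as np

def kmc2_chain(distances, seed):
  rng = np.random.RandomState(seed)  # stand-in for the kernel's Philox RNG
  # Start the Markov chain at the first candidate point.
  selected_index = 0
  selected_distance = distances[0]
  # Each candidate replaces the current state with probability
  # min(1, candidate_distance / selected_distance).
  for i in range(1, len(distances)):
    if distances[i] > rng.rand() * selected_distance:
      selected_index = i
      selected_distance = distances[i]
  return selected_index  # index of the point to add as a new center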
......@@ -55,6 +55,63 @@ class KmeansPlusPlusInitializationTest(test.TestCase):
self.runTestWithSeed(seed)
class KMC2InitializationTest(test.TestCase):
def runTestWithSeed(self, seed):
with self.test_session():
distances = np.zeros(1000).astype(np.float32)
distances[6] = 10e7
distances[4] = 10e3
sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed)
self.assertEquals(sampled_point.eval(), 6)
distances[6] = 0.0
sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed)
self.assertEquals(sampled_point.eval(), 4)
def testBasic(self):
for seed in range(100):
self.runTestWithSeed(seed)
class KMC2InitializationLargeTest(test.TestCase):
def setUp(self):
self._distances = np.zeros(1001)
self._distances[500] = 100.0
self._distances[1000] = 50.0
def testBasic(self):
with self.test_session():
counts = {}
seed = 0
for i in range(50):
sample = clustering_ops.kmc2_chain_initialization(
self._distances, seed + i).eval()
counts[sample] = counts.get(sample, 0) + 1
self.assertEquals(len(counts), 2)
self.assertTrue(500 in counts)
self.assertTrue(1000 in counts)
self.assertGreaterEqual(counts[500], 5)
self.assertGreaterEqual(counts[1000], 5)
class KMC2InitializationCornercaseTest(test.TestCase):
def setUp(self):
self._distances = np.zeros(10)
def runTestWithSeed(self, seed):
with self.test_session():
sampled_point = clustering_ops.kmc2_chain_initialization(
self._distances, seed)
self.assertEquals(sampled_point.eval(), 0)
def testBasic(self):
for seed in range(100):
self.runTestWithSeed(seed)
# A simple test that can be verified by hand.
class NearestCentersTest(test.TestCase):
......
......@@ -50,6 +50,7 @@ COSINE_DISTANCE = 'cosine'
RANDOM_INIT = 'random'
KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'
KMC2_INIT = 'kmc2'
# The name of the variable holding the cluster centers. Used by the Estimator.
CLUSTERS_VAR_NAME = 'clusters'
......@@ -66,7 +67,8 @@ class KMeans(object):
use_mini_batch=False,
mini_batch_steps_per_iteration=1,
random_seed=0,
kmeans_plus_plus_num_retries=2):
kmeans_plus_plus_num_retries=2,
kmc2_chain_length=200):
"""Creates an object for generating KMeans clustering graph.
This class implements the following variants of K-means algorithm:
......@@ -95,7 +97,8 @@ class KMeans(object):
exactly like a full-batch version.
Args:
inputs: An input tensor or list of input tensors
inputs: An input tensor or list of input tensors. It is assumed that the
data points have been previously randomly permuted.
num_clusters: An integer tensor specifying the number of clusters. This
argument is ignored if initial_clusters is a tensor or numpy array.
initial_clusters: Specifies the clusters used during initialization. One
......@@ -104,6 +107,7 @@ class KMeans(object):
- a function f(inputs, k) that returns up to k centers from `inputs`.
- "random": Choose centers randomly from `inputs`.
- "kmeans_plus_plus": Use kmeans++ to choose centers from `inputs`.
- "kmc2": Use the fast k-MC2 algorithm to choose centers from `inputs`.
In the last three cases, one batch of `inputs` may not yield
`num_clusters` centers, in which case initialization will require
multiple batches until enough centers are chosen. In the case of
......@@ -121,13 +125,17 @@ class KMeans(object):
additional points to draw from the current distribution before selecting
the best. If a negative value is specified, a heuristic is used to
sample O(log(num_to_sample)) additional points.
kmc2_chain_length: Determines how many candidate points are used by the
k-MC2 algorithm to produce one new cluster center. If a (mini-)batch
contains fewer points, one new cluster center is generated from the
(mini-)batch.
Raises:
ValueError: An invalid argument was passed to initial_clusters or
distance_metric.
"""
if isinstance(initial_clusters, str) and initial_clusters not in [
RANDOM_INIT, KMEANS_PLUS_PLUS_INIT
RANDOM_INIT, KMEANS_PLUS_PLUS_INIT, KMC2_INIT
]:
raise ValueError(
"Unsupported initialization algorithm '%s'" % initial_clusters)
......@@ -141,6 +149,7 @@ class KMeans(object):
self._mini_batch_steps_per_iteration = int(mini_batch_steps_per_iteration)
self._random_seed = random_seed
self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
self._kmc2_chain_length = kmc2_chain_length
@classmethod
def _distance_graph(cls, inputs, clusters, distance_metric):
......@@ -302,9 +311,10 @@ class KMeans(object):
else:
cluster_centers_updated = cluster_centers
update_in_steps = None
cluster_counts = (variable_scope.variable(
array_ops.ones([num_clusters], dtype=dtypes.int64))
if self._use_mini_batch else None)
cluster_counts = (
variable_scope.variable(
array_ops.ones([num_clusters], dtype=dtypes.int64))
if self._use_mini_batch else None)
return (cluster_centers, cluster_centers_initialized, cluster_counts,
cluster_centers_updated, update_in_steps)
......@@ -359,7 +369,7 @@ class KMeans(object):
init_op = _InitializeClustersOpFactory(
self._inputs, num_clusters, initial_clusters, self._distance_metric,
self._random_seed, self._kmeans_plus_plus_num_retries,
cluster_centers_var, cluster_centers_updated,
self._kmc2_chain_length, cluster_centers_var, cluster_centers_updated,
cluster_centers_initialized).op()
cluster_centers = cluster_centers_var
......@@ -520,8 +530,9 @@ class KMeans(object):
array_ops.reshape(array_ops.shape(inp)[0], [-1])),
[-1, 1]), cluster_idx, num_clusters))
with ops.colocate_with(cluster_centers, ignore_existing=True):
new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast(
math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
new_clusters_centers = math_ops.add_n(cluster_sums) / (
math_ops.cast(math_ops.add_n(cluster_counts), cluster_sums[0].dtype) +
epsilon)
if self._clusters_l2_normalized():
new_clusters_centers = nn_impl.l2_normalize(new_clusters_centers, dim=1)
return state_ops.assign(cluster_centers, new_clusters_centers)
......@@ -548,9 +559,12 @@ class _InitializeClustersOpFactory(object):
cluster_centers_initialized := true
"""
# TODO(ccolby): Refactor this class so that kmc2 isn't so much a special case.
def __init__(self, inputs, num_clusters, initial_clusters, distance_metric,
random_seed, kmeans_plus_plus_num_retries, cluster_centers,
cluster_centers_updated, cluster_centers_initialized):
random_seed, kmeans_plus_plus_num_retries, kmc2_chain_length,
cluster_centers, cluster_centers_updated,
cluster_centers_initialized):
"""Creates an op factory.
Args:
......@@ -560,6 +574,7 @@ class _InitializeClustersOpFactory(object):
distance_metric: See KMeans constructor.
random_seed: See KMeans constructor.
kmeans_plus_plus_num_retries: See KMeans constructor.
kmc2_chain_length: See KMeans constructor.
cluster_centers: The TF variable holding the initial centers. It may
already contain some centers when the op is executed.
cluster_centers_updated: A second TF variable to hold a copy of the
......@@ -575,6 +590,7 @@ class _InitializeClustersOpFactory(object):
self._distance_metric = distance_metric
self._random_seed = random_seed
self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
self._kmc2_chain_length = kmc2_chain_length
self._cluster_centers = cluster_centers
self._cluster_centers_updated = cluster_centers_updated
self._cluster_centers_initialized = cluster_centers_initialized
......@@ -604,6 +620,90 @@ class _InitializeClustersOpFactory(object):
math_ops.to_int64(self._num_remaining), self._random_seed,
self._kmeans_plus_plus_num_retries)
def _kmc2_multiple_centers(self):
"""Adds new initial cluster centers using the k-MC2 algorithm.
In each call to the op, the provided batch is split into subsets based on
the specified `kmc2_chain_length`. On each subset, a single Markov chain of
the k-MC2 algorithm is used to add *one* new cluster center. If there
are fewer than `kmc2_chain_length` points in the subset, a single center is
added using one Markov chain on the full input. It is assumed that the
provided batch has previously been randomly permuted. Otherwise, k-MC2 may
return suboptimal centers.
Returns:
An op that adds new cluster centers.
"""
# The op only operates on the first shard of data.
first_shard = self._inputs[0]
# Number of points in the input that can be used.
batch_size = array_ops.shape(first_shard)[0]
# Maximum number of subsets such that the size of each subset is at least
# `kmc2_chain_length`. Final subsets may be larger.
max_to_sample = math_ops.cast(
batch_size / self._kmc2_chain_length, dtype=dtypes.int32)
# We sample at least one new center and at most all remaining centers.
num_to_sample = math_ops.maximum(
math_ops.minimum(self._num_remaining, max_to_sample), 1)
def _cond(i, _):
"""Stopping condition for the while loop."""
return math_ops.less(i, num_to_sample)
def _body(i, _):
"""Body that adds a single new center based on a subset."""
def _sample_random():
"""Returns a random point as a cluster center."""
# By assumption the batch is reshuffled and _sample_random is always
# called for i=0. Hence, we simply return the first point.
new_center = array_ops.reshape(first_shard[0], [1, -1])
if self._distance_metric == COSINE_DISTANCE:
new_center = nn_impl.l2_normalize(new_center, dim=1)
return new_center
def _sample_kmc2_chain():
"""Returns previous centers as well as a new center sampled using k-MC2.
"""
# Extract the subset from the underlying batch.
start = i * self._kmc2_chain_length
end = start + self._kmc2_chain_length
subset = first_shard[start:end]
# Compute the distances from points in the subset to previous centers.
_, distances = gen_clustering_ops.nearest_neighbors(
subset, self._cluster_centers, 1)
# Sample index of new center using k-MC2 Markov chain.
new_center_index = gen_clustering_ops.kmc2_chain_initialization(
array_ops.squeeze(distances), self._random_seed)
# Extract actual new center.
newly_sampled_center = array_ops.reshape(subset[new_center_index],
[1, -1])
# Return concatenation with previously sampled centers.
if self._distance_metric == COSINE_DISTANCE:
newly_sampled_center = nn_impl.l2_normalize(
newly_sampled_center, dim=1)
return array_ops.concat([self._cluster_centers, newly_sampled_center],
0)
# Obtain a random point if there are no previously sampled centers.
# Otherwise, construct a k-MC2 Markov chain.
new_centers = control_flow_ops.cond(
math_ops.equal(self._num_selected, 0), _sample_random,
_sample_kmc2_chain)
# Assign new cluster centers to underlying variable.
assigned_centers = state_ops.assign(
self._cluster_centers, new_centers, validate_shape=False)
if self._cluster_centers_updated is not self._cluster_centers:
assigned_centers = state_ops.assign(
self._cluster_centers_updated,
assigned_centers,
validate_shape=False)
return i + 1, self._num_clusters - array_ops.shape(assigned_centers)[0]
# Add num_to_sample new data points.
_, num_remaining = control_flow_ops.while_loop(_cond, _body, [0, 0])
return num_remaining
def _greedy_batch_sampler(self, sampler):
# If the input dataset size is smaller than the number of centers
# remaining, choose the entire input dataset as centers. This can happen
......@@ -657,7 +757,10 @@ class _InitializeClustersOpFactory(object):
with ops.control_dependencies([
check_ops.assert_positive(self._num_remaining),
]):
num_now_remaining = self._add_new_centers()
if self._initial_clusters == KMC2_INIT:
num_now_remaining = self._kmc2_multiple_centers()
else:
num_now_remaining = self._add_new_centers()
return control_flow_ops.cond(
math_ops.equal(num_now_remaining, 0),
lambda: state_ops.assign(self._cluster_centers_initialized, True),
......
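The index arithmetic in `_kmc2_multiple_centers` is easy to check by hand. The sketch below (not part of the patch) shows how a 1000-point pre-shuffled batch with `kmc2_chain_length = 200` is carved into subsets, each of which contributes one new center per call to the init op.

# Hedged sketch of the subset bookkeeping in _kmc2_multiple_centers.
batch_size = 1000
kmc2_chain_length = 200
num_remaining = 3  # centers still needed, for illustration

# Maximum number of chains this batch can support, and how many to run now.
max_to_sample = batch_size // kmc2_chain_length            # 5
num_to_sample = max(min(num_remaining, max_to_sample), 1)  # 3

# Each loop iteration i draws its chain from first_shard[start:end].
subsets = [(i * kmc2_chain_length, (i + 1) * kmc2_chain_length)
           for i in range(num_to_sample)]
# [(0, 200), (200, 400), (400, 600)] -- one new center sampled from each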
......@@ -37,6 +37,7 @@ See the @{$python/contrib.framework} guide.
@@arg_scope
@@add_arg_scope
@@current_arg_scope
@@has_arg_scope
@@arg_scoped_arguments
......