Commit 13eb3b90 authored by Alexandre Passos, committed by TensorFlower Gardener

Experimental C and Python APIs to invoke TensorFlow kernels on concrete values.

PiperOrigin-RevId: 164902588
Parent 7dfabcc0
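
In brief, the commit adds a TFE_* layer on top of the existing C API so a client can run a kernel directly on concrete values, without building a graph. A minimal sketch, condensed from the c_api.h and c_api_test.cc diffs below (`EagerMatMulSketch` is an illustrative name; status checks omitted for brevity):

```cpp
#include <string.h>

#include "tensorflow/c/eager/c_api.h"

// Sketch: build a 2x2 matrix, run MatMul on it eagerly, read the product.
void EagerMatMulSketch() {
  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TF_DeleteSessionOptions(opts);

  // Wrap a concrete tensor value in a TFE_TensorHandle.
  int64_t dims[] = {2, 2};
  float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(data));
  memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
  TFE_TensorHandle* m = TFE_NewTensorHandle(t);
  TF_DeleteTensor(t);

  // Describe the op: inputs are handles, attributes are set directly.
  TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
  TFE_OpAddInput(matmul, m, status);
  TFE_OpAddInput(matmul, m, status);
  TFE_OpSetAttrBool(matmul, "transpose_a", 0);
  TFE_OpSetAttrBool(matmul, "transpose_b", 0);
  TFE_OpSetAttrType(matmul, "T", TFE_TensorHandleDataType(m));

  // Execute and resolve the result to host memory.
  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  TF_Tensor* product = TFE_TensorHandleResolve(retvals[0], status);
  // TF_TensorData(product) now holds {7, 10, 15, 22}.

  TF_DeleteTensor(product);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteTensorHandle(m);
  TFE_DeleteOp(matmul);
  TFE_DeleteContext(ctx, status);
  TF_DeleteStatus(status);
}
```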
......@@ -380,6 +380,7 @@ filegroup(
"//tensorflow/java/src/main/native:all_files",
"//tensorflow/python:all_files",
"//tensorflow/python/debug:all_files",
"//tensorflow/python/eager:all_files",
"//tensorflow/python/estimator:all_files",
"//tensorflow/python/feature_column:all_files",
"//tensorflow/python/kernel_tests:all_files",
......
# Experimental extensions to the C API for eager execution of kernels.
licenses(["notice"]) # Apache 2.0
cc_library(
name = "c_api",
srcs = ["c_api.cc"],
hdrs = ["c_api.h"],
visibility = [
"//tensorflow:internal",
"//tensorflow/python/eager:__pkg__",
],
deps = [
":runtime",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
],
)
cc_test(
name = "c_api_test",
srcs = ["c_api_test.cc"],
deps = [
":c_api",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "runtime",
srcs = ["runtime.cc"],
hdrs = ["runtime.h"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
],
)
cc_test(
name = "runtime_test",
srcs = ["runtime_test.cc"],
deps = [
":runtime",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:client_session",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
(This diff is collapsed.)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_H_
#define TENSORFLOW_C_EAGER_C_API_H_
// C API extensions to experiment with eager execution of kernels.
#include "tensorflow/c/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
// "Context" under which operations/functions are executed. It encapsulates
// things like the available devices, resource manager etc.
//
// TODO(ashankar): Merge with TF_Session?
typedef struct TFE_Context TFE_Context;
extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts,
TF_Status* status);
extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status);
extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
TF_Status* status);
// A handle to a tensor on a device.
//
// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,
// type etc. Unlike a TF_Tensor, a TFE_TensorHandle may refer to such tensors
// placed in memory of different devices or remote address spaces.
typedef struct TFE_TensorHandle TFE_TensorHandle;
extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t);
extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h);
extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index);
extern const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h);
extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
TF_Status* status);
// Create a new TFE_TensorHandle with the same contents as 'h' but placed
// in the memory of the device name 'device_name'.
//
// Currently requires at least one of the source or destination devices to
// be CPU (i.e., for the source or destination tensor to be placed in
// host memory).
extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TFE_Context* ctx,
const char* device_name,
TF_Status* status);
// Description of the TensorFlow op to execute.
//
// Assumes that the provided 'ctx' outlives the returned TFE_Op, i.e.,
// TFE_DeleteOp() is called before TFE_DeleteContext().
//
// Very similar to TF_OperationDescription with some differences:
// (1) TF_Output or TFE_TensorHandle* as arguments to TF_AddInput,
// TF_AddInputList
// (2) TF_ColocateWith, TF_AddControlInput etc. do not make sense.
// (3) Implementation detail: Avoid use of NodeBuilder/NodeDefBuilder since
// the additional sanity checks there seem unnecessary.
typedef struct TFE_Op TFE_Op;
extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status);
extern void TFE_DeleteOp(TFE_Op* op);
// TODO(ashankar): TFE_OpSetDevice and TFE_Execute should not have a TFE_Context
// parameter. Instead, the TFE_Context should be captured when creating the
// TFE_Op.
extern void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx,
const char* device_name, TF_Status* status);
extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status);
extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
unsigned char* is_list, TF_Status* status);
extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name,
const char* value);
extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value);
extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value);
extern void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name,
unsigned char value);
extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name,
TF_DataType value);
// If the number of dimensions is unknown, `num_dims` must be set to
// -1 and `dims` can be null. If a dimension is unknown, the
// corresponding entry in the `dims` array must be -1.
extern void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name,
const int64_t* dims, const int num_dims,
TF_Status* out_status);
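// For example (illustrative):
//   int64_t dims[] = {2, -1};  // rank 2, second dimension unknown
//   TFE_OpSetAttrShape(op, "shape", dims, 2, out_status);
//   TFE_OpSetAttrShape(op, "shape", nullptr, -1, out_status);  // unknown rank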
extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
const char** value, int num_values);
extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values);
extern void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
const float* values, int num_values);
extern void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
const unsigned char* values, int num_values);
extern void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
const TF_DataType* values, int num_values);
extern void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, TF_Status* out_status);
// Execute the operation defined by 'op' and return handles to computed
// tensors in 'retvals'.
//
// 'retvals' must point to a pre-allocated array of TFE_TensorHandle*
// and '*num_retvals' should be set to the size of this array.
//
// On return, 'num_retvals' will be set to the actual number of outputs
// returned by the operation.
extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals,
int* num_retvals, TF_Status* status);
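// An illustrative sketch of the in/out convention (MatMul has one output):
//   TFE_TensorHandle* retvals[2] = {nullptr};  // caller-allocated
//   int num_retvals = 2;  // capacity on input, actual output count on return
//   TFE_Execute(op, &retvals[0], &num_retvals, status);
//   // num_retvals is now 1.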
// Add a function (serialized FunctionDef protocol buffer) to ctx so
// that it can be invoked using TFE_Execute.
extern void TFE_ContextAddFunctionDef(TFE_Context* ctx,
const char* serialized_function_def,
size_t size, TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#ifdef __cplusplus
// A workaround to ease conversion to and from numpy objects and
// TFE_TensorHandle's.
//
// TODO(ashankar): Figure out an alternative scheme that precludes the need for
// these API-boundary breaking methods.
namespace tensorflow {
class Tensor;
} // namespace tensorflow
const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
TFE_TensorHandle* h, TF_Status* status);
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t);
#endif
#endif // TENSORFLOW_C_EAGER_C_API_H_
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api.h"
#include <string.h>
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
using tensorflow::string;
namespace {
TFE_TensorHandle* TestMatrixTensorHandle() {
int64_t dims[] = {2, 2};
float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
TF_Tensor* t = TF_AllocateTensor(
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandle(t);
TF_DeleteTensor(t);
return th;
}
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, a, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, b, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_OpSetAttrBool(op, "transpose_a", 0);
TFE_OpSetAttrBool(op, "transpose_b", 0);
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
return op;
}
// TODO(apassos) uncomment after rewriting to use the right benchmark API
// void BM_InitOp(benchmark::State& state) {
// TF_Status* status = TF_NewStatus();
// TF_SessionOptions* opts = TF_NewSessionOptions();
// TFE_Context* ctx = TFE_NewContext(opts, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteSessionOptions(opts);
// TFE_TensorHandle* m = TestMatrixTensorHandle();
// for (auto _ : state) {
// TFE_Op* matmul = MatMulOp(ctx, m, m);
// TFE_DeleteOp(matmul);
// }
// TFE_DeleteTensorHandle(m);
// TFE_DeleteContext(ctx, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteStatus(status);
// }
// BENCHMARK(BM_InitOp);
// void BM_Execute(benchmark::State& state) {
// TF_Status* status = TF_NewStatus();
// TF_SessionOptions* opts = TF_NewSessionOptions();
// TFE_Context* ctx = TFE_NewContext(opts, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteSessionOptions(opts);
// TFE_TensorHandle* m = TestMatrixTensorHandle();
// TFE_Op* matmul = MatMulOp(ctx, m, m);
// TFE_TensorHandle* retvals[1];
// int num_retvals = 1;
// for (auto _ : state) {
// TFE_Execute(matmul, &retvals[0], &num_retvals, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// }
// TFE_DeleteOp(matmul);
// TFE_DeleteTensorHandle(m);
// TFE_DeleteContext(ctx, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteStatus(status);
// }
// BENCHMARK(BM_Execute);
TEST(CAPI, Context) {
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
TF_DeleteSessionOptions(opts);
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContext(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const int num_devices = TF_DeviceListCount(devices);
EXPECT_GE(num_devices, 1) << "At least one CPU device should exist";
for (int i = 0; i < num_devices; ++i) {
EXPECT_NE("", TF_DeviceListName(devices, i, status)) << i;
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
TF_DeleteDeviceList(devices);
TF_DeleteStatus(status);
}
TEST(CAPI, TensorHandle) {
TFE_TensorHandle* h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
ASSERT_EQ(16, TF_TensorByteSize(t));
float data[4] = {0};
memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t));
EXPECT_EQ(1.0, data[0]);
EXPECT_EQ(2.0, data[1]);
EXPECT_EQ(3.0, data[2]);
EXPECT_EQ(4.0, data[3]);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(h);
}
TEST(CAPI, TensorHandleCopyBetweenDevices) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TF_DeleteSessionOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
const int num_devices = TF_DeviceListCount(devices);
const char* kCPUDevice = "CPU:0";
for (int i = 0; i < num_devices; ++i) {
const string name(TF_DeviceListName(devices, i, status.get()));
if (TF_GetCode(status.get()) != TF_OK) {
ADD_FAILURE() << i << " -- " << TF_Message(status.get());
continue;
}
auto tag = tensorflow::strings::StrCat("Device #", i, " (", name, ")");
// Copy to device
TFE_TensorHandle* hdevice =
TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
if (TF_GetCode(status.get()) != TF_OK) {
ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
continue;
}
// Copy back to CPU
TFE_TensorHandle* hcopy =
TFE_TensorHandleCopyToDevice(hdevice, ctx, kCPUDevice, status.get());
if (TF_GetCode(status.get()) != TF_OK) {
ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
continue;
}
TFE_DeleteTensorHandle(hdevice);
// Ensure that the contents are the same!
TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get());
TFE_DeleteTensorHandle(hcopy);
if (TF_GetCode(status.get()) != TF_OK) {
ADD_FAILURE() << tag;
continue;
}
EXPECT_EQ(TF_TensorByteSize(t), TF_TensorByteSize(tcopy)) << tag;
EXPECT_EQ(
0, memcmp(TF_TensorData(t), TF_TensorData(tcopy), TF_TensorByteSize(t)))
<< tag;
TF_DeleteTensor(tcopy);
}
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
TFE_DeleteContext(ctx, status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}
TEST(CAPI, Execute) {
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_TensorHandle* retvals[2] = {nullptr};
int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call.
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
TFE_DeleteContext(ctx, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
TFE_DeleteTensorHandle(retvals[0]);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TF_DeleteStatus(status);
}
string MatMulFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'MatMulFunction'"
" input_arg {"
" name: 'a'"
" type: DT_FLOAT"
" }"
" output_arg {"
" name: 'm'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'matmul'"
" op: 'MatMul'"
" input: 'a'"
" input: 'a'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'm'"
" value: 'matmul:product'"
" }",
&def));
return def.SerializeAsString();
}
TEST(CAPI, FunctionDefAndExecute) {
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
string function_def = MatMulFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_TensorHandle* retval[1] = {nullptr};
int num_retvals = 1;
TFE_Op* op = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, m, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Execute(op, &retval[0], &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TFE_DeleteOp(op);
TFE_DeleteTensorHandle(m);
TF_Tensor* t = TFE_TensorHandleResolve(retval[0], status);
TFE_DeleteTensorHandle(retval[0]);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteContext(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
// TODO(apassos) uncomment after rewriting to use the right benchmark API
// void BM_ExecuteFunction(benchmark::State& state) {
// TF_Status* status = TF_NewStatus();
// TF_SessionOptions* opts = TF_NewSessionOptions();
// TFE_Context* ctx = TFE_NewContext(opts, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteSessionOptions(opts);
// string function_def = MatMulFunction();
// TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
// status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TFE_TensorHandle* m = TestMatrixTensorHandle();
// TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TFE_OpAddInput(matmul, m, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TFE_TensorHandle* retval[1] = {nullptr};
// int num_retvals = 1;
// for (auto _ : state) {
// TFE_Execute(matmul, &retval[0], &num_retvals, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// }
// TFE_DeleteTensorHandle(m);
// TFE_DeleteTensorHandle(retval[0]);
// TFE_DeleteContext(ctx, status);
// EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteStatus(status);
// }
// BENCHMARK(BM_ExecuteFunction);
// TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
// TF_Status* status) {
// // Create the variable handle.
// TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
// if (TF_GetCode(status) != TF_OK) return nullptr;
// TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
// TFE_OpSetAttrShape(op, "shape", {}, 0, status);
// TFE_OpSetAttrString(op, "container", "");
// TFE_OpSetAttrString(op, "shared_name", "");
// if (TF_GetCode(status) != TF_OK) return nullptr;
// TFE_TensorHandle* var_handle = nullptr;
// int num_retvals = 1;
// TFE_Execute(op, &var_handle, &num_retvals, status);
// TFE_DeleteOp(op);
// if (TF_GetCode(status) != TF_OK) return nullptr;
// CHECK_EQ(1, num_retvals);
// // Assign 'value' to it.
// op = TFE_NewOp(ctx, "AssignVariableOp", status);
// if (TF_GetCode(status) != TF_OK) return nullptr;
// TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
// TFE_OpAddInput(op, var_handle, status);
// // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
// std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
// TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)),
// TF_DeleteTensor);
// memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
// std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
// value_handle(TFE_NewTensorHandle(t.get()), TFE_DeleteTensorHandle);
// TFE_OpAddInput(op, value_handle.get(), status);
// if (TF_GetCode(status) != TF_OK) return nullptr;
// num_retvals = 0;
// TFE_Execute(op, nullptr, &num_retvals, status);
// TFE_DeleteOp(op);
// if (TF_GetCode(status) != TF_OK) return nullptr;
// CHECK_EQ(0, num_retvals);
// return var_handle;
// }
// TEST(CAPI, Variables) {
// // Variables use resource handles, so this is really a test for resource
// // tensor handling.
// TF_Status* status = TF_NewStatus();
// TF_SessionOptions* opts = TF_NewSessionOptions();
// TFE_Context* ctx = TFE_NewContext(opts, status);
// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteSessionOptions(opts);
// TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
// TFE_OpAddInput(op, var_handle, status);
// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// int num_retvals = 1;
// TFE_TensorHandle* value_handle = nullptr;
// TFE_Execute(op, &value_handle, &num_retvals, status);
// TFE_DeleteOp(op);
// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// ASSERT_EQ(1, num_retvals);
// EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle));
// EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle));
// float value = 0.0f;
// TF_Tensor* t = TFE_TensorHandleResolve(value_handle, status);
// ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// ASSERT_EQ(sizeof(float), TF_TensorByteSize(t));
// memcpy(&value, TF_TensorData(t), sizeof(float));
// TF_DeleteTensor(t);
// EXPECT_EQ(12.0, value);
// TFE_DeleteTensorHandle(var_handle);
// TFE_DeleteTensorHandle(value_handle);
// TFE_DeleteContext(ctx, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteStatus(status);
// }
// void BM_ReadVariable(benchmark::State& state) {
// TF_Status* status = TF_NewStatus();
// TF_SessionOptions* opts = TF_NewSessionOptions();
// TFE_Context* ctx = TFE_NewContext(opts, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteSessionOptions(opts);
// TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
// TFE_OpAddInput(op, var_handle, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// int num_retvals = 1;
// TFE_TensorHandle* h = nullptr;
// for (auto _ : state) {
// TFE_Execute(op, &h, &num_retvals, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// CHECK_EQ(1, num_retvals);
// CHECK(h);
// CHECK_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
// CHECK_EQ(0, TFE_TensorHandleNumDims(h));
// h = nullptr;
// }
// TFE_DeleteOp(op);
// TFE_DeleteTensorHandle(var_handle);
// TFE_DeleteContext(ctx, status);
// CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TF_DeleteStatus(status);
// }
// BENCHMARK(BM_ReadVariable);
} // namespace
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/runtime.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
namespace tensorflow {
namespace {
mutex g_op_name_to_attr_type_map_lock(LINKER_INITIALIZED);
std::unordered_map<string, const AttrTypeMap*>* OpNameToAttrTypeMap() {
static auto* const m = new std::unordered_map<string, const AttrTypeMap*>;
return m;
}
const uint32 kIsList = 1U << 31;
} // namespace
Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
mutex_lock l(g_op_name_to_attr_type_map_lock);
*out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
if (*out != nullptr) return Status::OK();
const OpRegistrationData* op_reg_data = nullptr;
Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
if (!s.ok()) return s;
std::unique_ptr<AttrTypeMap> m(new AttrTypeMap);
// TODO(agarwal): Avoid having to create this "registry" at runtime,
// perhaps can be done at op registration time?
for (const auto& attr : op_reg_data->op_def.attr()) {
string type = attr.type();
const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0);
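    // e.g. for type "list(int)": length 9 > 6 and the first four characters
    // are "list", so is_list is true and substr(5, 9 - 6) extracts "int"
    // from between the parentheses.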
if (is_list) {
type = type.substr(5, type.length() - 6);
}
uint32 t = is_list ? kIsList : 0;
if (type == "string") {
t |= TF_ATTR_STRING;
} else if (type == "int") {
t |= TF_ATTR_INT;
} else if (type == "float") {
t |= TF_ATTR_FLOAT;
} else if (type == "bool") {
t |= TF_ATTR_BOOL;
} else if (type == "type") {
t |= TF_ATTR_TYPE;
} else if (type == "shape") {
t |= TF_ATTR_SHAPE;
} else if (type == "tensor") {
t |= TF_ATTR_TENSOR;
} else {
return errors::Unimplemented(
"TODO(agarwal): Enable support for ops with attributes of type '",
type, "'");
}
gtl::InsertIfNotPresent(m.get(), attr.name(), t);
}
*out = m.get();
(*OpNameToAttrTypeMap())[op_name] = m.release();
return Status::OK();
}
Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
TF_AttrType* out, unsigned char* is_list) {
CHECK(m);
auto* t = gtl::FindOrNull(*m, attr_name);
if (t == nullptr) {
return errors::InvalidArgument("Attribute '", attr_name,
"' does not exist for this operation");
}
*out = static_cast<TF_AttrType>(*t & ~kIsList);
if (*t & kIsList) {
*is_list = 1;
} else {
*is_list = 0;
}
return Status::OK();
}
#define DEFINE_SET_ATTR(value_type, value_field) \
template <> \
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, value_type&& value) { \
value_field.push_back(std::make_pair(attr_name, value)); \
return *this; \
}
DEFINE_SET_ATTR(StringPiece, string_attrs_);
DEFINE_SET_ATTR(float, float_attrs_);
DEFINE_SET_ATTR(int, int_attrs_);
DEFINE_SET_ATTR(bool, bool_attrs_);
DEFINE_SET_ATTR(tensorflow::DataType, type_attrs_);
#undef DEFINE_SET_ATTR
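// Written out, DEFINE_SET_ATTR(float, float_attrs_) above expands to:
//
//   template <>
//   AttrBuilder& AttrBuilder::Set(StringPiece attr_name, float&& value) {
//     float_attrs_.push_back(std::make_pair(attr_name, value));
//     return *this;
//   }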
AttrBuilder& AttrBuilder::NumInputs(int n) {
DCHECK(!node_def_finalized_) << "Calling NumInputs after BuildNodeDef.";
num_inputs_ = n;
return *this;
}
const NodeDef& AttrBuilder::BuildNodeDef() {
if (node_def_finalized_) return *node_def_;
MayBeInitializeNodeDef();
for (int i = 0; i < num_inputs_; ++i) {
node_def_->add_input("dummy_input");
}
for (const auto& p : string_attrs_) {
SetInNodeDef(p.first, p.second);
}
for (const auto& p : int_attrs_) {
SetInNodeDef(p.first, p.second);
}
for (const auto& p : float_attrs_) {
SetInNodeDef(p.first, p.second);
}
for (const auto& p : bool_attrs_) {
SetInNodeDef(p.first, p.second);
}
for (const auto& p : type_attrs_) {
SetInNodeDef(p.first, p.second);
}
node_def_finalized_ = true;
return *node_def_;
}
namespace {
inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
const tensorflow::Fprint128& b) {
  return {tensorflow::FingerprintCat64(a.low64, b.low64),
          tensorflow::FingerprintCat64(a.high64, b.high64)};
}
void CombineUnordered(const tensorflow::Fprint128& a,
tensorflow::Fprint128* b) {
b->low64 += a.low64;
b->high64 += a.high64;
}
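// Note: CombineUnordered is commutative, so the cache key computed in
// CacheKey below does not depend on the order in which attributes are
// visited (e.g. the iteration order of the NodeDef attr map).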
inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s,
const tensorflow::Fprint128& b) {
// TODO(agarwal): avoid ToString().
tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString());
return FingerprintCat128(a, b);
}
inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, uint64 b) {
return CacheKeyHelper(s, {b, b});
}
} // namespace
tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) const {
tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name_);
f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
if (node_def_ != nullptr) {
// Some attributes are directly written to node_def_ instead of being
// stored explicitly.
string value;
for (const auto& attr : node_def_->attr()) {
attr.second.SerializeToString(&value);
CombineUnordered(
CacheKeyHelper(attr.first, tensorflow::Fingerprint128(value)), &f);
}
// Note that node_def_ may be created but not finalized. This can happen
// when the creation was triggered by a call to Set, but BuildNodeDef has
// not been called.
if (node_def_finalized_) return f;
}
for (const auto& p : string_attrs_) {
// TODO(agarwal): avoid ToString().
CombineUnordered(CacheKeyHelper(p.first, tensorflow::Fingerprint128(
p.second.ToString())),
&f);
}
for (const auto& p : int_attrs_) {
CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
&f);
}
static std::hash<float> float_hasher;
for (const auto& p : float_attrs_) {
CombineUnordered(
CacheKeyHelper(p.first, static_cast<uint64>(float_hasher(p.second))),
&f);
}
for (const auto& p : bool_attrs_) {
CombineUnordered(CacheKeyHelper(p.first, p.second ? 1u : 0u), &f);
}
for (const auto& p : type_attrs_) {
CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
&f);
}
return f;
}
void AttrBuilder::MayBeInitializeNodeDef() {
if (node_def_ == nullptr) {
node_def_.reset(new NodeDef());
node_def_->set_name(op_name_);
node_def_->set_op(op_name_);
}
}
// static
Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
KernelAndDevice* out) {
OpKernel* k = nullptr;
Status s = CreateOpKernel(device->device_type().c_str(), device,
device->GetAllocator(AllocatorAttributes()),
nullptr, ndef, TF_GRAPH_DEF_VERSION, &k);
out->device_ = device;
out->kernel_.reset(k);
out->flib_ = nullptr;
return s;
}
// static
Status KernelAndDevice::InitFn(const NodeDef& ndef,
FunctionLibraryRuntime* flib,
KernelAndDevice* out) {
OpKernel* k = nullptr;
Status s = flib->CreateKernel(ndef, &k);
out->device_ = flib->device();
out->kernel_.reset(k);
out->flib_ = flib;
return s;
}
Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
std::vector<Tensor>* output_tensors) {
gtl::InlinedVector<TensorValue, 4> inputs;
for (Tensor& t : *input_tensors) {
inputs.push_back(TensorValue(&t));
}
std::vector<AllocatorAttributes> out_attrs(kernel_->num_outputs());
for (size_t i = 0; i < out_attrs.size(); ++i) {
out_attrs[i].set_on_host(kernel_->output_memory_types()[i] ==
tensorflow::HOST_MEMORY);
}
OpKernelContext::Params params;
params.device = device_;
params.frame_iter = FrameAndIter(0, 0);
params.inputs = &inputs;
params.op_kernel = kernel_.get();
params.resource_manager = device_->resource_manager();
params.output_attr_array = gtl::vector_as_array(&out_attrs);
params.function_library = flib_;
params.slice_reader_cache = &slice_reader_cache_;
// TODO(apassos): use a thread pool.
std::function<void(std::function<void()>)> runner =
[](std::function<void()> f) { f(); };
params.runner = &runner;
OpKernelContext context(&params);
device_->Compute(kernel_.get(), &context);
if (!context.status().ok()) return context.status();
output_tensors->clear();
for (int i = 0; i < context.num_outputs(); ++i) {
output_tensors->push_back(Tensor(*context.mutable_output(i)));
}
return Status::OK();
}
} // namespace tensorflow
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_RUNTIME_H_
#define TENSORFLOW_C_EAGER_RUNTIME_H_
// Support for eager execution of TensorFlow kernels.
#include <memory>
#include <unordered_map>
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
namespace tensorflow {
// Maps attribute name to an encoding of the type of the attribute value.
// If the type is not a list type, the value is the same as the TF_AttrType type
// of the value. Else, the highest order bit is on, and the rest of the bits
// represent the TF_AttrType type of the values in the list.
typedef std::unordered_map<string, uint32> AttrTypeMap;
// Returns the AttrTypeMap for the TensorFlow operation named op_name.
Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out);
// Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'.
Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
TF_AttrType* out, unsigned char* is_list);
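// For example: an "int" attr is encoded as TF_ATTR_INT, while a "list(int)"
// attr is encoded as (1u << 31) | TF_ATTR_INT; AttrTypeByName strips the
// high bit into 'is_list' and returns the base type in 'out'.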
// KernelAndDevice::Init needs a NodeDef only to pass the attribute map through.
// An AttrBuilder is a convenience class to help with that - providing a smaller
// interface than NodeDefBuilder and avoiding expensive (unnecessary?) sanity
// checks (like number of inputs matching the OpDef - we only care about
// attributes here).
//
// TODO(ashankar): Take a closer look at checks in NodeDefBuilder and see which
// ones make sense to replicate.
// This is a helper class for creating a NodeDef. Additionally, this class
// allows computing a cache key based on fingerprinting the attributes of this
// NodeDef.
//
// Example usage:
// AttrBuilder a;
// a.NumInputs(2);
// a.Set("T", TF_FLOAT);
// tensorflow::Fprint128 cache_key = a.CacheKey("cpu:0");
// const NodeDef& n = a.BuildNodeDef();
//
// Note that all calls to Set and NumInputs should happen before calling
// BuildNodeDef. Also, calls to NumInputs or Set between multiple invocations
// to CacheKey may cause different values to be returned by CacheKey.
//
// For performance reasons, the class internally delays the actual construction
// of the NodeDef till BuildNodeDef is called, or Set is called with certain
// uncommon types (see template specializations of Set to see which types
// trigger a NodeDef creation).
class AttrBuilder {
public:
explicit AttrBuilder(const char* op)
: op_name_(op),
num_inputs_(0),
node_def_(nullptr),
node_def_finalized_(false) {}
// Needed to work around call to ValidateNodeDef in CreateOpKernel.
AttrBuilder& NumInputs(int n);
template <class T>
AttrBuilder& Set(StringPiece attr_name, T&& value) {
MayBeInitializeNodeDef();
return SetInNodeDef(attr_name, value);
}
tensorflow::Fprint128 CacheKey(const string& device) const;
const NodeDef& BuildNodeDef();
private:
template <class T>
using AttrVec = tensorflow::gtl::InlinedVector<std::pair<StringPiece, T>, 2>;
void MayBeInitializeNodeDef();
template <class T>
AttrBuilder& SetInNodeDef(StringPiece attr_name, T&& value) {
DCHECK(!node_def_finalized_) << "Calling SetInNodeDef after BuildNodeDef.";
// Copied from NodeDefBuilder::Attr
const AttrValue* found = AttrSlice(*node_def_).Find(attr_name);
if (found == nullptr) {
AddNodeAttr(attr_name, std::forward<T>(value), node_def_.get());
} else {
AttrValue attr_value;
SetAttrValue(std::forward<T>(value), &attr_value);
// TODO(ashankar): Do what is done in
// NodeDefBuilder::CheckInconsistency(attr_name, *found, attr_value);
}
return *this;
}
AttrVec<StringPiece> string_attrs_;
AttrVec<int> int_attrs_;
AttrVec<float> float_attrs_;
AttrVec<bool> bool_attrs_;
AttrVec<tensorflow::DataType> type_attrs_;
string op_name_;
int num_inputs_;
std::unique_ptr<NodeDef> node_def_;
bool node_def_finalized_;
};  // class AttrBuilder
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, StringPiece&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, int&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, float&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, bool&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name,
tensorflow::DataType&& value);
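// Note: these specializations (generated by DEFINE_SET_ATTR in runtime.cc)
// buffer the common attribute types in the AttrVec members instead of
// writing to the NodeDef, which is what allows delaying NodeDef construction
// until BuildNodeDef.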
// KernelAndDevice encapsulates an instantiated kernel and the device it is on.
//
// Also see:
// https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
// and
// https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h
class KernelAndDevice {
public:
// Populates 'out' with a kernel appropriate for 'ndef'.
//
// Assumes that 'ndef' refers to a primitive op (as opposed to a function).
static Status InitOp(Device* device, const NodeDef& ndef,
KernelAndDevice* out);
// Like InitOp but for functions defined in flib (i.e., ndef.op() refers to a
// TensorFlow function in the FunctionLibraryRuntime).
//
// The provided FunctionLibraryRuntime MUST outlive all calls to
// Run() on the returned KernelAndDevice.
//
// TODO(ashankar): There shouldn't be a need for a separate InitOp and InitFn.
// The implementation of InitFn should work for both because
// FunctionLibraryRuntime::CreateKernel will create a primitive op kernel if
// appropriate. However, for now we keep them separate because I haven't
// figured out thread-safety concerns around FunctionLibraryRuntime (in
// particular, how the underlying FunctionLibraryDefinition might be mutated
// by another thread as new functions are registered with it).
// Conservatively, thread-safe usage of the FunctionLibraryRuntime is pushed
// on to the caller (see locking in c_api.cc) for now. But I really should
// dig into this so that both InitOp and InitFn can be collapsed to
// FunctionLibraryRuntime::CreateKernel.
static Status InitFn(const NodeDef& ndef, FunctionLibraryRuntime* flib,
KernelAndDevice* out);
KernelAndDevice() : device_(nullptr), flib_(nullptr) {}
// TODO(ashankar): Handle list-valued inputs.
Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs);
const OpKernel* kernel() const { return kernel_.get(); }
private:
std::unique_ptr<OpKernel> kernel_;
tensorflow::Device* device_;
tensorflow::FunctionLibraryRuntime* flib_;
tensorflow::checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_RUNTIME_H_
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/runtime.h"
#include <memory>
#include <vector>
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace {
Device* CPUDevice() {
return DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
}
TEST(AttrTypeMap, Lookup) {
const AttrTypeMap* m = nullptr;
Status s = AttrTypeMapForOp("ThisOpCannotPossiblyExist", &m);
EXPECT_FALSE(s.ok());
s = AttrTypeMapForOp("MatMul", &m);
ASSERT_TRUE(s.ok()) << s;
TF_AttrType t;
unsigned char is_list = 1;
s = AttrTypeByName(m, "ThisAttribyteCannotPossiblyExist", &t, &is_list);
EXPECT_FALSE(s.ok());
EXPECT_NE(is_list, 0);
s = AttrTypeByName(m, "transpose_a", &t, &is_list);
ASSERT_TRUE(s.ok()) << s;
EXPECT_EQ(TF_ATTR_BOOL, t);
EXPECT_EQ(is_list, 0);
s = AttrTypeMapForOp("Squeeze", &m);
ASSERT_TRUE(s.ok()) << s;
s = AttrTypeByName(m, "squeeze_dims", &t, &is_list);
ASSERT_TRUE(s.ok()) << s;
EXPECT_EQ(TF_ATTR_INT, t);
EXPECT_NE(is_list, 0);
}
TEST(KernelAndDevice, Run) {
Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor());
std::vector<Tensor> inputs;
inputs.push_back(t);
inputs.push_back(t);
NodeDef ndef(AttrBuilder("MatMul")
.Set("T", DT_FLOAT)
.Set("transpose_a", false)
.Set("transpose_b", false)
.NumInputs(inputs.size())
.BuildNodeDef());
std::unique_ptr<Device> device(CPUDevice());
KernelAndDevice kernel;
Status s = KernelAndDevice::InitOp(device.get(), ndef, &kernel);
ASSERT_TRUE(s.ok()) << s;
std::vector<Tensor> outputs;
s = kernel.Run(&inputs, &outputs);
ASSERT_TRUE(s.ok()) << s;
ASSERT_EQ(1, outputs.size());
const Tensor& out = outputs[0];
EXPECT_EQ(7, out.matrix<float>()(0, 0));
EXPECT_EQ(10, out.matrix<float>()(0, 1));
EXPECT_EQ(15, out.matrix<float>()(1, 0));
EXPECT_EQ(22, out.matrix<float>()(1, 1));
}
// TODO(apassos) uncomment after rewriting to use the right benchmark API
// void BM_CreateGraph(benchmark::State& state) {
// for (auto _ : state) {
// Scope root = Scope::NewRootScope();
// auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
// auto M = ops::MatMul(root, C, C);
// TF_CHECK_OK(root.status());
// }
// }
// BENCHMARK(BM_CreateGraph);
// void BM_RunGraph(benchmark::State& state) {
// Scope root = Scope::NewRootScope();
// auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
// auto M = ops::MatMul(root, C, C);
// SessionOptions opts;
// opts.config.set_inter_op_parallelism_threads(1);
// opts.config.set_intra_op_parallelism_threads(1);
// ClientSession sess(root, opts);
// std::vector<Tensor> outputs;
// for (auto _ : state) {
// outputs.clear();
// TF_CHECK_OK(sess.Run({M}, &outputs));
// }
// }
// BENCHMARK(BM_RunGraph);
// void BM_CreateAndDestroySession(benchmark::State& state) {
// Scope root = Scope::NewRootScope();
// auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}});
// auto M = ops::MatMul(root, C, C);
// for (auto _ : state) {
// ClientSession sess(root);
// }
// }
// BENCHMARK(BM_CreateAndDestroySession);
// void BM_KernelAndDeviceInit(benchmark::State& state) {
// NodeDef ndef(AttrBuilder("MatMul")
// .Set("T", DT_FLOAT)
// .Set("transpose_a", false)
// .Set("transpose_b", false)
// .NumInputs(2)
// .BuildNodeDef());
// std::unique_ptr<Device> device(CPUDevice());
// KernelAndDevice k;
// for (auto _ : state) {
// TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &k));
// }
// }
// BENCHMARK(BM_KernelAndDeviceInit);
// void BM_KernelAndDeviceRun(benchmark::State& state) {
// Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor());
// std::vector<Tensor> inputs;
// inputs.push_back(t);
// inputs.push_back(t);
// std::vector<Tensor> outputs;
// NodeDef ndef(AttrBuilder("MatMul")
// .Set("T", DT_FLOAT)
// .Set("transpose_a", false)
// .Set("transpose_b", false)
// .NumInputs(inputs.size())
// .BuildNodeDef());
// std::unique_ptr<Device> device(CPUDevice());
// KernelAndDevice kernel;
// TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &kernel));
// for (auto _ : state) {
// TF_CHECK_OK(kernel.Run(&inputs, &outputs));
// }
// }
// BENCHMARK(BM_KernelAndDeviceRun);
} // namespace
} // namespace tensorflow
......@@ -18,6 +18,10 @@
set(tf_c_srcs
"${tensorflow_source_dir}/tensorflow/c/c_api.cc"
"${tensorflow_source_dir}/tensorflow/c/c_api.h"
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.h"
"${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/runtime.h"
"${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc"
"${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h"
"${tensorflow_source_dir}/tensorflow/c/tf_status_helper.cc"
......
......@@ -755,6 +755,8 @@ add_custom_command(
set (pywrap_tensorflow_internal_src
"${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.h"
"${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.cc"
"${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe.h"
"${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe_src.h"
"${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.h"
"${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.cc"
"${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.h"
......
......@@ -60,6 +60,7 @@ INCLUDEPRE_RE = re.compile(r"google::protobuf::internal::ExplicitlyConstructed|"
# Include if matched after exclude
INCLUDE_RE = re.compile(r"^(TF_\w*)$|"
r"^(TFE_\w*)$|"
r"tensorflow::|"
r"functor::|"
r"perftools::gputools")
......
......@@ -2814,6 +2814,7 @@ tf_py_wrap_cc(
"lib/io/py_record_reader.i",
"lib/io/py_record_writer.i",
"platform/base.i",
"pywrap_tfe.i",
"training/quantize_training.i",
"training/server_lib.i",
"util/kernel_registry.i",
......@@ -2838,6 +2839,7 @@ tf_py_wrap_cc(
"//tensorflow/c:checkpoint_reader",
"//tensorflow/c:python_api",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
"//tensorflow/core/grappler:grappler_item",
......@@ -2850,6 +2852,7 @@ tf_py_wrap_cc(
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/profiler/internal:print_model_analysis",
"//tensorflow/tools/graph_transforms:transform_graph_lib",
"//tensorflow/python/eager:pywrap_tfe_lib",
"//util/python:python_headers",
] + (tf_additional_lib_deps() +
tf_additional_plugin_deps() +
......
......@@ -56,7 +56,7 @@ tensorflow::ImportNumpy();
// const char*.
%typemap(in) (const char* target) {
$1 = PyBytes_AsString($input);
if (!$1) {
// Python has raised an error.
SWIG_fail;
}
......
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
cc_library(
name = "pywrap_tfe_lib",
srcs = ["pywrap_tfe_src.cc"],
hdrs = ["pywrap_tfe.h"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/core:lib",
"//tensorflow/python:numpy_lib",
"//tensorflow/python:py_func_lib",
"//util/python:python_headers",
],
)
py_library(
name = "core",
srcs = ["core.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":context",
":memory_trace",
":tape",
"//tensorflow/python:errors",
"//tensorflow/python:pywrap_tensorflow",
],
)
py_library(
name = "tensor",
srcs = ["tensor.py"],
srcs_version = "PY2AND3",
visibility = ["//learning/brain/contrib/eager:__subpackages__"],
deps = [
":context",
":core",
":tape",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:tensor_shape",
"//third_party/py/numpy",
],
)
py_library(
name = "context",
srcs = ["context.py"],
srcs_version = "PY2AND3",
visibility = ["//learning/brain/contrib/eager:__subpackages__"],
deps = [
"//tensorflow/python:device",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:util",
],
)
py_library(
name = "tape",
srcs = ["tape.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:dtypes",
"//tensorflow/python:util",
],
)
py_library(
name = "memory_trace",
srcs = ["memory_trace.py"],
srcs_version = "PY2AND3",
)
cuda_py_test(
name = "core_test",
srcs = ["core_test.py"],
additional_deps = [
":context",
":core",
":execute",
"//tensorflow/python:pywrap_tensorflow",
":tensor",
":test",
"//third_party/py/numpy",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
],
)
py_library(
name = "test",
srcs = ["test.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":context",
"//tensorflow/python:client_testlib",
],
)
py_library(
name = "execute",
srcs = ["execute.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":context",
":core",
":tape",
":tensor",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:dtypes",
"//tensorflow/python:lib",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"@six_archive//:six",
],
)
cc_library(
name = "python_eager_op_gen",
srcs = ["python_eager_op_gen.cc"],
hdrs = ["python_eager_op_gen.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:op_gen_lib",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
"//tensorflow/python:python_op_gen",
],
)
cc_library(
name = "python_eager_op_gen_main",
srcs = [
"python_eager_op_gen_main.cc",
],
visibility = ["//visibility:public"],
deps = [
":python_eager_op_gen",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_binary(
name = "python_eager_op_gen_demo",
deps = [
":python_eager_op_gen_main",
"//tensorflow/core:ops",
],
)
py_library(
name = "custom_gradient",
srcs = ["custom_gradient.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":core",
":tape",
"//tensorflow/python:framework_ops",
"//tensorflow/python:util",
],
)
py_library(
name = "graph_only_ops",
srcs = ["graph_only_ops.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_ops",
],
)
py_library(
name = "framework_for_generated_wrappers",
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/eager:execute",
],
)
py_library(
name = "function",
srcs = ["function.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":graph_only_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:graph_to_function_def",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:core",
"//tensorflow/python/eager:execute",
"//tensorflow/python/eager:tape",
"//tensorflow/python/eager:tensor",
"//third_party/py/numpy",
],
)
py_library(
name = "pip_dependencies",
visibility = ["//tensorflow:internal"],
deps = [
":context",
":core",
":execute",
":tensor",
":test",
"//tensorflow/python:pywrap_tensorflow",
],
)
# -----------------------------------------------------------------------------
# Google-internal targets.
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for TensorFlow's "Eager" mode of execution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import threading
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.platform import app
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib
GRAPH_MODE = 0
EAGER_MODE = 1
# Default execution mode.
_default_mode = GRAPH_MODE
# TODO(agarwal): better name?
class _EagerContext(threading.local):
"""Thread local eager context."""
def __init__(self):
super(_EagerContext, self).__init__()
self.device_index = -1
self.mode = _default_mode
self.scope_name = ""
self.recording_summaries = False
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
class Context(object):
"""Environment in which eager operations execute."""
def __init__(self, graph=None):
self._eager_context = _EagerContext()
if not self.in_eager_mode():
raise ValueError("Trying to create a Context in GRAPH_MODE")
# Create a handle
opts = pywrap_tensorflow.TF_NewSessionOptions(target=compat.as_bytes(""),
config=None)
with errors.raise_exception_on_not_ok_status() as status:
self._handle = pywrap_tensorflow.TFE_NewContext(opts, status)
pywrap_tensorflow.TF_DeleteSessionOptions(opts)
# Store list of devices
self._devices = []
with errors.raise_exception_on_not_ok_status() as status:
device_list = pywrap_tensorflow.TFE_ContextListDevices(
self._handle, status)
try:
for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
with errors.raise_exception_on_not_ok_status() as status:
dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i, status)
self._devices.append(pydev.canonical_name(dev_name))
finally:
pywrap_tensorflow.TF_DeleteDeviceList(device_list)
self._summary_writer_resource = None
self._graph = graph or tf_ops.get_default_graph()
def __del__(self):
if self._handle is not None:
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.TFE_DeleteContext(self._handle, status)
def __str__(self):
lines = [
"Eager TensorFlow environment with %d devices" % (len(self._devices))
]
for i, d in enumerate(self._devices):
lines.append(" Device %d: %s" % (i, d))
return "\n".join(lines)
@tf_contextlib.contextmanager
def _mode(self, mode):
ctx = self._eager_context
old_mode = ctx.mode
ctx.mode = mode
try:
yield
finally:
ctx.mode = old_mode
def in_graph_mode(self):
"""Returns True if current thread is in GRAPH mode."""
return self._eager_context.mode == GRAPH_MODE
def in_eager_mode(self):
"""Returns True if current thread is in EAGER mode."""
return self._eager_context.mode == EAGER_MODE
@property
def scope_name(self):
"""Returns scope name for the current thread."""
return self._eager_context.scope_name
@scope_name.setter
def scope_name(self, s):
"""Sets scope name for the current thread."""
self._eager_context.scope_name = s
@property
def summary_writer_resource(self):
"""Returns summary writer resource."""
return self._summary_writer_resource
@summary_writer_resource.setter
def summary_writer_resource(self, resource):
"""Sets summary writer resource."""
self._summary_writer_resource = resource
@property
def recording_summaries(self):
"""Returns True if recording summaries is enabled in current thread.."""
return self._eager_context.recording_summaries
@recording_summaries.setter
def recording_summaries(self, val):
"""Enables recording summaries is enabled in current thread.."""
self._eager_context.recording_summaries = val
# TODO(agarwal): remove?
@property
def _device_index(self):
return self._eager_context.device_index
# TODO(agarwal): remove?
@_device_index.setter
def _device_index(self, val):
self._eager_context.device_index = val
@property
def device_name(self):
"""Returns the device name for the current thread."""
index = self._device_index
return None if index < 0 else self._devices[index]
def devices(self):
"""List of the names of devices available to execute operations."""
return self._devices
def num_gpus(self):
"""The number of GPUs available to execute operations."""
# TODO(ashankar): Use TF_DeviceListType to count GPU devices.
return len(self._devices) - 1
def as_default(self):
"""Returns a context manager to make self the default for this thread."""
return _default_context_stack.get_controller(self)
class _DefaultContextStack(tf_ops._DefaultStack): # pylint: disable=protected-access
"""A thread-local stack of Context objects."""
def __init__(self):
super(_DefaultContextStack, self).__init__()
self._global_default_context = None
def get_default(self):
"""Returns a thread local object if present, else a global default."""
return (super(_DefaultContextStack, self).get_default() or
self.global_default_context)
@property
def global_default_context(self):
if self._global_default_context is None:
self._global_default_context = Context()
return self._global_default_context
def reset(self):
super(_DefaultContextStack, self).reset()
self._global_default_context = None
_default_context_stack = _DefaultContextStack()
def get_default_context():
"""Returns a default Context object."""
return _default_context_stack.get_default()
# TODO(agarwal): switch users to get_default_context and get rid of this
# function.
def context():
return get_default_context()
def in_graph_mode():
"""Returns True if current thread is in GRAPH mode for default context."""
return get_default_context().in_graph_mode()
def in_eager_mode():
"""Returns True if current thread is in EAGER mode for default context."""
return get_default_context().in_eager_mode()
def graph_mode():
"""Context-manager to enable GRAPH mode for current thread."""
return get_default_context()._mode(GRAPH_MODE) # pylint: disable=protected-access
def eager_mode():
"""Context-manager to enable EAGER mode for current thread."""
return get_default_context()._mode(EAGER_MODE) # pylint: disable=protected-access
@contextlib.contextmanager
def namescope(name):
"""ContextManager for creating hierarchical name scopes."""
ctx = get_default_context()
old_name = ctx.scope_name
ctx.scope_name = "%s/%s" % (old_name, name) if old_name else name
try:
yield
finally:
ctx.scope_name = old_name
def scope_name():
"""Name of the current scope."""
return get_default_context().scope_name
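# A small usage sketch (illustrative only): nested namescopes compose with
# "/", so scope_name() reflects the full hierarchy.
#
#   with namescope("outer"):
#     with namescope("inner"):
#       scope_name()  # => "outer/inner"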
@tf_contextlib.contextmanager
def device(name):
"""Context-manager to force placement of operations and Tensors on a device.
For example:
```python
with tfe.device('gpu:0'):
with tfe.device('cpu:0'):
shape = tfe.Tensor([], dtype=tf.int32)
x = ops.truncated_normal(shape, tf.float32)
```
will ensure that the `shape` Tensor is on CPU but the `truncated_normal`
operation runs on GPU 0.
Args:
name: Name of the device (see get_default_context().devices()), or None to
enable automatic placement.
Yields:
Nothing.
Raises:
ValueError: If name does not correspond to a valid device.
"""
device_index = -1
ctx = get_default_context()
if name is not None:
name = pydev.canonical_name(name)
all_devices = ctx.devices()
for i, d in enumerate(all_devices):
# TODO(ashankar): This will change when we have distributed support.
# At that point, should not look for a string suffix but be able to
# do a full string comparison.
if d.endswith(name):
device_index = i
break
if device_index < 0:
raise ValueError("device {} does not match the available devices ({})".
format(name, all_devices))
old_device_index = ctx._device_index # pylint: disable=protected-access
try:
ctx._device_index = device_index # pylint: disable=protected-access
yield
finally:
ctx._device_index = old_device_index # pylint: disable=protected-access
@contextlib.contextmanager
def record_summaries():
"""Context-manager to enable recording of summaries."""
ctx = get_default_context()
old = ctx.recording_summaries
ctx.recording_summaries = True
try:
yield
finally:
ctx.recording_summaries = old
def should_record_summary():
"""True if a summary should be recorded now."""
c = get_default_context()
return c.recording_summaries and c.summary_writer_resource is not None
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list.
The program will run with eager execution enabled.
Args:
main: the main function to run
argv: the arguments to pass to it
"""
enable_eager_execution()
app.run(main, argv)
# TODO(apassos): This should not be a part of the public API.
def enable_eager_execution():
"""Enables, for the rest of the lifetime of this program, eager execution.
If not called immediately on startup, this risks creating breakage and bugs.
"""
global _default_mode
assert _default_mode == GRAPH_MODE
_default_mode = EAGER_MODE
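# A minimal usage sketch (illustrative only; assumes this module is imported
# as `context` from tensorflow.python.eager):
#
#   from tensorflow.python.eager import context
#
#   context.enable_eager_execution()  # once, at program startup
#   ctx = context.get_default_context()
#   assert ctx.in_eager_mode()
#   print(ctx.devices())  # canonical names of the available devices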
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for TensorFlow's "Eager" mode of execution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import memory_trace
from tensorflow.python.framework import errors
# Trace of execution and memory usage.
_active_trace = None
_uid_counter = 0
_uid_lock = threading.Lock()
def uid():
"""A unique (within this program execution) integer."""
with _uid_lock:
global _uid_counter
_uid_counter += 1
return _uid_counter
def _status_to_exception(code, message):
try:
error_class = errors.exception_type_from_error_code(code)
return error_class(None, None, message)
except KeyError:
return errors.UnknownError(None, None, message, code)
class _NotOkStatusException(Exception):
"""Exception class to handle not ok Status."""
def __init__(self, message, code):
super(_NotOkStatusException, self).__init__()
self.message = message
self.code = code
def __str__(self):
e = _status_to_exception(self.code, self.message)
return "%s: %s" % (e.__class__.__name__, e)
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(_NotOkStatusException)
def enable_tracing():
"""Enables tracing of execution and memory usage.
WARNING: tracing is not thread-safe.
"""
global _active_trace
_active_trace = memory_trace.MemoryTrace(
len(context.get_default_context().devices()))
def flush_trace():
"""Flushes the active trace, if it exists.
WARNING: tracing is not thread-safe.
"""
if _active_trace is not None:
_active_trace.flush_trace()
def active_trace():
"""Returns the current global active trace of execution and memory usage."""
return _active_trace
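# A tracing sketch (illustrative only): ops executed after enable_tracing()
# are recorded, and flush_trace() prints the per-device memory trace
# accumulated so far.
#
#   from tensorflow.python.eager import core
#
#   core.enable_tracing()
#   # ... execute some eager ops ...
#   core.flush_trace()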
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for core."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import execute
from tensorflow.python.eager import tensor
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
def truncated_normal(shape):
return execute.execute(
'TruncatedNormal',
1,
inputs=[shape],
attrs=('dtype', dtypes.float32.as_datatype_enum, 'T',
shape.dtype.as_datatype_enum, 'seed', 0, 'seed2', 0))[0]
class TFETest(test_util.TensorFlowTestCase):
def testContext(self):
ctx = context.Context()
self.assertFalse(ctx.in_graph_mode())
self.assertTrue(ctx.in_eager_mode())
self.assertEqual('', ctx.scope_name)
self.assertEqual(-1, ctx._device_index) # pylint: disable=protected-access
self.assertFalse(ctx.recording_summaries)
self.assertIsNone(ctx.summary_writer_resource)
del ctx
def testDefaultContext(self):
orig = context.get_default_context()
self.assertIs(context.get_default_context(), orig)
c0 = context.Context()
self.assertIs(context.get_default_context(), orig)
context_manager_0 = c0.as_default()
self.assertIs(context.get_default_context(), orig)
with context_manager_0 as c0:
self.assertIs(context.get_default_context(), c0)
with context.Context().as_default() as c1:
self.assertIs(context.get_default_context(), c1)
self.assertIs(context.get_default_context(), c0)
self.assertIs(context.get_default_context(), orig)
def testContextWithThreads(self):
def run_fn(ctx1):
ctx2 = context.get_default_context()
# Default context created in different threads are different.
self.assertIsNot(ctx1, ctx2)
# Check that default values of the context created in a different thread
# are set correctly.
self.assertFalse(ctx2.in_graph_mode())
self.assertTrue(ctx2.in_eager_mode())
self.assertEqual('', ctx2.scope_name)
self.assertEqual(-1, ctx2._device_index) # pylint: disable=protected-access
self.assertFalse(ctx2.recording_summaries)
self.assertIsNone(ctx2.summary_writer_resource)
ctx1 = context.get_default_context()
t = threading.Thread(target=run_fn, args=(ctx1,))
t.start()
t.join()
def testScalarTensor(self):
t = tensor.Tensor(3)
self.assertEqual(t.numpy(), tensor.Tensor(np.array(3)).numpy())
self.assertEqual(dtypes.int32, t.dtype)
self.assertEqual(0, t.shape.ndims)
self.assertAllEqual([], t.shape.as_list())
def testTensorAndNumpyMatrix(self):
expected = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
actual = tensor.Tensor([[1.0, 2.0], [3.0, 4.0]])
self.assertAllEqual(expected, actual.numpy())
self.assertEqual(np.float32, actual.numpy().dtype)
self.assertEqual(dtypes.float32, actual.dtype)
self.assertAllEqual([2, 2], actual.shape.as_list())
def testFloatDowncast(self):
# Unless explicitly specified, float64->float32
t = tensor.Tensor(3.0)
self.assertEqual(dtypes.float32, t.dtype)
t = tensor.Tensor(3.0, dtype=dtypes.float64)
self.assertEqual(dtypes.float64, t.dtype)
def testBool(self):
t = tensor.Tensor(False)
if t:
self.assertFalse(True)
def testIntDowncast(self):
t = tensor.Tensor(3)
self.assertEqual(dtypes.int32, t.dtype)
t = tensor.Tensor(3, dtype=dtypes.int64)
self.assertEqual(dtypes.int64, t.dtype)
t = tensor.Tensor(2**33)
self.assertEqual(dtypes.int64, t.dtype)
def testTensorCreationFailure(self):
with self.assertRaises(Exception):
# Should fail because each row of the Python object has a different
# number of columns.
self.assertEqual(None, tensor.Tensor([[1], [1, 2]]))
def testTensorPlacement(self):
if not context.context().num_gpus():
self.skipTest('No GPUs found')
x = tensor.Tensor(1.).as_gpu_tensor()
with context.device('gpu:0'):
y = tensor.Tensor(2.)
# Add would fail if y were not on GPU
result = execute.execute(
'Add', 1, inputs=[x, y],
attrs=('T', x.dtype.as_datatype_enum))[0].as_cpu_tensor().numpy()
self.assertEqual(3, result)
def testNumpyOrderHandling(self):
n = np.array([[1, 2], [3, 4]], order='F')
t = tensor.Tensor(n)
self.assertAllEqual([[1, 2], [3, 4]], t.numpy())
def testCopyBetweenDevices(self):
if not context.context().num_gpus():
self.skipTest('No GPUs found')
cpu = tensor.Tensor([[1., 2.], [3., 4.]])
c2g = cpu.as_gpu_tensor()
# Exercise a copy from GPU to CPU, even though we ignore the value.
_ = c2g.as_cpu_tensor()
with self.assertRaises(errors.InvalidArgumentError):
# c2g is on GPU. Copying between GPU devices fails
# (must redirect through CPU for now).
# TODO(ashankar): Perhaps the function should not fail and instead
# facilitate the copy through host memory?
c2g.as_gpu_tensor()
# Invalid device
with self.assertRaises(errors.InvalidArgumentError):
cpu.as_gpu_tensor(context.context().num_gpus() + 1)
def testNumpyForceCPU(self):
if not context.context().num_gpus():
self.skipTest('No GPUs found')
cpu = tensor.Tensor([[1., 2.], [3., 4.]])
c2g = cpu.as_gpu_tensor()
self.assertAllEqual(c2g.numpy(), cpu.numpy())
def testCopyFromCPUToCPU(self):
ta = tensor.Tensor([[1, 2], [3, 4]])
tb = ta.as_cpu_tensor()
self.assertNotEqual(ta._handle, tb._handle)
self.assertAllEqual(ta.numpy(), tb.numpy())
def testRegisterExceptionClass(self):
with self.assertRaises(TypeError):
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(str)
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(core._NotOkStatusException) # pylint: disable=protected-access
# TODO(agarwal): add tests passing incorrect typed values to attrs.
def testExecuteBasic(self):
three = tensor.Tensor(3)
five = tensor.Tensor(5)
product = execute.execute(
'Mul',
num_outputs=1,
inputs=[three, five],
attrs=('T', three.dtype.as_datatype_enum))[0]
self.assertEqual(15, product.numpy())
def testExecuteTooManyNumOutputs(self):
# num_outputs provided is 50, but only one output is produced.
# That should be okay.
product = execute.execute(
'Mul',
num_outputs=50,
inputs=[tensor.Tensor(3), tensor.Tensor(5)],
attrs=('T', dtypes.int32.as_datatype_enum))[0]
self.assertEqual(15, product.numpy())
def testMatMulGPU(self):
if not context.context().num_gpus():
self.skipTest('No GPUs found')
three = tensor.Tensor([[3.]]).as_gpu_tensor()
five = tensor.Tensor([[5.]]).as_gpu_tensor()
product = execute.execute(
'MatMul',
num_outputs=1,
inputs=[three, five],
attrs=('transpose_a', False, 'transpose_b', False, 'T',
three.dtype.as_datatype_enum))[0]
self.assertEqual([[15.0]], product.numpy())
def testExecuteStringAttr(self):
checked_three = execute.execute(
'CheckNumerics',
num_outputs=1,
inputs=[tensor.Tensor(3.)],
attrs=('message', 'just checking', 'T',
dtypes.float32.as_datatype_enum))[0]
self.assertEqual([[3]], checked_three.numpy())
def testExecuteStringAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
_ = execute.execute(
'CheckNumerics',
num_outputs=1,
inputs=[tensor.Tensor(3.)],
attrs=('message', 1, 'T', dtypes.float32.as_datatype_enum))
def testExecuteFloatAttr(self):
almost_equal = execute.execute(
'ApproximateEqual',
num_outputs=1,
inputs=[tensor.Tensor(3.0), tensor.Tensor(2.9)],
attrs=('tolerance', 0.3, 'T', dtypes.float32.as_datatype_enum))[0]
self.assertTrue(almost_equal.numpy())
def testExecuteFloatAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
_ = execute.execute(
'ApproximateEqual',
num_outputs=1,
inputs=[tensor.Tensor(3.0), tensor.Tensor(2.9)],
attrs=('tolerance', '0.3', 'T', dtypes.float32.as_datatype_enum))
def testExecuteIntAttr(self):
total = execute.execute(
'AddN',
num_outputs=1,
inputs=[tensor.Tensor(3), tensor.Tensor(4)],
attrs=('T', dtypes.int32.as_datatype_enum, 'N', 2))[0]
self.assertEqual(7, total.numpy())
def testExecuteIntAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
_ = execute.execute(
'AddN',
num_outputs=1,
inputs=[tensor.Tensor(3), tensor.Tensor(4)],
attrs=('T', dtypes.int32.as_datatype_enum, 'N', '2'))
# Looks like we don't have an existing op with list(bool) attrs.
def testExecuteBoolAttr(self):
product = execute.execute(
'MatMul',
num_outputs=1,
inputs=[tensor.Tensor([[3]]), tensor.Tensor([[5]])],
attrs=('transpose_a', True, 'transpose_b', False, 'T',
dtypes.int32.as_datatype_enum))[0]
self.assertEqual([[15]], product.numpy())
def testExecuteShapeAttr(self):
execute.execute(
'VarHandleOp',
num_outputs=1,
inputs=[],
attrs=('shape', [1, 2], 'dtype', dtypes.int32.as_datatype_enum,
'container', '', 'shared_name', ''))
def testExecuteShapeAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'VarHandleOp',
num_outputs=1,
inputs=[],
attrs=('shape', 1, 'dtype', dtypes.int32.as_datatype_enum,
'container', '', 'shared_name', ''))
def testExecuteListStringAttr(self):
execute.execute(
'TensorSummary',
num_outputs=1,
inputs=[tensor.Tensor(3.0)],
attrs=('T', dtypes.float32.as_datatype_enum, 'description',
'tensor_summary', 'labels', ['3',
'summary'], 'display_name', 'test'))
def testExecuteListStringAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'TensorSummary',
num_outputs=1,
inputs=[tensor.Tensor(3.0)],
attrs=('T', dtypes.float32.as_datatype_enum, 'description', '',
'labels', 3, 'display_name', 'test'))
def testExecuteListStringAttrBadListValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'TensorSummary',
num_outputs=1,
inputs=[tensor.Tensor(3.0)],
attrs=('T', dtypes.float32.as_datatype_enum, 'description', '',
'labels', [3], 'display_name', 'test'))
def testExecuteListFloatAttr(self):
b = execute.execute(
'Bucketize',
num_outputs=1,
inputs=[tensor.Tensor([3.0, 5.0, 7.0])],
attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', [4.0,
6.0]))[0]
self.assertAllEqual([0, 1, 2], b.numpy())
def testExecuteListFloatAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Bucketize',
num_outputs=1,
inputs=[tensor.Tensor([3.0, 5.0, 7.0])],
attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', 4.0))
def testExecuteListFloatAttrBadListValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Bucketize',
num_outputs=1,
inputs=[tensor.Tensor([3.0, 5.0, 7.0])],
attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries',
['4.0', '6.0']))
def testExecuteListIntAttr(self):
b = execute.execute(
'Squeeze',
num_outputs=1,
inputs=[tensor.Tensor([[[3.0]]])],
attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', [0, 2]))[0]
self.assertAllEqual([3], b.numpy())
def testExecuteListIntAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Squeeze',
num_outputs=1,
inputs=[tensor.Tensor([[[3.0]]])],
attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', 0))
def testExecuteListIntAttrBadListValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Squeeze',
num_outputs=1,
inputs=[tensor.Tensor([[[3.0]]])],
attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims',
['0', '2']))
def testExecuteListTypeListShapeAttr(self):
execute.execute(
'Barrier',
num_outputs=1,
inputs=[],
attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
[[1, 2]], 'capacity', -1, 'container', '', 'shared_name', ''))
def testExecuteListTypeAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Barrier',
num_outputs=1,
inputs=[],
attrs=('component_types', dtypes.float64.as_datatype_enum, 'shapes',
[[1, 2]], 'capacity', -1, 'container', '', 'shared_name', ''))
def testExecuteListTypeAttrBadListValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Barrier',
num_outputs=1,
inputs=[],
attrs=('component_types', '1', 'shapes', [[1, 2]], 'capacity', -1,
'container', '', 'shared_name', ''))
def testExecuteListShapeAttrBadValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Barrier',
num_outputs=1,
inputs=[],
attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
[1, 2], 'capacity', -1, 'container', '', 'shared_name', ''))
def testExecuteListShapeAttrBadListValue(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Barrier',
num_outputs=1,
inputs=[],
attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
[1], 'capacity', -1, 'container', '', 'shared_name', ''))
def testExecuteMultipleOutputs(self):
split_dim = 1
value = [[0, 1, 2], [3, 4, 5]]
x1, x2, x3 = execute.execute(
'Split',
num_outputs=3,
inputs=[tensor.Tensor(split_dim),
tensor.Tensor(value)],
attrs=('num_split', 3, 'T', dtypes.int32.as_datatype_enum))
self.assertAllEqual([[0], [3]], x1.numpy())
self.assertAllEqual([[1], [4]], x2.numpy())
self.assertAllEqual([[2], [5]], x3.numpy())
def testExecuteBadNumOutputsArgument(self):
with self.assertRaises(TypeError):
execute.execute(
'Relu', [],
inputs=[tensor.Tensor(3.0)],
attrs=('T', dtypes.float32.as_datatype_enum))
def testExecuteUnknownOp(self):
with self.assertRaises(errors.NotFoundError):
execute.execute('BlahBlahBlah', num_outputs=1, inputs=[], attrs=None)
def testExecuteUnknownAttr(self):
with self.assertRaises(errors.InvalidArgumentError):
execute.execute(
'Identity',
num_outputs=1,
inputs=[tensor.Tensor(3)],
attrs=('T', dtypes.int32.as_datatype_enum, 'unknown_attr', 'blah'))
def testComposition(self):
def add(x, y):
return execute.execute(
'Add',
num_outputs=1,
inputs=[x, y],
attrs=('T', dtypes.int32.as_datatype_enum))[0]
x = tensor.Tensor(1)
three_x = add(add(x, x), x)
self.assertEqual(dtypes.int32, three_x.dtype)
self.assertEqual(3, three_x.numpy())
def testOperationWithNoInputsRunsOnDevice(self):
if not context.context().num_gpus():
self.skipTest('No GPUs found')
shape = tensor.Tensor([], dtype=dtypes.int32)
# x: Run the "TruncatedNormal" op CPU and copy result to GPU.
x = truncated_normal(shape).as_gpu_tensor()
# y: Explicitly run the "TruncatedNormal" op on GPU.
with context.device('gpu:0'):
y = truncated_normal(shape)
# Add would fail if x and y were not on the same device.
execute.execute('Add', 1, inputs=[x, y],
attrs=('T', x.dtype.as_datatype_enum))
def testInvalidDevice(self):
with self.assertRaises(ValueError):
with context.device('pu:0'):
_ = tensor.Tensor(1)
if __name__ == '__main__':
test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Decorator to overrides the gradient for a function."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from autograd import core as ag_core
from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor as _tensor
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.util import nest
def _watch_value_from_tape(tensor):
for t in tape._tape_stack.stack: # pylint: disable=protected-access
w = t.value.tensors.get(tape.tensor_id(tensor), None)
if w is not None:
return w
return tensor
def custom_gradient(f):
"""Decorator to define a function with a custom gradient.
The input function is expected to return the tuple
(results, gradient_function)
The output function will return results while possibly recording the
gradient_function and inputs in the tape.
Args:
f: function to be decorated.
Returns:
decorated function.
"""
def decorated(*args, **kwargs):
"""Decorated function with custom gradient."""
input_tensors = [_watch_value_from_tape(x) for x in args
if isinstance(x, (_tensor.Tensor, tf_ops.Tensor))
or ag_core.isnode(x)]
result, grad_fn = f(*args, **kwargs)
flat_result = nest.flatten(result)
flat_result = [ag_core.getval(x) for x in flat_result]
flat_result = tape.record_operation(
flat_result,
input_tensors,
[],
grad_fn)
flat_result = list(flat_result)
return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)
return decorated
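# A hedged usage sketch (illustrative only; `log1pexp` is a hypothetical
# function, the `math_ops` calls stand in for whatever eager-compatible op
# wrappers are available, and the gradient function is assumed to receive the
# gradient with respect to the result):
#
#   @custom_gradient
#   def log1pexp(x):
#     e = math_ops.exp(x)
#     def grad(dy):
#       return dy * (1 - 1 / (1 + e))
#     return math_ops.log(1 + e), grad
#
# Here the custom gradient reuses the already-computed exp(x) instead of
# letting backprop rederive it.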
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions called by the generated code to execute an eager-mode op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from autograd import core as ag_core
import six
from google.protobuf import text_format
from tensorflow.core.framework import tensor_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import compat
def execute(op_name, num_outputs, inputs, attrs=None, name=None):
"""Execute a TensorFlow operation.
Args:
op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
execute.
num_outputs: The number of outputs of the operation to fetch.
(Explicitly provided instead of being inferred for performance
reasons).
inputs: A list of inputs to the operation. Each entry should be a Tensor, or
a value which can be passed to the Tensor constructor to create one.
attrs: A tuple with alternating string attr names and attr values for this
operation.
name: Customized name for the operation.
Returns:
None if there are no outputs, a single Tensor object if there is one output
and a list of Tensor objects if there are multiple outputs.
Raises:
An exception on error.
"""
ctx = context.get_default_context()
# TODO(apassos) move this to convert_to_tensor
inputs = [ag_core.getval(x) for x in inputs]
# pylint: disable=protected-access
input_handles = [c._handle for c in inputs]
device_name = ctx.device_name
try:
outh = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
op_name, input_handles, attrs,
num_outputs)
# pylint: enable=protected-access
except core._NotOkStatusException as e: # pylint: disable=protected-access
raise core._status_to_exception(e.code, e.message) # pylint: disable=protected-access
# pylint: enable=protected-access
tensors = [tensor._tensor_from_handle(x) for x in outh] # pylint: disable=protected-access
if core.active_trace() is not None:
trace_name = name if name else op_name
for t in tensors:
# pylint: disable=protected-access
core.active_trace().record_tensor(trace_name,
tape.tensor_id(t),
t._device_name(),
t.shape.num_elements())
# pylint: enable=protected-access
return tensors
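# An illustrative call, mirroring testExecuteBasic in core_test.py above:
#
#   product = execute('Mul', num_outputs=1,
#                     inputs=[tensor.Tensor(3), tensor.Tensor(5)],
#                     attrs=('T', dtypes.int32.as_datatype_enum))[0]
#   product.numpy()  # => 15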
def record_gradient(unused_op_name, unused_inputs, unused_attrs, results,
unused_name):
"""Import backprop if you want gradients recorded."""
return results
def make_float(v, arg_name):
if not isinstance(v, compat.real_types):
raise TypeError("Expected float for argument '%s' not %s." %
(arg_name, repr(v)))
return float(v)
def make_int(v, arg_name):
if isinstance(v, six.string_types):
raise TypeError("Expected int for argument '%s' not %s." %
(arg_name, repr(v)))
try:
return int(v)
except (ValueError, TypeError):
raise TypeError("Expected int for argument '%s' not %s." %
(arg_name, repr(v)))
def make_str(v, arg_name):
if not isinstance(v, compat.bytes_or_text_types):
raise TypeError("Expected string for argument '%s' not %s." %
(arg_name, repr(v)))
return compat.as_bytes(v) # Convert unicode strings to bytes.
def make_bool(v, arg_name):
if not isinstance(v, bool):
raise TypeError("Expected bool for argument '%s' not %s." %
(arg_name, repr(v)))
return v
def make_type(v, arg_name):
try:
v = dtypes.as_dtype(v).base_dtype
except TypeError:
raise TypeError("Expected DataType for argument '%s' not %s." %
(arg_name, repr(v)))
i = v.as_datatype_enum
return i
def make_shape(v, arg_name):
"""Convert v into a list."""
# Args:
# v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
# arg_name: String, for error messages.
# Returns:
# None if the rank is unknown, otherwise a list of ints (or Nones in the
# position where the dimension is unknown).
try:
shape = tensor_shape.as_shape(v)
except TypeError as e:
raise TypeError("Error converting %s to a TensorShape: %s" % (arg_name, e))
except ValueError as e:
raise ValueError("Error converting %s to a TensorShape: %s" % (arg_name, e))
if shape.ndims is None:
return None
else:
return shape.as_list()
def make_tensor(v, arg_name):
"""Ensure v is a TensorProto."""
if isinstance(v, tensor_pb2.TensorProto):
return v
elif isinstance(v, six.string_types):
pb = tensor_pb2.TensorProto()
text_format.Merge(v, pb)
return pb
raise TypeError(
"Don't know how to convert %s to a TensorProto for argument '%s'" %
(repr(v), arg_name))
def args_to_matching_eager(l, default_dtype=None):
"""Convert sequence `l` to eager same-type Tensors."""
# TODO(josh11b): Could we do a better job if we also passed in the
# allowed dtypes when that was known?
# Is some input already a Tensor with a dtype?
dtype = None
for t in l:
if isinstance(ag_core.getval(t), tensor.Tensor):
dtype = t.dtype
break
if dtype is None:
# TODO(josh11b): At the moment, I don't think this can fail, but at some
# point we likely should have some logic to prevent bad conversions.
dtype = default_dtype
if dtype is None:
# Infer a dtype based on the first value, and use that dtype for the
# remaining values.
ret = []
for t in l:
ret.append(tensor.convert_to_eager_tensor(t, dtype))
if dtype is None:
dtype = ret[-1].dtype
else:
ret = [tensor.convert_to_eager_tensor(t, dtype) for t in l]
return dtype, ret
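# E.g. (sketch): for [tensor.Tensor(1.0), 2, 3], the first element's dtype
# (float32) drives the conversion of the plain Python values, yielding
# (dtypes.float32, [Tensor(1.0), Tensor(2.0), Tensor(3.0)]).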
def convert_to_mixed_eager_tensors(values):
v = [t if isinstance(ag_core.getval(t), tensor.Tensor) else tensor.Tensor(t)
for t in values]
types = [t.dtype for t in v]
return types, v
def args_to_mixed_eager_tensors(lists):
"""Converts a list of same-length lists of values to eager tensors."""
assert len(lists) > 1
# Generate an error if len(lists[i]) is not the same for all i.
lists_ret = []
for l in lists[1:]:
if len(l) != len(lists[0]):
raise ValueError(
"Expected list arguments to be the same length: %d != %d (%r vs. %r)"
% (len(lists[0]), len(l), lists[0], l))
lists_ret.append([])
# Convert the first element of each list first, then the second element, etc.
types = []
for i in range(len(lists[0])):
dtype = None
# If any list has a Tensor, use that dtype
for l in lists:
if isinstance(ag_core.getval(l[i]), tensor.Tensor):
dtype = l[i].dtype
break
if dtype is None:
# Convert the first one and use its dtype.
lists_ret[0].append(tensor.convert_to_eager_tensor(lists[0][i]))
dtype = lists_ret[0][i].dtype
for j in range(1, len(lists)):
lists_ret[j].append(
tensor.convert_to_eager_tensor(lists[j][i], dtype=dtype))
else:
# Convert everything to the found dtype.
for j in range(len(lists)):
lists_ret[j].append(
tensor.convert_to_eager_tensor(lists[j][i], dtype=dtype))
types.append(dtype)
return types, lists_ret
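# E.g. (sketch): args_to_mixed_eager_tensors([[1, tensor.Tensor(2.0)],
# [3, 4]]) converts column-wise: column 0 has no Tensor, so 1 and 3 become
# int32 Tensors, while column 1 adopts float32 from Tensor(2.0). The returned
# types are [dtypes.int32, dtypes.float32].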
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Defun decorator for defining graph-mode functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import contextlib
import threading
from autograd import core as ag_core
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.util import nest
# Thread-local storage for tfe Tensors which are referenced while evaluating a
# graph-mode function.
_scoped_captures = threading.local()
# _scoped_captures.tensors is either None or a map from tfe.Tensor id to a pair
# of a tfe tensor and its corresponding placeholder to pass as a function
# argument. The value should be None unless we're in function definition
# context.
_scoped_captures.tensors = None
@contextlib.contextmanager
def capture_tensors(captures):
old = _scoped_captures.__dict__.get("tensors", None)
try:
_scoped_captures.tensors = captures
yield
finally:
_scoped_captures.tensors = old
def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
"""Captures a tfe Tensor while building a graph mode function.
Creates a placeholder to pass the tensor as an argument.
Arguments:
value: A tfe.Tensor object
dtype: The datatype of the value produced by the node in the graph.
name: Name of the node in the graph.
as_ref: Ignored (required by register_tensor_conversion_function).
Returns:
A placeholder which will, at runtime, have the value of this tensor.
Raises:
ValueError: if called outside a defun context.
"""
_ = as_ref
tensor_map = _scoped_captures.tensors
if tensor_map is None:
raise ValueError(
"Trying to use tfe.Tensor objects in a graph outside graph mode. "
"To build a graph use tfe.defun or tfe.func_to_object.")
captured_value = tensor_map.get(tape.tensor_id(value), None)
if captured_value is None:
captured_value = graph_placeholder(dtype=dtype or value.dtype,
shape=value.shape,
name=name)
if captured_value.dtype == dtypes.resource:
captured_value._handle_data = value._handle_data # pylint: disable=protected-access
tensor_map[tape.tensor_id(value)] = (value, captured_value)
else:
captured_value = captured_value[1]
return captured_value
# TODO(apassos): it'd be really nice if we could scope this registration.
ops.register_tensor_conversion_function(tensor.Tensor,
_convert_to_graph_constant)
class _CapturingContext(object):
"""Tracks references to Tensors outside this context while it is active."""
def __init__(self):
# known_ops are ops which are created while this context is active
self.known_ops = set()
# captured_tensors are all tensors referenced by ops in this context but
# not produced in it
self.captured_tensors = set()
def AddOp(self, op): # pylint: disable=invalid-name
if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
raise ValueError("tfe.defun cannot capture variables created without "
"using tf.get_variable. Op: %s" % op)
self.known_ops.add(op)
for i in op.inputs:
if i.op not in self.known_ops:
self.captured_tensors.add(i)
def __enter__(self):
self._g = ops.get_default_graph()
self._old = self._g._get_control_flow_context() # pylint: disable=protected-access
self._g._set_control_flow_context(self) # pylint: disable=protected-access
def __exit__(self, _, __, ___): # pylint: disable=invalid-name
self._g._set_control_flow_context(self._old) # pylint: disable=protected-access
def _forward_name(n):
"""The name of a generated forward defun named n."""
return "__forward_%s_%s" % (n, core.uid())
def _backward_name(n):
"""The name of a generated backward defun named n."""
return "__backward_%s_%s" % (n, core.uid())
def _inference_name(n):
"""The name of a forward-but-no-gradient defun named n."""
return "__inference_%s_%s" % (n, core.uid())
class _DefinedFunction(object):
"""Mocks the interface of tf _DefinedFunction."""
def __init__(self, fdef):
self.definition = fdef
self.name = fdef.signature.name
self.grad_func_name = None
self.python_grad_func = None
def _map_sequence_obj_to_idx(sequence):
"""Maps objs in the sequence from id(obj) to sequence index."""
return {id(x): i for i, x in enumerate(sequence)}
class _GraphModeFunction(object):
"""Callable object representing a graph-mode function.
Args:
input_placeholders: list of placeholder values to feed when calling
the wrapped function.
extra_inputs: Tensor inputs this function definition closed over which
are passed as arguments. Need to track so gradients are supported
correctly.
fdef: the function definition we want to call.
graph: the graph from which the fdef operations were pulled. Used as
a context when computing gradients.
operations: the subset of operations in the graph used in the function
definition.
func_outputs: the python outputs of the graph-mode function, with
tensorflow.Tensor objects to be replaced by tfe values when called.
func_outputs_to_fdef_outputs: Maps id(obj) in func_outputs to index of
fdef's outputs. It allows mapping fdef output tensors to nested
func_outputs structure.
output_shapes: List of shapes of all tensors which are output by the
internal function.
"""
def __init__(self,
input_placeholders,
extra_inputs,
fdef,
graph,
operations,
func_outputs,
func_outputs_to_fdef_outputs,
output_shapes):
assert len(input_placeholders) == len(fdef.signature.input_arg), "%s %s" % (
len(input_placeholders), len(fdef.signature.input_arg))
self._input_placeholders = input_placeholders
self._extra_inputs = list(extra_inputs)
self._graph = graph
self._has_backprop = False
self._func_name = fdef.signature.name
self._fdef = _DefinedFunction(fdef)
self._num_outputs = len(fdef.signature.output_arg)
self._ops = operations
self._func_outputs = func_outputs
if (isinstance(func_outputs, (ops.Tensor, type(None)))
or ag_core.isnode(func_outputs)):
self._returns = [func_outputs]
else:
self._returns = list(func_outputs)
self._returns_to_fdef_outputs = func_outputs_to_fdef_outputs
self._output_shapes = output_shapes
def _compute_backprop(self):
"""Computes the backprop function object for this function."""
self._has_backprop = True
with self._graph.as_default(), context.graph_mode():
c = _CapturingContext()
with c:
filtered_outputs = [ag_core.getval(x)
for x in self._returns if x is not None]
self._out_grad_placeholders = [
graph_placeholder(x.dtype, x.shape)
for x in filtered_outputs
]
in_gradients = gradients_impl.gradients(
filtered_outputs,
self._input_placeholders,
grad_ys=self._out_grad_placeholders)
shapes = [x.shape for x in in_gradients if x is not None]
captures = list(sorted(c.captured_tensors, key=lambda x: x.name))
forward_function_def = graph_to_function_def.graph_to_function_def(
self._graph, self._ops,
self._input_placeholders,
filtered_outputs + captures)
self._forward_fdef = _DefinedFunction(forward_function_def)
_register_with_name(_forward_name(self._func_name),
forward_function_def)
backward_outputs = [x for x in in_gradients if x is not None]
all_inputs = self._out_grad_placeholders + captures
backward_function_def = graph_to_function_def.graph_to_function_def(
self._graph,
[x.op for x in self._out_grad_placeholders] +
list(sorted(c.known_ops, key=lambda x: x.name)),
all_inputs,
backward_outputs)
_register_with_name(_backward_name(self._func_name), backward_function_def)
self._backward_function = _GraphModeFunction(
all_inputs, [], backward_function_def, self._graph, c.known_ops,
in_gradients, _map_sequence_obj_to_idx(backward_outputs), shapes)
def _backprop_call(self, args):
"""Calls the wrapped function and records the result on a tape."""
all_args = args + self._extra_inputs
signature = self._forward_fdef.definition.signature
if context.in_graph_mode():
g = ops.get_default_graph()
g._add_function(self._forward_fdef) # pylint: disable=protected-access
unwrapped_args = [ag_core.getval(x) for x in all_args]
op = g.create_op(signature.name,
[ops.convert_to_tensor(x) for x in unwrapped_args],
[dtypes.DType(x.type) for x in signature.output_arg],
op_def=signature,
name="FunctionCall",
compute_shapes=False)
outputs = op.outputs
outputs = [outputs] if isinstance(
outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs)
for i, s in enumerate(self._output_shapes):
outputs[i].set_shape(s)
else:
outputs = execute.execute(
signature.name,
num_outputs=len(signature.output_arg),
inputs=all_args)
real_outputs = outputs[:len(self._returns)]
side_outputs = outputs[len(self._returns):]
watched_extra_inputs = []
for t in self._extra_inputs:
tid = tape.tensor_id(t)
for t in tape._tape_stack.stack: # pylint: disable=protected-access
w = t.value.tensors.get(tid, None)
if w is not None:
watched_extra_inputs.append(w)
break
else: # Note: for-else here done on purpose
watched_extra_inputs.append(t)
real_outputs = tape.record_operation(real_outputs,
(args + watched_extra_inputs),
side_outputs,
self._backward_function)
return self._build_call_outputs(self._returns, real_outputs)
def __call__(self, *args):
"""Executes the passed function in eager mode."""
tensor_inputs = [x for x in nest.flatten(args)
if isinstance(x, (tensor.Tensor, ops.Tensor,
tensor.LazyZero))
or ag_core.isnode(x)]
if tape.should_record(tensor_inputs) or any(
tape.any_tape_has(t) for t in self._extra_inputs):
if not self._has_backprop:
self._compute_backprop()
return self._backprop_call(tensor_inputs)
if context.in_graph_mode():
g = ops.get_default_graph()
g._add_function(self._fdef) # pylint: disable=protected-access
signature = self._fdef.definition.signature
args = list(tensor_inputs) + self._extra_inputs
op = g.create_op(signature.name,
[ops.convert_to_tensor(x) for x in args],
[dtypes.DType(x.type) for x in signature.output_arg],
op_def=signature,
name="FunctionCall",
compute_shapes=False)
result = op.outputs
for i, s in enumerate(self._output_shapes):
result[i].set_shape(s)
else:
tensor_inputs = [x.tensor() if isinstance(x, tensor.LazyZero) else x
for x in tensor_inputs]
result = execute.execute(
self._func_name,
num_outputs=self._num_outputs,
inputs=tensor_inputs + self._extra_inputs)
return self._build_call_outputs(self._returns, result)
def _build_call_outputs(self, func_outputs, result):
"""Maps the fdef output list to actual output structure.
Args:
func_outputs: The outputs originally defined by the graph function. It
could potentially be a nested structure.
result: Output lists defined by FunctionDef.
Returns:
The actual call output.
"""
if self._func_outputs is None:
return None
if isinstance(ag_core.getval(self._func_outputs), ops.Tensor):
return result[0]
outputs = []
for o in func_outputs:
vo = ag_core.getval(o)
if isinstance(vo, ops.Tensor):
outputs.append(result[self._returns_to_fdef_outputs[id(vo)]])
elif type(vo) in (tuple, list):
outputs.append(self._build_call_outputs(o, result))
else:
outputs.append(o)
return tuple(outputs) if type(func_outputs) is tuple else outputs
def _get_defun_inputs(args):
"""Maps the inputs args to graph inputs."""
ret = []
for a in args:
a = ag_core.getval(a)
if isinstance(a, (tensor.LazyZero, ops.Tensor, tensor.Tensor)):
ret.append(graph_placeholder(a.dtype, a.shape))
elif type(a) in (tuple, list):
ret.append(_get_defun_inputs(a))
else:
ret.append(a)
return tuple(ret) if type(args) is tuple else ret
def _defun_internal(name, func, args, kwds):
"""Defines and returns graph-mode version of func."""
with context.graph_mode():
tmp_graph = ops.Graph()
with tmp_graph.as_default():
func_inputs = _get_defun_inputs(args)
captures = {}
with capture_tensors(captures):
func_outputs = func(*func_inputs, **kwds)
ids = list(sorted(captures.keys()))
if ids:
extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
else:
extra_inputs = []
extra_placeholders = []
outputs_list = nest.flatten(func_outputs)
output_shapes = [x.shape for x in outputs_list if x is not None]
flat_inputs = [x for x in nest.flatten(func_inputs)
if isinstance(x, ops.Tensor)]
all_inputs = flat_inputs + list(extra_placeholders)
func_def_outputs = [ag_core.getval(x) for x in outputs_list if x is not None]
inference_function_def = graph_to_function_def.graph_to_function_def(
tmp_graph, tmp_graph.get_operations(),
all_inputs,
func_def_outputs)
# Register any other functions defined in the graph
# TODO(ashankar): Oh lord, forgive me for this lint travesty.
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
# TODO(ashankar): What about the gradient registry?
_register_with_name(f.name, f.definition)
_register_with_name(_inference_name(name), inference_function_def)
return _GraphModeFunction(
all_inputs,
extra_inputs,
inference_function_def,
tmp_graph,
tmp_graph.get_operations(),
func_outputs,
_map_sequence_obj_to_idx(func_def_outputs),
output_shapes)
# Defun uses this instead of Tensor as a cache key. Using dtype because
# TensorFlow graphs are not parametric wrt dtypes, and using shapes for
# performance reasons, as much TensorFlow code specializes on known shapes to
# produce slimmer graphs.
_TensorDtype = collections.namedtuple("_TensorDtype", ["dtype", "shape"])
_ZeroDtype = collections.namedtuple("_ZeroDtype", ["dtype", "shape"])
def _cache_key(x):
"""Cache key for tfe functions."""
x = ag_core.getval(x)
if isinstance(x, tensor.Tensor):
return _TensorDtype(x.dtype, x._shape_tuple()) # pylint: disable=protected-access
if isinstance(x, tensor.LazyZero):
return _TensorDtype(x.dtype, tuple(x.shape.as_list())) # pylint: disable=protected-access
if isinstance(x, np.ndarray):
return ("array", x.shape, tuple(x.reshape(-1)))
if type(x) in (list, tuple):
return tuple([_cache_key(a) for a in x])
return x
def register_function_def(fdef):
fdef_string = fdef.SerializeToString()
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.TFE_ContextAddFunctionDef(
context.get_default_context()._handle, # pylint: disable=protected-access
fdef_string,
len(fdef_string),
status)
def _register_with_name(name, fdef):
"""Registers the function `fdef` with the name `name`."""
fdef.signature.name = name
register_function_def(fdef)
# TODO(apassos): better error messages for non-hashable arguments.
def named_defun(func, name):
"""Defines a function with a given name.
See the documentation for `defun` for more information on the semantics of the
function.
Args:
func: the function to be wrapped.
name: the name given to it.
Returns:
the wrapped function.
"""
arguments_to_functions = {}
def decorated(*args, **kwds):
"""Decorated version of func."""
# Macroexpand on non-Tensor arguments
cache_key = tuple(_cache_key(x) for x in args)
assert all(not isinstance(x, tensor.Tensor) for x in kwds.values())
cache_key = (cache_key, tuple(kwds.items()))
if cache_key not in arguments_to_functions:
arguments_to_functions[cache_key] = _defun_internal(
name, func, args, kwds)
return arguments_to_functions[cache_key](*args)
return decorated
def defun(func):
"""Decorator to compile func into graph_mode.
defun converts a function that constructs a TensorFlow graph into a function
that executes the graph. TensorFlow graphs typically execute faster and with a
lower memory-footprint than executing each of the operations that make up the
function individually as the TensorFlow runtime can optimize the graph and
execute sub-operations in parallel.
func must be a Python function that constructs a TensorFlow graph,
typically using functions in the tensorflow module.
Arguments to func can be either tfe.Tensor objects or Python
objects. Non-Tensor python objects are treated as constants, and new function
definitions are created internally based on their values.
func must return a tf.Tensor (NOT a tfe.Tensor) or a list of tf.Tensor (NOT a
tfe.Tensor). TODO(apassos) make the wrapped tfe ops return tf.Tensors when in
graph mode.
TODO(apassos): deal with captured global state. Deal with control flow.
Args:
func: function to be compiled.
Returns:
A callable that will execute the compiled function (and return zero
or more tfe.Tensor objects)
"""
return named_defun(func, func.__name__)
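# A usage sketch (illustrative only; assumes graph-building ops such as
# tf.matmul are importable under their usual names):
#
#   @defun
#   def matmul_square(a):
#     return tf.matmul(a, a)
#
# Calling matmul_square with a tfe.Tensor executes the compiled graph
# function; each new combination of input dtypes/shapes (and non-Tensor
# argument values) produces and caches a separate function definition.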
"""For eager-mode Python."""
load("//tensorflow:tensorflow.bzl", "clean_dep", "tf_copts")
def tfe_gen_op_wrapper_py(name,
out=None,
visibility=None,
deps=[],
generated_target_name=None):
"""Generate an eager-mode Python op wrapper for an op library."""
# Construct a cc_binary containing the specified ops.
tool_name = "gen_" + name + "_py_wrappers_cc"
if not deps:
deps = [str(Label("//tensorflow/core:" + name + "_op_lib"))]
native.cc_binary(
name=tool_name,
linkopts=["-lm"],
copts=tf_copts(),
linkstatic=1,
deps=([
clean_dep("//tensorflow/python/eager:python_eager_op_gen_main")
] + deps),
visibility=[clean_dep("//visibility:public")],)
# Invoke the previous cc_binary to generate a python file.
if not out:
out = "gen_" + name + ".py"
native.genrule(
name=name + "_pygenrule",
outs=[out],
tools=[tool_name],
cmd=("$(location " + tool_name + ") > $@"))
# Make a py_library out of the generated python file.
if not generated_target_name:
generated_target_name = name
native.py_library(
name=generated_target_name,
srcs=[out],
srcs_version="PY2AND3",
visibility=visibility,
deps=[
clean_dep("//tensorflow/python/eager:framework_for_generated_wrappers"),
],)
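# Example BUILD usage (hypothetical load path and target name; with the
# defaults above this generates gen_array_ops.py from
# //tensorflow/core:array_ops_op_lib):
#
#   load("//tensorflow/python/eager:eager.bzl", "tfe_gen_op_wrapper_py")
#
#   tfe_gen_op_wrapper_py(
#       name = "array_ops",
#   )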
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Graph-only versions of a few op functions, for internal use only."""
# Must be separate from array_ops to avoid a cyclic dependency.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import ops
def graph_zeros_like(tensor):
"""Graph-only version of tf.zeros_like(), for internal use only."""
g = ops._get_graph_from_inputs([tensor]) # pylint: disable=protected-access
with g.as_default(), ops.name_scope(None, "zeros_like", [tensor]) as name:
tensor = ops.convert_to_tensor(tensor, name="tensor")
dtype = tensor.dtype.base_dtype.as_datatype_enum
dtype_value = attr_value_pb2.AttrValue(type=dtype)
op = g.create_op("ZerosLike", [tensor], [dtype], input_types=[dtype],
attrs={"T": dtype_value}, name=name)
result, = op.outputs
return result
def graph_placeholder(dtype, shape, name=None):
"""Graph-only version of tf.placeholder(), for internal use only."""
dtype = dtype.base_dtype.as_datatype_enum
dtype_value = attr_value_pb2.AttrValue(type=dtype)
shape = attr_value_pb2.AttrValue(shape=shape.as_proto())
g = ops.get_default_graph()
with ops.name_scope(name, "placeholder", []) as name:
op = g.create_op("Placeholder", [], [dtype], input_types=[],
attrs={"dtype": dtype_value, "shape": shape}, name=name)
result, = op.outputs
return result
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility to trace per-device memory consumption across time over execution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
TraceEntry = collections.namedtuple(
"TraceEntry", ["op_name", "tensor_id", "mem_usage", "device", "size"])
TensorData = collections.namedtuple(
"TensorData", ["op_name", "tensor_size", "device"])
class MemoryTrace(object):
"""Records a trace of memory usage over operation execution."""
def __init__(self, n_devices):
self.trace = []
self.tensor_to_data = {}
self.current_device_mem_usage = [0] * n_devices
def record_tensor(self, op_name, tensor_id, device, size):
self.current_device_mem_usage[device] += size
self.tensor_to_data[tensor_id] = TensorData(op_name, size, device)
self.trace.append(TraceEntry(op_name,
tensor_id,
self.current_device_mem_usage[:],
device,
size))
def delete_tensor(self, tensor_id):
if tensor_id not in self.tensor_to_data:
return
data = self.tensor_to_data.pop(tensor_id)
self.current_device_mem_usage[data.device] -= data.tensor_size
self.trace.append(TraceEntry(data.op_name,
tensor_id,
self.current_device_mem_usage[:],
data.device,
-data.tensor_size))
def flush_trace(self):
"""Prints the formatted trace recorded so far."""
longest_op_name = max(len(t.op_name) for t in self.trace)
longest_op_name = max(longest_op_name, len("op_name"))
longest_heap_size = max(max(len(str(d)) for d in t.mem_usage)
for t in self.trace)
longest_heap_size = max(longest_heap_size, len("d0"))
longest_id_len = max(len(str(t.tensor_id)) for t in self.trace)
longest_id_len = max(longest_id_len, 2)
first_line = []
first_line.append("+/-")
first_line.append("op_name".ljust(longest_op_name))
first_line.append("id".ljust(longest_id_len))
for i in range(len(self.current_device_mem_usage)):
first_line.append(("d"+str(i)).ljust(longest_heap_size))
first_line.append("size")
print(" | ".join(first_line))
for t in self.trace:
line = []
if t.size > 0:
line.append("+ ")
else:
line.append("- ")
line.append(t.op_name.ljust(longest_op_name))
line.append(str(t.tensor_id).ljust(longest_id_len))
for d in t.mem_usage:
line.append(str(d).ljust(longest_heap_size))
line.append(str(t.size))
print(" | ".join(line))
self.trace = []
print()
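# A minimal sketch (illustrative only): record two tensors on device 0, free
# one, and print the accumulated trace.
#
#   trace = MemoryTrace(n_devices=1)
#   trace.record_tensor("MatMul", tensor_id=1, device=0, size=4096)
#   trace.record_tensor("Add", tensor_id=2, device=0, size=1024)
#   trace.delete_tensor(1)
#   trace.flush_trace()  # rows of: +/- | op_name | id | d0 | size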
This diff has been collapsed.
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
#define THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
#include <string>
#include <vector>
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
// hidden_ops should be a list of Op names that should get a leading _
// in the output. Prints the output to stdout.
void PrintEagerPythonOps(const OpList& ops,
const std::vector<string>& hidden_ops,
bool require_shapes);
// Get the python wrappers for a list of ops in an OpList.
// `op_list_buf` should be a pointer to a buffer containing
// the binary encoded OpList proto, and `op_list_len` should be the
// length of that buffer.
string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/eager/python_eager_op_gen.h"
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/init_main.h"
namespace tensorflow {
namespace {
void PrintAllPythonOps(const std::vector<string>& hidden_ops) {
OpList ops;
OpRegistry::Global()->Export(false, &ops);
PrintEagerPythonOps(ops, hidden_ops, true /* require_shapes */);
}
} // namespace
} // namespace tensorflow
int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
if (argc == 1) {
tensorflow::PrintAllPythonOps({});
} else {
return -1;
}
return 0;
}
This diff has been collapsed.
This diff has been collapsed.
This diff has been collapsed.
This diff has been collapsed.
This diff has been collapsed.
This diff has been collapsed.
......@@ -17,6 +17,8 @@ limitations under the License.
* The includes are intentionally not alphabetically sorted, as the order of
* includes follows dependency order */
%include "tensorflow/python/pywrap_tfe.i"
%include "tensorflow/python/util/port.i"
%include "tensorflow/python/util/py_checkpoint_reader.i"
%include "tensorflow/python/util/stat_summarizer.i"
......@@ -45,3 +47,4 @@ limitations under the License.
%include "tensorflow/python/grappler/tf_optimizer.i"
%include "tensorflow/python/grappler/cost_analyzer.i"
%include "tensorflow/python/grappler/model_analyzer.i"
......@@ -175,6 +175,7 @@ sh_binary(
"//tensorflow/python/debug:debug_pip",
"//tensorflow/python/saved_model:saved_model",
"//tensorflow/python:spectral_ops_test_util",
"//tensorflow/python/eager:pip_dependencies",
"//tensorflow/python/tools:tools_pip",
"//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
],
......