提交 f51d8c75 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Add end-to-end test for RemoteFusedGraphExecuteOp

Change: 149971795
上级 cbc73103
......@@ -3782,6 +3782,7 @@ filegroup(
],
exclude = [
"*test.cc",
"*test_util*",
"*testutil*",
"*testlib*",
"*main.cc",
......@@ -4286,6 +4287,20 @@ cc_library(
],
)
cc_library(
name = "remote_fused_graph_execute_op_test_utils",
srcs = ["remote_fused_graph_execute_op_test_utils.cc"],
hdrs = ["remote_fused_graph_execute_op_test_utils.h"],
deps = [
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:cwise_op",
],
)
tf_cc_test(
name = "remote_fused_graph_execute_utils_test",
size = "small",
......@@ -4293,6 +4308,7 @@ tf_cc_test(
"remote_fused_graph_execute_utils_test.cc",
],
deps = [
":remote_fused_graph_execute_op_test_utils",
":remote_fused_graph_execute_utils",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:scope",
......@@ -4320,7 +4336,13 @@ tf_cc_test(
":ops_testutil",
":ops_util",
":remote_fused_graph_execute_op",
":remote_fused_graph_execute_op_test_utils",
":remote_fused_graph_execute_utils",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/core:core_cpu",
"//tensorflow/core:direct_session",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
......
......@@ -96,6 +96,8 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
CHECK(status.ok());
status = gt->LoadGraphFromProto(ops_definitions, def, inputs, outputs, false,
tensor_shape_map);
const DataType input_data_type =
inputs.empty() ? DT_FLOAT : inputs.at(0).second.dtype();
Scope root = Scope::NewRootScope();
std::vector<Output> output_list;
......@@ -103,6 +105,9 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
const Scope& scope = root.WithOpName(input_node_info.first);
Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("PlaceholderV2");
const DataType dt = input_node_info.second.dtype();
// DataType of input arguments should be same.
CHECK_EQ(input_data_type, dt);
auto builder = NodeBuilder(unique_name, "PlaceholderV2")
.Attr("dtype", input_node_info.second.dtype())
.Attr("shape", input_node_info.second.shape());
......@@ -115,6 +120,13 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
const RemoteFusedGraphExecuteInfo execute_info =
BuildRemoteFusedGraphExecuteInfo(gt->GetGraphTransferInfo());
const DataType output_data_type =
outputs.empty() ? DT_FLOAT : tensor_shape_map.at(outputs.at(0)).first;
for (const string& output_node_name : outputs) {
const DataType dt = tensor_shape_map.at(output_node_name).first;
CHECK_EQ(output_data_type, dt);
}
const Scope& scope = root.WithOpName(remote_graph_execute_name);
CHECK(scope.ok());
auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list));
......@@ -122,7 +134,10 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute");
auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute")
.Input(node_out_list)
.Attr("M", static_cast<int64>(output_list.size()))
.Attr("N", static_cast<int64>(outputs.size()))
.Attr("T", input_data_type)
.Attr("U", output_data_type)
.Attr("serialized_graph_transfer_info",
StringPiece(execute_info.SerializeAsString()));
CHECK(scope.ok());
......
......@@ -57,6 +57,7 @@ bool HexagonControlWrapper::Init(const RemoteFusedGraphExecuteInfo& info) {
}
graph_transferer_.SetSerializedGraphTransferInfo(
info.serialized_executor_parameters());
execute_info_ = &info;
return soc_interface_Init();
}
......@@ -274,7 +275,30 @@ bool HexagonControlWrapper::FillInputNode(const string& node_name,
}
bool HexagonControlWrapper::ReadOutputNode(
const string node_name, std::vector<ByteArray> *const outputs) {
const string& node_name, TensorAllocatorFunc tensor_allocator) {
CHECK_NE(execute_info_, nullptr);
TensorShape output_shape;
// TODO(satok): Switch shape corresponding to input shape
for (int i = 0; i < execute_info_->graph_output_node_name_size(); ++i) {
if (execute_info_->graph_output_node_name(i) == node_name) {
for (const TensorShapeProto::Dim& dim :
execute_info_->default_graph_output_tensor_shape(i).shape().dim()) {
output_shape.AddDim(dim.size());
}
break;
}
}
std::vector<IRemoteFusedGraphExecutor::ByteArray> outputs;
ReadOutputNode(node_name, &outputs);
Tensor* output = tensor_allocator(output_shape);
CHECK(output->TotalBytes() >= std::get<1>(outputs[0]));
// TODO(satok): Avoid specifying float
std::memcpy(output->flat<float>().data(), std::get<0>(outputs[0]),
std::get<1>(outputs[0]));
}
bool HexagonControlWrapper::ReadOutputNode(
const string& node_name, std::vector<ByteArray>* const outputs) {
CHECK(outputs != nullptr);
ByteArray output;
soc_interface_ReadOutputNodeFloat(node_name.c_str(), &std::get<0>(output),
......@@ -323,7 +347,11 @@ bool HexagonControlWrapper::FillInputNode(const string&, const ConstByteArray) {
bool HexagonControlWrapper::FillInputNode(const string&, const Tensor&) {
return false;
}
bool HexagonControlWrapper::ReadOutputNode(const string,
bool HexagonControlWrapper::ReadOutputNode(
const string& node_name, TensorAllocatorFunc tensor_allocator) {
return false;
}
bool HexagonControlWrapper::ReadOutputNode(const string&,
std::vector<ByteArray>* const) {
return false;
}
......
......@@ -39,24 +39,28 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor {
bool SetupGraph() final;
bool ExecuteGraph() final;
bool TeardownGraph() final;
bool FillInputNode(const string& node_name, const ConstByteArray bytes) final;
bool FillInputNode(const string& node_name, const Tensor& tensor) final;
bool ReadOutputNode(string node_name, std::vector<ByteArray>* outputs) final;
bool ReadOutputNode(const string& node_name,
TensorAllocatorFunc tensor_allocator) final;
bool ReadOutputNode(const string& node_name, std::vector<ByteArray>* outputs);
private:
bool FillInputNode(const string& node_name, const ConstByteArray bytes);
// CAVEAT: Need offset as HVX library reserves some ids
static constexpr int NODE_ID_OFFSET = 0x10000;
static GraphTransferInfo::NodeInfo* FindNodeInfo(
const string& node_name, GraphTransferInfo* graph_transfer_info);
GraphTransferer graph_transferer_;
const RemoteFusedGraphExecuteInfo* execute_info_{};
GraphTransferer graph_transferer_{};
// Dummy float array for input node.
// TODO(satok): Use actual data passed by FillInputNode and remove
std::vector<float> dummy_input_float_;
std::vector<float> dummy_input_float_{};
// Dummy byte array for cosnt node.
// TODO(satok): Remove
std::unordered_map<int, std::vector<uint8>> dummy_const_data_;
std::unordered_map<int, std::vector<uint8>> dummy_const_data_{};
TF_DISALLOW_COPY_AND_ASSIGN(HexagonControlWrapper);
};
......
......@@ -165,11 +165,20 @@ static void LoadImage(std::vector<float>* img_floats_ptr) {
}
}
static Tensor BuildImageTensor(const std::vector<float>& img_floats) {
LOG(INFO) << "Ioading image finished.";
Tensor img_tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH});
CHECK_EQ(WIDTH * HEIGHT * DEPTH, img_floats.size());
CHECK_EQ(img_tensor.TotalBytes(), img_floats.size() * sizeof(float));
LOG(INFO) << "Copy data to tensor.";
std::memcpy(img_tensor.flat<float>().data(), img_floats.data(),
img_tensor.TotalBytes());
return img_tensor;
}
static void RunInferenceByHexagonControlWrapper(
const GraphTransferer& gt, const std::vector<float>& img_floats) {
const ConstByteArray ba =
std::make_tuple(reinterpret_cast<const uint8*>(img_floats.data()),
img_floats.size() * sizeof(float), DT_FLOAT);
const Tensor img_tensor = BuildImageTensor(img_floats);
const RemoteFusedGraphExecuteInfo execute_info =
GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
......@@ -183,7 +192,7 @@ static void RunInferenceByHexagonControlWrapper(
hexagon_control_wrapper.SetupGraph();
// 3. Fill input node's output
hexagon_control_wrapper.FillInputNode("Mul", ba);
hexagon_control_wrapper.FillInputNode("Mul", img_tensor);
// 4. Execute graph
profile_utils::CpuUtils::EnableClockCycleProfiling(true);
......@@ -216,13 +225,7 @@ static void RunFusedGraph(const GraphDef& fused_graph_def) {
LoadImage(&img_floats);
LOG(INFO) << "Ioading image finished.";
Tensor img_tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH});
ASSERT_EQ(WIDTH * HEIGHT * DEPTH, img_floats.size());
ASSERT_EQ(img_tensor.TotalBytes(), img_floats.size() * sizeof(float));
LOG(INFO) << "Copy data to tensor.";
std::memcpy(img_tensor.flat<float>().data(), img_floats.data(),
img_tensor.TotalBytes());
const Tensor img_tensor = BuildImageTensor(img_floats);
// Setup session
std::vector<Tensor> output_tensors;
......
......@@ -29,6 +29,7 @@ class IRemoteFusedGraphExecutor {
std::tuple<uint8* /* data */, uint64 /* size */, DataType /* type */>;
using ConstByteArray = std::tuple<const uint8* /* data */, uint64 /* size */,
DataType /* type */>;
using TensorAllocatorFunc = std::function<Tensor*(const TensorShape& shape)>;
IRemoteFusedGraphExecutor() = default;
virtual ~IRemoteFusedGraphExecutor() = default;
......@@ -55,16 +56,12 @@ class IRemoteFusedGraphExecutor {
// Teardown Graph
virtual bool TeardownGraph() = 0;
// Fill input node's output with a ByteArray
virtual bool FillInputNode(const string& node_name,
const ConstByteArray bytes) = 0;
// Fill input node's output with Tensor
virtual bool FillInputNode(const string& node_name, const Tensor& tensor) = 0;
// Read output node's outputs as ByteArrays
virtual bool ReadOutputNode(string node_name,
std::vector<ByteArray>* outputs) = 0;
virtual bool ReadOutputNode(const string& node_name,
TensorAllocatorFunc tensor_allocator) = 0;
private:
TF_DISALLOW_COPY_AND_ASSIGN(IRemoteFusedGraphExecutor);
......
......@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
......@@ -92,27 +93,14 @@ class RemoteFusedGraphExecuteOp : public OpKernel {
CHECK(output_count == execute_info_.graph_output_node_name_size());
for (int i = 0; i < output_count; ++i) {
Tensor* output = nullptr;
TensorShape output_shape;
const string& output_node_name = execute_info_.graph_output_node_name(i);
// TODO(satok): Switch shape corresponding to input shape
for (const TensorShapeProto::Dim& dim :
execute_info_.default_graph_output_tensor_shape(i).shape().dim()) {
output_shape.AddDim(dim.size());
}
OP_REQUIRES_OK(ctx, ctx->allocate_output(i, output_shape, &output));
if (remote_fused_graph_executor_) {
std::vector<IRemoteFusedGraphExecutor::ByteArray> outputs;
remote_fused_graph_executor_->ReadOutputNode(output_node_name,
&outputs);
// TODO(satok): Remove this check (<= 1). And support multiple outputs
// for each output node
CHECK(outputs.size() <= 1);
if (!outputs.empty()) {
CHECK(output->TotalBytes() >= std::get<1>(outputs[0]));
// TODO(satok): Avoid specifying float
std::memcpy(output->flat<float>().data(), std::get<0>(outputs[0]),
std::get<1>(outputs[0]));
}
remote_fused_graph_executor_->ReadOutputNode(
output_node_name,
[i, &ctx, &output](const TensorShape& shape) -> Tensor* {
TF_CHECK_OK(ctx->allocate_output(i, shape, &output));
return output;
});
}
}
}
......
......@@ -13,16 +13,25 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
......@@ -35,10 +44,257 @@ TEST_F(RemoteFusedGraphExecuteTest, ExecuteAddGraph) {
.Attr("M", 2)
.Attr("N", 1)
.Attr("T", DataTypeToEnum<float>::v())
.Attr("U", DataTypeToEnum<float>::v())
.Attr("serialized_graph_transfer_info", "")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
// TODO(satok): Add benchmark
}
////////////////////////////
// End-to-end test: Begin //
////////////////////////////
// This test does a end-to-end test for a simple usage of
// RemoteFusedGraphExecuteOp.
constexpr const char* const NAME_A = "a";
constexpr const char* const NAME_B = "b";
constexpr const char* const NAME_A_PLUS_B = "a_plus_b";
constexpr const char* const REMOTE_FUSED_EXECUTE_OP_NODE_NAME =
"remote_fused_execute_op";
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME =
"build_test_remote_fused_graph_executor";
constexpr float NODE_A_VAL = 2.0f;
constexpr float NODE_A_VAL2 = 10.0f;
constexpr float NODE_B_VAL = 3.0f;
constexpr float FLOAT_VALUE_TOLERANCE = 1e-8f;
// Utility functions //
static Output BuildPlaceHolderOp(const string& name, const DataType dt,
const TensorShape& tensor_shape, Scope* root) {
const Scope& scope = root->WithOpName(name);
Node* ret;
const string unique_name = scope.GetUniqueNameForOp("PlaceholderV2");
NodeBuilder builder = NodeBuilder(unique_name, "PlaceholderV2")
.Attr("dtype", dt)
.Attr("shape", tensor_shape);
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
CHECK(scope.ok());
return Output(ret, 0);
}
static Output BuildRemoteFusedGraphExecuteOp(
const string& name, const std::vector<Output>& output_list,
const int output_node_count,
const RemoteFusedGraphExecuteInfo& execute_info, Scope* root) {
const Scope& scope = root->WithOpName(name);
Node* ret;
CHECK(scope.ok());
auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list));
const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute");
auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute")
.Input(node_out_list)
.Attr("M", static_cast<int64>(output_list.size()))
.Attr("N", static_cast<int64>(output_node_count))
.Attr("T", DT_FLOAT)
.Attr("U", DT_FLOAT)
.Attr("serialized_graph_transfer_info",
StringPiece(execute_info.SerializeAsString()));
CHECK(scope.ok());
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
CHECK(scope.ok());
return Output(ret, 0);
}
static RemoteFusedGraphExecuteInfo BuildRemoteFusedGraphExecuteInfo(
const GraphDef& original_graph) {
RemoteFusedGraphExecuteInfo execute_info;
execute_info.set_executor_name(REMOTE_FUSED_EXECUTOR_NAME);
// In this example, simply copy all nodes. Basically, you don't need to add
// unused node for inference.
for (const NodeDef& node : original_graph.node()) {
NodeDef& copied_node = *execute_info.add_node();
copied_node = node;
// Adding tensor shape type to the node
// TODO(satok): Use TensorShapeMap to detime tensor shape type
RemoteFusedGraphExecuteUtils::AddOutputTensorShapeType(
std::vector<DataType>({DT_FLOAT}),
std::vector<TensorShape>({TensorShape()}), &copied_node);
}
// Add node A as input
execute_info.add_graph_input_node_name(NAME_A);
RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_a =
*execute_info.add_default_graph_input_tensor_shape();
shape_a.set_dtype(DT_FLOAT);
// (skip setting shape to shape_a as it's shape is rank = 0.)
// Add node A + B as output
execute_info.add_graph_output_node_name(NAME_A_PLUS_B);
RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_a_plus_b =
*execute_info.add_default_graph_output_tensor_shape();
shape_a_plus_b.set_dtype(DT_FLOAT);
// (skip setting shape to shape_a_plus_b as it's shape is rank = 0.)
return execute_info;
}
// 1. Create TestRemoteFusedGraphExecutor to execute your fused graph
class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
public:
int GetVersion() final { return 1; }
bool Init(const RemoteFusedGraphExecuteInfo& info) final {
info_ = &info;
for (const NodeDef& node_def : info.node()) {
node_def_map_.emplace(node_def.name(), &node_def);
}
return true;
}
bool Finalize() final { return true; }
bool SetupGraph() final { return true; }
bool ExecuteGraph() final {
CHECK(info_ != nullptr);
// TODO(satok): Add utilities to implement this function more easily.
// CAVEAT: This test only handles add op. You can implement here as you
// like.
CHECK_EQ(1, info_->graph_input_node_name_size());
const string& input_node_name = info_->graph_input_node_name(0);
const Tensor& input_tensor = input_tensor_cache_[input_node_name];
const float input_val = *input_tensor.scalar<float>().data();
// TODO(satok): Read NAME_B from node_a_plus_b
const NodeDef& node_b = *node_def_map_.at(NAME_B);
const TensorProto* proto = nullptr;
GetNodeAttr(node_b, "value", &proto);
Tensor const_tensor;
RemoteFusedGraphExecuteUtils::MakeTensorFromProto(*proto, &const_tensor);
const float b_val = *const_tensor.scalar<float>().data();
Tensor output_a_plus_b(DT_FLOAT, {});
output_a_plus_b.flat<float>().data()[0] = input_val + b_val;
output_tensor_buf_.emplace(info_->graph_output_node_name(0),
output_a_plus_b);
return true;
}
bool TeardownGraph() final { return true; }
bool FillInputNode(const string& node_name, const Tensor& tensor) final {
input_tensor_cache_[node_name] = tensor;
return true;
}
bool ReadOutputNode(const string& node_name,
TensorAllocatorFunc tensor_allocator) final {
// TODO(satok): Specify tensor shape by using default_graph_tensor_shape.
const Tensor& buffered_output_tensor = output_tensor_buf_.at(node_name);
const TensorShape& output_shape = buffered_output_tensor.shape();
Tensor* output_tensor = tensor_allocator(output_shape);
CHECK_EQ(buffered_output_tensor.dtype(), output_tensor->dtype());
CHECK(output_tensor->CopyFrom(buffered_output_tensor, output_shape));
return true;
}
private:
const RemoteFusedGraphExecuteInfo* info_;
std::unordered_map<string, Tensor> input_tensor_cache_;
std::unordered_map<string, const NodeDef*> node_def_map_;
std::unordered_map<string, Tensor> output_tensor_buf_;
};
// 2. Register a builder of your custom executor
namespace remote_fused_graph_execute_op {
Status BuildRemoteFusedGraphExecutor(
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
executor->reset(new TestRemoteFusedGraphExecutor());
return Status::OK();
}
// This class instantiation registers executor to the
// RemoteFusedGraphExecuteOp. This architecture makes executors to be
// pluggable in order not to link unnecessary libraries.
static RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
k_test_remote_fused_graph_executor_build(REMOTE_FUSED_EXECUTOR_NAME,
BuildRemoteFusedGraphExecutor);
} // namespace remote_fused_graph_execute_op
// 3. Create Graph transform function to fuse your graph
static Status RewriteGraphToFusedGraph(const GraphDef& original_graph,
GraphDef* fused_graph) {
Scope root = Scope::NewRootScope();
std::vector<Output> output_list;
const Output op_a = BuildPlaceHolderOp(NAME_A, DT_FLOAT, {}, &root);
output_list.emplace_back(op_a);
const RemoteFusedGraphExecuteInfo execute_info =
BuildRemoteFusedGraphExecuteInfo(original_graph);
BuildRemoteFusedGraphExecuteOp(REMOTE_FUSED_EXECUTE_OP_NODE_NAME, output_list,
1, execute_info, &root);
GraphDef fused_graph_def;
TF_CHECK_OK(root.ToGraphDef(&fused_graph_def));
*fused_graph = fused_graph_def;
return Status::OK();
}
// 4. Register transform function
// You can register transform function by REGISTER_GRAPH_TRANSFORM.
// In this test, we don't use graph transform tool to avoid linking to
// the graph transform library.
// To register transform function, you need to change the interface of
// BuildFusedGraphDefOfAddGraph to
// Status BuildFusedGraphDefOfAddGraph(
// const GraphDef& original_graph, const TransformFuncContext& context,
// GraphDef* output_graph_def);
// Then register the function like:
// REGISTER_GRAPH_TRANSFORM("rewrite_graph", RewriteGraph);
// 5. Fuse the original graph and run the inference the new fused graph
TEST(RemoteFusedExecuteGraphOp, EndToEndTest) {
// 5.1 Load original graph
const GraphDef original_graph =
RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B);
// 5.2 Fuse graph
GraphDef fused_graph;
RewriteGraphToFusedGraph(original_graph, &fused_graph);
// 5.3 Setup session
std::vector<Tensor> output_tensors;
SessionOptions session_options;
session_options.env = Env::Default();
std::unique_ptr<Session> session =
std::unique_ptr<Session>(NewSession(session_options));
Status status = session->Create(fused_graph);
ASSERT_TRUE(status.ok());
RunOptions run_options;
run_options.set_trace_level(RunOptions::FULL_TRACE);
RunMetadata run_metadata;
// 5.4 Setup input
Tensor input_a(DT_FLOAT, {});
input_a.flat<float>().data()[0] = NODE_A_VAL2;
std::vector<std::pair<string, Tensor>> inputs;
inputs.emplace_back(NAME_A, input_a);
// 5.5 Setup output
const std::vector<string> outputs{REMOTE_FUSED_EXECUTE_OP_NODE_NAME};
// 5.6 Run inference with all node as output
status = session->Run(run_options, inputs, outputs, {}, &output_tensors,
&run_metadata);
ASSERT_TRUE(status.ok());
// 5.7 Check output tensor value
ASSERT_EQ(1, output_tensors.size());
EXPECT_NEAR(NODE_A_VAL2 + NODE_B_VAL,
output_tensors.at(0).flat<float>().data()[0],
FLOAT_VALUE_TOLERANCE);
}
////////////////////////////
// End-to-end test: End //
////////////////////////////
} // namespace tensorflow
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
/* static */ Output RemoteFusedGraphExecuteOpTestUtils::BuildAddOp(
const Scope& scope, const Input& x, const Input& y) {
CHECK(scope.ok());
auto _x = ops::AsNodeOut(scope, x);
CHECK(scope.ok());
auto _y = ops::AsNodeOut(scope, y);
CHECK(scope.ok());
Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("Add");
auto builder = NodeBuilder(unique_name, "Add").Input(_x).Input(_y);
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
CHECK(scope.ok()) << scope.status();
return Output(ret, 0);
}
/* static */ GraphDef RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
const string& name0, const float val0, const string& name1,
const float val1, const string& name_out) {
Scope root = Scope::NewRootScope();
Output node0 = ops::Const(root.WithOpName(name0), val0);
Output node1 = ops::Const(root.WithOpName(name1), val1);
RemoteFusedGraphExecuteOpTestUtils::BuildAddOp(root.WithOpName(name_out),
node0, node1);
GraphDef def;
TF_CHECK_OK(root.ToGraphDef(&def));
return def;
}
} // namespace tensorflow
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
// RemoteFusedGraphExecuteOpTestUtils is a set of utilities in tests for
// RemoteFusedGraphExecuteOp.
class RemoteFusedGraphExecuteOpTestUtils {
public:
static Output BuildAddOp(const Scope& scope, const Input& x, const Input& y);
static GraphDef BuildAddGraph(const string& name0, const float val0,
const string& name1, const float val1,
const string& name_out);
private:
RemoteFusedGraphExecuteOpTestUtils() = delete;
TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOpTestUtils);
};
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_
......@@ -17,11 +17,17 @@ limitations under the License.
#include <utility>
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES;
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES;
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar::ExecutorBuildRegistrar(
const string& name, ExecutorBuildFunc executor_build_func) {
ExecutorBuildRegistry& executor_build_registry = *GetExecutorBuildRegistry();
......@@ -195,4 +201,25 @@ RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() {
}
}
/* static */ Status RemoteFusedGraphExecuteUtils::MakeTensorFromProto(
const TensorProto& tensor_proto, Tensor* tensor) {
if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
Tensor parsed(tensor_proto.dtype());
if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
*tensor = parsed;
return Status::OK();
}
}
return errors::InvalidArgument("Cannot parse tensor from proto");
}
/* static */ bool RemoteFusedGraphExecuteUtils::AddOutputTensorShapeType(
const std::vector<DataType>& data_types,
const std::vector<TensorShape>& shapes, NodeDef* node_def) {
// const gtl::ArraySlice<DataType> data_types_array(data_types);
AddNodeAttr(ATTR_OUTPUT_DATA_TYPES, data_types, node_def);
AddNodeAttr(ATTR_OUTPUT_SHAPES, shapes, node_def);
return true;
}
} // namespace tensorflow
......@@ -30,6 +30,10 @@ namespace tensorflow {
// functions for IRemoteFusedGraphExecutor.
class RemoteFusedGraphExecuteUtils {
public:
static constexpr const char* const ATTR_OUTPUT_DATA_TYPES =
"_output_data_types";
static constexpr const char* const ATTR_OUTPUT_SHAPES = "_output_shapes";
using ExecutorBuildFunc = std::function<Status(
std::unique_ptr<IRemoteFusedGraphExecutor>* executor)>;
// Registrar class for IRemoteFusedGraphExecutor.
......@@ -78,6 +82,13 @@ class RemoteFusedGraphExecuteUtils {
const std::vector<tensorflow::Tensor>& output_tensors,
TensorShapeMap* tensor_shape_map);
static Status MakeTensorFromProto(const TensorProto& tensor_proto,
Tensor* tensor);
static bool AddOutputTensorShapeType(const std::vector<DataType>& data_types,
const std::vector<TensorShape>& shapes,
NodeDef* node_def);
private:
static ExecutorBuildRegistry* GetExecutorBuildRegistry();
......
......@@ -16,45 +16,22 @@ limitations under the License.
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
const string NAME_A = "a";
const string NAME_B = "b";
const string NAME_A_PLUS_B = "a_plus_b";
constexpr const char* const NAME_A = "a";
constexpr const char* const NAME_B = "b";
constexpr const char* const NAME_A_PLUS_B = "a_plus_b";
constexpr float NODE_A_VAL = 2.0f;
constexpr float NODE_B_VAL = 3.0f;
constexpr float VALUE_TOLERANCE_FLOAT = 1e-8f;
static Output BuildAddOps(const Scope& scope, const Input& x, const Input& y) {
EXPECT_TRUE(scope.ok());
auto _x = ops::AsNodeOut(scope, x);
EXPECT_TRUE(scope.ok());
auto _y = ops::AsNodeOut(scope, y);
EXPECT_TRUE(scope.ok());
Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("Add");
auto builder = NodeBuilder(unique_name, "Add").Input(_x).Input(_y);
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
EXPECT_TRUE(scope.ok());
return Output(ret, 0);
}
static GraphDef CreateAddGraphDef() {
Scope root = Scope::NewRootScope();
Output node_a = ops::Const(root.WithOpName(NAME_A), NODE_A_VAL);
Output node_b = ops::Const(root.WithOpName(NAME_B), NODE_B_VAL);
Output node_add = BuildAddOps(root.WithOpName(NAME_A_PLUS_B), node_a, node_b);
GraphDef def;
TF_CHECK_OK(root.ToGraphDef(&def));
return def;
}
TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphA) {
GraphDef def = CreateAddGraphDef();
GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B);
std::pair<string, Tensor> input_node_info;
input_node_info.first = NAME_A;
input_node_info.second = Tensor(DT_FLOAT, {});
......@@ -73,7 +50,8 @@ TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphA) {
}
TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphAUninitialized) {
GraphDef def = CreateAddGraphDef();
GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B);
std::pair<string, Tensor> input_node_info;
input_node_info.first = NAME_A;
input_node_info.second = Tensor(DT_FLOAT, {});
......@@ -91,7 +69,8 @@ TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphAUninitialized) {
}
TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphAB) {
GraphDef def = CreateAddGraphDef();
GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B);
std::pair<string, Tensor> input_node_info_a;
input_node_info_a.first = NAME_A;
input_node_info_a.second = Tensor(DT_FLOAT, {});
......@@ -122,7 +101,9 @@ TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphForAllNodes) {
// Setup dryrun arguments
const std::vector<std::pair<string, Tensor>> inputs{input_node_info_a};
RemoteFusedGraphExecuteUtils::TensorShapeMap output_tensor_info;
GraphDef def = CreateAddGraphDef();
GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B);
// dryrun
const Status status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
......
......@@ -22,10 +22,11 @@ namespace tensorflow {
// TODO(satok): Implement shape_inference
REGISTER_OP("RemoteFusedGraphExecute")
.Input("values: M * T")
.Output("output: N * T")
.Output("output: N * U")
.Attr("M: int >= 0")
.Attr("N: int >= 0")
.Attr("T: type")
.Attr("U: type")
.Attr("serialized_graph_transfer_info: string")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
......
......@@ -35,6 +35,8 @@ TEST(RemoteFusedGraphOpsTest, RemoteFusedGraphExecute_ShapeFn) {
.Input(src_list)
.Attr("M", input_count)
.Attr("N", output_count)
.Attr("T", DT_FLOAT)
.Attr("U", DT_FLOAT)
.Finalize(&op.node_def));
};
set_n(4, 2);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册