Commit 7b8e31c5 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Factor out shape inference propagation to RemoteFusedGraphExecuteUtils

Change: 149984977
Parent 6121fe5b
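This change moves the ShapeRefiner-based shape propagation out of GraphTransferer::LoadGraphFromProto into two reusable static helpers on RemoteFusedGraphExecuteUtils: PropagateShapeInference, which seeds the refiner with the shapes of the feed tensors and recomputes the shapes of all downstream nodes, and BuildTensorShapeMapFromGraph, which collects the inferred (DataType, TensorShape) pair for every node output. A minimal sketch of how the two helpers compose, mirroring the new unit test in this commit (the wrapper name InferGraphShapes is illustrative, not part of the change):

#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Hypothetical wrapper: import a GraphDef, propagate the feed shapes, and
// collect the per-output (DataType, TensorShape) pairs for every node.
Status InferGraphShapes(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    RemoteFusedGraphExecuteUtils::TensorShapeMap* tensor_shape_map) {
  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry());
  TF_RETURN_IF_ERROR(ImportGraphDef(ImportGraphDefOptions(), graph_def, &graph,
                                    &shape_refiner));
  // Seed the refiner with the feed shapes; recompute everything downstream.
  TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::PropagateShapeInference(
      graph_def, input_node_info_list, &graph, &shape_refiner));
  // Fails with InvalidArgument if any output shape is still unknown.
  return RemoteFusedGraphExecuteUtils::BuildTensorShapeMapFromGraph(
      graph, shape_refiner, tensor_shape_map);
}

}  // namespace tensorflow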
@@ -57,7 +57,7 @@ static string ToString(T val) {
/**
 * graph loading functions
 * - LoadGraphFromProto
 * - LoadGraphFromProtoFile
 * These functions read a graph definition and store the parameters of each
 * node needed to transfer the graph to the SOC.
 */
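For context, the hexagon_graph_execution test updated later in this commit drives the file-loading path roughly as follows (ops_definitions, MODEL_FILENAME, inputs, and output_node_names are the test's own values):

GraphTransferer gt;
gt.EnableStrictCheckMode(false);
RemoteFusedGraphExecuteUtils::TensorShapeMap output_tensor_info;
const Status status = gt.LoadGraphFromProtoFile(
    *ops_definitions, MODEL_FILENAME, inputs, output_node_names,
    false /* is_text_proto */, true /* shape_inference_for_unknown_shape */,
    false /* dry_run_for_unknown_shape */, &output_tensor_info);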
@@ -67,60 +67,19 @@ Status GraphTransferer::LoadGraphFromProto(
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names,
const bool shape_inference_for_unknown_shape,
const OutputTensorMap& output_tensor_map) {
const TensorShapeMap& output_tensor_map) {
ImportGraphDefOptions opts;
Graph graph(OpRegistry::Global());
ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry());
VLOG(1) << "Start import graph";
Status status = ImportGraphDef(opts, graph_def, &graph, &shape_refiner);
if (!status.ok()) {
VLOG(1) << "Failed to import graph " << status.ToString();
return status;
}
if (shape_inference_for_unknown_shape && !input_node_info_list.empty()) {
auto visit = [&shape_refiner, &input_node_info_list, &status](Node* node) {
if (!status.ok()) {
return;
}
CHECK_NE(node, nullptr);
// If we visit an input node, we use the shape provided and set the
// shape accordingly.
bool is_input_node = false;
for (const std::pair<string, Tensor>& input_node_info :
input_node_info_list) {
if (node->name() == input_node_info.first) {
shape_inference::InferenceContext* context =
shape_refiner.GetContext(node);
shape_inference::ShapeHandle handle;
status = context->MakeShapeFromTensorShape(
input_node_info.second.shape(), &handle);
// TODO(b/32704451): Don't just ignore this status!
shape_refiner.SetShape(node, 0, handle).IgnoreError();
is_input_node = true;
}
if (!status.ok()) {
break;
}
}
// If the node is not an input node, call AddNode(), which recomputes the shape.
if (!is_input_node && status.ok()) {
status = shape_refiner.AddNode(node);
if (!status.ok()) {
VLOG(1) << "Shape inference failed for node: " << node->name();
}
}
};
// Runs a reverse DFS over the entire graph setting the shape for the input
// nodes provided and then recomputing the shape of all the nodes downstream
// from them. The "visit" function is executed for each node after all its
// parents have been visited.
ReverseDFS(graph, {}, visit);
if (shape_inference_for_unknown_shape) {
status = RemoteFusedGraphExecuteUtils::PropagateShapeInference(
graph_def, input_node_info_list, &graph, &shape_refiner);
if (!status.ok()) {
VLOG(1) << "Failed to run shape inference: " << status.ToString();
return status;
}
}
@@ -149,6 +108,7 @@ Status GraphTransferer::LoadGraphFromProto(
return status;
}
}
SortParams(output_node_names);
for (const std::pair<string, Tensor>& input_node_info :
@@ -319,7 +279,7 @@ bool GraphTransferer::AreAllInputsCached(const Node& node) const {
Status GraphTransferer::RegisterNode(
const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map,
const ShapeRefiner& shape_refiner, const TensorShapeMap& output_tensor_map,
const Node& node,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names) {
@@ -352,7 +312,7 @@ Status GraphTransferer::RegisterNode(
void GraphTransferer::RegisterConstantNode(
const ShapeRefiner& shape_refiner, const Node& node,
const OutputTensorMap& output_tensor_map) {
const TensorShapeMap& output_tensor_map) {
VLOG(1) << "Register constant node: " << node.name();
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
const int id = node_name_to_id_cache_map_[node.name()];
@@ -439,7 +399,7 @@ bool GraphTransferer::HasPaddingAndStrides(const Node& node) {
}
bool GraphTransferer::IsNodeFlattenReshape(
const Node& node, const OutputTensorMap& output_tensor_map,
const Node& node, const TensorShapeMap& output_tensor_map,
const ShapeRefiner& shape_refiner) {
// Check if node is reshape op
if (node.type_string() != RESHAPE_NODE_TYPE_STRING) {
@@ -477,7 +437,7 @@ bool GraphTransferer::IsNodeFlattenReshape(
void GraphTransferer::RegisterNodeWithPaddingAndStrides(
const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map,
const ShapeRefiner& shape_refiner, const TensorShapeMap& output_tensor_map,
const Node& node) {
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
const int id = node_name_to_id_cache_map_[node.name()];
@@ -512,7 +472,7 @@ void GraphTransferer::RegisterNodeWithPaddingAndStrides(
void GraphTransferer::RegisterInputNode(
const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map,
const ShapeRefiner& shape_refiner, const TensorShapeMap& output_tensor_map,
const Node& node) {
VLOG(1) << "Register input node: " << node.name();
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
@@ -530,7 +490,7 @@ void GraphTransferer::RegisterInputNode(
void GraphTransferer::RegisterFlattenNode(
const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map,
const ShapeRefiner& shape_refiner, const TensorShapeMap& output_tensor_map,
const Node& node) {
VLOG(1) << "Register flatten node: " << node.name();
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
@@ -547,7 +507,7 @@ void GraphTransferer::RegisterFlattenNode(
void GraphTransferer::RegisterGenericNode(
const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map,
const ShapeRefiner& shape_refiner, const TensorShapeMap& output_tensor_map,
const Node& node) {
VLOG(1) << "Register generic node: " << node.name();
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
@@ -569,7 +529,7 @@ Status GraphTransferer::RegisterNodeIfAllInputsAreCached(
const bool only_register_const_node,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names,
const OutputTensorMap& output_tensor_map) {
const TensorShapeMap& output_tensor_map) {
if (only_register_const_node && !node.IsConstant()) {
return Status();
}
@@ -627,7 +587,7 @@ void GraphTransferer::AppendNodeInputParams(
}
void GraphTransferer::AppendNodeOutputParams(
const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map,
const ShapeRefiner& shape_refiner, const TensorShapeMap& output_tensor_map,
const int id, const Node& node) {
VLOG(1) << "Append output params: " << node.name() << ", "
<< node.num_outputs();
@@ -670,7 +630,7 @@ void GraphTransferer::AppendNodeOutputParams(
}
void GraphTransferer::AppendNodeParamsWithIoParams(
const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map,
const ShapeRefiner& shape_refiner, const TensorShapeMap& output_tensor_map,
const Node& node, const string& name, const int id, const string& type,
const int type_id, const int padding, const int inputs_size,
const std::vector<int>& extra_inputs, const int outputs_size,
@@ -757,7 +717,7 @@ GraphTransferer::ToTensorShapeArray(const TensorShape& shape) {
}
/* static */ void GraphTransferer::CheckShape(
const OutputTensorMap& output_tensor_map, const string& node_name,
const TensorShapeMap& output_tensor_map, const string& node_name,
const std::array<int64, SHAPE_ARRAY_SIZE>& expected) {
if (output_tensor_map.empty()) {
// As output_tensor_map is empty, skip checking tensor shape.
......
@@ -45,7 +45,7 @@ class GraphTransferer {
static constexpr int MAX_SUPPORTED_RANK = 4;
// TODO(satok): Remove. Use proto definition instead.
static constexpr int SHAPE_ARRAY_SIZE = MAX_SUPPORTED_RANK;
using OutputTensorMap = RemoteFusedGraphExecuteUtils::TensorShapeMap;
using TensorShapeMap = RemoteFusedGraphExecuteUtils::TensorShapeMap;
GraphTransferer() = default;
@@ -58,7 +58,7 @@ class GraphTransferer {
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names,
const bool shape_inference_for_unknown_shape,
const OutputTensorMap& output_tensor_map);
const TensorShapeMap& output_tensor_map);
// Load graph structure into GraphTransferer from protobuf file
// TODO(satok): Pass a pair of TensorShape and DataType instead of
@@ -107,12 +107,12 @@ class GraphTransferer {
Status RegisterNode(
const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner,
const OutputTensorMap& output_tensor_map, const Node& node,
const TensorShapeMap& output_tensor_map, const Node& node,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names);
void RegisterConstantNode(const ShapeRefiner& shape_refiner, const Node& node,
const OutputTensorMap& output_tensor_map);
const TensorShapeMap& output_tensor_map);
int RegisterConstantShape(const std::vector<int>& shape);
@@ -122,27 +122,27 @@ class GraphTransferer {
// TODO(satok): Remove this method once generic reshape op is implemented in
// SOC
bool IsNodeFlattenReshape(const Node& node,
const OutputTensorMap& output_tensor_map,
const TensorShapeMap& output_tensor_map,
const ShapeRefiner& shape_refiner);
void RegisterNodeWithPaddingAndStrides(
const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner,
const OutputTensorMap& output_tensor_map, const Node& node);
const TensorShapeMap& output_tensor_map, const Node& node);
void RegisterInputNode(const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner,
const OutputTensorMap& output_tensor_map,
const TensorShapeMap& output_tensor_map,
const Node& node);
void RegisterFlattenNode(const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner,
const OutputTensorMap& output_tensor_map,
const TensorShapeMap& output_tensor_map,
const Node& node);
void RegisterGenericNode(const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner,
const OutputTensorMap& output_tensor_map,
const TensorShapeMap& output_tensor_map,
const Node& node);
Status RegisterNodeIfAllInputsAreCached(
@@ -151,7 +151,7 @@ class GraphTransferer {
const bool only_register_const_node,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names,
const OutputTensorMap& output_tensor_map);
const TensorShapeMap& output_tensor_map);
void AppendNodeParams(const string& name, const int id, const string& type,
const int type_id, const int padding,
@@ -163,7 +163,7 @@ class GraphTransferer {
const std::vector<int>& extra_inputs);
void AppendNodeOutputParams(const ShapeRefiner& shape_refiner,
const OutputTensorMap& output_tensor_map,
const TensorShapeMap& output_tensor_map,
const int id, const Node& node);
static std::array<int64, SHAPE_ARRAY_SIZE> BuildShapeArray(
@@ -172,7 +172,7 @@ class GraphTransferer {
void AppendNodeParamsWithIoParams(
const ShapeRefiner& shape_refiner,
const OutputTensorMap& output_tensor_map, const Node& node,
const TensorShapeMap& output_tensor_map, const Node& node,
const string& name, const int id, const string& type, const int type_id,
const int padding, const int inputs_size,
const std::vector<int>& extra_inputs, const int outputs_size,
@@ -183,7 +183,7 @@ class GraphTransferer {
static string ToPaddingDebugString(int padding);
static void CheckShape(const OutputTensorMap& output_tensor_map,
static void CheckShape(const TensorShapeMap& output_tensor_map,
const string& node_name,
const std::array<int64, SHAPE_ARRAY_SIZE>& actual);
......
@@ -49,7 +49,7 @@ class GraphTransfererTest : public ::testing::Test {
static const std::vector<string> OP_TYPES{
"INPUT", "OUTPUT", "Conv2D", "MaxPool", "NoOp", "Add", "Const", "Softmax"};
const GraphTransferer::OutputTensorMap EMPTY_OUTPUT_TENSOR_MAP;
const RemoteFusedGraphExecuteUtils::TensorShapeMap EMPTY_OUTPUT_TENSOR_MAP;
class TestGraphTransferOpsDefinitions : public IGraphTransferOpsDefinitions {
public:
......
@@ -56,6 +56,7 @@ constexpr const char* const FUSED_MODEL_FILENAME =
"tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb";
constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME =
"remote_fused_graph_execute_node";
constexpr bool USE_SHAPE_INFERENCE = false;
const bool DBG_DUMP_FLOAT_DATA = false;
const int WIDTH = 299;
@@ -282,11 +283,18 @@ TEST(GraphTransferer,
RemoteFusedGraphExecuteUtils::TensorShapeMap output_tensor_info;
GraphTransferer gt;
gt.EnableStrictCheckMode(false);
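// Profile the load so the cost of the shape-inference path
// (USE_SHAPE_INFERENCE) can be compared against the dry-run path; the
// statistics are dumped after the load completes.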
profile_utils::CpuUtils::EnableClockCycleProfiling(true);
ClockCycleProfiler prof;
prof.Start();
Status status = gt.LoadGraphFromProtoFile(
*ops_definitions, MODEL_FILENAME, inputs, output_node_names,
false /* is_text_proto */, false /* shape_inference_for_unknown_shape */,
true /* dry_run_for_unknown_shape */, &output_tensor_info);
false, // is_text_proto
USE_SHAPE_INFERENCE, // shape_inference_for_unknown_shape
!USE_SHAPE_INFERENCE, // dry_run_for_unknown_shape
&output_tensor_info);
ASSERT_TRUE(status.ok()) << status;
prof.Stop();
prof.DumpStatistics("LoadGraphFromProtoFile");
std::vector<float> img_floats;
LoadImage(&img_floats);
......
@@ -15,9 +15,12 @@ limitations under the License.
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include <algorithm>
#include <utility>
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
@@ -222,4 +225,78 @@ RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() {
return true;
}
/* static */ Status RemoteFusedGraphExecuteUtils::PropagateShapeInference(
const GraphDef& graph_def,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
Graph* graph, ShapeRefiner* shape_refiner) {
Status status;
auto visit = [&shape_refiner, &input_node_info_list, &status](Node* node) {
if (!status.ok()) {
return;
}
CHECK_NE(node, nullptr);
// If we visit an input node, we use the shape provided and set the
// shape accordingly.
bool is_input_node = false;
for (const std::pair<string, Tensor>& input_node_info :
input_node_info_list) {
if (node->name() == input_node_info.first) {
shape_inference::InferenceContext* context =
shape_refiner->GetContext(node);
shape_inference::ShapeHandle handle;
status = context->MakeShapeFromTensorShape(
input_node_info.second.shape(), &handle);
shape_refiner->SetShape(node, 0, handle);
is_input_node = true;
}
if (!status.ok()) {
break;
}
}
// If the node is not an input node, call AddNode(), which recomputes the shape.
if (!is_input_node && status.ok()) {
status = shape_refiner->AddNode(node);
if (!status.ok()) {
VLOG(1) << "Shape inference failed for node: " << node->name();
}
}
};
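// Runs a reverse DFS over the entire graph, setting the shapes of the
// provided input nodes and then recomputing the shapes of all the nodes
// downstream of them. The "visit" function is executed for each node after
// all its parents have been visited.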
ReverseDFS(*graph, {}, visit);
return status;
}
/* static */ Status RemoteFusedGraphExecuteUtils::BuildTensorShapeMapFromGraph(
const Graph& graph, const ShapeRefiner& shape_refiner,
TensorShapeMap* tensor_shape_map) {
for (int i = 0; i < graph.num_node_ids(); ++i) {
const Node* node = graph.FindNodeId(i);
CHECK_NE(node, nullptr);
for (int j = 0; j < node->num_outputs(); ++j) {
const int output_index = j;
const DataType dt = node->output_type(output_index);
shape_inference::InferenceContext* context =
shape_refiner.GetContext(node);
CHECK_NE(context, nullptr);
shape_inference::ShapeHandle shape_handle = context->output(output_index);
if (context->RankKnown(shape_handle)) {
TensorShape ts;
for (int k = 0; k < context->Rank(shape_handle); ++k) {
shape_inference::DimensionHandle dh = context->Dim(shape_handle, k);
CHECK(context->ValueKnown(dh));
ts.AddDim(context->Value(dh));
}
const string& node_name = node->name();
CHECK(tensor_shape_map->count(node_name) == 0);
tensor_shape_map->emplace(node_name, std::make_pair(dt, ts));
} else {
return errors::InvalidArgument("Graph contains unknown shapes");
}
}
}
return Status::OK();
}
} // namespace tensorflow
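Each TensorShapeMap entry maps a node name to its inferred (DataType, TensorShape) pair. A sketch of a lookup, using the accessors exercised by the tests in this commit (node_name is any name present in the map):

const auto& type_and_shape = tensor_shape_map.at(node_name);
const DataType dt = type_and_shape.first;          // e.g. DT_FLOAT
const TensorShape& shape = type_and_shape.second;  // fully defined by now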
@@ -20,6 +20,8 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
@@ -89,6 +91,15 @@ class RemoteFusedGraphExecuteUtils {
const std::vector<TensorShape>& shapes,
NodeDef* node_def);
static Status PropagateShapeInference(
const GraphDef& graph_def,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
Graph* graph, ShapeRefiner* shape_refiner);
static Status BuildTensorShapeMapFromGraph(const Graph& graph,
const ShapeRefiner& shape_refiner,
TensorShapeMap* tensor_shape_map);
private:
static ExecutorBuildRegistry* GetExecutorBuildRegistry();
......
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
@@ -100,27 +101,66 @@ TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphForAllNodes) {
// Setup dryrun arguments
const std::vector<std::pair<string, Tensor>> inputs{input_node_info_a};
RemoteFusedGraphExecuteUtils::TensorShapeMap output_tensor_info;
RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map;
GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B);
// dryrun
const Status status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
def, inputs, false /* initialize_by_zero */, &output_tensor_info);
def, inputs, false /* initialize_by_zero */, &tensor_shape_map);
ASSERT_TRUE(status.ok()) << status;
// Assert output node count
ASSERT_EQ(3, output_tensor_info.size());
ASSERT_EQ(1, output_tensor_info.count(NAME_A));
ASSERT_EQ(1, output_tensor_info.count(NAME_B));
ASSERT_EQ(1, output_tensor_info.count(NAME_A_PLUS_B));
EXPECT_EQ(DT_FLOAT, output_tensor_info.at(NAME_B).first);
EXPECT_EQ(DT_FLOAT, output_tensor_info.at(NAME_A_PLUS_B).first);
const TensorShape& shape_b = output_tensor_info.at(NAME_B).second;
const TensorShape& shape_a_b = output_tensor_info.at(NAME_A_PLUS_B).second;
ASSERT_EQ(3, tensor_shape_map.size());
ASSERT_EQ(1, tensor_shape_map.count(NAME_A));
ASSERT_EQ(1, tensor_shape_map.count(NAME_B));
ASSERT_EQ(1, tensor_shape_map.count(NAME_A_PLUS_B));
EXPECT_EQ(DT_FLOAT, tensor_shape_map.at(NAME_B).first);
EXPECT_EQ(DT_FLOAT, tensor_shape_map.at(NAME_A_PLUS_B).first);
const TensorShape& shape_b = tensor_shape_map.at(NAME_B).second;
const TensorShape& shape_a_b = tensor_shape_map.at(NAME_A_PLUS_B).second;
EXPECT_EQ(0, shape_b.dims());
EXPECT_EQ(0, shape_a_b.dims());
}
TEST(RemoteFusedGraphExecuteUtils, PropagateAndBuildTensorShapeMap) {
std::pair<string, Tensor> input_node_info_a;
input_node_info_a.first = NAME_A;
input_node_info_a.second = Tensor(DT_FLOAT, {});
input_node_info_a.second.scalar<float>()() = NODE_A_VAL;
std::pair<string, Tensor> input_node_info_b;
input_node_info_b.first = NAME_B;
input_node_info_b.second = Tensor(DT_FLOAT, {});
input_node_info_b.second.scalar<float>()() = NODE_B_VAL;
const std::vector<std::pair<string, Tensor>> inputs{input_node_info_a,
input_node_info_b};
RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map;
GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B);
ImportGraphDefOptions opts;
Graph graph(OpRegistry::Global());
ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry());
Status status = ImportGraphDef(opts, def, &graph, &shape_refiner);
ASSERT_TRUE(status.ok()) << status;
ASSERT_TRUE(RemoteFusedGraphExecuteUtils::PropagateShapeInference(
def, inputs, &graph, &shape_refiner)
.ok());
ASSERT_TRUE(RemoteFusedGraphExecuteUtils::BuildTensorShapeMapFromGraph(
graph, shape_refiner, &tensor_shape_map)
.ok());
ASSERT_EQ(3, tensor_shape_map.size());
ASSERT_EQ(1, tensor_shape_map.count(NAME_A));
ASSERT_EQ(1, tensor_shape_map.count(NAME_B));
ASSERT_EQ(1, tensor_shape_map.count(NAME_A_PLUS_B));
EXPECT_EQ(DT_FLOAT, tensor_shape_map.at(NAME_B).first);
EXPECT_EQ(DT_FLOAT, tensor_shape_map.at(NAME_A_PLUS_B).first);
const TensorShape& shape_b = tensor_shape_map.at(NAME_B).second;
const TensorShape& shape_a_b = tensor_shape_map.at(NAME_A_PLUS_B).second;
EXPECT_EQ(0, shape_b.dims());
EXPECT_EQ(0, shape_a_b.dims());
}
......