未验证 提交 d7345959 编写于 作者: Y Yan Chunwei 提交者: GitHub

Feature/pass manager (#11440)

上级 16a0f746
set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor init) set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor init)
cc_library(analysis SRCS dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc fluid_to_data_flow_graph_pass.cc cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
DEPS paddle_fluid) fluid_to_data_flow_graph_pass.cc
data_flow_graph_to_fluid_pass.cc
tensorrt_subgraph_pass.cc
dfg_graphviz_draw_pass.cc
DEPS framework_proto)
cc_test(test_node SRCS node_tester.cc DEPS analysis) cc_test(test_node SRCS node_tester.cc DEPS analysis)
cc_test(test_dot SRCS dot_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
cc_test(test_data_flow_graph SRCS data_flow_graph_tester.cc DEPS analysis ${FLUID_CORE_MODULES} paddle_fluid function (inference_analysis_test TARGET)
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) set(options "")
set_tests_properties(test_data_flow_graph PROPERTIES DEPENDS test_word2vec) set(oneValueArgs "")
set(multiValueArgs SRCS)
cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cc_test(test_subgraph_splitter cc_test(${TARGET}
SRCS subgraph_splitter_tester.cc SRCS "${analysis_test_SRCS}"
DEPS analysis paddle_fluid tensor
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model)
set_tests_properties(test_subgraph_splitter PROPERTIES DEPENDS test_word2vec)
cc_test(test_dfg_graphviz_draw_pass
SRCS dfg_graphviz_draw_pass_tester.cc
DEPS analysis DEPS analysis
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model --fraction_of_gpu_memory_to_use=0.5)
set_tests_properties(test_dfg_graphviz_draw_pass PROPERTIES DEPENDS test_word2vec) set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
endfunction(inference_analysis_test)
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc)
inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc)
inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc)
inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc)
#inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc)
inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/argument.h"
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file defines the class Argument, which is the input and output of the
* analysis module. All the fields that needed either by Passes or PassManagers
* are contained in Argument.
*
* TODO(Superjomn) Find some way better to contain the fields when it grow too
* big.
*/
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
namespace paddle {
namespace inference {
namespace analysis {
/*
* The argument definition of both Pass and PassManagers.
*
* All the fields should be registered here for clearness.
*/
struct Argument {
// The graph that process by the Passes or PassManagers.
std::unique_ptr<DataFlowGraph> main_dfg;
// The original program desc.
std::unique_ptr<framework::proto::ProgramDesc> origin_program_desc;
};
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#define ANALYSIS_ARGUMENT_CHECK_FIELD(field__) \
if (!UNLIKELY(field__)) { \
LOG(ERROR) << "field " << #field__ << " should be set."; \
return false; \
}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/dot.h" #include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/inference/analysis/node.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -57,19 +58,7 @@ std::string DataFlowGraph::DotString() const { ...@@ -57,19 +58,7 @@ std::string DataFlowGraph::DotString() const {
// Add nodes // Add nodes
for (size_t i = 0; i < nodes.size(); i++) { for (size_t i = 0; i < nodes.size(); i++) {
const Node &node = nodes.Get(i); const Node &node = nodes.Get(i);
switch (node.type()) {
case Node::Type::kValue:
dot.AddNode(node.repr(), node.dot_attrs()); dot.AddNode(node.repr(), node.dot_attrs());
break;
case Node::Type::kFunction:
dot.AddNode(node.repr(), node.dot_attrs());
break;
case Node::Type::kFunctionBlock:
dot.AddNode(node.repr(), node.dot_attrs());
break;
default:
PADDLE_THROW("unsupported Node type %d", static_cast<int>(node.type()));
}
} }
// Add edges // Add edges
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/framework/proto_desc.h"
namespace paddle {
namespace inference {
namespace analysis {
bool DataFlowGraphToFluidPass::Initialize(Argument* argument) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument)
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc)
desc_ = argument->origin_program_desc.get();
// Here some logic from program_desc.cc and will not add new interfaces into
// framework::ProgramDesc class, use some UT to assure the correctness.
auto* block = desc_->mutable_blocks()->Add();
block->set_idx(framework::kRootBlockIndex);
block->set_parent_idx(framework::kNoneBlockIndex);
return true;
}
bool DataFlowGraphToFluidPass::Finalize() { return true; }
void DataFlowGraphToFluidPass::Run(DataFlowGraph* graph) {
auto traits = GraphTraits<DataFlowGraph>(graph);
for (auto it = traits.nodes().begin(); it != traits.nodes().end(); ++it) {
if (it->deleted()) continue;
switch (it->type()) {
case Node::Type::kFunction:
LOG(INFO) << "add function " << it->name();
AddFluidOp(&(*it));
break;
case Node::Type::kFunctionBlock:
AddEngineOp(&(*it));
break;
default:
continue;
}
}
}
void DataFlowGraphToFluidPass::AddFluidOp(Node* node) {
LOG(INFO) << "processing func " << node->name();
auto* ori_op = static_cast<framework::proto::OpDesc*>(node->pb_desc());
// currently only the main block is analyzed.
auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
auto* op = main_block->add_ops();
LOG(INFO) << "to copy the op";
*op = *ori_op; // copy the attributes, by default, these will not be changed
// by analysis phrase.
// The inputs and outputs of the existing ops are not changed by tensorrt
// subgraph pass.
// NOTE It might be changed by other passes in the long run.
}
void DataFlowGraphToFluidPass::AddEngineOp(Node* node) {
// auto* ori_op = static_cast<framework::proto::OpDesc*>(node->extra_info());
// auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
// auto* op = main_block->add_ops();
// TODO(Superjomn) Here need to expose some arguments for default setting.
}
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file implements the transformation from fluid ProgramDesc to data flow
* graph.
*/
#pragma once
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace paddle {
namespace inference {
namespace analysis {
class DataFlowGraphToFluidPass final : public DataFlowGraphPass {
public:
DataFlowGraphToFluidPass() = default;
bool Initialize(Argument *argument) override;
bool Finalize() override;
void Run(DataFlowGraph *graph) override;
std::string repr() const override { return "DFG to fluid"; }
std::string description() const override {
return "Transform a DFG to a Fluid ProgramDesc";
}
Pass *CreatePrinterPass(std::ostream &os,
const std::string &banner) const override {
return nullptr;
}
protected:
// Add a Fluid Op into the ProgramDesc.
void AddFluidOp(Node *node);
// Add a EngineOp into the ProgramDesc.
void AddEngineOp(Node *node);
private:
framework::proto::ProgramDesc *desc_;
};
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -27,13 +27,12 @@ namespace inference { ...@@ -27,13 +27,12 @@ namespace inference {
namespace analysis { namespace analysis {
TEST_F(DFG_Tester, Test) { TEST_F(DFG_Tester, Test) {
framework::proto::ProgramDesc new_desc;
DataFlowGraph graph; DataFlowGraph graph;
FluidToDataFlowGraphPass pass0; FluidToDataFlowGraphPass pass0;
DataFlowGraphToFluidPass pass1; DataFlowGraphToFluidPass pass1;
pass0.Initialize(desc); ASSERT_TRUE(pass0.Initialize(&argument));
pass1.Initialize(&new_desc); ASSERT_TRUE(pass1.Initialize(&argument));
pass0.Run(&graph); pass0.Run(&graph);
pass1.Run(&graph); pass1.Run(&graph);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
void DFG_GraphvizDrawPass::Run(DataFlowGraph *graph) {
auto content = Draw(graph);
std::ofstream file(GenDotPath());
file.write(content.c_str(), content.size());
file.close();
LOG(INFO) << "draw dot to " << GenDotPath();
}
std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) {
Dot dot;
// Add nodes
for (size_t i = 0; i < graph->nodes.size(); i++) {
const Node &node = graph->nodes.Get(i);
if (config_.display_deleted_node || !node.deleted()) {
dot.AddNode(node.repr(), node.dot_attrs());
}
}
// Add edges
for (size_t i = 0; i < graph->nodes.size(); i++) {
const Node &node = graph->nodes.Get(i);
if (!config_.display_deleted_node && node.deleted()) continue;
for (auto &in : node.inlinks) {
if (!config_.display_deleted_node && in->deleted()) continue;
for (auto &in : node.inlinks) {
dot.AddEdge(in->repr(), node.repr(), {});
}
}
}
return dot.Build();
}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include <fstream> #include <fstream>
#include <string> #include <string>
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/inference/analysis/pass.h" #include "paddle/fluid/inference/analysis/pass.h"
namespace paddle { namespace paddle {
...@@ -32,35 +33,39 @@ namespace analysis { ...@@ -32,35 +33,39 @@ namespace analysis {
*/ */
class DFG_GraphvizDrawPass : public DataFlowGraphPass { class DFG_GraphvizDrawPass : public DataFlowGraphPass {
public: public:
DFG_GraphvizDrawPass(const std::string& dir, const std::string& id) struct Config {
: dir_(dir), id_(id) {} Config(const std::string &dir, const std::string &id,
bool display_deleted_node = false)
bool Initialize() override { return Pass::Initialize(); } : dir(dir), id(id), display_deleted_node(display_deleted_node) {}
void Run(DataFlowGraph* graph) override {
auto content = Draw(graph); // The directory to store the .dot or .png files.
std::ofstream file(GenDotPath()); const std::string dir;
file.write(content.c_str(), content.size()); // The identifier for this dot file.
file.close(); const std::string id;
LOG(INFO) << "draw dot to " << GenDotPath(); // Whether to display deleted nodes, default false.
} const bool display_deleted_node;
};
DFG_GraphvizDrawPass(const Config &config) : config_(config) {}
bool Initialize(Argument *argument) override { return true; }
void Run(DataFlowGraph *graph) override;
bool Finalize() override { return Pass::Finalize(); } bool Finalize() override { return Pass::Finalize(); }
Pass* CreatePrinterPass(std::ostream& os, std::string repr() const override { return "DFG graphviz drawer"; }
const std::string& banner) const override { std::string description() const override {
return nullptr; return "Debug a DFG by draw with graphviz";
} }
private: private:
// Path of the dot file to output. // Path of the dot file to output.
std::string GenDotPath() const { std::string GenDotPath() const {
return dir_ + "/" + "graph_" + id_ + ".dot"; return config_.dir + "/" + "graph_" + config_.id + ".dot";
} }
std::string Draw(DataFlowGraph* graph) { return graph->DotString(); } std::string Draw(DataFlowGraph *graph);
std::string dir_; Config config_;
std::string id_;
}; };
} // namespace analysis } // namespace analysis
......
...@@ -24,9 +24,10 @@ namespace inference { ...@@ -24,9 +24,10 @@ namespace inference {
namespace analysis { namespace analysis {
TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) { TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) {
auto dfg = ProgramDescToDFG(desc); auto dfg = ProgramDescToDFG(*argument.origin_program_desc);
DFG_GraphvizDrawPass pass("./", "test"); DFG_GraphvizDrawPass::Config config("./", "test");
pass.Initialize(); DFG_GraphvizDrawPass pass(config);
pass.Initialize(&argument);
pass.Run(&dfg); pass.Run(&dfg);
// test content // test content
...@@ -38,7 +39,8 @@ TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) { ...@@ -38,7 +39,8 @@ TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) {
while (std::getline(file, line)) { while (std::getline(file, line)) {
no++; no++;
} }
ASSERT_EQ(no, 82); // DFG is sensitive to ProgramDesc, be careful to change the existing models.
ASSERT_EQ(no, 112);
} }
} // namespace analysis } // namespace analysis
......
...@@ -21,19 +21,23 @@ namespace paddle { ...@@ -21,19 +21,23 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
FluidToDataFlowGraphPass::FluidToDataFlowGraphPass() {} bool FluidToDataFlowGraphPass::Initialize(Argument *argument) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument);
bool FluidToDataFlowGraphPass::Initialize() { return Pass::Initialize(); } ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc);
PADDLE_ENFORCE(argument);
bool FluidToDataFlowGraphPass::Initialize( if (!argument->main_dfg) {
const framework::proto::ProgramDesc &desc) { LOG(INFO) << "Init DFG";
desc_ = &desc; argument->main_dfg.reset(new DataFlowGraph);
}
desc_ = argument->origin_program_desc.get();
return true; return true;
} }
bool FluidToDataFlowGraphPass::Finalize() { return Pass::Finalize(); } bool FluidToDataFlowGraphPass::Finalize() { return Pass::Finalize(); }
void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
PADDLE_ENFORCE(graph);
PADDLE_ENFORCE(desc_);
// insert vars // insert vars
std::unordered_map<std::string, size_t> var2id; std::unordered_map<std::string, size_t> var2id;
auto &main_block = desc_->blocks(framework::kRootBlockIndex); auto &main_block = desc_->blocks(framework::kRootBlockIndex);
...@@ -41,7 +45,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { ...@@ -41,7 +45,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
const auto &var = main_block.vars(i); const auto &var = main_block.vars(i);
auto *v = graph->nodes.Create(Node::Type::kValue); auto *v = graph->nodes.Create(Node::Type::kValue);
v->SetName(var.name()); v->SetName(var.name());
v->SetExtraInfo(const_cast<void *>(static_cast<const void *>(&var))); v->SetPbDesc(const_cast<void *>(static_cast<const void *>(&var)));
var2id[var.name()] = v->id(); var2id[var.name()] = v->id();
} }
for (int i = 0; i < main_block.ops_size(); i++) { for (int i = 0; i < main_block.ops_size(); i++) {
...@@ -51,7 +55,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { ...@@ -51,7 +55,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
static_cast<Function *>(o)->SetFuncType(op.type()); static_cast<Function *>(o)->SetFuncType(op.type());
// Link to the original protobuf message's memory, make it easier to // Link to the original protobuf message's memory, make it easier to
// generate from a data flow graph to fluid ProgramDesc. // generate from a data flow graph to fluid ProgramDesc.
o->SetExtraInfo(const_cast<void *>(static_cast<const void *>(&op))); o->SetPbDesc(const_cast<void *>(static_cast<const void *>(&op)));
// set inputs and outputs // set inputs and outputs
// TODO(Superjomn) make sure the InputNames is the real variable name. // TODO(Superjomn) make sure the InputNames is the real variable name.
for (int j = 0; j < op.inputs_size(); j++) { for (int j = 0; j < op.inputs_size(); j++) {
......
...@@ -34,13 +34,18 @@ namespace analysis { ...@@ -34,13 +34,18 @@ namespace analysis {
*/ */
class FluidToDataFlowGraphPass final : public DataFlowGraphPass { class FluidToDataFlowGraphPass final : public DataFlowGraphPass {
public: public:
FluidToDataFlowGraphPass(); FluidToDataFlowGraphPass() = default;
bool Initialize() override;
bool Initialize(const framework::proto::ProgramDesc &desc) override; bool Initialize(Argument *argument) override;
bool Finalize() override; bool Finalize() override;
void Run(DataFlowGraph *graph) override; void Run(DataFlowGraph *graph) override;
std::string repr() const override { return "fluid-to-data-flow-graph"; }
std::string description() const override {
return "transform a fluid ProgramDesc to a data flow graph.";
}
Pass *CreatePrinterPass(std::ostream &os, Pass *CreatePrinterPass(std::ostream &os,
const std::string &banner) const override; const std::string &banner) const override;
......
...@@ -23,11 +23,11 @@ namespace analysis { ...@@ -23,11 +23,11 @@ namespace analysis {
TEST_F(DFG_Tester, Init) { TEST_F(DFG_Tester, Init) {
FluidToDataFlowGraphPass pass; FluidToDataFlowGraphPass pass;
pass.Initialize(); pass.Initialize(&argument);
pass.Initialize(desc);
DataFlowGraph graph; DataFlowGraph graph;
pass.Run(&graph); pass.Run(&graph);
ASSERT_GT(graph.nodes.size(), 0); // Analysis is sensitive to ProgramDesc, careful to change the original model.
ASSERT_EQ(graph.nodes.size(), 37);
pass.Finalize(); pass.Finalize();
LOG(INFO) << '\n' << graph.DotString(); LOG(INFO) << '\n' << graph.DotString();
} }
......
...@@ -62,6 +62,7 @@ struct DataTypeNamer { ...@@ -62,6 +62,7 @@ struct DataTypeNamer {
SET_TYPE(int); SET_TYPE(int);
SET_TYPE(bool); SET_TYPE(bool);
SET_TYPE(float); SET_TYPE(float);
SET_TYPE(void *);
} }
std::unordered_map<decltype(typeid(int).hash_code()), // NOLINT std::unordered_map<decltype(typeid(int).hash_code()), // NOLINT
......
...@@ -40,6 +40,9 @@ Node *NodeMap::Create(Node::Type type) { ...@@ -40,6 +40,9 @@ Node *NodeMap::Create(Node::Type type) {
case Node::Type::kValue: case Node::Type::kValue:
nodes_.emplace_back(new Value); nodes_.emplace_back(new Value);
break; break;
case Node::Type::kFunctionBlock:
nodes_.emplace_back(new FunctionBlock);
break;
default: default:
PADDLE_THROW("Not supported node type."); PADDLE_THROW("Not supported node type.");
} }
......
...@@ -71,12 +71,17 @@ class Node { ...@@ -71,12 +71,17 @@ class Node {
// Get an additional attribute and convert it to T data type. NOTE this will // Get an additional attribute and convert it to T data type. NOTE this will
// silently create a new attribute if not exists. // silently create a new attribute if not exists.
Attr &attr(const std::string &name) { return attrs_[name]; } Attr &attr(const std::string &name) const { return attrs_[name]; }
int id() const { return id_; } int id() const { return id_; }
bool deleted() const { return deleted_; } // The Protobuf description is set/get with a void* to decouple Node interface
// from a specific kind of Protobuf message.
void SetPbDesc(void *pb) { attr("pb_desc").Pointer() = pb; }
void *pb_desc() const { return attr("pb_desc").Pointer(); }
void SetDeleted() { deleted_ = true; } void SetDeleted() { deleted_ = true; }
bool deleted() const { return deleted_; }
void SetName(const std::string &name) { name_ = name; } void SetName(const std::string &name) { name_ = name; }
const std::string &name() const { return name_; } const std::string &name() const { return name_; }
...@@ -84,29 +89,25 @@ class Node { ...@@ -84,29 +89,25 @@ class Node {
void SetType(Type type) { type_ = type; } void SetType(Type type) { type_ = type; }
Type type() const { return type_; } Type type() const { return type_; }
void *extra_info() const { return extra_info_; }
void SetExtraInfo(void *extra_info) { extra_info_ = extra_info; }
// Input links. // Input links.
std::vector<Node *> inlinks; std::vector<Node *> inlinks;
// Output links. // Output links.
std::vector<Node *> outlinks; std::vector<Node *> outlinks;
// A helper class to maintain the status from Pass. // A helper class to maintain the status from Pass.
// TODO(superjomn) add a checker here to ensure the T is primary.
struct Attr { struct Attr {
// NOTE T should be a primary type or a struct combined by several primary // NOTE T should be a primary type or a struct combined by several primary
// types. // types.
// NOTE the STL containers should not use here. // NOTE the STL containers should not use here.
// Some usages // Some usages
// Attr attr; // Attr attr;
// T data; // attr.Bool() = true;
// attr.data.assign((char*)data, sizeof(data));
bool &Bool() { return As<bool>(); } bool &Bool() { return As<bool>(); }
float &Float() { return As<float>(); } float &Float() { return As<float>(); }
int32_t &Int32() { return As<int32_t>(); } int32_t &Int32() { return As<int32_t>(); }
int64_t &Int64() { return As<int64_t>(); } int64_t &Int64() { return As<int64_t>(); }
void *&Pointer() { return As<void *>(); }
private: private:
template <typename T> template <typename T>
...@@ -130,6 +131,7 @@ class Node { ...@@ -130,6 +131,7 @@ class Node {
size_t type_hash_{std::numeric_limits<size_t>::max()}; size_t type_hash_{std::numeric_limits<size_t>::max()};
}; };
// Type checks.
bool IsFunction() const { return type_ == Node::Type::kFunction; } bool IsFunction() const { return type_ == Node::Type::kFunction; }
bool IsValue() const { return type_ == Node::Type::kValue; } bool IsValue() const { return type_ == Node::Type::kValue; }
bool IsFunctionBlock() const { return type_ == Node::Type::kFunctionBlock; } bool IsFunctionBlock() const { return type_ == Node::Type::kFunctionBlock; }
...@@ -148,9 +150,6 @@ class Node { ...@@ -148,9 +150,6 @@ class Node {
Type type_{Type::kNone}; Type type_{Type::kNone};
// Mark this node is deleted by some pass. // Mark this node is deleted by some pass.
bool deleted_{false}; bool deleted_{false};
void *extra_info_;
mutable std::unordered_map<std::string, Attr> attrs_; mutable std::unordered_map<std::string, Attr> attrs_;
}; };
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/node.h" #include "paddle/fluid/inference/analysis/node.h"
...@@ -30,19 +31,24 @@ namespace analysis { ...@@ -30,19 +31,24 @@ namespace analysis {
class Pass { class Pass {
public: public:
Pass() = default; Pass() = default;
virtual ~Pass() {} virtual ~Pass() = default;
// Virtual method overridden by subclasses to do only necessary initialization // Virtual method overridden by subclasses to do only necessary initialization
// before any pass is run. // before any pass is run.
virtual bool Initialize() { return false; } // virtual bool Initialize() { return false; }
// There is some passes such as FlowToDataFlowGraphPass that needs a // There is some passes such as FlowToDataFlowGraphPass that needs a
// ProgramDesc. Here use the native ProgramDesc ProtoBuf message, so that it // ProgramDesc. Here use the native ProgramDesc ProtoBuf message, so that it
// only couple with the proto file. // only couple with the proto file.
virtual bool Initialize(const framework::proto::ProgramDesc &desc) { // virtual bool Initialize(const framework::proto::ProgramDesc &desc) { return
return false; // false; }
}
// There are some Passes such as DataFlowGraphToFluidPass that will output a // There are some Passes such as DataFlowGraphToFluidPass that will output a
// ProgramDesc. // ProgramDesc.
virtual bool Initialize(framework::proto::ProgramDesc *desc) { return false; } // virtual bool Initialize(framework::proto::ProgramDesc *desc) { return
// false; }
// Mutable Pass.
virtual bool Initialize(Argument *argument) { return false; }
// Readonly Pass.
virtual bool Initialize(const Argument &argument) { return false; }
// Virtual method overriden by subclasses to do any necessary clean up after // Virtual method overriden by subclasses to do any necessary clean up after
// all passes have run. // all passes have run.
...@@ -50,7 +56,9 @@ class Pass { ...@@ -50,7 +56,9 @@ class Pass {
// Get a Pass appropriate to print the Node this pass operates on. // Get a Pass appropriate to print the Node this pass operates on.
virtual Pass *CreatePrinterPass(std::ostream &os, virtual Pass *CreatePrinterPass(std::ostream &os,
const std::string &banner) const = 0; const std::string &banner) const {
return nullptr;
}
// Run on a single Node. // Run on a single Node.
virtual void Run(Node *x) { LOG(FATAL) << "not valid"; } virtual void Run(Node *x) { LOG(FATAL) << "not valid"; }
...@@ -60,6 +68,11 @@ class Pass { ...@@ -60,6 +68,11 @@ class Pass {
virtual void Run(FunctionBlock *x) { LOG(FATAL) << "not valid"; } virtual void Run(FunctionBlock *x) { LOG(FATAL) << "not valid"; }
// Run on a single DataFlowGraph. // Run on a single DataFlowGraph.
virtual void Run(DataFlowGraph *x) { LOG(FATAL) << "not valid"; } virtual void Run(DataFlowGraph *x) { LOG(FATAL) << "not valid"; }
// Human-readable short representation.
virtual std::string repr() const = 0;
// Human-readable long description.
virtual std::string description() const = 0;
}; };
// NodePass process on any Node types. // NodePass process on any Node types.
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/pass_manager.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
void DfgPassManager::RunAll() {
PADDLE_ENFORCE(argument_);
for (auto& pass : data_) {
VLOG(4) << "Running pass [" << pass->repr() << "]";
pass->Run(argument_->main_dfg.get());
}
}
void NodePassManager::RunAll() {
PADDLE_ENFORCE(argument_);
PADDLE_ENFORCE(argument_->main_dfg.get());
auto trait =
GraphTraits<DataFlowGraph>(argument_->main_dfg.get()).nodes_in_DFS();
for (auto& node : trait) {
for (auto& pass : data_) {
pass->Run(&node);
}
}
}
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file defines the logic of pass management. The analysis for inference is
* a pipeline of Passes, a PassManager is a agency that helps to manage the
* executation of the Passes.
*
* There are two modes of Passes, the first one is called NodePass and takes
* an Node as input and output; the second one is called DFGPass and takes a
* DFG(Data Flow Graph) as input and output. It is hard to put all the passes in
* the same pipeline, there are two kinds of PassManagers, both takes a DFG as
* input and output a DFG, but the Passes inside are different:
*
* 1. NodePassManager: the passes inside are all NodePasses, it can have
* different graph trivial algorithm, for example, DFS_NodePassManager will
* trigger the passes in depth first order;
* 2. DfgPassManager: the passes inside are all DfgPasses.
*/
#pragma once
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace paddle {
namespace inference {
namespace analysis {
/*
* PassManager is the base class for all pass managers, a pass manager has
* several Pass-es registered, and execute them in the linear order.
*/
class PassManager : public OrderedRegistry<Pass> {
public:
PassManager() = default;
// Call all the passes' Initialize methods. The desc and data_flow_graph are
// globally shared, so pass them as the arguemnts for all the pass managers.
virtual bool Initialize(const Argument& argument) { return false; }
virtual bool Initialize(Argument* argument) {
argument_ = argument;
for (auto& pass : data_) {
LOG(INFO) << "Initializing pass " << pass->repr();
if (!pass->Initialize(argument)) {
LOG(ERROR) << "Failed to initialize pass [" << pass->repr() << "]";
return false;
}
}
return true;
}
// Call all the passes' Finalize methods.
virtual bool Finalize() {
for (auto& pass : data_) {
if (!pass->Finalize()) {
LOG(ERROR) << "Failed to finalize pass [" << pass->repr() << "]";
return false;
}
}
return true;
}
// Run all the passes.
virtual void RunAll() = 0;
// Short identifier.
virtual std::string repr() const = 0;
// Long description.
virtual std::string description() const = 0;
virtual ~PassManager() = default;
protected:
Argument* argument_{nullptr};
};
/*
* A pass manager that process a DFG.
*/
class DfgPassManager : public PassManager {
public:
DfgPassManager() = default;
void RunAll() override;
virtual ~DfgPassManager() = default;
};
/*
* A pass manager that process a Node each time.
*/
class NodePassManager : public PassManager {
public:
NodePassManager() = default;
void RunAll() override;
virtual ~NodePassManager() = default;
};
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/pass_manager.h"
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include <gtest/gtest.h>
namespace paddle {
namespace inference {
namespace analysis {
class TestDfgPassManager final : public DfgPassManager {
public:
TestDfgPassManager() = default;
virtual ~TestDfgPassManager() = default;
// Short identifier.
std::string repr() const override { return "test-pass-manager"; }
// Long description.
std::string description() const override { return "test doc"; }
};
class TestNodePassManager final : public NodePassManager {
public:
virtual ~TestNodePassManager() = default;
std::string repr() const override { return "test-node-pass-manager"; }
std::string description() const override { return "test doc"; }
};
class TestNodePass final : public NodePass {
public:
virtual ~TestNodePass() = default;
bool Initialize(Argument* argument) override { return true; }
void Run(Node* node) override {
LOG(INFO) << "- Processing node " << node->repr();
}
std::string repr() const override { return "test-node"; }
std::string description() const override { return "some doc"; }
};
TEST_F(DFG_Tester, DFG_pass_manager) {
TestDfgPassManager manager;
DFG_GraphvizDrawPass::Config config("./", "dfg.dot");
manager.Register("fluid-to-flow-graph", new FluidToDataFlowGraphPass);
manager.Register("graphviz", new DFG_GraphvizDrawPass(config));
manager.Register("dfg-to-fluid", new DataFlowGraphToFluidPass);
ASSERT_TRUE(manager.Initialize(&argument));
manager.RunAll();
}
TEST_F(DFG_Tester, Node_pass_manager) {
// Pre-process: initialize the DFG with the ProgramDesc first.
FluidToDataFlowGraphPass pass0;
pass0.Initialize(&argument);
pass0.Run(argument.main_dfg.get());
TestNodePassManager manager;
manager.Register("test-node-pass", new TestNodePass);
ASSERT_TRUE(manager.Initialize(&argument));
manager.RunAll();
}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -19,12 +19,7 @@ namespace paddle { ...@@ -19,12 +19,7 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
TEST_F(DFG_Tester, Split) { SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) {
auto desc = LoadProgramDesc();
auto dfg = ProgramDescToDFG(desc);
LOG(INFO) << "spliter\n" << dfg.DotString();
SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) {
if (node->type() != Node::Type::kFunction) return false; if (node->type() != Node::Type::kFunction) return false;
const auto* func = static_cast<const Function*>(node); const auto* func = static_cast<const Function*>(node);
if (func->func_type() == "elementwise_add" || func->func_type() == "relu" || if (func->func_type() == "elementwise_add" || func->func_type() == "relu" ||
...@@ -34,7 +29,13 @@ TEST_F(DFG_Tester, Split) { ...@@ -34,7 +29,13 @@ TEST_F(DFG_Tester, Split) {
return true; return true;
} }
return false; return false;
}; };
TEST_F(DFG_Tester, Split) {
auto desc = LoadProgramDesc();
auto dfg = ProgramDescToDFG(desc);
LOG(INFO) << "spliter\n" << dfg.DotString();
ASSERT_GT(dfg.nodes.size(), 5UL); ASSERT_GT(dfg.nodes.size(), 5UL);
auto subgraphs = SubGraphSplitter(&dfg, teller)(); auto subgraphs = SubGraphSplitter(&dfg, teller)();
...@@ -62,6 +63,28 @@ TEST_F(DFG_Tester, Split) { ...@@ -62,6 +63,28 @@ TEST_F(DFG_Tester, Split) {
ASSERT_EQ(subgraphs.back().size(), 6UL); ASSERT_EQ(subgraphs.back().size(), 6UL);
} }
TEST_F(DFG_Tester, Fuse) {
auto desc = LoadProgramDesc();
auto dfg = ProgramDescToDFG(desc);
size_t count0 = dfg.nodes.size();
SubGraphFuse fuse(&dfg, teller);
fuse();
int count1 = 0;
for (auto& node : dfg.nodes.nodes()) {
if (node->deleted()) {
LOG(INFO) << "deleted " << node->repr();
}
count1 += node->deleted();
}
// At least one nodes should be deleted.
ASSERT_EQ(dfg.nodes.size(), count0 + 1); // added a new FunctionBlock
ASSERT_EQ(6UL, count1);
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
#include "paddle/fluid/inference/analysis/subgraph_splitter.h"
namespace paddle {
namespace inference {
namespace analysis {
TensorRTSubGraphPass::TensorRTSubGraphPass(
const TensorRTSubGraphPass::NodeInsideSubgraphTeller &teller)
: node_inside_subgraph_teller_(teller) {}
void TensorRTSubGraphPass::Run(DataFlowGraph *graph) {
SubGraphFuse(graph, node_inside_subgraph_teller_);
}
} // analysis
} // inference
} // paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/inference/analysis/node.h"
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/subgraph_splitter.h"
namespace paddle {
namespace inference {
namespace analysis {
/*
* Parse the graph and replace TensorRT supported nodes with SubGraphNode
*/
class TensorRTSubGraphPass : public DataFlowGraphPass {
public:
// Tell whether to transform a sub-graph into TensorRT.
using NodeInsideSubgraphTeller = SubGraphFuse::NodeInsideSubgraphTeller;
TensorRTSubGraphPass(const NodeInsideSubgraphTeller& teller);
bool Initialize(Argument* argument) override { return true; }
// This class get a sub-graph as input and determine whether to transform this
// sub-graph into TensorRT.
void Run(DataFlowGraph* graph) override;
private:
NodeInsideSubgraphTeller node_inside_subgraph_teller_;
};
} // namespace analysis
} // namespace inference
} // paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
DEFINE_string(model_dir, "", "inference test model dir");
TEST(TensorRTSubGraph, single_pass) {
auto desc = LoadProgramDesc();
auto dfg = ProgramDescToDFG(desc);
SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) {
if (node->type() != Node::Type::kFunction) return false;
const auto* func = static_cast<const Function*>(node);
if (func->func_type() == "elementwise_add" || func->func_type() == "relu" ||
func->func_type() == "conv2d" || func->func_type() == "mul" ||
func->func_type() == "sigmoid" || func->func_type() == "softmax") {
LOG(INFO) << "sub-graph marked " << node->repr();
return true;
}
return false;
};
DFG_GraphvizDrawPass::Config config{"./", "test"};
DFG_GraphvizDrawPass dfg_pass(config);
dfg_pass.Initialize();
DFG_GraphvizDrawPass dfg_pass1(config);
dfg_pass1.Initialize();
dfg_pass.Run(&dfg);
TensorRTSubGraphPass trt_pass(std::move(teller));
trt_pass.Initialize();
trt_pass.Run(&dfg);
dfg_pass1.Run(&dfg);
// Check the TRT op's block desc
for (auto node : dfg.nodes.nodes()) {
if (node->IsFunctionBlock()) {
}
}
}
TEST(TensorRTSubGraph, pass_manager) {}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -15,33 +15,46 @@ limitations under the License. */ ...@@ -15,33 +15,46 @@ limitations under the License. */
#pragma once #pragma once
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <fstream>
#include <string> #include <string>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" #include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h" #include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/io.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
// Read ProgramDesc from a __model__ file, defined in io.cc
extern void ReadBinaryFile(const std::string& filename, std::string* contents);
namespace analysis { namespace analysis {
DEFINE_string(inference_model_dir, "", "inference test model dir"); DEFINE_string(inference_model_dir, "", "inference test model dir");
static framework::proto::ProgramDesc LoadProgramDesc( static framework::proto::ProgramDesc LoadProgramDesc(
const std::string& model_dir = FLAGS_inference_model_dir) { const std::string& model_dir = FLAGS_inference_model_dir) {
paddle::platform::CPUPlace place; std::string msg;
paddle::framework::Executor executor(place); std::string net_file = FLAGS_inference_model_dir + "/__model__";
paddle::framework::Scope scope; std::ifstream fin(net_file, std::ios::in | std::ios::binary);
auto program = Load(&executor, &scope, model_dir); PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", net_file);
return *program->Proto(); fin.seekg(0, std::ios::end);
msg.resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&(msg.at(0)), msg.size());
fin.close();
framework::proto::ProgramDesc program_desc;
program_desc.ParseFromString(msg);
return program_desc;
} }
static DataFlowGraph ProgramDescToDFG( static DataFlowGraph ProgramDescToDFG(
const framework::proto::ProgramDesc& desc) { const framework::proto::ProgramDesc& desc) {
DataFlowGraph graph; DataFlowGraph graph;
FluidToDataFlowGraphPass pass; FluidToDataFlowGraphPass pass;
pass.Initialize(desc); Argument argument;
argument.origin_program_desc.reset(new framework::proto::ProgramDesc(desc));
pass.Initialize(&argument);
pass.Run(&graph); pass.Run(&graph);
pass.Finalize(); pass.Finalize();
return graph; return graph;
...@@ -49,9 +62,12 @@ static DataFlowGraph ProgramDescToDFG( ...@@ -49,9 +62,12 @@ static DataFlowGraph ProgramDescToDFG(
class DFG_Tester : public ::testing::Test { class DFG_Tester : public ::testing::Test {
protected: protected:
void SetUp() override { desc = LoadProgramDesc(FLAGS_inference_model_dir); } void SetUp() override {
auto desc = LoadProgramDesc(FLAGS_inference_model_dir);
argument.origin_program_desc.reset(new framework::proto::ProgramDesc(desc));
}
framework::proto::ProgramDesc desc; Argument argument;
}; };
} // namespace analysis } // namespace analysis
......
...@@ -240,7 +240,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { ...@@ -240,7 +240,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
} }
// Test with a larger FC layer. // Test with a larger FC layer.
TEST(TensorRTEngineOp, fc) { Execute(40, 256, 256); } TEST(TensorRTEngineOp, fc) { Execute(40, 28, 28); }
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册