提交 cfc6338a 编写于 作者: Q qiaolongfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add-none-layers-api-doc

...@@ -14,4 +14,3 @@ ...@@ -14,4 +14,3 @@
# #
add_subdirectory(inference) add_subdirectory(inference)
add_subdirectory(tape)
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(APPLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pessimizing-move")
endif(APPLE)
cc_library(tape_variable SRCS variable.cc DEPS ${FLUID_CORE_MODULES} device_context framework_proto proto_desc operator)
cc_library(tape SRCS tape.cc DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB} tape_variable)
cc_test(test_tape
SRCS test_tape.cc
DEPS tape tape_variable)
# Dynamic Graph on Fluid
PaddlePaddle Fluid is targeting the autodiff without tape, which, however, is very
challenging and we are still way from there. DyNet and PyTorch provide a good design
idea, the *tape*, that significantly eases the challenge. Also, DyNet provides
a C++ API that is as convenient as Python but with higher efficiency and could
conveniently integrate with industrial/production systems. This package, `tape`,
combines the good of
1. tape from PyTorch and DyNet
2. C++ API and core from DyNet
3. rich set of operators from PaddlePaddle
## Overview
We can implement Dynet-like Tape(See this [survey](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/survey/dynamic_graph.md))
by wrapping Paddle Fluid's `Operator` and `Variable`.
The user API is straight forward since
1. it is imperative. And it uses host language's control flow logic.
1. it avoids extra concepts such as `Scope` and `Executor`.
All of these benefits come at the cost of just adding one line `reset_global_tape`
at every iteration.
## Code Structure
In short, the `Tape` contains a vector of `OpHandle`s. And an `OpHandle` contains its
`type`, the pointers to the `Variable`s, and necessary attributes.
```c++
class Variable {
public:
VriableHandle Grad(); // returns its gradient variable
private:
framework::VarDesc desc_; // compile time infershape, necessary for lazy execution
framework::Variable var_; // run time variable, holds data memory
};
using VariableHandle = shared_ptr<Variable>;
struct OpHandle {
string type_;
map<string, vector<VariableHandle>> inputs_;
map<string, vector<VariableHandle>> outputs_;
AttributeMap attrs_;
};
class Tape {
public:
void AddOp(OpHandle); // add op
void Forward(); // execute the tape_
void Backward(); // execute the backward of the tape_
private:
vector<OpHandle> tape_;
};
```
We uses `Function` to indicate layers. It takes care of parameter
initialization and `AddOp` to the Tape when it is called.
```c++
class Linear {
public:
Linear(int in_dim, int out_dim, const std::string &act)
: w_(new Variable("LinearWeight")),
b_(new Variable("LinearBias")),
act_(act) {
Tape init_tape;
std::string initializer = "fill_constant";
framework::AttributeMap attrs;
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
attrs["shape"] = std::vector<int>{in_dim, out_dim};
attrs["value"] = 1.0f;
init_tape.AddOp(initializer, {}, {{"Out", {w_}}}, attrs);
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
attrs["shape"] = std::vector<int>{out_dim};
attrs["value"] = 1.0f;
init_tape.AddOp(initializer, {}, {{"Out", {b_}}}, attrs);
init_tape.Forward();
}
VariableHandle operator()(VariableHandle input) {
VariableHandle pre_bias(new Variable("linear"));
get_global_tape().AddOp("mul",
{{"X", {input}}, {"Y", {w_}}},
{{"Out", {pre_bias}}},
{{"x_num_col_dims", 1}, {"y_num_col_dims", 1}});
VariableHandle pre_act(new Variable("linear"));
get_global_tape().AddOp("elementwise_add",
{{"X", {pre_bias}}, {"Y", {b_}}},
{{"Out", {pre_act}}},
{{"axis", 1}});
VariableHandle post_act(new Variable("linear"));
get_global_tape().AddOp(act_,
{{"X", {pre_act}}},
{{"Out", {post_act}}},
{});
return post_act;
}
std::vector<VariableHandle> Params() { return {w_, b_}; }
private:
VariableHandle w_;
VariableHandle b_;
std::string act_;
};
```
## User API
```c++
// Model function
paddle::tape::Linear linear1(3, 3, "relu"); // init weight and bias
paddle::tape::Linear linear2(3, 3, "relu"); // init weight and bias
paddle::tape::Mean mean;
// Optimizer
paddle::tape::SGD sgd(0.001);
// Data Feeder
paddle::tape::Fill data_feeder(...);
VariableHandle input(new paddle::tape::Variable("input"));
VariableHandle label(new paddle::tape::Variable("label"));
for (int i = 0; i < 2; ++i) {
reset_global_tape();
data_feeder(input, label);
auto loss = softmax(linear2(linear1(input)), label); // compile time InferShape & InferVarType
LOG(INFO) << loss.value(); // Run forward up to loss
// Run backward, store gradient of w at w->Grad()
get_global_tape.Backward(loss);
// Update w
sgd(linear1.Params());
sgd(linear2.Params());
}
```
<details>
<summary></summary>
digraph G {
subgraph cluster_0 {
node [shape=record,style=filled];
style=filled;
color=lightgrey;
linear1 [label="{type: mul | {input | {<before_mul1>X: before_mul1 |<weight1> Y: weight1}} | {output |<before_bias1> Out: before_bias1}}"];
elementwise_add1 [label="{type: elementwise_add | {input | {<before_bias1>X: before_bias1 |<bias1> Y: bias1}} | {output |<before_act1> Out: before_act1}}"];
relu1 [label="{type: relu | {input | {<before_act1>X: before_act1 }} | {output |<after_act1> Out: after_act1}}"];
linear1 -> elementwise_add1->relu1;
label = "forward tape";
}
linear1:before_mul1->before_mul1
linear1:weight1->weight1
linear1:before_bias1->before_bias1
elementwise_add1:bias1->bias1
elementwise_add1:before_bias1->before_bias1
elementwise_add1:before_act1->before_act1
relu1:before_act1->before_act1
relu1:after_act1->after_act1
subgraph cluster_1 {
node [shape=record,style=filled];
style=filled;
color=lightgrey;
linear1_grad [label="{type: mul_grad | {input | {<before_mul1>X: before_mul1 |<weight1> Y: weight1|<before_bias1_grad> Out_grad: before_bias1_grad}} | {output |{<before_mul1_grad>X_grad: before_mul1_grad |<weight1_grad> Y_grad: weight1_grad}}}"];
elementwise_add1_grad [label="{type: elementwise_add_grad | {input | <before_act1_grad> Out_grad: before_act1_grad} | {output |{<before_bias1_grad>X_grad: before_bias1_grad |<bias1_grad> Y_grad: bias1_grad}}}"];
relu1_grad [label="{type: relu_grad | {input |<after_act1_grad> Out_grad: after_act1_grad} | {ouput | {<before_act1_grad>X_grad: before_act1_grad }}}"];
linear1_grad -> elementwise_add1_grad ->relu1_grad [dir=back];
label = "backward tape";
}
relu1_grad:after_act1_grad->after_act1_grad
relu1_grad:before_act1_grad->before_act1_grad
elementwise_add1_grad:before_act1_grad->before_act1_grad
elementwise_add1_grad:before_bias1_grad->before_bias1_grad
elementwise_add1_grad:bias1_grad->bias1_grad
linear1_grad:before_mul1->before_mul1
linear1_grad:weight1->weight1
linear1_grad:before_bias1_grad->before_bias1_grad
linear1_grad:before_mul1_grad->before_mul1_grad
linear1_grad:weight1_grad->weight1_grad
subgraph cluster_2 {
node [shape=record];
label = "Linear1";
weight1
bias1
}
weight1 -> weight1_grad [ label="Grad()", style="dashed" ];
bias1 -> bias1_grad [ label="Grad()", style="dashed"];
}
</details>
![Image](https://github.com/tonyyang-svail/Paddle/blob/cpp_tap/paddle/contrib/tape/computation_graph.png)
## Code Reuse
We want to stay close to Paddle Fluid as much as possible.
### Reuse All Operators
As all Ops are registered at `OpInfoMap`, the effort of adding a new `Function`
is about 10 lines of code, similar to expose an operator to Python.
### Reuse Compile Time InferShape and InferVarType
Note that all the symbolic information is stored at `tape::Varaible::desc_`, instead
of `ProgramDesc.block.vars`, we create a temporary `BlockDesc` to do `InferShape` and
`InferVarType` every time we `AddOp` to the tape.
### Reuse Operator::Run
We use smart pointer, instead of `Scope`, to manage memory. So we create a temporary
`Scope` for every `Operator::Run()`.
## Possible Feature
### Release Memory on Backward
We can release memory aggressively. During backward, we can delete the OpHandle once
we have finished its backward. Since all the variable is managed by smart pointer, the
memory is automatically released when its `ref_count` goes to 0.
### Kernel Fusion
As a symbolic representation of the Tape is constructed first before the actual
execution, it would be possible to perform graph optimization. One use case is kernel
fusion.
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/contrib/tape/tape.h"
#include "paddle/contrib/tape/variable.h"
#include "paddle/fluid/framework/type_defs.h"
namespace paddle {
namespace tape {
class Function {};
class Fill {
public:
Fill(const std::string &initializer, const framework::AttributeMap &attrs)
: initializer_(initializer), attrs_(attrs) {}
void operator()(VariableHandle var) {
get_global_tape().AddOp(initializer_, {}, {{"Out", {var}}}, attrs_);
}
private:
const std::string initializer_;
const framework::AttributeMap attrs_;
};
class Mean {
public:
VariableHandle operator()(VariableHandle var) {
VariableHandle out(new Variable("mean"));
get_global_tape().AddOp("mean", {{"X", {var}}}, {{"Out", {out}}}, {});
return out;
}
};
class Linear {
public:
Linear(int in_dim, int out_dim, const std::string &act)
: w_(new Variable("LinearWeight")),
b_(new Variable("LinearBias")),
act_(act) {
Tape init_tape;
std::string initializer = "fill_constant";
framework::AttributeMap attrs;
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
attrs["shape"] = std::vector<int>{in_dim, out_dim};
attrs["value"] = 1.0f;
init_tape.AddOp(initializer, {}, {{"Out", {w_}}}, attrs);
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
attrs["shape"] = std::vector<int>{out_dim};
attrs["value"] = 1.0f;
init_tape.AddOp(initializer, {}, {{"Out", {b_}}}, attrs);
init_tape.Forward();
}
VariableHandle operator()(VariableHandle input) {
VariableHandle pre_bias(new Variable("linear"));
get_global_tape().AddOp("mul",
{{"X", {input}}, {"Y", {w_}}},
{{"Out", {pre_bias}}},
{{"x_num_col_dims", 1}, {"y_num_col_dims", 1}});
VariableHandle pre_act(new Variable("linear"));
get_global_tape().AddOp("elementwise_add",
{{"X", {pre_bias}}, {"Y", {b_}}},
{{"Out", {pre_act}}},
{{"axis", 1}});
VariableHandle post_act(new Variable("linear"));
get_global_tape().AddOp(
act_, {{"X", {pre_act}}}, {{"Out", {post_act}}}, {});
return post_act;
}
std::vector<VariableHandle> Params() { return {w_, b_}; }
private:
VariableHandle w_;
VariableHandle b_;
std::string act_;
};
class SGD {
public:
SGD(float learning_rate) : learning_rate_(new Variable("sgd")) {
Tape init_tape;
std::string initializer = "fill_constant";
framework::AttributeMap attrs;
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
attrs["shape"] = std::vector<int>{1};
attrs["value"] = learning_rate;
init_tape.AddOp(initializer, {}, {{"Out", {learning_rate_}}}, attrs);
init_tape.Forward();
}
void operator()(VariableHandle input) {
PADDLE_ENFORCE(get_global_tape().HasBeenBackwarded(),
"optimization must happen after the backward");
Tape temp_tape;
temp_tape.AddOp("sgd",
{{"Param", {input}},
{"LearningRate", {learning_rate_}},
{"Grad", {input->Grad()}}},
{{"ParamOut", {input}}},
{});
temp_tape.Forward();
}
private:
VariableHandle learning_rate_;
};
}
}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/contrib/tape/tape.h"
#include <list>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/dim.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"
namespace paddle {
namespace tape {
// borrowed from
// https://stackoverflow.com/questions/874134/find-if-string-ends-with-another-string-in-c
inline bool ends_with(std::string const &value, std::string const &ending) {
if (ending.size() > value.size()) return false;
return std::equal(ending.rbegin(), ending.rend(), value.rbegin());
}
std::ostream &operator<<(std::ostream &os, const framework::VarDesc &var_desc) {
os << var_desc.Name();
os << "[" << var_desc.GetType() << "]";
os << "[" << var_desc.GetDataType() << "]";
os << "{";
for (auto &i : var_desc.GetShape()) {
os << i << ",";
}
os << "}";
return os;
}
std::string to_string(const std::string &type,
const VariableHandleMap &in_vars,
const VariableHandleMap &out_vars,
const framework::AttributeMap &attrs) {
std::stringstream ss;
ss << type << " ";
for (auto &param_name : in_vars) {
for (auto &var : param_name.second) {
ss << param_name.first << ":(" << var->Desc() << ") ";
}
}
for (auto &param_name : out_vars) {
for (auto &var : param_name.second) {
ss << param_name.first << ":(" << var->Desc() << ") ";
}
}
return ss.str();
}
framework::OpDesc CreateOpDesc(const std::string &type,
const VariableHandleMap &in_vars,
const VariableHandleMap &out_vars,
const framework::AttributeMap &attrs) {
framework::VariableNameMap inputs;
for (auto &param_name : in_vars) {
for (auto &var : param_name.second) {
inputs[param_name.first].emplace_back(var->Name());
}
}
framework::VariableNameMap outputs;
for (auto &param_name : out_vars) {
for (auto &var : param_name.second) {
outputs[param_name.first].emplace_back(var->Name());
}
}
return framework::OpDesc(type, inputs, outputs, attrs);
}
void InferShapeAndVarType(const std::string &type,
const VariableHandleMap &in_vars,
VariableHandleMap *out_vars,
const framework::AttributeMap &attrs) {
framework::OpDesc op_desc = CreateOpDesc(type, in_vars, *out_vars, attrs);
// Create a temporary block for compile-time
framework::ProgramDesc program_desc;
framework::BlockDesc *block_desc = program_desc.MutableBlock(0);
PADDLE_ENFORCE(block_desc);
for (auto &param_name : in_vars) {
for (auto &var : param_name.second) {
*block_desc->Var(var->Name())->Proto() = *var->MutableDesc()->Proto();
}
}
for (auto &param_name : *out_vars) {
for (auto &var : param_name.second) {
*block_desc->Var(var->Name())->Proto() = *var->MutableDesc()->Proto();
}
}
LOG(INFO) << "- " << to_string(type, in_vars, *out_vars, attrs);
op_desc.InferShape(*block_desc);
op_desc.InferVarType(block_desc);
for (auto &param_name : *out_vars) {
for (auto &var : param_name.second) {
*var->MutableDesc()->Proto() = *block_desc->Var(var->Name())->Proto();
}
}
LOG(INFO) << "+ " << to_string(type, in_vars, *out_vars, attrs);
}
void Tape::AddOp(const std::string &type,
const VariableHandleMap &in_vars,
VariableHandleMap out_vars,
const framework::AttributeMap &attrs) {
InferShapeAndVarType(type, in_vars, &out_vars, attrs);
tape_.emplace_back(type, in_vars, out_vars, attrs);
}
// Temporary Scope for Operator::Run()
class ScopeWrapper : public framework::Scope {
public:
ScopeWrapper(const VariableHandleMap &in_vars,
const VariableHandleMap &out_vars) {
for (auto &v : in_vars) {
for (auto &vv : v.second) {
if (!vars_.count(vv->Name())) {
vars_[vv->Name()].reset(vv->Var());
}
}
}
for (auto &v : out_vars) {
for (auto &vv : v.second) {
if (!vars_.count(vv->Name())) {
vars_[vv->Name()].reset(vv->Var());
}
}
}
}
~ScopeWrapper() {
for (auto &pair : vars_) {
pair.second.release();
}
}
};
void Tape::Forward() {
LOG(INFO) << "Starting forward -------------------------";
PADDLE_ENFORCE(!has_been_backwarded_);
while (current_position_ < tape_.size()) {
OpHandle &op = tape_[current_position_];
// Create Output Tensor, this is only necessary for OpWithKernel
for (auto &param2var : op.outputs_) {
for (auto &var : param2var.second) {
var->InitializeVariable();
}
}
framework::OpDesc op_desc =
CreateOpDesc(op.type_, op.inputs_, op.outputs_, op.attrs_);
ScopeWrapper scope(op.inputs_, op.outputs_);
framework::OpRegistry::CreateOp(op_desc)->Run(scope, platform::CPUPlace());
current_position_++;
}
LOG(INFO) << "Finishing forward -------------------------";
}
void Tape::Backward(VariableHandle target) {
PADDLE_ENFORCE(!has_been_backwarded_);
Forward();
// TODO(tonyyang-svail): check output of last op is target
backward_tape_.reset(new Tape());
framework::AttributeMap attrs;
// FIXME(tonyyang-svail): Need to infer_data_type
attrs["dtype"] = framework::proto::VarType::Type::VarType_Type_FP32;
attrs["shape"] = std::vector<int>{1};
attrs["value"] = 1.0f;
backward_tape_->AddOp(
"fill_constant", {}, {{"Out", {target->Grad()}}}, attrs);
for (auto it = tape_.rbegin(); it != tape_.rend(); ++it) {
framework::OpDesc op_desc =
CreateOpDesc(it->type_, it->inputs_, it->outputs_, it->attrs_);
std::unordered_map<std::string, std::string> grad_to_var;
std::vector<std::unique_ptr<framework::OpDesc>> grad_op_descs =
framework::OpInfoMap::Instance()
.Get(op_desc.Type())
.GradOpMaker()(op_desc, {}, &grad_to_var, {});
for (auto &op_desc : grad_op_descs) {
std::unordered_map<std::string, VariableHandle> name2var;
for (auto &param2vars : it->inputs_) {
for (auto &a : param2vars.second) {
name2var[a->Name()] = a;
}
}
for (auto &param2vars : it->outputs_) {
for (auto &a : param2vars.second) {
name2var[a->Name()] = a;
}
}
VariableHandleMap in_vars;
VariableHandleMap out_vars;
std::map<const framework::VariableNameMap *, VariableHandleMap *>
loop_over{{&op_desc->Inputs(), &in_vars},
{&op_desc->Outputs(), &out_vars}};
for (auto &each : loop_over) {
auto &vmp = *each.first;
auto &vhm = *each.second;
for (auto &p2a : vmp) {
for (auto &argu : p2a.second) {
if (name2var.count(argu)) {
vhm[p2a.first].push_back(name2var[argu]);
} else {
PADDLE_ENFORCE(ends_with(argu, framework::kGradVarSuffix),
argu.c_str());
std::string name = argu.substr(
0, argu.size() - std::strlen(framework::kGradVarSuffix));
PADDLE_ENFORCE(name2var.count(name), name.c_str());
vhm[p2a.first].push_back(name2var[name]->Grad());
}
}
}
}
backward_tape_->AddOp(
op_desc->Type(), in_vars, out_vars, op_desc->GetAttrMap());
}
// TODO(tonyyang-svail): how to fill empty grad?
// TODO(tonyyang-svail): Sum var grad is necessary
}
backward_tape_->Forward();
has_been_backwarded_ = true;
}
Tape &get_global_tape() {
static Tape T;
return T;
}
void reset_global_tape() { get_global_tape() = Tape(); }
}
}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/operator.h" // framework::kGradVarSuffix
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace tape {
class Variable;
using VariableHandle = std::shared_ptr<Variable>;
/*
* Combination of
* framework::VarDesc desc_;
* framework::Variable var_;
*/
class Variable {
public:
Variable(const std::string pre_fix)
: desc_(pre_fix + std::to_string(count())) {}
Variable(const std::string pre_fix, bool is_grad)
: desc_(pre_fix + (is_grad ? framework::kGradVarSuffix
: std::to_string(count()))) {}
~Variable() { LOG(INFO) << "Deleting " << Name(); }
// Instantiate LoDTensor/SelectedRow
void InitializeVariable();
VariableHandle Grad() {
if (grad_.expired()) {
VariableHandle new_grad(new Variable(desc_.Name(), true));
grad_ = new_grad;
return new_grad;
} else {
return VariableHandle(grad_);
}
}
// Stochastic Gradient Descent with Momentum
// VariableHandle Momentum ();
// void init(const std::string& initializer,
// const framework::AttributeMap& attrs);
// void value() {};
const framework::VarDesc& Desc() const { return desc_; }
framework::VarDesc* MutableDesc() { return &desc_; }
// TODO(tonyyang-svail): No need to expose name
std::string Name() const { return desc_.Name(); }
framework::Variable* Var() { return &var_; }
private:
int count() {
static int counter = 0;
return counter++;
}
framework::VarDesc desc_;
framework::Variable var_;
std::weak_ptr<Variable> grad_;
};
}
}
set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor init) set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor init)
cc_library(analysis SRCS dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc fluid_to_data_flow_graph_pass.cc cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
DEPS paddle_fluid) fluid_to_data_flow_graph_pass.cc
data_flow_graph_to_fluid_pass.cc
tensorrt_subgraph_pass.cc
dfg_graphviz_draw_pass.cc
DEPS framework_proto)
cc_test(test_node SRCS node_tester.cc DEPS analysis) cc_test(test_node SRCS node_tester.cc DEPS analysis)
cc_test(test_dot SRCS dot_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
cc_test(test_data_flow_graph SRCS data_flow_graph_tester.cc DEPS analysis ${FLUID_CORE_MODULES} paddle_fluid function (inference_analysis_test TARGET)
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) set(options "")
set_tests_properties(test_data_flow_graph PROPERTIES DEPENDS test_word2vec) set(oneValueArgs "")
set(multiValueArgs SRCS)
cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cc_test(test_subgraph_splitter cc_test(${TARGET}
SRCS subgraph_splitter_tester.cc SRCS "${analysis_test_SRCS}"
DEPS analysis paddle_fluid tensor DEPS analysis
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model --fraction_of_gpu_memory_to_use=0.5)
set_tests_properties(test_subgraph_splitter PROPERTIES DEPENDS test_word2vec) set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
endfunction(inference_analysis_test)
cc_test(test_dfg_graphviz_draw_pass inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
SRCS dfg_graphviz_draw_pass_tester.cc inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc)
DEPS analysis inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc)
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc)
set_tests_properties(test_dfg_graphviz_draw_pass PROPERTIES DEPENDS test_word2vec) inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc)
#inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc)
inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc)
...@@ -12,22 +12,4 @@ ...@@ -12,22 +12,4 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/contrib/tape/variable.h" #include "paddle/fluid/inference/analysis/argument.h"
namespace paddle {
namespace tape {
void Variable::InitializeVariable() {
LOG(INFO) << "Initialzing " << desc_.Name() << " as " << desc_.GetType();
framework::proto::VarType::Type var_type = desc_.GetType();
if (var_type == framework::proto::VarType::LOD_TENSOR) {
var_.GetMutable<framework::LoDTensor>();
} else if (var_type == framework::proto::VarType::SELECTED_ROWS) {
var_.GetMutable<framework::SelectedRows>();
} else {
PADDLE_THROW("Variable type %d is not in [LOD_TENSOR, SELECTED_ROWS]",
var_type);
}
}
}
}
...@@ -11,54 +11,45 @@ ...@@ -11,54 +11,45 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/contrib/tape/variable.h"
namespace paddle {
namespace tape {
using VariableHandleMap = std::map<std::string, std::vector<VariableHandle>>;
struct OpHandle { /*
OpHandle(const std::string &type, * This file defines the class Argument, which is the input and output of the
const VariableHandleMap &in_vars, * analysis module. All the fields that needed either by Passes or PassManagers
const VariableHandleMap &out_vars, * are contained in Argument.
const framework::AttributeMap &attrs) *
: type_(type), inputs_(in_vars), outputs_(out_vars), attrs_(attrs) {} * TODO(Superjomn) Find some way better to contain the fields when it grow too
* big.
*/
std::string type_; #pragma once
VariableHandleMap inputs_;
VariableHandleMap outputs_;
framework::AttributeMap attrs_;
};
class Tape {
public:
void AddOp(const std::string &type,
const VariableHandleMap &in_vars,
VariableHandleMap out_vars,
const framework::AttributeMap &attrs);
void Forward();
void Backward(VariableHandle target);
bool HasBeenBackwarded() { return has_been_backwarded_; }
private: #include "paddle/fluid/framework/program_desc.h"
bool has_been_backwarded_ = false; #include "paddle/fluid/inference/analysis/data_flow_graph.h"
size_t current_position_ = 0;
std::vector<OpHandle> tape_; namespace paddle {
std::shared_ptr<Tape> backward_tape_; namespace inference {
namespace analysis {
/*
* The argument definition of both Pass and PassManagers.
*
* All the fields should be registered here for clearness.
*/
struct Argument {
// The graph that process by the Passes or PassManagers.
std::unique_ptr<DataFlowGraph> main_dfg;
// The original program desc.
std::unique_ptr<framework::proto::ProgramDesc> origin_program_desc;
}; };
Tape &get_global_tape(); #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#define ANALYSIS_ARGUMENT_CHECK_FIELD(field__) \
if (UNLIKELY(!(field__))) { \
LOG(ERROR) << "field " << #field__ << " should be set."; \
return false; \
}
void reset_global_tape(); } // namespace analysis
} } // namespace inference
} } // namespace paddle
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/dot.h" #include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/inference/analysis/node.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -57,19 +58,7 @@ std::string DataFlowGraph::DotString() const { ...@@ -57,19 +58,7 @@ std::string DataFlowGraph::DotString() const {
// Add nodes // Add nodes
for (size_t i = 0; i < nodes.size(); i++) { for (size_t i = 0; i < nodes.size(); i++) {
const Node &node = nodes.Get(i); const Node &node = nodes.Get(i);
switch (node.type()) { dot.AddNode(node.repr(), node.dot_attrs());
case Node::Type::kValue:
dot.AddNode(node.repr(), node.dot_attrs());
break;
case Node::Type::kFunction:
dot.AddNode(node.repr(), node.dot_attrs());
break;
case Node::Type::kFunctionBlock:
dot.AddNode(node.repr(), node.dot_attrs());
break;
default:
PADDLE_THROW("unsupported Node type %d", static_cast<int>(node.type()));
}
} }
// Add edges // Add edges
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/framework/proto_desc.h"
namespace paddle {
namespace inference {
namespace analysis {
bool DataFlowGraphToFluidPass::Initialize(Argument* argument) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument)
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc)
desc_ = argument->origin_program_desc.get();
// Here some logic from program_desc.cc and will not add new interfaces into
// framework::ProgramDesc class, use some UT to assure the correctness.
auto* block = desc_->mutable_blocks()->Add();
block->set_idx(framework::kRootBlockIndex);
block->set_parent_idx(framework::kNoneBlockIndex);
return true;
}
bool DataFlowGraphToFluidPass::Finalize() { return true; }
void DataFlowGraphToFluidPass::Run(DataFlowGraph* graph) {
auto traits = GraphTraits<DataFlowGraph>(graph);
for (auto it = traits.nodes().begin(); it != traits.nodes().end(); ++it) {
if (it->deleted()) continue;
switch (it->type()) {
case Node::Type::kFunction:
LOG(INFO) << "add function " << it->name();
AddFluidOp(&(*it));
break;
case Node::Type::kFunctionBlock:
AddEngineOp(&(*it));
break;
default:
continue;
}
}
}
void DataFlowGraphToFluidPass::AddFluidOp(Node* node) {
LOG(INFO) << "processing func " << node->name();
auto* ori_op = static_cast<framework::proto::OpDesc*>(node->pb_desc());
// currently only the main block is analyzed.
auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
auto* op = main_block->add_ops();
LOG(INFO) << "to copy the op";
*op = *ori_op; // copy the attributes, by default, these will not be changed
// by analysis phrase.
// The inputs and outputs of the existing ops are not changed by tensorrt
// subgraph pass.
// NOTE It might be changed by other passes in the long run.
}
void DataFlowGraphToFluidPass::AddEngineOp(Node* node) {
// auto* ori_op = static_cast<framework::proto::OpDesc*>(node->extra_info());
// auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
// auto* op = main_block->add_ops();
// TODO(Superjomn) Here need to expose some arguments for default setting.
}
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file implements the transformation from fluid ProgramDesc to data flow
* graph.
*/
#pragma once
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace paddle {
namespace inference {
namespace analysis {
class DataFlowGraphToFluidPass final : public DataFlowGraphPass {
public:
DataFlowGraphToFluidPass() = default;
bool Initialize(Argument *argument) override;
bool Finalize() override;
void Run(DataFlowGraph *graph) override;
std::string repr() const override { return "DFG to fluid"; }
std::string description() const override {
return "Transform a DFG to a Fluid ProgramDesc";
}
Pass *CreatePrinterPass(std::ostream &os,
const std::string &banner) const override {
return nullptr;
}
protected:
// Add a Fluid Op into the ProgramDesc.
void AddFluidOp(Node *node);
// Add a EngineOp into the ProgramDesc.
void AddEngineOp(Node *node);
private:
framework::proto::ProgramDesc *desc_;
};
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -27,13 +27,12 @@ namespace inference { ...@@ -27,13 +27,12 @@ namespace inference {
namespace analysis { namespace analysis {
TEST_F(DFG_Tester, Test) { TEST_F(DFG_Tester, Test) {
framework::proto::ProgramDesc new_desc;
DataFlowGraph graph; DataFlowGraph graph;
FluidToDataFlowGraphPass pass0; FluidToDataFlowGraphPass pass0;
DataFlowGraphToFluidPass pass1; DataFlowGraphToFluidPass pass1;
pass0.Initialize(desc); ASSERT_TRUE(pass0.Initialize(&argument));
pass1.Initialize(&new_desc); ASSERT_TRUE(pass1.Initialize(&argument));
pass0.Run(&graph); pass0.Run(&graph);
pass1.Run(&graph); pass1.Run(&graph);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
void DFG_GraphvizDrawPass::Run(DataFlowGraph *graph) {
auto content = Draw(graph);
std::ofstream file(GenDotPath());
file.write(content.c_str(), content.size());
file.close();
LOG(INFO) << "draw dot to " << GenDotPath();
}
std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) {
Dot dot;
// Add nodes
for (size_t i = 0; i < graph->nodes.size(); i++) {
const Node &node = graph->nodes.Get(i);
if (config_.display_deleted_node || !node.deleted()) {
dot.AddNode(node.repr(), node.dot_attrs());
}
}
// Add edges
for (size_t i = 0; i < graph->nodes.size(); i++) {
const Node &node = graph->nodes.Get(i);
if (!config_.display_deleted_node && node.deleted()) continue;
for (auto &in : node.inlinks) {
if (!config_.display_deleted_node && in->deleted()) continue;
for (auto &in : node.inlinks) {
dot.AddEdge(in->repr(), node.repr(), {});
}
}
}
return dot.Build();
}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include <fstream> #include <fstream>
#include <string> #include <string>
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/inference/analysis/pass.h" #include "paddle/fluid/inference/analysis/pass.h"
namespace paddle { namespace paddle {
...@@ -32,35 +33,39 @@ namespace analysis { ...@@ -32,35 +33,39 @@ namespace analysis {
*/ */
class DFG_GraphvizDrawPass : public DataFlowGraphPass { class DFG_GraphvizDrawPass : public DataFlowGraphPass {
public: public:
DFG_GraphvizDrawPass(const std::string& dir, const std::string& id) struct Config {
: dir_(dir), id_(id) {} Config(const std::string &dir, const std::string &id,
bool display_deleted_node = false)
bool Initialize() override { return Pass::Initialize(); } : dir(dir), id(id), display_deleted_node(display_deleted_node) {}
void Run(DataFlowGraph* graph) override {
auto content = Draw(graph); // The directory to store the .dot or .png files.
std::ofstream file(GenDotPath()); const std::string dir;
file.write(content.c_str(), content.size()); // The identifier for this dot file.
file.close(); const std::string id;
LOG(INFO) << "draw dot to " << GenDotPath(); // Whether to display deleted nodes, default false.
} const bool display_deleted_node;
};
DFG_GraphvizDrawPass(const Config &config) : config_(config) {}
bool Initialize(Argument *argument) override { return true; }
void Run(DataFlowGraph *graph) override;
bool Finalize() override { return Pass::Finalize(); } bool Finalize() override { return Pass::Finalize(); }
Pass* CreatePrinterPass(std::ostream& os, std::string repr() const override { return "DFG graphviz drawer"; }
const std::string& banner) const override { std::string description() const override {
return nullptr; return "Debug a DFG by draw with graphviz";
} }
private: private:
// Path of the dot file to output. // Path of the dot file to output.
std::string GenDotPath() const { std::string GenDotPath() const {
return dir_ + "/" + "graph_" + id_ + ".dot"; return config_.dir + "/" + "graph_" + config_.id + ".dot";
} }
std::string Draw(DataFlowGraph* graph) { return graph->DotString(); } std::string Draw(DataFlowGraph *graph);
std::string dir_; Config config_;
std::string id_;
}; };
} // namespace analysis } // namespace analysis
......
...@@ -24,9 +24,10 @@ namespace inference { ...@@ -24,9 +24,10 @@ namespace inference {
namespace analysis { namespace analysis {
TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) { TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) {
auto dfg = ProgramDescToDFG(desc); auto dfg = ProgramDescToDFG(*argument.origin_program_desc);
DFG_GraphvizDrawPass pass("./", "test"); DFG_GraphvizDrawPass::Config config("./", "test");
pass.Initialize(); DFG_GraphvizDrawPass pass(config);
pass.Initialize(&argument);
pass.Run(&dfg); pass.Run(&dfg);
// test content // test content
...@@ -38,7 +39,8 @@ TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) { ...@@ -38,7 +39,8 @@ TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) {
while (std::getline(file, line)) { while (std::getline(file, line)) {
no++; no++;
} }
ASSERT_EQ(no, 82); // DFG is sensitive to ProgramDesc, be careful to change the existing models.
ASSERT_EQ(no, 112);
} }
} // namespace analysis } // namespace analysis
......
...@@ -21,19 +21,23 @@ namespace paddle { ...@@ -21,19 +21,23 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
FluidToDataFlowGraphPass::FluidToDataFlowGraphPass() {} bool FluidToDataFlowGraphPass::Initialize(Argument *argument) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument);
bool FluidToDataFlowGraphPass::Initialize() { return Pass::Initialize(); } ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc);
PADDLE_ENFORCE(argument);
bool FluidToDataFlowGraphPass::Initialize( if (!argument->main_dfg) {
const framework::proto::ProgramDesc &desc) { LOG(INFO) << "Init DFG";
desc_ = &desc; argument->main_dfg.reset(new DataFlowGraph);
}
desc_ = argument->origin_program_desc.get();
return true; return true;
} }
bool FluidToDataFlowGraphPass::Finalize() { return Pass::Finalize(); } bool FluidToDataFlowGraphPass::Finalize() { return Pass::Finalize(); }
void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
PADDLE_ENFORCE(graph);
PADDLE_ENFORCE(desc_);
// insert vars // insert vars
std::unordered_map<std::string, size_t> var2id; std::unordered_map<std::string, size_t> var2id;
auto &main_block = desc_->blocks(framework::kRootBlockIndex); auto &main_block = desc_->blocks(framework::kRootBlockIndex);
...@@ -41,7 +45,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { ...@@ -41,7 +45,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
const auto &var = main_block.vars(i); const auto &var = main_block.vars(i);
auto *v = graph->nodes.Create(Node::Type::kValue); auto *v = graph->nodes.Create(Node::Type::kValue);
v->SetName(var.name()); v->SetName(var.name());
v->SetExtraInfo(const_cast<void *>(static_cast<const void *>(&var))); v->SetPbDesc(const_cast<void *>(static_cast<const void *>(&var)));
var2id[var.name()] = v->id(); var2id[var.name()] = v->id();
} }
for (int i = 0; i < main_block.ops_size(); i++) { for (int i = 0; i < main_block.ops_size(); i++) {
...@@ -51,7 +55,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { ...@@ -51,7 +55,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
static_cast<Function *>(o)->SetFuncType(op.type()); static_cast<Function *>(o)->SetFuncType(op.type());
// Link to the original protobuf message's memory, make it easier to // Link to the original protobuf message's memory, make it easier to
// generate from a data flow graph to fluid ProgramDesc. // generate from a data flow graph to fluid ProgramDesc.
o->SetExtraInfo(const_cast<void *>(static_cast<const void *>(&op))); o->SetPbDesc(const_cast<void *>(static_cast<const void *>(&op)));
// set inputs and outputs // set inputs and outputs
// TODO(Superjomn) make sure the InputNames is the real variable name. // TODO(Superjomn) make sure the InputNames is the real variable name.
for (int j = 0; j < op.inputs_size(); j++) { for (int j = 0; j < op.inputs_size(); j++) {
......
...@@ -34,13 +34,18 @@ namespace analysis { ...@@ -34,13 +34,18 @@ namespace analysis {
*/ */
class FluidToDataFlowGraphPass final : public DataFlowGraphPass { class FluidToDataFlowGraphPass final : public DataFlowGraphPass {
public: public:
FluidToDataFlowGraphPass(); FluidToDataFlowGraphPass() = default;
bool Initialize() override;
bool Initialize(const framework::proto::ProgramDesc &desc) override; bool Initialize(Argument *argument) override;
bool Finalize() override; bool Finalize() override;
void Run(DataFlowGraph *graph) override; void Run(DataFlowGraph *graph) override;
std::string repr() const override { return "fluid-to-data-flow-graph"; }
std::string description() const override {
return "transform a fluid ProgramDesc to a data flow graph.";
}
Pass *CreatePrinterPass(std::ostream &os, Pass *CreatePrinterPass(std::ostream &os,
const std::string &banner) const override; const std::string &banner) const override;
......
...@@ -23,11 +23,11 @@ namespace analysis { ...@@ -23,11 +23,11 @@ namespace analysis {
TEST_F(DFG_Tester, Init) { TEST_F(DFG_Tester, Init) {
FluidToDataFlowGraphPass pass; FluidToDataFlowGraphPass pass;
pass.Initialize(); pass.Initialize(&argument);
pass.Initialize(desc);
DataFlowGraph graph; DataFlowGraph graph;
pass.Run(&graph); pass.Run(&graph);
ASSERT_GT(graph.nodes.size(), 0); // Analysis is sensitive to ProgramDesc, careful to change the original model.
ASSERT_EQ(graph.nodes.size(), 37);
pass.Finalize(); pass.Finalize();
LOG(INFO) << '\n' << graph.DotString(); LOG(INFO) << '\n' << graph.DotString();
} }
......
...@@ -62,6 +62,7 @@ struct DataTypeNamer { ...@@ -62,6 +62,7 @@ struct DataTypeNamer {
SET_TYPE(int); SET_TYPE(int);
SET_TYPE(bool); SET_TYPE(bool);
SET_TYPE(float); SET_TYPE(float);
SET_TYPE(void *);
} }
std::unordered_map<decltype(typeid(int).hash_code()), // NOLINT std::unordered_map<decltype(typeid(int).hash_code()), // NOLINT
......
...@@ -40,6 +40,9 @@ Node *NodeMap::Create(Node::Type type) { ...@@ -40,6 +40,9 @@ Node *NodeMap::Create(Node::Type type) {
case Node::Type::kValue: case Node::Type::kValue:
nodes_.emplace_back(new Value); nodes_.emplace_back(new Value);
break; break;
case Node::Type::kFunctionBlock:
nodes_.emplace_back(new FunctionBlock);
break;
default: default:
PADDLE_THROW("Not supported node type."); PADDLE_THROW("Not supported node type.");
} }
......
...@@ -71,12 +71,17 @@ class Node { ...@@ -71,12 +71,17 @@ class Node {
// Get an additional attribute and convert it to T data type. NOTE this will // Get an additional attribute and convert it to T data type. NOTE this will
// silently create a new attribute if not exists. // silently create a new attribute if not exists.
Attr &attr(const std::string &name) { return attrs_[name]; } Attr &attr(const std::string &name) const { return attrs_[name]; }
int id() const { return id_; } int id() const { return id_; }
bool deleted() const { return deleted_; } // The Protobuf description is set/get with a void* to decouple Node interface
// from a specific kind of Protobuf message.
void SetPbDesc(void *pb) { attr("pb_desc").Pointer() = pb; }
void *pb_desc() const { return attr("pb_desc").Pointer(); }
void SetDeleted() { deleted_ = true; } void SetDeleted() { deleted_ = true; }
bool deleted() const { return deleted_; }
void SetName(const std::string &name) { name_ = name; } void SetName(const std::string &name) { name_ = name; }
const std::string &name() const { return name_; } const std::string &name() const { return name_; }
...@@ -84,29 +89,25 @@ class Node { ...@@ -84,29 +89,25 @@ class Node {
void SetType(Type type) { type_ = type; } void SetType(Type type) { type_ = type; }
Type type() const { return type_; } Type type() const { return type_; }
void *extra_info() const { return extra_info_; }
void SetExtraInfo(void *extra_info) { extra_info_ = extra_info; }
// Input links. // Input links.
std::vector<Node *> inlinks; std::vector<Node *> inlinks;
// Output links. // Output links.
std::vector<Node *> outlinks; std::vector<Node *> outlinks;
// A helper class to maintain the status from Pass. // A helper class to maintain the status from Pass.
// TODO(superjomn) add a checker here to ensure the T is primary.
struct Attr { struct Attr {
// NOTE T should be a primary type or a struct combined by several primary // NOTE T should be a primary type or a struct combined by several primary
// types. // types.
// NOTE the STL containers should not use here. // NOTE the STL containers should not use here.
// Some usages // Some usages
// Attr attr; // Attr attr;
// T data; // attr.Bool() = true;
// attr.data.assign((char*)data, sizeof(data));
bool &Bool() { return As<bool>(); } bool &Bool() { return As<bool>(); }
float &Float() { return As<float>(); } float &Float() { return As<float>(); }
int32_t &Int32() { return As<int32_t>(); } int32_t &Int32() { return As<int32_t>(); }
int64_t &Int64() { return As<int64_t>(); } int64_t &Int64() { return As<int64_t>(); }
void *&Pointer() { return As<void *>(); }
private: private:
template <typename T> template <typename T>
...@@ -130,6 +131,7 @@ class Node { ...@@ -130,6 +131,7 @@ class Node {
size_t type_hash_{std::numeric_limits<size_t>::max()}; size_t type_hash_{std::numeric_limits<size_t>::max()};
}; };
// Type checks.
bool IsFunction() const { return type_ == Node::Type::kFunction; } bool IsFunction() const { return type_ == Node::Type::kFunction; }
bool IsValue() const { return type_ == Node::Type::kValue; } bool IsValue() const { return type_ == Node::Type::kValue; }
bool IsFunctionBlock() const { return type_ == Node::Type::kFunctionBlock; } bool IsFunctionBlock() const { return type_ == Node::Type::kFunctionBlock; }
...@@ -148,9 +150,6 @@ class Node { ...@@ -148,9 +150,6 @@ class Node {
Type type_{Type::kNone}; Type type_{Type::kNone};
// Mark this node is deleted by some pass. // Mark this node is deleted by some pass.
bool deleted_{false}; bool deleted_{false};
void *extra_info_;
mutable std::unordered_map<std::string, Attr> attrs_; mutable std::unordered_map<std::string, Attr> attrs_;
}; };
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/node.h" #include "paddle/fluid/inference/analysis/node.h"
...@@ -30,19 +31,24 @@ namespace analysis { ...@@ -30,19 +31,24 @@ namespace analysis {
class Pass { class Pass {
public: public:
Pass() = default; Pass() = default;
virtual ~Pass() {} virtual ~Pass() = default;
// Virtual method overridden by subclasses to do only necessary initialization // Virtual method overridden by subclasses to do only necessary initialization
// before any pass is run. // before any pass is run.
virtual bool Initialize() { return false; } // virtual bool Initialize() { return false; }
// There is some passes such as FlowToDataFlowGraphPass that needs a // There is some passes such as FlowToDataFlowGraphPass that needs a
// ProgramDesc. Here use the native ProgramDesc ProtoBuf message, so that it // ProgramDesc. Here use the native ProgramDesc ProtoBuf message, so that it
// only couple with the proto file. // only couple with the proto file.
virtual bool Initialize(const framework::proto::ProgramDesc &desc) { // virtual bool Initialize(const framework::proto::ProgramDesc &desc) { return
return false; // false; }
}
// There are some Passes such as DataFlowGraphToFluidPass that will output a // There are some Passes such as DataFlowGraphToFluidPass that will output a
// ProgramDesc. // ProgramDesc.
virtual bool Initialize(framework::proto::ProgramDesc *desc) { return false; } // virtual bool Initialize(framework::proto::ProgramDesc *desc) { return
// false; }
// Mutable Pass.
virtual bool Initialize(Argument *argument) { return false; }
// Readonly Pass.
virtual bool Initialize(const Argument &argument) { return false; }
// Virtual method overriden by subclasses to do any necessary clean up after // Virtual method overriden by subclasses to do any necessary clean up after
// all passes have run. // all passes have run.
...@@ -50,7 +56,9 @@ class Pass { ...@@ -50,7 +56,9 @@ class Pass {
// Get a Pass appropriate to print the Node this pass operates on. // Get a Pass appropriate to print the Node this pass operates on.
virtual Pass *CreatePrinterPass(std::ostream &os, virtual Pass *CreatePrinterPass(std::ostream &os,
const std::string &banner) const = 0; const std::string &banner) const {
return nullptr;
}
// Run on a single Node. // Run on a single Node.
virtual void Run(Node *x) { LOG(FATAL) << "not valid"; } virtual void Run(Node *x) { LOG(FATAL) << "not valid"; }
...@@ -60,6 +68,11 @@ class Pass { ...@@ -60,6 +68,11 @@ class Pass {
virtual void Run(FunctionBlock *x) { LOG(FATAL) << "not valid"; } virtual void Run(FunctionBlock *x) { LOG(FATAL) << "not valid"; }
// Run on a single DataFlowGraph. // Run on a single DataFlowGraph.
virtual void Run(DataFlowGraph *x) { LOG(FATAL) << "not valid"; } virtual void Run(DataFlowGraph *x) { LOG(FATAL) << "not valid"; }
// Human-readable short representation.
virtual std::string repr() const = 0;
// Human-readable long description.
virtual std::string description() const = 0;
}; };
// NodePass process on any Node types. // NodePass process on any Node types.
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/pass_manager.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
void DfgPassManager::RunAll() {
PADDLE_ENFORCE(argument_);
for (auto& pass : data_) {
VLOG(4) << "Running pass [" << pass->repr() << "]";
pass->Run(argument_->main_dfg.get());
}
}
void NodePassManager::RunAll() {
PADDLE_ENFORCE(argument_);
PADDLE_ENFORCE(argument_->main_dfg.get());
auto trait =
GraphTraits<DataFlowGraph>(argument_->main_dfg.get()).nodes_in_DFS();
for (auto& node : trait) {
for (auto& pass : data_) {
pass->Run(&node);
}
}
}
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file defines the logic of pass management. The analysis for inference is
* a pipeline of Passes, a PassManager is a agency that helps to manage the
* executation of the Passes.
*
* There are two modes of Passes, the first one is called NodePass and takes
* an Node as input and output; the second one is called DFGPass and takes a
* DFG(Data Flow Graph) as input and output. It is hard to put all the passes in
* the same pipeline, there are two kinds of PassManagers, both takes a DFG as
* input and output a DFG, but the Passes inside are different:
*
* 1. NodePassManager: the passes inside are all NodePasses, it can have
* different graph trivial algorithm, for example, DFS_NodePassManager will
* trigger the passes in depth first order;
* 2. DfgPassManager: the passes inside are all DfgPasses.
*/
#pragma once
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace paddle {
namespace inference {
namespace analysis {
/*
* PassManager is the base class for all pass managers, a pass manager has
* several Pass-es registered, and execute them in the linear order.
*/
class PassManager : public OrderedRegistry<Pass> {
public:
PassManager() = default;
// Call all the passes' Initialize methods. The desc and data_flow_graph are
// globally shared, so pass them as the arguemnts for all the pass managers.
virtual bool Initialize(const Argument& argument) { return false; }
virtual bool Initialize(Argument* argument) {
argument_ = argument;
for (auto& pass : data_) {
LOG(INFO) << "Initializing pass " << pass->repr();
if (!pass->Initialize(argument)) {
LOG(ERROR) << "Failed to initialize pass [" << pass->repr() << "]";
return false;
}
}
return true;
}
// Call all the passes' Finalize methods.
virtual bool Finalize() {
for (auto& pass : data_) {
if (!pass->Finalize()) {
LOG(ERROR) << "Failed to finalize pass [" << pass->repr() << "]";
return false;
}
}
return true;
}
// Run all the passes.
virtual void RunAll() = 0;
// Short identifier.
virtual std::string repr() const = 0;
// Long description.
virtual std::string description() const = 0;
virtual ~PassManager() = default;
protected:
Argument* argument_{nullptr};
};
/*
* A pass manager that process a DFG.
*/
class DfgPassManager : public PassManager {
public:
DfgPassManager() = default;
void RunAll() override;
virtual ~DfgPassManager() = default;
};
/*
* A pass manager that process a Node each time.
*/
class NodePassManager : public PassManager {
public:
NodePassManager() = default;
void RunAll() override;
virtual ~NodePassManager() = default;
};
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/pass_manager.h"
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include <gtest/gtest.h>
namespace paddle {
namespace inference {
namespace analysis {
class TestDfgPassManager final : public DfgPassManager {
public:
TestDfgPassManager() = default;
virtual ~TestDfgPassManager() = default;
// Short identifier.
std::string repr() const override { return "test-pass-manager"; }
// Long description.
std::string description() const override { return "test doc"; }
};
class TestNodePassManager final : public NodePassManager {
public:
virtual ~TestNodePassManager() = default;
std::string repr() const override { return "test-node-pass-manager"; }
std::string description() const override { return "test doc"; }
};
class TestNodePass final : public NodePass {
public:
virtual ~TestNodePass() = default;
bool Initialize(Argument* argument) override { return true; }
void Run(Node* node) override {
LOG(INFO) << "- Processing node " << node->repr();
}
std::string repr() const override { return "test-node"; }
std::string description() const override { return "some doc"; }
};
TEST_F(DFG_Tester, DFG_pass_manager) {
TestDfgPassManager manager;
DFG_GraphvizDrawPass::Config config("./", "dfg.dot");
manager.Register("fluid-to-flow-graph", new FluidToDataFlowGraphPass);
manager.Register("graphviz", new DFG_GraphvizDrawPass(config));
manager.Register("dfg-to-fluid", new DataFlowGraphToFluidPass);
ASSERT_TRUE(manager.Initialize(&argument));
manager.RunAll();
}
TEST_F(DFG_Tester, Node_pass_manager) {
// Pre-process: initialize the DFG with the ProgramDesc first.
FluidToDataFlowGraphPass pass0;
pass0.Initialize(&argument);
pass0.Run(argument.main_dfg.get());
TestNodePassManager manager;
manager.Register("test-node-pass", new TestNodePass);
ASSERT_TRUE(manager.Initialize(&argument));
manager.RunAll();
}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -19,22 +19,23 @@ namespace paddle { ...@@ -19,22 +19,23 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) {
if (node->type() != Node::Type::kFunction) return false;
const auto* func = static_cast<const Function*>(node);
if (func->func_type() == "elementwise_add" || func->func_type() == "relu" ||
func->func_type() == "conv2d" || func->func_type() == "mul" ||
func->func_type() == "sigmoid" || func->func_type() == "softmax") {
LOG(INFO) << "sub-graph marked " << node->repr();
return true;
}
return false;
};
TEST_F(DFG_Tester, Split) { TEST_F(DFG_Tester, Split) {
auto desc = LoadProgramDesc(); auto desc = LoadProgramDesc();
auto dfg = ProgramDescToDFG(desc); auto dfg = ProgramDescToDFG(desc);
LOG(INFO) << "spliter\n" << dfg.DotString(); LOG(INFO) << "spliter\n" << dfg.DotString();
SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) {
if (node->type() != Node::Type::kFunction) return false;
const auto* func = static_cast<const Function*>(node);
if (func->func_type() == "elementwise_add" || func->func_type() == "relu" ||
func->func_type() == "conv2d" || func->func_type() == "mul" ||
func->func_type() == "sigmoid" || func->func_type() == "softmax") {
LOG(INFO) << "sub-graph marked " << node->repr();
return true;
}
return false;
};
ASSERT_GT(dfg.nodes.size(), 5UL); ASSERT_GT(dfg.nodes.size(), 5UL);
auto subgraphs = SubGraphSplitter(&dfg, teller)(); auto subgraphs = SubGraphSplitter(&dfg, teller)();
...@@ -62,6 +63,28 @@ TEST_F(DFG_Tester, Split) { ...@@ -62,6 +63,28 @@ TEST_F(DFG_Tester, Split) {
ASSERT_EQ(subgraphs.back().size(), 6UL); ASSERT_EQ(subgraphs.back().size(), 6UL);
} }
TEST_F(DFG_Tester, Fuse) {
auto desc = LoadProgramDesc();
auto dfg = ProgramDescToDFG(desc);
size_t count0 = dfg.nodes.size();
SubGraphFuse fuse(&dfg, teller);
fuse();
int count1 = 0;
for (auto& node : dfg.nodes.nodes()) {
if (node->deleted()) {
LOG(INFO) << "deleted " << node->repr();
}
count1 += node->deleted();
}
// At least one nodes should be deleted.
ASSERT_EQ(dfg.nodes.size(), count0 + 1); // added a new FunctionBlock
ASSERT_EQ(6UL, count1);
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,50 +12,22 @@ ...@@ -12,50 +12,22 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "gtest/gtest.h" #include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
#include "paddle/contrib/tape/function.h" #include "paddle/fluid/inference/analysis/subgraph_splitter.h"
using namespace paddle::tape; namespace paddle {
namespace inference {
namespace analysis {
TEST(Tape, TestMLP) { TensorRTSubGraphPass::TensorRTSubGraphPass(
LOG(INFO) << "TestMLP"; const TensorRTSubGraphPass::NodeInsideSubgraphTeller &teller)
Linear linear1(3, 3, "relu"); : node_inside_subgraph_teller_(teller) {}
Linear linear2(3, 3, "relu");
Mean mean;
SGD sgd(0.001); void TensorRTSubGraphPass::Run(DataFlowGraph *graph) {
SubGraphFuse(graph, node_inside_subgraph_teller_);
std::string initializer = "fill_constant";
paddle::framework::AttributeMap attrs;
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
attrs["shape"] = std::vector<int>{3, 3};
attrs["value"] = 1.0f;
Fill filler(initializer, attrs);
for (int i = 0; i < 2; ++i) {
reset_global_tape();
VariableHandle input(new Variable("input"));
filler(input);
auto loss = mean(linear2(linear1(input)));
get_global_tape().Backward(loss);
for (auto w : linear1.Params()) {
sgd(w);
}
for (auto w : linear2.Params()) {
sgd(w);
}
}
} }
int main(int argc, char** argv) { } // analysis
std::vector<paddle::platform::Place> places; } // inference
places.emplace_back(paddle::platform::CPUPlace());
paddle::platform::DeviceContextPool::Init(places);
testing::InitGoogleTest(&argc, argv); } // paddle
return RUN_ALL_TESTS();
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/inference/analysis/node.h"
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/subgraph_splitter.h"
namespace paddle {
namespace inference {
namespace analysis {
/*
* Parse the graph and replace TensorRT supported nodes with SubGraphNode
*/
class TensorRTSubGraphPass : public DataFlowGraphPass {
public:
// Tell whether to transform a sub-graph into TensorRT.
using NodeInsideSubgraphTeller = SubGraphFuse::NodeInsideSubgraphTeller;
TensorRTSubGraphPass(const NodeInsideSubgraphTeller& teller);
bool Initialize(Argument* argument) override { return true; }
// This class get a sub-graph as input and determine whether to transform this
// sub-graph into TensorRT.
void Run(DataFlowGraph* graph) override;
private:
NodeInsideSubgraphTeller node_inside_subgraph_teller_;
};
} // namespace analysis
} // namespace inference
} // paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
DEFINE_string(model_dir, "", "inference test model dir");
TEST(TensorRTSubGraph, single_pass) {
auto desc = LoadProgramDesc();
auto dfg = ProgramDescToDFG(desc);
SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) {
if (node->type() != Node::Type::kFunction) return false;
const auto* func = static_cast<const Function*>(node);
if (func->func_type() == "elementwise_add" || func->func_type() == "relu" ||
func->func_type() == "conv2d" || func->func_type() == "mul" ||
func->func_type() == "sigmoid" || func->func_type() == "softmax") {
LOG(INFO) << "sub-graph marked " << node->repr();
return true;
}
return false;
};
DFG_GraphvizDrawPass::Config config{"./", "test"};
DFG_GraphvizDrawPass dfg_pass(config);
dfg_pass.Initialize();
DFG_GraphvizDrawPass dfg_pass1(config);
dfg_pass1.Initialize();
dfg_pass.Run(&dfg);
TensorRTSubGraphPass trt_pass(std::move(teller));
trt_pass.Initialize();
trt_pass.Run(&dfg);
dfg_pass1.Run(&dfg);
// Check the TRT op's block desc
for (auto node : dfg.nodes.nodes()) {
if (node->IsFunctionBlock()) {
}
}
}
TEST(TensorRTSubGraph, pass_manager) {}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -15,33 +15,46 @@ limitations under the License. */ ...@@ -15,33 +15,46 @@ limitations under the License. */
#pragma once #pragma once
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <fstream>
#include <string> #include <string>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" #include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h" #include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/io.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
// Read ProgramDesc from a __model__ file, defined in io.cc
extern void ReadBinaryFile(const std::string& filename, std::string* contents);
namespace analysis { namespace analysis {
DEFINE_string(inference_model_dir, "", "inference test model dir"); DEFINE_string(inference_model_dir, "", "inference test model dir");
static framework::proto::ProgramDesc LoadProgramDesc( static framework::proto::ProgramDesc LoadProgramDesc(
const std::string& model_dir = FLAGS_inference_model_dir) { const std::string& model_dir = FLAGS_inference_model_dir) {
paddle::platform::CPUPlace place; std::string msg;
paddle::framework::Executor executor(place); std::string net_file = FLAGS_inference_model_dir + "/__model__";
paddle::framework::Scope scope; std::ifstream fin(net_file, std::ios::in | std::ios::binary);
auto program = Load(&executor, &scope, model_dir); PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", net_file);
return *program->Proto(); fin.seekg(0, std::ios::end);
msg.resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&(msg.at(0)), msg.size());
fin.close();
framework::proto::ProgramDesc program_desc;
program_desc.ParseFromString(msg);
return program_desc;
} }
static DataFlowGraph ProgramDescToDFG( static DataFlowGraph ProgramDescToDFG(
const framework::proto::ProgramDesc& desc) { const framework::proto::ProgramDesc& desc) {
DataFlowGraph graph; DataFlowGraph graph;
FluidToDataFlowGraphPass pass; FluidToDataFlowGraphPass pass;
pass.Initialize(desc); Argument argument;
argument.origin_program_desc.reset(new framework::proto::ProgramDesc(desc));
pass.Initialize(&argument);
pass.Run(&graph); pass.Run(&graph);
pass.Finalize(); pass.Finalize();
return graph; return graph;
...@@ -49,9 +62,12 @@ static DataFlowGraph ProgramDescToDFG( ...@@ -49,9 +62,12 @@ static DataFlowGraph ProgramDescToDFG(
class DFG_Tester : public ::testing::Test { class DFG_Tester : public ::testing::Test {
protected: protected:
void SetUp() override { desc = LoadProgramDesc(FLAGS_inference_model_dir); } void SetUp() override {
auto desc = LoadProgramDesc(FLAGS_inference_model_dir);
argument.origin_program_desc.reset(new framework::proto::ProgramDesc(desc));
}
framework::proto::ProgramDesc desc; Argument argument;
}; };
} // namespace analysis } // namespace analysis
......
...@@ -240,7 +240,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { ...@@ -240,7 +240,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
} }
// Test with a larger FC layer. // Test with a larger FC layer.
TEST(TensorRTEngineOp, fc) { Execute(40, 256, 256); } TEST(TensorRTEngineOp, fc) { Execute(40, 28, 28); }
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册