Unverified commit 6b1e1f0d authored by Yiqun Liu, committed by GitHub

Enable generating code for a given subgraph. (#21126)

* Enable generating code for a given subgraph.

* Support sorting the subgraph.

* Remove the rearrangement of expressions because we use the sorted subgraph directly.

* Enable generating code for a subgraph which is composed of grad ops.

* Use expression information to check the accuracy in unittest.

* Separate load and store from computation expressions.
test=develop

* Improve the loading statements in generated codes.
test=develop

* Remove unused arguments from formal list.
test=develop
Parent 3ff5cc2d
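For orientation before the diff: after this change the generated kernel body is split into load, compute, and store sections. The sketch below is only illustrative (it is not literal output of the patch) and assumes a two-op subgraph out = relu(x * y), float dtype, the arg/tmp naming of ArgName()/TmpName(), and the grid-stride loop of elementwise_cuda_template; the real argument ids come from EncodeVarNodes(), and helpers such as real_max() are prepended via predefined_cuda_functions.

extern "C" __global__ void elementwise_kernel_0(int N, float* arg0, float* arg1,
                                                float* arg2, float* arg3) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N;
       idx += gridDim.x * blockDim.x) {
    // Load: only the inputs recorded in `used` are read into temporaries.
    float tmp0 = arg0[idx];
    float tmp1 = arg1[idx];
    // Compute: one statement per OperationExpression, in sorted-subgraph order.
    float tmp2 = tmp0 * tmp1;        // elementwise_mul
    float tmp3 = real_max(tmp2, 0);  // relu
    // Store: every output id, including the intermediate tmp2, is written back.
    arg2[idx] = tmp2;
    arg3[idx] = tmp3;
  }
}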
cc_library(code_generator SRCS operation.cc code_generator.cc code_generator_helper.cc DEPS graph) cc_library(code_generator SRCS operation.cc code_generator.cc code_generator_helper.cc DEPS graph)
if(NOT APPLE AND NOT WIN32) if(NOT APPLE AND NOT WIN32)
if(WITH_GPU) if(WITH_GPU)
cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor) cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor graph_viz_pass)
endif() endif()
endif() endif()
......
...@@ -13,9 +13,10 @@ See the License for the specific language governing permissions and ...@@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/fusion_group/code_generator.h" #include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
#include <set>
#include <sstream> #include <sstream>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h" #include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -30,69 +31,205 @@ CodeGenerator::CodeGenerator() { ...@@ -30,69 +31,205 @@ CodeGenerator::CodeGenerator() {
code_templates_[0] = elementwise_t; code_templates_[0] = elementwise_t;
} }
std::string CodeGenerator::Generate(SubGraph* subgraph) {
std::vector<OperationExpression> expressions = ConvertToExpressions(subgraph);
return Generate(subgraph->func_name, expressions);
}
std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
SubGraph* subgraph) {
std::unordered_map<std::string, int> var_ids = EncodeVarNodes(subgraph);
std::vector<OperationExpression> expressions;
for (auto* node : subgraph->SortedNodes()) {
if (node && node->IsOp() && node->Op()) {
auto* op = node->Op();
// Input ids should be set in fixed order, like:
// - x, y in forward operations
// - x, y, out, out@GRAD in backward operations
std::vector<int> input_ids;
std::vector<std::string> input_names =
OperationMap::Instance().Get(op->Type()).input_names;
for (auto& name : input_names) {
// TODO(liuyiqun): support duplicated input.
if (op->Input(name).size() >= 1U) {
// Some input vars are not used in grad ops, such as
// "elementwise_add_grad", where "X", "Y" and "Out" are not used.
PADDLE_ENFORCE_NE(var_ids.find(op->Input(name)[0]), var_ids.end(),
"Input(%s) of operation %s should be set.", name,
op->Type());
input_ids.push_back(var_ids[op->Input(name)[0]]);
} else {
input_ids.push_back(-1);
}
}
// Output ids should be set in fixed order, like:
// - dx, dy in backward operations
std::vector<int> output_ids;
std::vector<std::string> output_names =
OperationMap::Instance().Get(op->Type()).output_names;
for (auto& name : output_names) {
PADDLE_ENFORCE_EQ(op->Output(name).size(), 1U,
"Output(%s) of operation %s should be set.", name,
op->Type());
PADDLE_ENFORCE_NE(var_ids.find(op->Output(name)[0]), var_ids.end(),
"Output(%s) of operation %s should be set.", name,
op->Type());
output_ids.push_back(var_ids[op->Output(name)[0]]);
}
expressions.push_back(
OperationExpression(node->Name(), input_ids, output_ids));
}
}
return expressions;
}
// In order to get the right result of an expression, we need to calculate and // In order to get the right result of an expression, we need to calculate and
// store them as suffix (postfix) expressions in a vector. // store them as suffix (postfix) expressions in a vector.
std::string CodeGenerator::GenerateCode( std::string CodeGenerator::Generate(
std::string func_name, std::vector<OperationExpression> expressions) { std::string func_name, std::vector<OperationExpression> expressions) {
// Check whether all expressions are elementwise operations. // TODO(liuyiqun): Check whether all expressions are elementwise operations.
std::string dtype = "float";
std::set<int> input_ids = DistilInputIds(expressions);
std::set<int> output_ids = DistilOutputIds(expressions);
TemplateVariable template_var; TemplateVariable template_var;
template_var.Add("func_name", func_name); template_var.Add("func_name", func_name);
template_var.Add("parameters", EmitParameters(expressions, "float")); template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtype));
template_var.Add("compute_body", EmitComputeBody(expressions)); template_var.Add("compute_body",
EmitComputeBody(expressions, input_ids, output_ids, dtype));
return predefined_cuda_functions + code_templates_[0].Format(template_var); return predefined_cuda_functions + code_templates_[0].Format(template_var);
} }
// we get the parameter list code for the expression information std::set<int> CodeGenerator::DistilInputIds(
std::string CodeGenerator::EmitParameters( const std::vector<OperationExpression>& expressions) {
std::vector<OperationExpression> expressions, std::string dtype) {
std::set<int> input_ids; std::set<int> input_ids;
std::set<int> output_ids; // Use std::set to remove repeated ids and get an ordered list.
// Remove the reptead id and get a ordered list.
for (size_t i = 0; i < expressions.size(); i++) { for (size_t i = 0; i < expressions.size(); i++) {
for (auto id : expressions[i].GetInputIds()) { for (auto id : expressions[i].GetInputIds()) {
input_ids.insert(id); if (id >= 0) {
input_ids.insert(id);
}
} }
}
return input_ids;
}
std::set<int> CodeGenerator::DistilOutputIds(
const std::vector<OperationExpression>& expressions) {
std::set<int> output_ids;
// Use std::set to remove repeated ids and get an ordered list.
for (size_t i = 0; i < expressions.size(); i++) {
for (auto id : expressions[i].GetOutputIds()) { for (auto id : expressions[i].GetOutputIds()) {
output_ids.insert(id); output_ids.insert(id);
} }
} }
return output_ids;
}
// Get the parameter list code from the expression information.
std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
const std::set<int>& output_ids,
std::string dtype) {
std::stringstream ret;
ret << "int N, ";
// If an id is in the input and output list at the same time, then remove it // If an id is in the input and output list at the same time, then remove it
// from the input list. // from the input list.
for (auto iter = input_ids.begin(); iter != input_ids.end();) { for (auto id : input_ids) {
if (output_ids.find(*iter) != output_ids.end()) { if (output_ids.find(id) == output_ids.end()) {
input_ids.erase(iter++); ret << dtype << "* " << ArgName(id) << ", ";
} else {
iter++;
} }
} }
std::stringstream ret; size_t index = 0;
ret << "int N, "; for (auto id : output_ids) {
for (auto iter = input_ids.begin(); iter != input_ids.end(); iter++) { ret << dtype << "* " << ArgName(id);
ret << dtype << "* " << VarName(*iter) << ", "; if (index != output_ids.size() - 1) {
}
size_t count_index = 0;
for (auto iter = output_ids.begin(); iter != output_ids.end(); iter++) {
ret << dtype << "* " << VarName(*iter);
if (count_index != output_ids.size() - 1) {
ret << ", "; ret << ", ";
} }
count_index++; index++;
} }
return ret.str(); return ret.str();
} }
std::string CodeGenerator::EmitComputeBody( std::string CodeGenerator::EmitComputeBody(
std::vector<OperationExpression> expressions) { const std::vector<OperationExpression>& expressions,
// get the right experssion code using suffix expression const std::set<int>& input_ids, const std::set<int>& output_ids,
std::stringstream ret; std::string dtype) {
std::ostringstream compute;
std::unordered_set<int> used;
for (size_t i = 0; i < expressions.size(); i++) { for (size_t i = 0; i < expressions.size(); i++) {
ret << expressions[i].GetExpression(); VLOG(3) << DebugString(expressions[i]);
compute << expressions[i].GetExpression(dtype, &used);
} }
return ret.str();
// Load inputs into temporary variables.
std::ostringstream load;
for (auto id : input_ids) {
if (output_ids.find(id) == output_ids.end() &&
used.find(id) != used.end()) {
load << dtype << " " << TmpName(id) << " = " << ArgName(id) << "[idx];";
}
}
// Store temporary variables to memory.
std::ostringstream store;
for (auto id : output_ids) {
store << ArgName(id) << "[idx] = " << TmpName(id) << ";";
}
return load.str() + compute.str() + store.str();
}
std::unordered_map<std::string, int> CodeGenerator::EncodeVarNodes(
SubGraph* subgraph) {
const auto& input_var_nodes = subgraph->GetInputVarNodes();
const auto& output_var_nodes = subgraph->GetOutputVarNodes();
int id = 0;
std::unordered_map<std::string, int> var_ids;
// Numbering input vars.
for (auto* in : input_var_nodes) {
VLOG(3) << "Encoding input names:" << in->Name() << ", id:" << id;
if (var_ids.find(in->Name()) == var_ids.end()) {
var_ids[in->Name()] = id++;
}
}
// Numbering internal vars.
for (auto* node : subgraph->SortedNodes()) {
if (node && node->IsVar() && node->Var()) {
bool is_found = false;
for (auto* in : input_var_nodes) {
if (node == in) {
is_found = true;
break;
}
}
if (is_found) {
continue;
}
for (auto* out : output_var_nodes) {
if (node == out) {
is_found = true;
break;
}
}
PADDLE_ENFORCE_EQ(
is_found, true,
"Subgraph with internal var nodes (%s) is not supported yet.",
node->Name());
}
}
// Encoding output vars.
for (auto* out : output_var_nodes) {
VLOG(3) << "Encoding output names:" << out->Name() << ", id:" << id;
if (var_ids.find(out->Name()) == var_ids.end()) {
var_ids[out->Name()] = id++;
}
}
return var_ids;
} }
} // namespace fusion_group } // namespace fusion_group
......
...@@ -14,9 +14,12 @@ limitations under the License. */ ...@@ -14,9 +14,12 @@ limitations under the License. */
#pragma once #pragma once
#include <set>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h" #include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
#include "paddle/fluid/framework/ir/fusion_group/subgraph.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -27,18 +30,31 @@ class CodeGenerator { ...@@ -27,18 +30,31 @@ class CodeGenerator {
public: public:
CodeGenerator(); CodeGenerator();
std::string GenerateCode(std::string func_name, std::string Generate(std::string func_name,
std::vector<OperationExpression> expressions); std::vector<OperationExpression> expressions);
// TODO(wangchao): add a more general interface std::string Generate(SubGraph* subgraph);
// std::string Generate(const std::string name, const SubGraph& subgraph);
std::vector<OperationExpression> ConvertToExpressions(SubGraph* subgraph);
private: private:
std::set<int> DistilInputIds(
const std::vector<OperationExpression>& expressions);
std::set<int> DistilOutputIds(
const std::vector<OperationExpression>& expressions);
// Get the parameter list code from the expression information. // Get the parameter list code from the expression information.
std::string EmitParameters(std::vector<OperationExpression> expressions, std::string EmitParameters(const std::set<int>& input_ids,
const std::set<int>& output_ids,
std::string dtype); std::string dtype);
std::string EmitComputeBody(std::vector<OperationExpression> expressions); std::string EmitComputeBody(
const std::vector<OperationExpression>& expressions,
const std::set<int>& input_ids, const std::set<int>& output_ids,
std::string dtype);
// Encode all var nodes in the subgraph with a unique number.
std::unordered_map<std::string, int> EncodeVarNodes(SubGraph* subgraph);
private: private:
std::vector<CodeTemplate> code_templates_; std::vector<CodeTemplate> code_templates_;
......
...@@ -33,8 +33,9 @@ static T StringTo(const std::string& str) { ...@@ -33,8 +33,9 @@ static T StringTo(const std::string& str) {
return value; return value;
} }
std::string OperationExpression::GetRHS(size_t i) { std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
auto rhs = OperationMap::Instance().Get(op_).exprs[i]; size_t i) const {
auto rhs = OperationMap::Instance().Get(op_type_).exprs[i];
for (size_t i = 0; i < rhs.size(); i++) { for (size_t i = 0; i < rhs.size(); i++) {
size_t pos = i; size_t pos = i;
if (rhs[pos] == '$' && rhs[pos + 1] == '{') { if (rhs[pos] == '$' && rhs[pos + 1] == '{') {
...@@ -47,29 +48,33 @@ std::string OperationExpression::GetRHS(size_t i) { ...@@ -47,29 +48,33 @@ std::string OperationExpression::GetRHS(size_t i) {
PADDLE_ENFORCE_LT(index, input_ids_.size(), PADDLE_ENFORCE_LT(index, input_ids_.size(),
"Only %d inputs are provided, but need %d.", "Only %d inputs are provided, but need %d.",
input_ids_.size(), index + 1); input_ids_.size(), index + 1);
rhs.replace(pos, length + 3, VarName(input_ids_[index]) + R"([idx])"); PADDLE_ENFORCE_GE(input_ids_[index], 0,
"Input id should be no less than 0.");
rhs.replace(pos, length + 3, TmpName(input_ids_[index]));
used->insert(input_ids_[index]);
} }
} }
return rhs; return rhs;
} }
std::string OperationExpression::GetLHS(size_t i) { std::string OperationExpression::GetLHS(size_t i) const {
std::stringstream ret; std::stringstream ret;
ret << VarName(output_ids_[i]) << R"([idx])"; ret << TmpName(output_ids_[i]);
return ret.str(); return ret.str();
} }
bool OperationExpression::IsSupport() { bool OperationExpression::IsSupport() const {
return OperationMap::Instance().Has(op_); return OperationMap::Instance().Has(op_type_);
} }
// We traverse the graph and get the group; all input and output ids are // We traverse the graph and get the group; all input and output ids are
// unique for the nodes which belong to the group. // unique for the nodes which belong to the group.
std::string OperationExpression::GetExpression() { std::string OperationExpression::GetExpression(
std::string dtype, std::unordered_set<int>* used) const {
std::stringstream ret; std::stringstream ret;
if (IsSupport()) { if (IsSupport()) {
for (size_t i = 0; i < output_ids_.size(); ++i) { for (size_t i = 0; i < output_ids_.size(); ++i) {
ret << GetLHS(i) << " = " << GetRHS(i) << ";"; ret << dtype << " " << GetLHS(i) << " = " << GetRHS(used, i) << ";";
} }
} }
......
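A brief usage sketch of the reworked OperationExpression interface above (assumed caller code, not part of the patch): an expression now emits a typed temporary and records which input ids it actually reads, so the code generator can emit matching loads.

// Assumes OperationMap::Init() has been called and the fusion_group headers are included.
std::unordered_set<int> used;
paddle::framework::ir::fusion_group::OperationExpression expr("elementwise_mul", {0, 1}, {2});
// Produces something like "float tmp2 = tmp0 * tmp1;" and inserts ids 0 and 1 into `used`,
// which EmitComputeBody later consults when emitting the "float tmp0 = arg0[idx];" loads.
std::string code = expr.GetExpression("float", &used);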
...@@ -14,10 +14,10 @@ limitations under the License. */ ...@@ -14,10 +14,10 @@ limitations under the License. */
#pragma once #pragma once
#include <set>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -27,28 +27,36 @@ namespace framework { ...@@ -27,28 +27,36 @@ namespace framework {
namespace ir { namespace ir {
namespace fusion_group { namespace fusion_group {
static std::string VarName(int index) { return "var" + std::to_string(index); } static inline std::string ArgName(int index) {
return "arg" + std::to_string(index);
}
static inline std::string TmpName(int index) {
return "tmp" + std::to_string(index);
}
class OperationExpression { class OperationExpression {
public: public:
explicit OperationExpression(std::string op, std::vector<int> input_ids, explicit OperationExpression(std::string op_type, std::vector<int> input_ids,
std::vector<int> output_ids) std::vector<int> output_ids)
: op_(op), input_ids_(input_ids), output_ids_(output_ids) {} : op_type_(op_type), input_ids_(input_ids), output_ids_(output_ids) {}
std::vector<int> GetInputIds() { return input_ids_; } std::string GetOpType() const { return op_type_; }
std::vector<int> GetOutputIds() { return output_ids_; } std::vector<int> GetInputIds() const { return input_ids_; }
std::vector<int> GetOutputIds() const { return output_ids_; }
// Check whether this operation type is supported in OperationMap. // Check whether this operation type is supported in OperationMap.
bool IsSupport(); bool IsSupport() const;
std::string GetExpression(); std::string GetExpression(std::string dtype,
std::unordered_set<int>* used) const;
private:
// TODO(wangchao): make offset more flexible we add stride and basic offset // TODO(wangchao): make offset more flexible we add stride and basic offset
std::string GetRHS(size_t i = 0); std::string GetRHS(std::unordered_set<int>* used, size_t i = 0) const;
std::string GetLHS(size_t i = 0); std::string GetLHS(size_t i = 0) const;
private: private:
std::string op_; std::string op_type_;
std::vector<int> input_ids_; std::vector<int> input_ids_;
std::vector<int> output_ids_; std::vector<int> output_ids_;
}; };
...@@ -58,6 +66,7 @@ class TemplateVariable { ...@@ -58,6 +66,7 @@ class TemplateVariable {
void Add(std::string identifier, std::string expression) { void Add(std::string identifier, std::string expression) {
strings_[identifier] = expression; strings_[identifier] = expression;
} }
void Remove(std::string identifier, std::string expression) { void Remove(std::string identifier, std::string expression) {
for (auto it = strings_.begin(); it != strings_.end();) { for (auto it = strings_.begin(); it != strings_.end();) {
if (it->first == identifier) { if (it->first == identifier) {
...@@ -155,7 +164,6 @@ __device__ double real_max(double x, double y) { return ::fmax(x, y); } ...@@ -155,7 +164,6 @@ __device__ double real_max(double x, double y) { return ::fmax(x, y); }
)"; )";
static const char elementwise_cuda_template[] = R"( static const char elementwise_cuda_template[] = R"(
extern "C" __global__ void $func_name($parameters) { extern "C" __global__ void $func_name($parameters) {
for(int idx = blockIdx.x * blockDim.x + threadIdx.x; for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < N; idx < N;
...@@ -165,6 +173,28 @@ extern "C" __global__ void $func_name($parameters) { ...@@ -165,6 +173,28 @@ extern "C" __global__ void $func_name($parameters) {
} }
)"; )";
static std::string DebugString(const OperationExpression& expr) {
std::stringstream ret;
ret << "Op(" << expr.GetOpType() << "), inputs:{";
auto input_ids = expr.GetInputIds();
for (size_t i = 0; i < input_ids.size(); ++i) {
if (i != 0) {
ret << ",";
}
ret << expr.GetInputIds()[i];
}
ret << "}, outputs:{";
auto output_ids = expr.GetOutputIds();
for (size_t i = 0; i < output_ids.size(); ++i) {
if (i != 0) {
ret << ",";
}
ret << expr.GetOutputIds()[i];
}
ret << "}";
return ret.str();
}
} // namespace fusion_group } // namespace fusion_group
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -18,16 +18,133 @@ limitations under the License. */ ...@@ -18,16 +18,133 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/fusion_group/operation.h" #include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math.h" #include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/device_code.h" #include "paddle/fluid/platform/device_code.h"
#include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/init.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
namespace fusion_group = paddle::framework::ir::fusion_group;
namespace paddle {
namespace framework {
namespace ir {
namespace fusion_group {
// relu
inline float relu(float x) { return x > 0 ? x : 0.; }
inline float relu_grad_dx(float x, float out, float dout) {
return x > 0 ? dout : 0;
}
// sigmoid
inline float sigmoid(float x) { return 1.0 / (1.0 + std::exp(-x)); }
inline float sigmoid_grad_dx(float x, float out, float dout) {
return dout * out * (1 - out);
}
// tanh
inline float tanh(float x) { return 2.0 / (1.0 + std::exp(-2 * x)) - 1.0; }
inline float tanh_grad_dx(float x, float out, float dout) {
return dout * (1.0 - out * out);
}
// elementwise_add
inline float elementwise_add(float x, float y) { return x + y; }
inline float elementwise_add_grad_dx(float x, float y, float out, float dout) {
return dout;
}
inline float elementwise_add_grad_dy(float x, float y, float out, float dout) {
return dout;
}
// elementwise_sub
inline float elementwise_sub(float x, float y) { return x - y; }
inline float elementwise_sub_grad_dx(float x, float y, float out, float dout) {
return dout;
}
inline float elementwise_sub_grad_dy(float x, float y, float out, float dout) {
return -dout;
}
// elementwise_mul
inline float elementwise_mul(float x, float y) { return x * y; }
inline float elementwise_mul_grad_dx(float x, float y, float out, float dout) {
return dout * y;
}
inline float elementwise_mul_grad_dy(float x, float y, float out, float dout) {
return dout * x;
}
void CheckOutput(const std::vector<OperationExpression>& expressions,
const std::vector<LoDTensor> cpu_tensors,
const std::vector<int> input_ids_of_subgraph,
const std::vector<int> output_ids_of_subgraph, int i) {
std::vector<float> var(cpu_tensors.size());
for (auto id : input_ids_of_subgraph) {
if (id >= 0) {
var[id] = cpu_tensors[id].data<float>()[i];
}
}
for (auto expression : expressions) {
std::string op_type = expression.GetOpType();
auto input_ids = expression.GetInputIds();
auto output_ids = expression.GetOutputIds();
if (op_type == "relu") {
var[output_ids[0]] = relu(var[input_ids[0]]);
} else if (op_type == "sigmoid") {
var[output_ids[0]] = sigmoid(var[input_ids[0]]);
} else if (op_type == "tanh") {
var[output_ids[0]] = tanh(var[input_ids[0]]);
} else if (op_type == "elementwise_add") {
var[output_ids[0]] =
elementwise_add(var[input_ids[0]], var[input_ids[1]]);
} else if (op_type == "elementwise_sub") {
var[output_ids[0]] =
elementwise_sub(var[input_ids[0]], var[input_ids[1]]);
} else if (op_type == "elementwise_mul") {
var[output_ids[0]] =
elementwise_mul(var[input_ids[0]], var[input_ids[1]]);
} else if (op_type == "relu_grad") {
var[output_ids[0]] =
relu_grad_dx(var[input_ids[0]], 0, var[input_ids[2]]);
} else if (op_type == "sigmoid_grad") {
var[output_ids[0]] =
sigmoid_grad_dx(0, var[input_ids[1]], var[input_ids[2]]);
} else if (op_type == "tanh_grad") {
var[output_ids[0]] =
tanh_grad_dx(0, var[input_ids[1]], var[input_ids[2]]);
} else if (op_type == "elementwise_add_grad") {
var[output_ids[0]] = elementwise_add_grad_dx(0, 0, 0, var[input_ids[3]]);
var[output_ids[1]] = elementwise_add_grad_dy(0, 0, 0, var[input_ids[3]]);
} else if (op_type == "elementwise_mul_grad") {
var[output_ids[0]] =
elementwise_mul_grad_dx(0, var[input_ids[1]], 0, var[input_ids[3]]);
var[output_ids[1]] =
elementwise_mul_grad_dy(var[input_ids[0]], 0, 0, var[input_ids[3]]);
}
}
for (auto id : output_ids_of_subgraph) {
float actual = cpu_tensors[id].data<float>()[i];
float expect = var[id];
PADDLE_ENFORCE_LT(fabs(actual - expect), 1.E-05,
"Get %f vs %f (actual vs expect).", actual, expect);
}
}
template <typename T> template <typename T>
void SetupRandomCPUTensor(paddle::framework::LoDTensor* tensor) { void SetupRandomCPUTensor(LoDTensor* tensor) {
static unsigned int seed = 100; static unsigned int seed = 100;
std::mt19937 rng(seed++); std::mt19937 rng(seed++);
std::uniform_real_distribution<double> uniform_dist(0, 1); std::uniform_real_distribution<double> uniform_dist(0, 1);
...@@ -40,15 +157,16 @@ void SetupRandomCPUTensor(paddle::framework::LoDTensor* tensor) { ...@@ -40,15 +157,16 @@ void SetupRandomCPUTensor(paddle::framework::LoDTensor* tensor) {
} }
} }
void TestMain(std::string func_name, } // namespace fusion_group
std::vector<fusion_group::OperationExpression> expressions, } // namespace ir
std::vector<paddle::framework::LoDTensor> cpu_tensors, int n, } // namespace framework
std::vector<int> input_ids, std::vector<int> output_ids) { } // namespace paddle
fusion_group::OperationMap::Init();
fusion_group::CodeGenerator code_generator; namespace fusion_group = paddle::framework::ir::fusion_group;
std::string code_str = code_generator.GenerateCode(func_name, expressions);
VLOG(3) << code_str;
void TestMainImpl(std::string func_name, std::string code_str,
std::vector<paddle::framework::LoDTensor> cpu_tensors, int n,
std::vector<int> input_ids, std::vector<int> output_ids) {
paddle::framework::InitDevices(false, {0}); paddle::framework::InitDevices(false, {0});
paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0); paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
paddle::platform::CUDADeviceCode device_code(place, func_name, code_str); paddle::platform::CUDADeviceCode device_code(place, func_name, code_str);
...@@ -60,20 +178,20 @@ void TestMain(std::string func_name, ...@@ -60,20 +178,20 @@ void TestMain(std::string func_name,
std::vector<void*> args; std::vector<void*> args;
args.push_back(&n); args.push_back(&n);
for (size_t i = 0; i < input_ids.size(); ++i) { for (auto id : input_ids) {
gpu_ptrs[input_ids[i]] = gpu_tensors[input_ids[i]].mutable_data<float>( if (id >= 0) {
cpu_tensors[input_ids[i]].dims(), place); gpu_ptrs[id] =
args.push_back(&gpu_ptrs[input_ids[i]]); gpu_tensors[id].mutable_data<float>(cpu_tensors[id].dims(), place);
fusion_group::SetupRandomCPUTensor<float>(&cpu_tensors[id]);
SetupRandomCPUTensor<float>(&cpu_tensors[input_ids[i]]); TensorCopySync(cpu_tensors[id], place, &gpu_tensors[id]);
TensorCopySync(cpu_tensors[input_ids[i]], place, args.push_back(&gpu_ptrs[id]);
&gpu_tensors[input_ids[i]]); }
} }
for (size_t i = 0; i < output_ids.size(); ++i) { for (auto id : output_ids) {
gpu_ptrs[output_ids[i]] = gpu_tensors[output_ids[i]].mutable_data<float>( gpu_ptrs[id] =
cpu_tensors[output_ids[i]].dims(), place); gpu_tensors[id].mutable_data<float>(cpu_tensors[id].dims(), place);
args.push_back(&gpu_ptrs[output_ids[i]]); args.push_back(&gpu_ptrs[id]);
} }
device_code.SetNumThreads(1024); device_code.SetNumThreads(1024);
...@@ -84,12 +202,40 @@ void TestMain(std::string func_name, ...@@ -84,12 +202,40 @@ void TestMain(std::string func_name,
paddle::platform::DeviceContextPool::Instance().Get(place)); paddle::platform::DeviceContextPool::Instance().Get(place));
dev_ctx->Wait(); dev_ctx->Wait();
for (size_t i = 0; i < output_ids.size(); ++i) { for (auto id : output_ids) {
TensorCopySync(gpu_tensors[output_ids[i]], paddle::platform::CPUPlace(), TensorCopySync(gpu_tensors[id], paddle::platform::CPUPlace(),
&cpu_tensors[output_ids[i]]); &cpu_tensors[id]);
} }
} }
void TestMain(std::string func_name,
std::vector<fusion_group::OperationExpression> expressions,
std::vector<paddle::framework::LoDTensor> cpu_tensors, int n,
std::vector<int> input_ids, std::vector<int> output_ids) {
fusion_group::OperationMap::Init();
fusion_group::CodeGenerator code_generator;
std::string code_str = code_generator.Generate(func_name, expressions);
VLOG(3) << code_str;
TestMainImpl(func_name, code_str, cpu_tensors, n, input_ids, output_ids);
}
std::vector<fusion_group::OperationExpression> TestMain(
fusion_group::SubGraph* subgraph,
std::vector<paddle::framework::LoDTensor> cpu_tensors, int n,
std::vector<int> input_ids, std::vector<int> output_ids) {
fusion_group::OperationMap::Init();
fusion_group::CodeGenerator code_generator;
std::string code_str = code_generator.Generate(subgraph);
VLOG(3) << code_str;
TestMainImpl(subgraph->func_name, code_str, cpu_tensors, n, input_ids,
output_ids);
// Need to check the accuracy according to expressions.
return code_generator.ConvertToExpressions(subgraph);
}
TEST(code_generator, elementwise) { TEST(code_generator, elementwise) {
// t2 = t0 * t1 // t2 = t0 * t1
// t4 = t2 + t3 // t4 = t2 + t3
...@@ -101,41 +247,33 @@ TEST(code_generator, elementwise) { ...@@ -101,41 +247,33 @@ TEST(code_generator, elementwise) {
fusion_group::OperationExpression exp3("elementwise_sub", {4, 5}, {6}); fusion_group::OperationExpression exp3("elementwise_sub", {4, 5}, {6});
fusion_group::OperationExpression exp4("relu", {6}, {7}); fusion_group::OperationExpression exp4("relu", {6}, {7});
fusion_group::OperationExpression exp5("sigmoid", {7}, {8}); fusion_group::OperationExpression exp5("sigmoid", {7}, {8});
std::vector<fusion_group::OperationExpression> expressions = { std::vector<fusion_group::OperationExpression> expressions = {
exp1, exp2, exp3, exp4, exp5}; exp1, exp2, exp3, exp4, exp5};
// Prepare CPU tensors // Prepare CPU tensors
std::vector<paddle::framework::LoDTensor> cpu_tensors(9); std::vector<paddle::framework::LoDTensor> cpu_tensors(9);
std::vector<int> input_ids = {0, 1, 3, 5};
std::vector<int> output_ids = {2, 4, 6, 7, 8};
auto dims = paddle::framework::make_ddim( auto dims = paddle::framework::make_ddim(
{static_cast<int64_t>(256), static_cast<int64_t>(1024)}); {static_cast<int64_t>(256), static_cast<int64_t>(1024)});
for (size_t i = 0; i < cpu_tensors.size(); ++i) { for (size_t i = 0; i < cpu_tensors.size(); ++i) {
cpu_tensors[i].mutable_data<float>(dims, paddle::platform::CPUPlace()); cpu_tensors[i].mutable_data<float>(dims, paddle::platform::CPUPlace());
} }
// Expressions:
// Op(elementwise_mul), inputs:{0,1}, outputs:{2}
// Op(elementwise_add), inputs:{2,3}, outputs:{4}
// Op(elementwise_sub), inputs:{4,5}, outputs:{6}
// Op(relu), inputs:{6}, outputs:{7}
// Op(sigmoid), inputs:{7}, outputs:{8}
int n = cpu_tensors[0].numel(); int n = cpu_tensors[0].numel();
TestMain("fused_elementwise_0", expressions, cpu_tensors, n, input_ids, std::vector<int> input_ids = {0, 1, 3, 5};
std::vector<int> output_ids = {2, 4, 6, 7, 8};
TestMain("elementwise_kernel_0", expressions, cpu_tensors, n, input_ids,
output_ids); output_ids);
auto cpu_kernel_handler = [&](float* var0, float* var1, float* var3,
float* var5, int i) -> float {
float var2_i = var0[i] * var1[i];
float var4_i = var2_i + var3[i];
float var6_i = var4_i - var5[i];
float var7_i = var6_i > 0.0 ? var6_i : 0.0;
float var8_i = 1.0 / (1.0 + std::exp(-var7_i));
return var8_i;
};
// Check the results // Check the results
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
float result = cpu_kernel_handler( fusion_group::CheckOutput(expressions, cpu_tensors, input_ids, output_ids,
cpu_tensors[0].data<float>(), cpu_tensors[1].data<float>(), i);
cpu_tensors[3].data<float>(), cpu_tensors[5].data<float>(), i);
PADDLE_ENFORCE_LT(fabs(cpu_tensors[8].data<float>()[i] - result), 1.E-05);
} }
} }
...@@ -145,48 +283,183 @@ TEST(code_generator, elementwise_grad) { ...@@ -145,48 +283,183 @@ TEST(code_generator, elementwise_grad) {
// t3 = relu(t2) // t3 = relu(t2)
// t2' = relu_grad(t2, t3, t3') // t2' = relu_grad(t2, t3, t3')
// t0', t1' = elementwise_mul_grad(t0, t1, t2, t2') // t0', t1' = elementwise_mul_grad(t0, t1, t2, t2')
fusion_group::OperationExpression exp1("relu_grad", {2, 3, 7}, {6}); fusion_group::OperationExpression exp1("relu_grad", {2, -1, 7}, {6});
fusion_group::OperationExpression exp2("elementwise_mul_grad", {0, 1, 2, 6}, fusion_group::OperationExpression exp2("elementwise_mul_grad", {0, 1, 2, 6},
{4, 5}); {4, 5});
std::vector<fusion_group::OperationExpression> expressions = {exp1, exp2}; std::vector<fusion_group::OperationExpression> expressions = {exp1, exp2};
// Prepare CPU tensors // Prepare CPU tensors
std::vector<paddle::framework::LoDTensor> cpu_tensors(8); std::vector<paddle::framework::LoDTensor> cpu_tensors(8);
std::vector<int> input_ids = {0, 1, 2, 3, 7};
std::vector<int> output_ids = {4, 5, 6};
auto dims = paddle::framework::make_ddim( auto dims = paddle::framework::make_ddim(
{static_cast<int64_t>(256), static_cast<int64_t>(1024)}); {static_cast<int64_t>(256), static_cast<int64_t>(1024)});
for (size_t i = 0; i < cpu_tensors.size(); ++i) { for (size_t i = 0; i < cpu_tensors.size(); ++i) {
cpu_tensors[i].mutable_data<float>(dims, paddle::platform::CPUPlace()); cpu_tensors[i].mutable_data<float>(dims, paddle::platform::CPUPlace());
} }
// Expressions:
// Op(relu_grad), inputs:{2,3,7}, outputs:{6}
// Op(elementwise_mul_grad), inputs:{0,1,2,6}, outputs:{4,5}
int n = cpu_tensors[0].numel(); int n = cpu_tensors[0].numel();
TestMain("fused_elementwise_grad_0", expressions, cpu_tensors, n, input_ids, std::vector<int> input_ids = {0, 1, 2, -1, 7};
std::vector<int> output_ids = {4, 5, 6};
TestMain("elementwise_grad_kernel_0", expressions, cpu_tensors, n, input_ids,
output_ids); output_ids);
auto cpu_kernel_handler = [&](float* var0, float* var1, float* var2, // Check the results
float* var3, float* var7, for (int i = 0; i < n; i++) {
int i) -> std::vector<float> { fusion_group::CheckOutput(expressions, cpu_tensors, input_ids, output_ids,
float var6_i = var2[i] > 0 ? var7[i] : 0; i);
float var4_i = var6_i * var1[i]; }
float var5_i = var6_i * var0[i]; }
return std::vector<float>{var4_i, var5_i, var6_i};
std::unique_ptr<paddle::framework::ir::Graph> BuildGraph(
bool backward = false) {
// inputs operator output
// --------------------------------------------------------
// x0 sigmoid -> tmp_0
// (tmp_0, x1) elementwise_mul -> tmp_1
// x2 tanh -> tmp_2
// (x3, tmp_2) elementwise_mul -> tmp_3
// (tmp_1, tmp_3) elementwise_add -> tmp_4
//
// Expression: tmp_4 = sigmoid(x0) * x1 + tanh(x2) * x3
// The var order (their ids may be different):
// backward is false - x0(0), x1(1), x2(2), x3(3);
// - tmp_0(4), tmp_2(5), tmp_3(6), tmp_1(7), tmp_4(8)
// backward is true - tmp_1(0), tmp_4@GRAD(1), tmp_3(2), tmp_4(3),
// tmp_2(4), x3(5), x1(6), tmp_0(7), x0(8), x2(9)
// - tmp_3@GRAD(10), tmp_1@GRAD(11), tmp_0@GRAD(12),
// tmp_2@GRAD(13), x2@GRAD(14), x0@GRAD(15),
// x3@GRAD(16), x1@GRAD(17)
paddle::framework::ir::Layers layers;
auto* x0 = layers.data("x0", {16, 32});
auto* tmp_0 = layers.sigmoid(x0);
tmp_0->SetShape({16, 32});
auto* x1 = layers.data("x1", {16, 32});
auto* tmp_1 = layers.elementwise_mul(tmp_0, x1);
tmp_1->SetShape({16, 32});
auto* x2 = layers.data("x2", {16, 32});
auto* tmp_2 = layers.tanh(x2);
tmp_2->SetShape({16, 32});
auto* x3 = layers.data("x3", {16, 32});
auto* tmp_3 = layers.elementwise_mul(x3, tmp_2);
tmp_3->SetShape({16, 32});
layers.elementwise_add(tmp_1, tmp_3);
if (backward) {
layers.backward();
}
std::unique_ptr<paddle::framework::ir::Graph> graph(
new paddle::framework::ir::Graph(layers.main_program()));
#ifdef __clang__
return graph;
#else
return std::move(graph);
#endif
}
std::unordered_set<paddle::framework::ir::Node*> DistilGradNodes(
const std::unique_ptr<paddle::framework::ir::Graph>& graph) {
auto is_grad_op = [&](paddle::framework::ir::Node* n) -> bool {
if (n && n->IsOp() && n->Op()) {
std::string suffix = "_grad";
std::string op_type = n->Op()->Type();
size_t pos = op_type.rfind(suffix);
return pos != std::string::npos &&
pos == (op_type.length() - suffix.length());
}
return false;
}; };
std::unordered_set<paddle::framework::ir::Node*> grad_nodes;
for (auto* n : graph->Nodes()) {
if (is_grad_op(n)) {
grad_nodes.insert(n);
} else if (n && n->IsVar() && n->Var()) {
// Remove forward op nodes from inputs
std::vector<paddle::framework::ir::Node*> inputs;
for (auto* in : n->inputs) {
if (in && in->IsOp() && in->Op() && is_grad_op(in)) {
inputs.push_back(in);
}
}
n->inputs = inputs;
// Remove forward op nodes from outputs
std::vector<paddle::framework::ir::Node*> outputs;
for (auto* out : n->outputs) {
if (out && out->IsOp() && out->Op() && is_grad_op(out)) {
outputs.push_back(out);
}
}
n->outputs = outputs;
grad_nodes.insert(n);
}
}
return grad_nodes;
}
TEST(code_generator, subgraph) {
std::unique_ptr<paddle::framework::ir::Graph> graph = BuildGraph(false);
fusion_group::SubGraph subgraph(0, "elementwise_kernel_1", true,
graph->Nodes());
// Prepare CPU tensors
std::vector<paddle::framework::LoDTensor> cpu_tensors(9);
auto dims = paddle::framework::make_ddim(
{static_cast<int64_t>(256), static_cast<int64_t>(1024)});
for (size_t i = 0; i < cpu_tensors.size(); ++i) {
cpu_tensors[i].mutable_data<float>(dims, paddle::platform::CPUPlace());
}
// Expressions generated by code_generator (they may be different):
// Op(sigmoid), inputs:{0}, outputs:{4}
// Op(elementwise_mul), inputs:{4,1}, outputs:{7}
// Op(tanh), inputs:{2}, outputs:{5}
// Op(elementwise_mul), inputs:{3,5}, outputs:{6}
// Op(elementwise_add), inputs:{7,6}, outputs:{8}
int n = cpu_tensors[0].numel();
std::vector<int> input_ids = {0, 1, 2, 3};
std::vector<int> output_ids = {4, 5, 6, 7, 8};
std::vector<fusion_group::OperationExpression> expressions =
TestMain(&subgraph, cpu_tensors, n, input_ids, output_ids);
// Check the results
for (int i = 0; i < n; i++) {
fusion_group::CheckOutput(expressions, cpu_tensors, input_ids, output_ids,
i);
}
}
TEST(code_generator, subgraph_grad) {
std::unique_ptr<paddle::framework::ir::Graph> graph = BuildGraph(true);
fusion_group::SubGraph subgraph(0, "elementwise_grad_kernel_1", true,
DistilGradNodes(graph));
// Prepare CPU tensors
std::vector<paddle::framework::LoDTensor> cpu_tensors(18);
auto dims = paddle::framework::make_ddim(
{static_cast<int64_t>(256), static_cast<int64_t>(1024)});
for (size_t i = 0; i < cpu_tensors.size(); ++i) {
cpu_tensors[i].mutable_data<float>(dims, paddle::platform::CPUPlace());
}
// Expressions generated by code_generator (they may be different):
// Op(elementwise_add_grad), inputs:{1,2,3,0}, outputs:{11,10}
// Op(elementwise_mul_grad), inputs:{5,4,2,10}, outputs:{17,13}
// Op(elementwise_mul_grad), inputs:{7,6,1,11}, outputs:{12,15}
// Op(sigmoid_grad), inputs:{8,7,12}, outputs:{16}
// Op(tanh_grad), inputs:{9,4,13}, outputs:{14}
int n = cpu_tensors[0].numel();
std::vector<int> input_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
std::vector<int> output_ids = {10, 11, 12, 13, 14, 15, 16, 17};
std::vector<fusion_group::OperationExpression> expressions =
TestMain(&subgraph, cpu_tensors, n, input_ids, output_ids);
// Check the results // Check the results
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
std::vector<float> results = cpu_kernel_handler( fusion_group::CheckOutput(expressions, cpu_tensors, input_ids, output_ids,
cpu_tensors[0].data<float>(), cpu_tensors[1].data<float>(), i);
cpu_tensors[2].data<float>(), cpu_tensors[3].data<float>(),
cpu_tensors[7].data<float>(), i);
PADDLE_ENFORCE_LT(fabs(cpu_tensors[4].data<float>()[i] - results[0]),
1.E-05);
PADDLE_ENFORCE_LT(fabs(cpu_tensors[5].data<float>()[i] - results[1]),
1.E-05);
PADDLE_ENFORCE_LT(fabs(cpu_tensors[6].data<float>()[i] - results[2]),
1.E-05);
} }
} }
#endif #endif
...@@ -108,13 +108,6 @@ bool ElementwiseGroupDetector::IsOutputOfElementwiseOp(Node* n) { ...@@ -108,13 +108,6 @@ bool ElementwiseGroupDetector::IsOutputOfElementwiseOp(Node* n) {
return false; return false;
} }
void ElementwiseGroupDetector::Insert(Node* n) {
if (subgraph_.nodes_set.find(n) == subgraph_.nodes_set.end()) {
VLOG(5) << "Insert " << n->Name() << " to subgraph " << name_;
subgraph_.nodes_set.insert(n);
}
}
int ElementwiseGroupDetector::Search(Node* n, std::vector<Node*> except_nodes) { int ElementwiseGroupDetector::Search(Node* n, std::vector<Node*> except_nodes) {
std::unordered_set<Node*> except_nodes_set; std::unordered_set<Node*> except_nodes_set;
for (size_t i = 0; i < except_nodes.size(); ++i) { for (size_t i = 0; i < except_nodes.size(); ++i) {
...@@ -123,16 +116,16 @@ int ElementwiseGroupDetector::Search(Node* n, std::vector<Node*> except_nodes) { ...@@ -123,16 +116,16 @@ int ElementwiseGroupDetector::Search(Node* n, std::vector<Node*> except_nodes) {
int num_operations = 0; int num_operations = 0;
if (IsElementwiseOp(n)) { if (IsElementwiseOp(n)) {
Insert(n); subgraph_.Insert(n);
num_operations += 1; num_operations += 1;
for (auto* var : n->inputs) { for (auto* var : n->inputs) {
Insert(var); subgraph_.Insert(var);
if (except_nodes_set.find(var) == except_nodes_set.end()) { if (except_nodes_set.find(var) == except_nodes_set.end()) {
num_operations += Search(var, {n}); num_operations += Search(var, {n});
} }
} }
for (auto* var : n->outputs) { for (auto* var : n->outputs) {
Insert(var); subgraph_.Insert(var);
if (except_nodes_set.find(var) == except_nodes_set.end()) { if (except_nodes_set.find(var) == except_nodes_set.end()) {
num_operations += Search(var, {n}); num_operations += Search(var, {n});
} }
...@@ -157,7 +150,7 @@ int ElementwiseGroupDetector::Search(Node* n, std::vector<Node*> except_nodes) { ...@@ -157,7 +150,7 @@ int ElementwiseGroupDetector::Search(Node* n, std::vector<Node*> except_nodes) {
int ElementwiseGroupDetector::operator()(Node* n) { int ElementwiseGroupDetector::operator()(Node* n) {
if (!IsOutputOfElementwiseOp(n) && IsInputOfElementwiseOp(n, "X")) { if (!IsOutputOfElementwiseOp(n) && IsInputOfElementwiseOp(n, "X")) {
name_ = n->Name(); name_ = n->Name();
Insert(n); subgraph_.Insert(n);
num_operations_ = Search(n, n->inputs); num_operations_ = Search(n, n->inputs);
VLOG(4) << "Detect elementwise subgraph begin with " << name_ << ", " VLOG(4) << "Detect elementwise subgraph begin with " << name_ << ", "
<< num_operations_ << " operations, " << GetSubgraph().GetNumNodes() << num_operations_ << " operations, " << GetSubgraph().GetNumNodes()
......
...@@ -36,7 +36,6 @@ class ElementwiseGroupDetector { ...@@ -36,7 +36,6 @@ class ElementwiseGroupDetector {
bool IsInputOfElementwiseOp(Node* n, std::string name = ""); bool IsInputOfElementwiseOp(Node* n, std::string name = "");
bool IsOutputOfElementwiseOp(Node* n); bool IsOutputOfElementwiseOp(Node* n);
void Insert(Node* n);
int Search(Node* n, std::vector<Node*> except_nodes = {}); int Search(Node* n, std::vector<Node*> except_nodes = {});
private: private:
......
...@@ -36,7 +36,7 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const { ...@@ -36,7 +36,7 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
for (Node* n : all_nodes) { for (Node* n : all_nodes) {
bool is_found = false; bool is_found = false;
for (auto& subgraph : subgraphs) { for (auto& subgraph : subgraphs) {
if (subgraph.nodes_set.find(n) != subgraph.nodes_set.end()) { if (subgraph.Has(n)) {
is_found = true; is_found = true;
break; break;
} }
...@@ -61,15 +61,17 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const { ...@@ -61,15 +61,17 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
// TODO(liuyiqun): check whether there are intersection between subgraphs // TODO(liuyiqun): check whether there are intersection between subgraphs
for (size_t i = 0; i < subgraphs.size(); ++i) { for (size_t i = 0; i < subgraphs.size(); ++i) {
InsertFusionGroupOp(graph, subgraphs[i]); InsertFusionGroupOp(graph, &subgraphs[i]);
} }
return subgraphs.size(); return subgraphs.size();
} }
void FusionGroupPass::InsertFusionGroupOp( void FusionGroupPass::InsertFusionGroupOp(
Graph* graph, const fusion_group::SubGraph& subgraph) const { Graph* graph, fusion_group::SubGraph* subgraph) const {
std::vector<Node*> input_vars_of_subgraph = subgraph.GetInputVarNodes(); const std::vector<Node*>& input_vars_of_subgraph =
std::vector<Node*> output_vars_of_subgraph = subgraph.GetOutputVarNodes(); subgraph->GetInputVarNodes();
const std::vector<Node*>& output_vars_of_subgraph =
subgraph->GetOutputVarNodes();
std::unordered_set<Node*> external_nodes; std::unordered_set<Node*> external_nodes;
OpDesc op_desc; OpDesc op_desc;
...@@ -88,8 +90,8 @@ void FusionGroupPass::InsertFusionGroupOp( ...@@ -88,8 +90,8 @@ void FusionGroupPass::InsertFusionGroupOp(
external_nodes.insert(n); external_nodes.insert(n);
} }
op_desc.SetOutput("Outs", output_names); op_desc.SetOutput("Outs", output_names);
op_desc.SetAttr("type", subgraph.type); op_desc.SetAttr("type", subgraph->type);
op_desc.SetAttr("func_name", subgraph.func_name); op_desc.SetAttr("func_name", subgraph->func_name);
auto fusion_group_node = graph->CreateOpNode(&op_desc); auto fusion_group_node = graph->CreateOpNode(&op_desc);
for (auto* in : input_vars_of_subgraph) { for (auto* in : input_vars_of_subgraph) {
...@@ -100,7 +102,7 @@ void FusionGroupPass::InsertFusionGroupOp( ...@@ -100,7 +102,7 @@ void FusionGroupPass::InsertFusionGroupOp(
} }
std::unordered_set<const Node*> internal_nodes; std::unordered_set<const Node*> internal_nodes;
for (auto* n : subgraph.nodes_set) { for (auto* n : subgraph->Nodes()) {
if (external_nodes.find(n) == external_nodes.end()) { if (external_nodes.find(n) == external_nodes.end()) {
internal_nodes.insert(n); internal_nodes.insert(n);
} }
......
...@@ -30,7 +30,7 @@ class FusionGroupPass : public Pass { ...@@ -30,7 +30,7 @@ class FusionGroupPass : public Pass {
private: private:
int DetectFusionGroup(Graph* graph, int type = 0) const; int DetectFusionGroup(Graph* graph, int type = 0) const;
void InsertFusionGroupOp(Graph* graph, void InsertFusionGroupOp(Graph* graph,
const fusion_group::SubGraph& subgraph) const; fusion_group::SubGraph* subgraph) const;
const std::string name_scope_{"fusion_group"}; const std::string name_scope_{"fusion_group"};
}; };
......
...@@ -22,6 +22,14 @@ namespace paddle { ...@@ -22,6 +22,14 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
void VisualizeGraph(std::unique_ptr<Graph> graph, std::string graph_viz_path) {
// Insert a graph_viz_pass to transform the graph to a .dot file.
// It can be used for debugging.
auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
graph_viz_pass->Set("graph_viz_path", new std::string(graph_viz_path));
graph.reset(graph_viz_pass->Apply(graph.release()));
}
TEST(FusionGroupPass, elementwise_list) { TEST(FusionGroupPass, elementwise_list) {
fusion_group::OperationMap::Init(); fusion_group::OperationMap::Init();
...@@ -46,29 +54,17 @@ TEST(FusionGroupPass, elementwise_list) { ...@@ -46,29 +54,17 @@ TEST(FusionGroupPass, elementwise_list) {
layers.elementwise_add(tmp_2, w); layers.elementwise_add(tmp_2, w);
std::unique_ptr<Graph> graph(new Graph(layers.main_program())); std::unique_ptr<Graph> graph(new Graph(layers.main_program()));
// VisualizeGraph(graph, "00_elementwise_list.dot");
// The following codes is to insert a graph_viz_pass to transform the graph to
// a .dot file. It is used for debug.
// auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
// graph_viz_pass->Set("graph_viz_path", new
// std::string("00_elementwise_list.dot"));
// graph.reset(graph_viz_pass->Apply(graph.release()));
auto fusion_group_pass = PassRegistry::Instance().Get("fusion_group_pass"); auto fusion_group_pass = PassRegistry::Instance().Get("fusion_group_pass");
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
graph.reset(fusion_group_pass->Apply(graph.release())); graph.reset(fusion_group_pass->Apply(graph.release()));
// VisualizeGraph(graph, "01_elementwise_list.fusion_group.dot");
int num_fusion_group_ops = GetNumOpNodes(graph, "fusion_group"); int num_fusion_group_ops = GetNumOpNodes(graph, "fusion_group");
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_fusion_group_ops, 1); PADDLE_ENFORCE_EQ(num_fusion_group_ops, 1);
// The following codes is to insert a graph_viz_pass to transform the graph to
// a .dot file. It is used for debug.
// auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
// graph_viz_pass->Set("graph_viz_path", new
// std::string("01_elementwise_list.fusion_group.dot"));
// graph.reset(graph_viz_pass->Apply(graph.release()));
} }
TEST(FusionGroupPass, elementwise_tree) { TEST(FusionGroupPass, elementwise_tree) {
...@@ -128,29 +124,17 @@ TEST(FusionGroupPass, elementwise_tree) { ...@@ -128,29 +124,17 @@ TEST(FusionGroupPass, elementwise_tree) {
layers.mul(tmp_6, tmp_9); layers.mul(tmp_6, tmp_9);
std::unique_ptr<Graph> graph(new Graph(layers.main_program())); std::unique_ptr<Graph> graph(new Graph(layers.main_program()));
// VisualizeGraph(graph, "00_elementwise_tree.dot");
// The following codes is to insert a graph_viz_pass to transform the graph to
// a .dot file. It is used for debug.
// auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
// graph_viz_pass->Set("graph_viz_path", new
// std::string("00_elementwise_tree.dot"));
// graph.reset(graph_viz_pass->Apply(graph.release()));
auto fusion_group_pass = PassRegistry::Instance().Get("fusion_group_pass"); auto fusion_group_pass = PassRegistry::Instance().Get("fusion_group_pass");
LOG(INFO) << DebugString(graph); VLOG(3) << DebugString(graph);
graph.reset(fusion_group_pass->Apply(graph.release())); graph.reset(fusion_group_pass->Apply(graph.release()));
// VisualizeGraph(graph, "01_elementwise_tree.fusion_group.dot");
int num_fusion_group_ops = GetNumOpNodes(graph, "fusion_group"); int num_fusion_group_ops = GetNumOpNodes(graph, "fusion_group");
LOG(INFO) << DebugString(graph); VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_fusion_group_ops, 2); PADDLE_ENFORCE_EQ(num_fusion_group_ops, 2);
// The following codes is to insert a graph_viz_pass to transform the graph to
// a .dot file. It is used for debug.
// auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
// graph_viz_pass->Set("graph_viz_path", new
// std::string("01_elementwise_tree.fusion_group.dot"));
// graph.reset(graph_viz_pass->Apply(graph.release()));
} }
} // namespace ir } // namespace ir
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/fusion_group/operation.h" #include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -38,15 +39,30 @@ std::unordered_set<std::string> OperationMap::Find(int type, int num_operands) { ...@@ -38,15 +39,30 @@ std::unordered_set<std::string> OperationMap::Find(int type, int num_operands) {
} }
void OperationMap::Insert(int type, int num_operands, std::string op_type, void OperationMap::Insert(int type, int num_operands, std::string op_type,
std::string expr, std::string expr, std::vector<std::string> grad_exprs,
std::vector<std::string> grad_exprs) { std::vector<std::string> input_names,
Operation op(type, num_operands, op_type, {expr}); std::vector<std::string> output_names) {
Operation op(type, num_operands, op_type, {expr}, input_names, output_names);
PADDLE_ENFORCE_EQ(op.IsValid(), true, "Operation %s is invalid.", op_type); PADDLE_ENFORCE_EQ(op.IsValid(), true, "Operation %s is invalid.", op_type);
operations_[op_type] = op; operations_[op_type] = op;
if (grad_exprs.size() > 0U) { if (grad_exprs.size() > 0U) {
std::string grad_op_type = op_type + "_grad"; std::string grad_op_type = op_type + "_grad";
Operation grad_op(type, num_operands, grad_op_type, grad_exprs); // grad_inputs = inputs + outputs + grad of outputs
std::vector<std::string> grad_input_names = input_names;
for (auto name : output_names) {
grad_input_names.push_back(name);
}
for (auto name : output_names) {
grad_input_names.push_back(GradVarName(name));
}
// grad_outputs = grads of inputs
std::vector<std::string> grad_output_names;
for (auto name : input_names) {
grad_output_names.push_back(GradVarName(name));
}
Operation grad_op(type, num_operands, grad_op_type, grad_exprs,
grad_input_names, grad_output_names);
PADDLE_ENFORCE_EQ(grad_op.IsValid(), true, "Operation %s is invalid.", PADDLE_ENFORCE_EQ(grad_op.IsValid(), true, "Operation %s is invalid.",
grad_op_type); grad_op_type);
operations_[grad_op_type] = grad_op; operations_[grad_op_type] = grad_op;
...@@ -54,59 +70,65 @@ void OperationMap::Insert(int type, int num_operands, std::string op_type, ...@@ -54,59 +70,65 @@ void OperationMap::Insert(int type, int num_operands, std::string op_type,
} }
void OperationMap::InsertUnaryElementwiseOperations() { void OperationMap::InsertUnaryElementwiseOperations() {
int type = 0;
int num_oprands = 1;
// For unary elementwise operations: // For unary elementwise operations:
// ${0} - x // ${0} - x
// ${1} - out // ${1} - out
// ${2} - dout // ${2} - dout
auto insert_handler = [&](std::string op_type, std::string expr,
std::vector<std::string> grad_exprs) {
int type = 0;
int num_oprands = 1;
Insert(type, num_oprands, op_type, expr, grad_exprs, {"X"}, {"Out"});
};
// relu: // relu:
// out = f(x) = x > 0 ? x : 0 // out = f(x) = x > 0 ? x : 0
// dx = dout * (out > 0 ? 1 : 0) = dout * (x > 0 ? 1 : 0) // dx = dout * (out > 0 ? 1 : 0) = dout * (x > 0 ? 1 : 0)
Insert(type, num_oprands, "relu", "real_max(${0}, 0)", insert_handler("relu", "real_max(${0}, 0)", {"${0} > 0 ? ${2} : 0"});
{"${0} > 0 ? ${2} : 0"});
// sigmoid: // sigmoid:
// out = f(x) = 1.0 / (1.0 + exp(-x)) // out = f(x) = 1.0 / (1.0 + exp(-x))
// dx = dout * out * (1 - out) // dx = dout * out * (1 - out)
Insert(type, num_oprands, "sigmoid", "1.0 / (1.0 + real_exp(- ${0}))", insert_handler("sigmoid", "1.0 / (1.0 + real_exp(- ${0}))",
{"${2} * ${1} * (1.0 - ${1})"}); {"${2} * ${1} * (1.0 - ${1})"});
// tanh: // tanh:
// out = f(x) = 2.0 / (1.0 + exp(-2.0 * x)) - 1.0; // out = f(x) = 2.0 / (1.0 + exp(-2.0 * x)) - 1.0;
// dx = dout * (1 - out * out) // dx = dout * (1 - out * out)
Insert(type, num_oprands, "tanh", "2.0 / (1.0 + real_exp(-2.0 * ${0})) - 1.0", insert_handler("tanh", "2.0 / (1.0 + real_exp(-2.0 * ${0})) - 1.0",
{"${2} * (1.0 - ${1} * ${1})"}); {"${2} * (1.0 - ${1} * ${1})"});
} }
void OperationMap::InsertBinaryElementwiseOperations() { void OperationMap::InsertBinaryElementwiseOperations() {
int type = 0;
int num_oprands = 2;
// For binary elementwise operations: // For binary elementwise operations:
// ${0} - x // ${0} - x
// ${1} - y // ${1} - y
// ${2} - out // ${2} - out
// ${3} - dout // ${3} - dout
auto insert_handler = [&](std::string op_type, std::string expr,
std::vector<std::string> grad_exprs) {
int type = 0;
int num_oprands = 2;
Insert(type, num_oprands, op_type, expr, grad_exprs, {"X", "Y"}, {"Out"});
};
// elementwise_add: // elementwise_add:
// out = x + y // out = x + y
// dx = dout * 1 // dx = dout * 1
// dy = dout * 1 // dy = dout * 1
Insert(type, num_oprands, "elementwise_add", "${0} + ${1}", {"${3}", "${3}"}); insert_handler("elementwise_add", "${0} + ${1}", {"${3}", "${3}"});
// elementwise_sub: // elementwise_sub:
// out = x - y // out = x - y
// dx = dout * 1 // dx = dout * 1
// dy = dout * (-1) // dy = dout * (-1)
Insert(type, num_oprands, "elementwise_sub", "${0} - ${1}", insert_handler("elementwise_sub", "${0} - ${1}", {"${3}", "- ${3}"});
{"${3}", "- ${3}"});
// elementwise_mul: // elementwise_mul:
// out = x * y // out = x * y
// dx = dout * y // dx = dout * y
// dy = dout * x // dy = dout * x
Insert(type, num_oprands, "elementwise_mul", "${0} * ${1}", insert_handler("elementwise_mul", "${0} * ${1}",
{"${3} * ${1}", "${3} * ${0}"}); {"${3} * ${1}", "${3} * ${0}"});
Insert(type, num_oprands, "elementwise_div", "${0} / ${1}", {}); insert_handler("elementwise_div", "${0} / ${1}", {});
Insert(type, num_oprands, "elementwise_min", "real_min(${0}, ${1})", {}); insert_handler("elementwise_min", "real_min(${0}, ${1})", {});
Insert(type, num_oprands, "elementwise_max", "real_max(${0}, ${1})", {}); insert_handler("elementwise_max", "real_max(${0}, ${1})", {});
} }
} // namespace fusion_group } // namespace fusion_group
......
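To make the grad bookkeeping in OperationMap::Insert above concrete, here is a hedged sketch (not code from the patch) of the name lists derived for elementwise_mul; GradVarName() appends the framework's gradient suffix, so "Out" becomes "Out@GRAD".

// insert_handler("elementwise_mul", "${0} * ${1}", {"${3} * ${1}", "${3} * ${0}"})
std::vector<std::string> input_names = {"X", "Y"};
std::vector<std::string> output_names = {"Out"};
// grad_inputs = inputs + outputs + grads of outputs
std::vector<std::string> grad_input_names = {"X", "Y", "Out", "Out@GRAD"};
// grad_outputs = grads of inputs
std::vector<std::string> grad_output_names = {"X@GRAD", "Y@GRAD"};
// In the grad expressions, ${0}=X, ${1}=Y, ${2}=Out, ${3}=Out@GRAD, so
// dx = "${3} * ${1}" (dout * y) and dy = "${3} * ${0}" (dout * x).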
...@@ -26,20 +26,32 @@ namespace ir { ...@@ -26,20 +26,32 @@ namespace ir {
namespace fusion_group { namespace fusion_group {
struct Operation { struct Operation {
Operation() {} Operation() = default;
Operation(int t, int n, std::string o, std::vector<std::string> e) Operation(int t, int n, std::string o, std::vector<std::string> e,
: type(t), num_operands(n), op_type(o), exprs(e) {} std::vector<std::string> i_n, std::vector<std::string> o_n)
: type(t),
num_operands(n),
op_type(o),
exprs(e),
input_names(i_n),
output_names(o_n) {}
bool IsGradOp() { bool IsGradOp() {
std::string suffix = "_grad"; std::string suffix = "_grad";
return op_type.rfind(suffix) == (op_type.length() - suffix.length()); size_t pos = op_type.rfind(suffix);
return pos != std::string::npos &&
pos == (op_type.length() - suffix.length());
} }
bool IsValid() { bool IsValid() {
if (!IsGradOp() && exprs.size() != 1U) { if (!IsGradOp() && exprs.size() != 1U) {
// When it is a forward operation, it should hold only one expression (for
// only one output).
return false; return false;
} }
if (IsGradOp() && exprs.size() != static_cast<size_t>(num_operands)) { if (IsGradOp() && exprs.size() != static_cast<size_t>(num_operands)) {
// When it is a backward operation, it should hold an expression for each
// operand.
return false; return false;
} }
return true; return true;
...@@ -49,6 +61,8 @@ struct Operation { ...@@ -49,6 +61,8 @@ struct Operation {
int num_operands; int num_operands;
std::string op_type; std::string op_type;
std::vector<std::string> exprs; std::vector<std::string> exprs;
std::vector<std::string> input_names;
std::vector<std::string> output_names;
}; };
class OperationMap { class OperationMap {
...@@ -83,7 +97,9 @@ class OperationMap { ...@@ -83,7 +97,9 @@ class OperationMap {
private: private:
void Insert(int type, int num_operands, std::string op_type, std::string expr, void Insert(int type, int num_operands, std::string op_type, std::string expr,
std::vector<std::string> grad_exprs); std::vector<std::string> grad_exprs,
std::vector<std::string> input_names,
std::vector<std::string> output_names);
void InsertUnaryElementwiseOperations(); void InsertUnaryElementwiseOperations();
void InsertBinaryElementwiseOperations(); void InsertBinaryElementwiseOperations();
......
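For reference, a forward/backward pair registered through OperationMap is expected to end up stored as Operation instances roughly like the sketch below (assuming the header above is included). The gradient input/output names used here are assumptions for illustration only; they are not taken from this patch.

// Illustration only: constructed directly instead of through OperationMap.
Operation add(/*type=*/0, /*num_operands=*/2, "elementwise_add",
              /*exprs=*/{"${0} + ${1}"},
              /*input_names=*/{"X", "Y"}, /*output_names=*/{"Out"});
Operation add_grad(0, 2, "elementwise_add_grad",
                   /*exprs=*/{"${3}", "${3}"},  // one expression per operand
                   /*input_names (assumed)=*/{"X", "Y", "Out", "Out@GRAD"},
                   /*output_names (assumed)=*/{"X@GRAD", "Y@GRAD"});
// add.IsGradOp()      -> false; add.IsValid()      -> true (exactly one expr)
// add_grad.IsGradOp() -> true;  add_grad.IsValid() -> true (num_operands exprs)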
...@@ -15,8 +15,10 @@ limitations under the License. */ ...@@ -15,8 +15,10 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
namespace paddle { namespace paddle {
...@@ -27,12 +29,35 @@ namespace fusion_group { ...@@ -27,12 +29,35 @@ namespace fusion_group {
struct SubGraph { struct SubGraph {
int type{-1}; int type{-1};
std::string func_name; std::string func_name;
std::unordered_set<Node*> nodes_set; bool save_intermediate_out{false};
SubGraph() = default;
SubGraph(int t, std::string f, bool s, const std::unordered_set<Node*>& n)
: type(t), func_name(f), save_intermediate_out(s), nodes_set(n) {}
bool IsEmpty() { return nodes_set.empty(); } bool IsEmpty() { return nodes_set.empty(); }
const std::unordered_set<Node*>& Nodes() const { return nodes_set; }
const std::vector<Node*>& SortedNodes() {
if (!is_sorted) {
Sort();
}
return sorted_nodes;
}
size_t GetNumNodes() { return nodes_set.size(); } size_t GetNumNodes() { return nodes_set.size(); }
bool Has(Node* n) { return nodes_set.find(n) != nodes_set.end(); }
void Insert(Node* n) {
if (nodes_set.find(n) == nodes_set.end()) {
VLOG(5) << "Insert " << n->Name() << " to subgraph " << this;
nodes_set.insert(n);
is_sorted = false;
}
}
int GetNumOperations() { int GetNumOperations() {
int num_operations = 0; int num_operations = 0;
for (auto* n : nodes_set) { for (auto* n : nodes_set) {
...@@ -43,11 +68,10 @@ struct SubGraph { ...@@ -43,11 +68,10 @@ struct SubGraph {
return num_operations; return num_operations;
} }
std::vector<Node*> GetInputVarNodes() const { std::vector<Node*> GetInputVarNodes() {
// The order of input nodes should be consistent with that of the generated // The order of input nodes should be consistent everywhere it is used.
// code.
std::vector<Node*> input_vars; std::vector<Node*> input_vars;
for (auto* n : nodes_set) { for (auto* n : SortedNodes()) {
if (n && n->IsVar() && n->Var()) { if (n && n->IsVar() && n->Var()) {
bool is_found = true; bool is_found = true;
// When the inputs size is 0, it is also considered the input var of // When the inputs size is 0, it is also considered the input var of
...@@ -57,7 +81,7 @@ struct SubGraph { ...@@ -57,7 +81,7 @@ struct SubGraph {
} }
// Normally a var node has only one input op node. // Normally a var node has only one input op node.
for (auto* in : n->inputs) { for (auto* in : n->inputs) {
if (nodes_set.find(in) == nodes_set.end()) { if (!Has(in)) {
is_found = false; is_found = false;
} }
} }
...@@ -69,28 +93,197 @@ struct SubGraph { ...@@ -69,28 +93,197 @@ struct SubGraph {
return input_vars; return input_vars;
} }
std::vector<Node*> GetOutputVarNodes() const { std::vector<Node*> GetOutputVarNodes() {
// The order of output nodes should be consistent with that of the generated // The order of output nodes should be consistent everywhere it is used.
// code.
std::vector<Node*> output_vars; std::vector<Node*> output_vars;
for (auto* n : SortedNodes()) {
if (n && n->IsVar() && n->Var()) {
if (save_intermediate_out) {
// If the var_node is the output of some op_node in the subgraph, it
// is considered an output var node of the subgraph.
bool is_found = false;
for (auto* in : n->inputs) {
if (Has(in)) {
is_found = true;
}
}
if (is_found) {
output_vars.push_back(n);
}
} else {
// If one of the var_node's outputs is the input of some operator
// outside the subgraph, it is considered an output var node of the
// subgraph.
bool is_found = true;
if (n->outputs.size() == 0U) {
is_found = false;
}
for (auto* out : n->outputs) {
if (!Has(out)) {
is_found = false;
}
}
if (!is_found) {
output_vars.push_back(n);
}
}
}
}
return output_vars;
}
private:
int FindIndexInSortedNodes(Node* n) {
for (size_t i = 0; i < sorted_nodes.size(); ++i) {
if (n == sorted_nodes[i]) {
return static_cast<int>(i);
}
}
return -1;
}
void SortVarsBasedOnSortedOps() {
// Insert var nodes to sorted_nodes.
std::unordered_map<std::string, Node*> sorted_vars;
for (auto* n : nodes_set) { for (auto* n : nodes_set) {
if (n && n->IsVar() && n->Var()) { if (n && n->IsVar() && n->Var()) {
bool is_found = true; int from = 0;
if (n->outputs.size() == 0U) { int to = sorted_nodes.size();
is_found = false;
for (auto* in : n->inputs) {
if (in && in->IsOp() && in->Op()) {
int index = FindIndexInSortedNodes(in);
// Insert after input op node
if (index >= 0) {
from = index + 1 > from ? index + 1 : from;
}
}
} }
for (auto* out : n->outputs) { for (auto* out : n->outputs) {
if (nodes_set.find(out) == nodes_set.end()) { if (out && out->IsOp() && out->Op()) {
is_found = false; int index = FindIndexInSortedNodes(out);
// Insert before output op node
if (index >= 0) {
to = index < to ? index : to;
}
} }
} }
if (!is_found) {
output_vars.push_back(n); PADDLE_ENFORCE_LE(from, to, "Range [%d, %d] is invalid.", from, to);
sorted_nodes.insert(sorted_nodes.begin() + to, n);
sorted_vars[n->Name()] = n;
}
}
}
std::vector<Node*> SortedOps() {
Node* start_op_n = nullptr;
std::unordered_set<Node*> ops;
for (auto* op_n : nodes_set) {
if (op_n && op_n->IsOp() && op_n->Op()) {
// Initialize ops to all ops in the subgraph.
ops.insert(op_n);
if (!start_op_n) {
// Find start op node whose inputs are produced outside the subgraph.
bool is_found = false;
for (auto* prev_op_n : GetPrevOpNodes(op_n)) {
if (Has(prev_op_n)) {
is_found = true;
break;
}
}
if (!is_found) {
start_op_n = op_n;
}
} }
} }
} }
return output_vars;
std::vector<Node*> sorted_ops;
sorted_ops.push_back(start_op_n);
ops.erase(start_op_n);
while (ops.size() > 0U) {
std::unordered_set<Node*> erased_ops;
for (auto* op_n : ops) {
bool found_connected_ops = false;
int from = 1;
int to = sorted_ops.size();
std::unordered_set<Node*> prev_op_nodes = GetPrevOpNodes(op_n);
std::unordered_set<Node*> next_op_nodes = GetNextOpNodes(op_n);
for (int i = static_cast<int>(sorted_ops.size()) - 1; i >= 0; --i) {
if (prev_op_nodes.find(sorted_ops[i]) != prev_op_nodes.end()) {
// Insert after i (i + 1)
found_connected_ops = true;
from = (i + 1 > from) ? i + 1 : from;
}
if (next_op_nodes.find(sorted_ops[i]) != next_op_nodes.end()) {
// Insert before i
found_connected_ops = true;
to = (i < to) ? i : to;
}
}
if (found_connected_ops) {
PADDLE_ENFORCE_LE(from, to, "Range [%d, %d] is invalid.", from, to);
sorted_ops.insert(sorted_ops.begin() + to, op_n);
erased_ops.insert(op_n);
}
}
PADDLE_ENFORCE_GT(erased_ops.size(), 0U);
for (auto* op_n : erased_ops) {
ops.erase(op_n);
}
}
return sorted_ops;
}
std::unordered_set<Node*> GetPrevOpNodes(Node* op_n) {
PADDLE_ENFORCE_EQ(op_n && op_n->IsOp() && op_n->Op(), true,
"Node %p is not a op node.", op_n);
std::unordered_set<Node*> prev_op_nodes;
for (auto* in_var : op_n->inputs) {
if (in_var && in_var->IsVar() && in_var->Var()) {
for (auto* prev_op_n : in_var->inputs) {
if (prev_op_n && prev_op_n->IsOp() && prev_op_n->Op()) {
prev_op_nodes.insert(prev_op_n);
}
}
}
}
return prev_op_nodes;
}
std::unordered_set<Node*> GetNextOpNodes(Node* op_n) {
PADDLE_ENFORCE_EQ(op_n && op_n->IsOp() && op_n->Op(), true,
"Node %p is not a op node.", op_n);
std::unordered_set<Node*> next_op_nodes;
for (auto* out_var : op_n->outputs) {
if (out_var && out_var->IsVar() && out_var->Var()) {
for (auto* next_op_n : out_var->outputs) {
if (next_op_n && next_op_n->IsOp() && next_op_n->Op()) {
next_op_nodes.insert(next_op_n);
}
}
}
}
return next_op_nodes;
} }
void Sort() {
if (!is_sorted) {
sorted_nodes = SortedOps();
SortVarsBasedOnSortedOps();
}
is_sorted = true;
}
private:
std::unordered_set<Node*> nodes_set;
bool is_sorted{false};
std::vector<Node*> sorted_nodes;
}; };
} // namespace fusion_group } // namespace fusion_group
......
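The op-sorting loop in SortedOps() places each remaining op into the already-sorted list so that it lands after every placed predecessor and before every placed successor, retrying unconnected ops in later rounds. A standalone sketch of that placement rule, decoupled from the Node/Graph types (all names here are illustrative only, not part of this patch):

#include <algorithm>
#include <unordered_set>
#include <vector>

// Returns the index at which a new element may be inserted into `placed`
// so that it stays after all of its predecessors and before all of its
// successors already present in `placed`; returns -1 if it is not yet
// connected to anything placed (the caller retries it in a later round).
static int PlacementIndex(const std::vector<int>& placed,
                          const std::unordered_set<int>& predecessors,
                          const std::unordered_set<int>& successors) {
  int from = 0;
  int to = static_cast<int>(placed.size());
  bool connected = false;
  for (int i = static_cast<int>(placed.size()) - 1; i >= 0; --i) {
    if (predecessors.count(placed[i])) {  // must go after placed[i]
      connected = true;
      from = std::max(from, i + 1);
    }
    if (successors.count(placed[i])) {    // must go before placed[i]
      connected = true;
      to = std::min(to, i);
    }
  }
  // Mirrors the PADDLE_ENFORCE_LE(from, to, ...) check in SortedOps().
  return (connected && from <= to) ? to : -1;
}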
...@@ -19,7 +19,10 @@ limitations under the License. */ ...@@ -19,7 +19,10 @@ limitations under the License. */
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -267,6 +270,47 @@ struct Layers { ...@@ -267,6 +270,47 @@ struct Layers {
return outs; return outs;
} }
void backward() {
BlockDesc* block = program_.MutableBlock(0);
std::vector<OpDesc*> forward_ops = block->AllOps();
for (int i = forward_ops.size() - 1; i >= 0; --i) {
OpDesc* op = forward_ops[i];
OpDesc* grad_op = block->AppendOp();
grad_op->SetType(op->Type() + "_grad");
// All of the op's inputs become inputs of the grad_op.
for (auto name : op->InputNames()) {
grad_op->SetInput(name, op->Input(name));
}
// All of the op's outputs become inputs of the grad_op.
for (auto name : op->OutputNames()) {
grad_op->SetInput(name, op->Output(name));
}
// The gradients of all of the op's outputs become inputs of the grad_op.
for (auto name : op->OutputNames()) {
std::vector<std::string> grad_var_names;
for (auto var_name : op->Output(name)) {
VarDesc* var = block->FindVar(var_name);
VarDesc* grad_var =
lod_tensor(GradVarName(var_name), var->GetShape(), false);
grad_var_names.push_back(grad_var->Name());
}
grad_op->SetInput(GradVarName(name), grad_var_names);
}
// The gradients of all of the op's inputs become outputs of the grad_op.
for (auto name : op->InputNames()) {
std::vector<std::string> grad_var_names;
for (auto var_name : op->Input(name)) {
VarDesc* var = block->FindVar(var_name);
VarDesc* grad_var =
lod_tensor(GradVarName(var_name), var->GetShape(), false);
grad_var_names.push_back(grad_var->Name());
}
grad_op->SetOutput(GradVarName(name), grad_var_names);
}
// TODO(liuyiqun): attrs
}
}
private: private:
VarDesc* lod_tensor(std::string name, std::vector<int64_t> shape = {}, VarDesc* lod_tensor(std::string name, std::vector<int64_t> shape = {},
bool is_persistable = false) { bool is_persistable = false) {
...@@ -412,7 +456,7 @@ static std::string DebugString(Node* node) { ...@@ -412,7 +456,7 @@ static std::string DebugString(Node* node) {
return os.str(); return os.str();
} }
static std::string DebugString(const std::unordered_set<Node*>& nodes) { static std::string DebugString(const std::vector<Node*>& nodes) {
std::ostringstream os; std::ostringstream os;
for (auto* node : nodes) { for (auto* node : nodes) {
if (node->IsOp() && node->Op()) { if (node->IsOp() && node->Op()) {
...@@ -425,6 +469,14 @@ static std::string DebugString(const std::unordered_set<Node*>& nodes) { ...@@ -425,6 +469,14 @@ static std::string DebugString(const std::unordered_set<Node*>& nodes) {
return os.str(); return os.str();
} }
static std::string DebugString(const std::unordered_set<Node*>& nodes) {
std::vector<Node*> vec;
for (auto* node : nodes) {
vec.push_back(node);
}
return DebugString(vec);
}
static std::string DebugString(const std::unique_ptr<Graph>& graph) { static std::string DebugString(const std::unique_ptr<Graph>& graph) {
std::ostringstream os; std::ostringstream os;
os << "Graph: {\n" << DebugString(graph->Nodes()) << "}\n"; os << "Graph: {\n" << DebugString(graph->Nodes()) << "}\n";
......
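A typical use of the backward() helper in a unittest might look like the sketch below. It is illustrative only: the op-builder methods and shapes are assumptions about the Layers helper, and only backward() itself comes from this patch.

// Build a small forward program, then append naive grad ops for every
// forward op in reverse order.
Layers layers;
auto* x = layers.data("x", {16, 32});
auto* y = layers.data("y", {16, 32});
auto* tmp = layers.elementwise_add(x, y);
layers.relu(tmp);
layers.backward();  // appends relu_grad and elementwise_add_grad, wiring the
                    // op's inputs, outputs, and output grads as grad_op
                    // inputs and the input grads as grad_op outputs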