未验证 提交 1be6bf45 编写于 作者: Y Yiqun Liu 提交者: GitHub

Add assign to fusion_group and enhance inplace execution in fusion_group. (#26121)

上级 b2034c28
...@@ -68,11 +68,35 @@ static bool HasInput(Node* n, std::string name) { ...@@ -68,11 +68,35 @@ static bool HasInput(Node* n, std::string name) {
return input_names_set.find(name) != input_names_set.end(); return input_names_set.find(name) != input_names_set.end();
} }
// Finds the input var node of operator node `n` whose name equals `name`.
// PADDLE_ENFORCE_EQ checks that `n` is non-null, is an op node and has an op
// description. Returns the first matching entry of n->inputs, or nullptr when
// no input carries that name.
static Node* GetInputVar(Node* n, const std::string& name) {
  const bool is_op_node = n && n->IsOp() && n->Op();
  PADDLE_ENFORCE_EQ(is_op_node, true,
                    platform::errors::InvalidArgument(
                        "Expected node %p to be an operator node.", n));
  Node* found = nullptr;
  for (auto* candidate : n->inputs) {
    if (candidate->Name() == name) {
      found = candidate;
      break;
    }
  }
  return found;
}
// Finds the output var node of operator node `n` whose name equals `name`.
// PADDLE_ENFORCE_EQ checks that `n` is non-null, is an op node and has an op
// description. Returns the first matching entry of n->outputs, or nullptr when
// no output carries that name.
static Node* GetOutputVar(Node* n, const std::string& name) {
  const bool is_op_node = n && n->IsOp() && n->Op();
  PADDLE_ENFORCE_EQ(is_op_node, true,
                    platform::errors::InvalidArgument(
                        "Expected node %p to be an operator node.", n));
  Node* found = nullptr;
  for (auto* candidate : n->outputs) {
    if (candidate->Name() == name) {
      found = candidate;
      break;
    }
  }
  return found;
}
std::vector<OperationExpression> CodeGenerator::ConvertToExpressions( std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
SubGraph* subgraph) { SubGraph* subgraph) {
std::unordered_map<std::string, int> var_ids = EncodeVarNodes(subgraph); std::unordered_map<Node*, int> var_ids = EncodeVarNodes(subgraph);
std::vector<Node*> intermediate_out_nodes = std::unordered_set<Node*> intermediate_out_vars_set =
subgraph->GetIntermediateOutVarNodes(); subgraph->GetIntermediateOutVarNodesSet();
std::vector<OperationExpression> expressions; std::vector<OperationExpression> expressions;
for (auto* node : subgraph->SortedNodes()) { for (auto* node : subgraph->SortedNodes()) {
if (node && node->IsOp() && node->Op()) { if (node && node->IsOp() && node->Op()) {
...@@ -92,11 +116,12 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions( ...@@ -92,11 +116,12 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
// "elementwise_add_grad", where "X", "Y" and "Out" are not used. // "elementwise_add_grad", where "X", "Y" and "Out" are not used.
if ((HasInput(node, name) && op->Input(name).size() >= 1U)) { if ((HasInput(node, name) && op->Input(name).size() >= 1U)) {
for (size_t i = 0; i < op->Input(name).size(); i++) { for (size_t i = 0; i < op->Input(name).size(); i++) {
Node* input_var = GetInputVar(node, op->Input(name)[i]);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
var_ids.find(op->Input(name)[i]), var_ids.end(), var_ids.find(input_var), var_ids.end(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Input(%s) of operation %s is not set.", name, op->Type())); "Input(%s) of operation %s is not set.", name, op->Type()));
input_ids.push_back(var_ids[op->Input(name)[i]]); input_ids.push_back(var_ids[input_var]);
} }
} else { } else {
input_ids.push_back(-1); input_ids.push_back(-1);
...@@ -106,31 +131,29 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions( ...@@ -106,31 +131,29 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
// Output ids should be set in fixed order, like: // Output ids should be set in fixed order, like:
// - dx, dy in backward operations // - dx, dy in backward operations
std::vector<int> output_ids; std::vector<int> output_ids;
std::vector<int> intermediate_output_ids;
std::vector<std::string> output_names = std::vector<std::string> output_names =
OperationMap::Instance().Get(op->Type()).output_names; OperationMap::Instance().Get(op->Type()).output_names;
std::unordered_map<int, bool> intermediate_state;
for (auto& name : output_names) { for (auto& name : output_names) {
Node* output_var = GetOutputVar(node, op->Output(name)[0]);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
var_ids.find(op->Output(name)[0]), var_ids.end(), var_ids.find(output_var), var_ids.end(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Output(%s) of operation %s is not set.", name, op->Type())); "Output(%s) of operation %s is not set.", name, op->Type()));
output_ids.push_back(var_ids[op->Output(name)[0]]); output_ids.push_back(var_ids[output_var]);
bool enable_intermediate = false; if (!subgraph->SaveIntermediateOut() &&
for (auto* n : intermediate_out_nodes) { intermediate_out_vars_set.find(output_var) !=
if (n->Name() == op->Output(name)[0]) { intermediate_out_vars_set.end()) {
enable_intermediate = true; intermediate_output_ids.push_back(var_ids[output_var]);
break;
}
} }
intermediate_state[var_ids[op->Output(name)[0]]] = enable_intermediate;
} }
std::string lhs_type = ExtractDataType(node->outputs); std::string lhs_type = ExtractDataType(node->outputs);
std::string rhs_type = ExtractDataType(node->inputs); std::string rhs_type = ExtractDataType(node->inputs);
auto expression = auto expression =
OperationExpression(node->Name(), input_ids, output_ids, rhs_type, OperationExpression(node->Name(), input_ids, output_ids, rhs_type,
lhs_type, intermediate_state); lhs_type, intermediate_output_ids);
expression.SetAttr(attr); expression.SetAttr(attr);
expressions.push_back(expression); expressions.push_back(expression);
} }
...@@ -146,17 +169,18 @@ std::string CodeGenerator::Generate( ...@@ -146,17 +169,18 @@ std::string CodeGenerator::Generate(
// TODO(liuyiqun): Check whether all expressions are elementwise operations. // TODO(liuyiqun): Check whether all expressions are elementwise operations.
std::set<int> input_ids = std::move(DistilInputIds(expressions)); std::set<int> input_ids = std::move(DistilInputIds(expressions));
std::set<int> output_ids = std::move(DistilOutputIds(expressions)); std::set<int> output_ids = std::move(DistilOutputIds(expressions));
std::set<int> intermediate_ids = std::set<int> intermediate_output_ids =
std::move(DistilIntermediateIds(expressions)); std::move(DistilIntermediateIds(expressions));
std::unordered_map<int, std::string> dtypes = std::unordered_map<int, std::string> dtypes =
std::move(DistilDtypes(expressions)); std::move(DistilDtypes(expressions));
TemplateVariable template_var; TemplateVariable template_var;
template_var.Add("func_name", func_name); template_var.Add("func_name", func_name);
template_var.Add("parameters", EmitParameters(input_ids, output_ids, template_var.Add(
intermediate_ids, dtypes)); "parameters",
EmitParameters(input_ids, output_ids, intermediate_output_ids, dtypes));
template_var.Add("compute_body", template_var.Add("compute_body",
EmitComputeBody(expressions, input_ids, output_ids, EmitComputeBody(expressions, input_ids, output_ids,
intermediate_ids, dtypes)); intermediate_output_ids, dtypes));
std::set<std::string> all_dtype; std::set<std::string> all_dtype;
for (const auto& type : dtypes) { for (const auto& type : dtypes) {
...@@ -204,18 +228,14 @@ std::set<int> CodeGenerator::DistilOutputIds( ...@@ -204,18 +228,14 @@ std::set<int> CodeGenerator::DistilOutputIds(
std::set<int> CodeGenerator::DistilIntermediateIds( std::set<int> CodeGenerator::DistilIntermediateIds(
const std::vector<OperationExpression>& expressions) { const std::vector<OperationExpression>& expressions) {
std::set<int> intermediate_ids; std::set<int> intermediate_output_ids;
// Use std::set to remove the reptead id and get a ordered list. // Use std::set to remove the reptead id and get a ordered list.
for (size_t i = 0; i < expressions.size(); i++) { for (size_t i = 0; i < expressions.size(); i++) {
for (auto id : expressions[i].GetOutputIds()) { for (auto id : expressions[i].GetIntermediateOutputIds()) {
auto intermediate_state = expressions[i].GetIntermediateState(); intermediate_output_ids.insert(id);
if (intermediate_state.find(id) != intermediate_state.end() &&
intermediate_state[id]) {
intermediate_ids.insert(id);
}
} }
} }
return intermediate_ids; return intermediate_output_ids;
} }
std::unordered_map<int, std::string> CodeGenerator::DistilDtypes( std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
...@@ -316,26 +336,29 @@ std::string CodeGenerator::EmitComputeBody( ...@@ -316,26 +336,29 @@ std::string CodeGenerator::EmitComputeBody(
return load.str() + compute.str() + store.str(); return load.str() + compute.str() + store.str();
} }
std::unordered_map<std::string, int> CodeGenerator::EncodeVarNodes( std::unordered_map<Node*, int> CodeGenerator::EncodeVarNodes(
SubGraph* subgraph) { SubGraph* subgraph) {
const auto& input_var_nodes = subgraph->GetInputVarNodes(); const auto& input_var_nodes = subgraph->GetInputVarNodes();
const auto& output_var_nodes = subgraph->GetOutputVarNodes(); // Encode all var nodes, including intermediate output var nodes.
const auto& output_var_nodes = subgraph->GetOutputVarNodes(true);
int id = 0; int id = 0;
std::unordered_map<std::string, int> var_ids; std::unordered_map<Node*, int> var_ids;
// Numbering input vars. // Numbering input vars.
for (auto* in : input_var_nodes) { for (auto* in : input_var_nodes) {
VLOG(3) << "Encoding input names:" << in->Name() << ", id:" << id; VLOG(3) << "Encoding input names:" << in->Name() << "(" << in
if (var_ids.find(in->Name()) == var_ids.end()) { << "), id:" << id;
var_ids[in->Name()] = id++; if (var_ids.find(in) == var_ids.end()) {
var_ids[in] = id++;
} }
} }
// Encoding output vars. // Encoding output vars.
for (auto* out : output_var_nodes) { for (auto* out : output_var_nodes) {
VLOG(3) << "Ecoding output names:" << out->Name() << ", id:" << id; VLOG(3) << "Ecoding output names:" << out->Name() << "(" << out
if (var_ids.find(out->Name()) == var_ids.end()) { << "), id:" << id;
var_ids[out->Name()] = id++; if (var_ids.find(out) == var_ids.end()) {
var_ids[out] = id++;
} }
} }
return var_ids; return var_ids;
......
...@@ -61,7 +61,7 @@ class CodeGenerator { ...@@ -61,7 +61,7 @@ class CodeGenerator {
const std::unordered_map<int, std::string>& dtypes) const; const std::unordered_map<int, std::string>& dtypes) const;
// Encode all var nodes in the subgraph with an unique number. // Encode all var nodes in the subgraph with an unique number.
std::unordered_map<std::string, int> EncodeVarNodes(SubGraph* subgraph); std::unordered_map<Node*, int> EncodeVarNodes(SubGraph* subgraph);
private: private:
std::vector<CodeTemplate> code_templates_; std::vector<CodeTemplate> code_templates_;
......
...@@ -48,20 +48,20 @@ class OperationExpression { ...@@ -48,20 +48,20 @@ class OperationExpression {
std::string op_type, const std::vector<int>& input_ids, std::string op_type, const std::vector<int>& input_ids,
const std::vector<int>& output_ids, std::string rhs_type, const std::vector<int>& output_ids, std::string rhs_type,
std::string lhs_type, std::string lhs_type,
const std::unordered_map<int, bool>& intermediate_state = {}) const std::vector<int>& intermediate_output_ids = {})
: op_type_(op_type), : op_type_(op_type),
input_ids_(input_ids), input_ids_(input_ids),
output_ids_(output_ids), output_ids_(output_ids),
rhs_type_(rhs_type), rhs_type_(rhs_type),
lhs_type_(lhs_type), lhs_type_(lhs_type),
intermediate_state_(intermediate_state) {} intermediate_output_ids_(intermediate_output_ids) {}
std::string GetOpType() const { return op_type_; } std::string GetOpType() const { return op_type_; }
std::unordered_map<int, bool> GetIntermediateState() const {
return intermediate_state_;
}
std::vector<int> GetInputIds() const { return input_ids_; } std::vector<int> GetInputIds() const { return input_ids_; }
std::vector<int> GetOutputIds() const { return output_ids_; } std::vector<int> GetOutputIds() const { return output_ids_; }
std::vector<int> GetIntermediateOutputIds() const {
return intermediate_output_ids_;
}
std::string GetRHSType() const { return rhs_type_; } std::string GetRHSType() const { return rhs_type_; }
std::string GetLHSType() const { return lhs_type_; } std::string GetLHSType() const { return lhs_type_; }
void SetAttr(AttributeMap attr) { attr_ = attr; } void SetAttr(AttributeMap attr) { attr_ = attr; }
...@@ -84,7 +84,7 @@ class OperationExpression { ...@@ -84,7 +84,7 @@ class OperationExpression {
AttributeMap attr_; AttributeMap attr_;
std::string rhs_type_; std::string rhs_type_;
std::string lhs_type_; std::string lhs_type_;
std::unordered_map<int, bool> intermediate_state_; std::vector<int> intermediate_output_ids_;
}; };
class TemplateVariable { class TemplateVariable {
......
...@@ -144,7 +144,6 @@ void CheckOutput(const std::vector<OperationExpression>& expressions, ...@@ -144,7 +144,6 @@ void CheckOutput(const std::vector<OperationExpression>& expressions,
LOG(INFO) << "Precision check failed from i = " << id LOG(INFO) << "Precision check failed from i = " << id
<< ", expect: " << expect << ", actual: " << actual; << ", expect: " << expect << ", actual: " << actual;
EXPECT_LT(fabs(actual - expect), eps); EXPECT_LT(fabs(actual - expect), eps);
break;
} }
} }
} }
...@@ -465,7 +464,7 @@ TEST(code_generator, subgraph) { ...@@ -465,7 +464,7 @@ TEST(code_generator, subgraph) {
for (std::string dtype : {"float", "__half"}) { for (std::string dtype : {"float", "__half"}) {
std::unique_ptr<paddle::framework::ir::Graph> graph = std::unique_ptr<paddle::framework::ir::Graph> graph =
BuildGraph(false, dtype); BuildGraph(false, dtype);
fusion_group::SubGraph subgraph(0, "elementwise_kernel_1", false, fusion_group::SubGraph subgraph(0, "elementwise_kernel_1", true,
graph->Nodes()); graph->Nodes());
// Expressions generated by code_generator (they may be different): // Expressions generated by code_generator (they may be different):
...@@ -484,7 +483,7 @@ TEST(code_generator, subgraph_grad) { ...@@ -484,7 +483,7 @@ TEST(code_generator, subgraph_grad) {
for (std::string dtype : {"float", "__half"}) { for (std::string dtype : {"float", "__half"}) {
std::unique_ptr<paddle::framework::ir::Graph> graph = std::unique_ptr<paddle::framework::ir::Graph> graph =
BuildGraph(true, dtype); BuildGraph(true, dtype);
fusion_group::SubGraph subgraph(0, "elementwise_grad_kernel_1", false, fusion_group::SubGraph subgraph(0, "elementwise_grad_kernel_1", true,
DistilGradNodes(graph)); DistilGradNodes(graph));
// Expressions generated by code_generator (they may be different): // Expressions generated by code_generator (they may be different):
......
...@@ -63,7 +63,7 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l, ...@@ -63,7 +63,7 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
bool GroupDetector::CheckPrecondition(const Node* n) { bool GroupDetector::CheckPrecondition(const Node* n) {
auto check_data_type = [&](const std::vector<Node*>& nodes) -> bool { auto check_data_type = [&](const std::vector<Node*>& nodes) -> bool {
bool is_first = true; bool is_first = true;
proto::VarType::Type data_type_0; proto::VarType::Type data_type_0 = proto::VarType::BOOL;
for (auto* n : nodes) { for (auto* n : nodes) {
if (n && n->IsVar() && n->Var()) { if (n && n->IsVar() && n->Var()) {
if (n->Var()->GetType() != proto::VarType::LOD_TENSOR) { if (n->Var()->GetType() != proto::VarType::LOD_TENSOR) {
......
...@@ -63,11 +63,6 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const { ...@@ -63,11 +63,6 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
std::unordered_set<Node*>(vec.begin(), vec.end())); std::unordered_set<Node*>(vec.begin(), vec.end()));
VLOG(3) << "subgraph: {\n" << DebugString(subgraph.SortedNodes()) << "}\n"; VLOG(3) << "subgraph: {\n" << DebugString(subgraph.SortedNodes()) << "}\n";
// In elementwise fused kernel, memory is the bound of execution,
// here we remove the output id to use less memory and less time.
if (subgraph.RemoveIntermediateOut()) {
subgraph.DetectIntermediateOutWithGraph(graph);
}
if (subgraph.IsValid(min_subgraph_size)) { if (subgraph.IsValid(min_subgraph_size)) {
subgraph.SetFuncName("fused_elementwise_" + std::to_string(index++)); subgraph.SetFuncName("fused_elementwise_" + std::to_string(index++));
if (GenerateCode(&subgraph)) { if (GenerateCode(&subgraph)) {
...@@ -115,57 +110,52 @@ static int ExtractOpRole(fusion_group::SubGraph* subgraph) { ...@@ -115,57 +110,52 @@ static int ExtractOpRole(fusion_group::SubGraph* subgraph) {
void FusionGroupPass::InsertFusionGroupOp( void FusionGroupPass::InsertFusionGroupOp(
Graph* graph, fusion_group::SubGraph* subgraph) const { Graph* graph, fusion_group::SubGraph* subgraph) const {
const std::vector<Node*>& input_vars_of_subgraph = const std::vector<Node*>& input_vars = subgraph->GetInputVarNodes();
subgraph->GetInputVarNodes(); const std::vector<Node*>& output_vars =
const std::vector<Node*>& output_vars_of_subgraph = subgraph->GetOutputVarNodes(subgraph->SaveIntermediateOut());
subgraph->GetOutputVarNodes();
const std::vector<Node*> intermediate_vars_of_subgraph =
subgraph->GetIntermediateOutVarNodes();
std::unordered_set<Node*> external_nodes; std::unordered_set<Node*> external_nodes;
OpDesc op_desc; // Prepare inputs.
op_desc.SetType("fusion_group");
std::vector<std::string> input_names; std::vector<std::string> input_names;
std::vector<std::string> inputs_data_types; std::vector<int> input_dtypes;
for (auto* n : input_vars_of_subgraph) { std::unordered_set<Node*> output_vars_set(output_vars.begin(),
output_vars.end());
for (auto* n : input_vars) {
// It is not an output var node.
if (output_vars_set.find(n) == output_vars_set.end()) {
input_names.push_back(n->Name()); input_names.push_back(n->Name());
inputs_data_types.push_back(DataTypeToString(n->Var()->GetDataType())); input_dtypes.push_back(n->Var()->GetDataType());
external_nodes.insert(n); external_nodes.insert(n);
} }
op_desc.SetInput("Inputs", input_names); }
// Prepare outputs.
std::vector<std::string> output_names; std::vector<std::string> output_names;
std::vector<std::string> outs_data_types; std::vector<int> output_dtypes;
std::vector<Node*> output_var_without_intermediate; for (auto* n : output_vars) {
for (auto* n : output_vars_of_subgraph) {
auto it_input =
find(input_vars_of_subgraph.begin(), input_vars_of_subgraph.end(), n);
auto it_intermediate = find(intermediate_vars_of_subgraph.begin(),
intermediate_vars_of_subgraph.end(), n);
if (it_intermediate == intermediate_vars_of_subgraph.end() &&
it_input == input_vars_of_subgraph.end()) {
output_names.push_back(n->Name()); output_names.push_back(n->Name());
outs_data_types.push_back(DataTypeToString(n->Var()->GetDataType())); output_dtypes.push_back(n->Var()->GetDataType());
output_var_without_intermediate.push_back(n);
}
external_nodes.insert(n); external_nodes.insert(n);
} }
OpDesc op_desc;
op_desc.SetType("fusion_group");
op_desc.SetInput("Inputs", input_names);
op_desc.SetOutput("Outs", output_names); op_desc.SetOutput("Outs", output_names);
op_desc.SetAttr("inputs_data_type", inputs_data_types); op_desc.SetAttr("inputs_dtype", input_dtypes);
op_desc.SetAttr("outs_data_type", outs_data_types); op_desc.SetAttr("outs_dtype", output_dtypes);
op_desc.SetAttr("type", subgraph->GetType()); op_desc.SetAttr("type", subgraph->GetType());
op_desc.SetAttr("func_name", subgraph->GetFuncName()); op_desc.SetAttr("func_name", subgraph->GetFuncName());
op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
ExtractOpRole(subgraph)); ExtractOpRole(subgraph));
Node* fusion_group_node = graph->CreateOpNode(&op_desc); Node* fusion_group_node = graph->CreateOpNode(&op_desc);
for (auto* in : input_vars_of_subgraph) { for (auto* in : input_vars) {
if (output_vars_set.find(in) == output_vars_set.end()) {
IR_NODE_LINK_TO(in, fusion_group_node); IR_NODE_LINK_TO(in, fusion_group_node);
} }
}
for (auto* out : output_var_without_intermediate) { for (auto* out : output_vars) {
IR_NODE_LINK_TO(fusion_group_node, out); IR_NODE_LINK_TO(fusion_group_node, out);
} }
......
...@@ -105,12 +105,6 @@ void OperationMap::InsertUnaryElementwiseOperations() { ...@@ -105,12 +105,6 @@ void OperationMap::InsertUnaryElementwiseOperations() {
insert_handler("tanh", "%{2.0} / (%{1.0} + Exp(-%{2.0} * ${0})) - %{1.0}", insert_handler("tanh", "%{2.0} / (%{1.0} + Exp(-%{2.0} * ${0})) - %{1.0}",
{"${2} * (%{1.0} - ${1} * ${1})"}); {"${2} * (%{1.0} - ${1} * ${1})"});
// cast:
// out = static_cast<T>(x)
// TODO(wangchaochaohu): This is not the compelete definition of
// cast Op, We need refine it later.
insert_handler("cast", "${0}", {});
// sqrt: // sqrt:
// out = x^(1/2) // out = x^(1/2)
// dx = dout * 0.5 / out // dx = dout * 0.5 / out
...@@ -121,6 +115,16 @@ void OperationMap::InsertUnaryElementwiseOperations() { ...@@ -121,6 +115,16 @@ void OperationMap::InsertUnaryElementwiseOperations() {
// dx = dout * 2.0 * x // dx = dout * 2.0 * x
insert_handler("square", "${0} * ${0}", {"${2} * %{2.0} * ${0}"}); insert_handler("square", "${0} * ${0}", {"${2} * %{2.0} * ${0}"});
// assign:
// out = x
insert_handler("assign", "${0}", {});
// cast:
// out = static_cast<T>(x)
// TODO(wangchaochaohu): This is not the complete definition of
// the cast op; we need to refine it later.
insert_handler("cast", "${0}", {});
// scale // scale
// out = (bias_after_scale) ? scale * X + bias : scale(X + bias) // out = (bias_after_scale) ? scale * X + bias : scale(X + bias)
// here we use '=' operator to seperate th default value // here we use '=' operator to seperate th default value
......
...@@ -66,11 +66,12 @@ class SubGraph { ...@@ -66,11 +66,12 @@ class SubGraph {
} }
int GetType() const { return type_; } int GetType() const { return type_; }
bool RemoveIntermediateOut() { return !save_intermediate_out_; }
void SetFuncName(std::string func_name) { func_name_ = func_name; } void SetFuncName(std::string func_name) { func_name_ = func_name; }
std::string GetFuncName() const { return func_name_; } std::string GetFuncName() const { return func_name_; }
bool SaveIntermediateOut() const { return save_intermediate_out_; }
const std::unordered_set<Node*>& Nodes() const { return nodes_set_; } const std::unordered_set<Node*>& Nodes() const { return nodes_set_; }
const std::vector<Node*>& SortedNodes() { const std::vector<Node*>& SortedNodes() {
if (!is_sorted_) { if (!is_sorted_) {
...@@ -118,66 +119,88 @@ class SubGraph { ...@@ -118,66 +119,88 @@ class SubGraph {
return input_vars; return input_vars;
} }
std::vector<Node*> GetOutputVarNodes() { std::vector<Node*> GetOutputVarNodes(bool with_intermediate_out) {
// The order of output nodes should be consistant anywhere.. // The order of output nodes should be consistant anywhere..
std::vector<Node*> output_vars_all; std::vector<Node*> output_vars;
for (auto* n : SortedNodes()) { for (auto* n : SortedNodes()) {
if (n && n->IsVar() && n->Var()) { if (IsOutputOfInternalOp(n)) {
// If the var_node is the output of some op_node in the subgraph, it // If the var_node is the output of some op_node in the subgraph, it
// is considered the output var node of the subgraph. // is considered the output var node of the subgraph.
bool is_found = false; if (with_intermediate_out) {
for (auto* in : n->inputs) { output_vars.push_back(n);
if (Has(in)) { } else {
is_found = true; if (n->outputs.empty() || IsInputOfExternalOp(n)) {
} output_vars.push_back(n);
} }
if (is_found) {
output_vars_all.push_back(n);
} }
} }
} }
return output_vars_all; return output_vars;
} }
std::vector<Node*> GetIntermediateOutVarNodes() { std::vector<Node*> GetIntermediateOutVarNodes() {
return intermediate_out_nodes_; // Intermediate output var nodes: the output of some op_node in the
// subgraph, but not referenced outside the subgraph.
std::vector<Node*> intermediate_out_vars;
for (auto* n : SortedNodes()) {
if (IsOutputOfInternalOp(n) && IsInputOfInternalOp(n) &&
!IsInputOfExternalOp(n)) {
// When the outputs size is 0, it is also considered an intermediate
// output. It may be an unused output or a fetched var, so we
// cannot eliminate it directly here.
intermediate_out_vars.push_back(n);
}
}
return intermediate_out_vars;
} }
void DetectIntermediateOutWithGraph(Graph* graph) { std::unordered_set<Node*> GetIntermediateOutVarNodesSet() {
auto graph_nodes = graph->Nodes(); std::vector<Node*> intermediate_out_vars = GetIntermediateOutVarNodes();
return std::unordered_set<Node*>(intermediate_out_vars.begin(),
for (auto* n : SortedNodes()) { intermediate_out_vars.end());
bool enable_remove = true; }
if (n && n->IsVar() && n->Var()) { private:
bool leaf_graph = true; bool IsInputOfInternalOp(Node* n) {
for (auto* node : graph_nodes) { bool is_input_of_internal_op = false;
if (node->IsOp()) { if (Has(n) && n && n->IsVar() && n->Var()) {
auto inputs = node->inputs; for (auto* out : n->outputs) {
for (auto* in : inputs) { if (Has(out)) {
if (in && in->Name() == n->Name()) { is_input_of_internal_op = true;
if (!Has(node)) enable_remove = false; break;
leaf_graph = false; }
} }
} }
return is_input_of_internal_op;
} }
if (!enable_remove) {
bool IsInputOfExternalOp(Node* n) {
// Whether n is the input of any node outside the subgraph.
bool is_input_of_external_op = false;
if (Has(n) && n && n->IsVar() && n->Var()) {
for (auto* out : n->outputs) {
if (!Has(out)) {
is_input_of_external_op = true;
break; break;
} }
} }
if (leaf_graph) enable_remove = false; }
return is_input_of_external_op;
} else {
enable_remove = false;
} }
if (enable_remove) { bool IsOutputOfInternalOp(Node* n) {
intermediate_out_nodes_.push_back(n); bool is_output_of_internal_op = false;
if (Has(n) && n && n->IsVar() && n->Var()) {
for (auto* in : n->inputs) {
if (Has(in)) {
is_output_of_internal_op = true;
break;
}
} }
} }
return is_output_of_internal_op;
} }
private:
void TopologicalSort() { void TopologicalSort() {
if (!is_sorted_) { if (!is_sorted_) {
std::unordered_map<Node*, std::vector<Node*>> inputs_map; std::unordered_map<Node*, std::vector<Node*>> inputs_map;
...@@ -236,7 +259,6 @@ class SubGraph { ...@@ -236,7 +259,6 @@ class SubGraph {
bool save_intermediate_out_{true}; bool save_intermediate_out_{true};
std::unordered_set<Node*> nodes_set_; std::unordered_set<Node*> nodes_set_;
std::vector<Node*> intermediate_out_nodes_{};
bool is_sorted_{false}; bool is_sorted_{false};
std::vector<Node*> sorted_nodes_; std::vector<Node*> sorted_nodes_;
}; };
......
...@@ -22,8 +22,14 @@ class FusionGroupOp : public framework::OperatorWithKernel { ...@@ -22,8 +22,14 @@ class FusionGroupOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
const size_t num_ins = ctx->Inputs("Inputs").size(); OP_INOUT_CHECK(ctx->HasInputs("Inputs"), "Input", "Inputs", "FusionGroup");
const size_t num_outs = ctx->Outputs("Outs").size(); OP_INOUT_CHECK(ctx->HasOutputs("Outs"), "Output", "Outs", "FusionGroup");
auto input_names = ctx->Inputs("Inputs");
auto output_names = ctx->Outputs("Outs");
const size_t num_ins = input_names.size();
const size_t num_outs = output_names.size();
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
num_ins, 1UL, num_ins, 1UL,
...@@ -42,9 +48,12 @@ class FusionGroupOp : public framework::OperatorWithKernel { ...@@ -42,9 +48,12 @@ class FusionGroupOp : public framework::OperatorWithKernel {
std::vector<framework::DDim> x_dims = ctx->GetInputsDim("Inputs"); std::vector<framework::DDim> x_dims = ctx->GetInputsDim("Inputs");
if (type == 0) { if (type == 0) {
for (size_t i = 1; i < num_ins; ++i) { for (size_t i = 1; i < num_ins; ++i) {
PADDLE_ENFORCE_EQ(x_dims[0], x_dims[i], PADDLE_ENFORCE_EQ(
x_dims[0], x_dims[i],
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"All the inputs' dims should be the same.")); "All the inputs' dims is expected to be the same. "
"But recieved [%s] (name: %s) vs [%s] (name: %s).",
x_dims[0], input_names[0], x_dims[i], input_names[i]));
} }
std::vector<framework::DDim> out_dims; std::vector<framework::DDim> out_dims;
for (size_t j = 0; j < num_outs; ++j) { for (size_t j = 0; j < num_outs; ++j) {
...@@ -76,11 +85,11 @@ class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -76,11 +85,11 @@ class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Outs", AddOutput("Outs",
"(std::vector<LoDTensor>) The outputs of fusion_group op.") "(std::vector<LoDTensor>) The outputs of fusion_group op.")
.AsDuplicable(); .AsDuplicable();
AddAttr<std::vector<std::string>>( AddAttr<std::vector<int>>("outs_dtype",
"outs_data_type", "The data type of Outputs in fusion_group op.") "The data type of Outputs in fusion_group op.")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<std::string>>( AddAttr<std::vector<int>>("inputs_dtype",
"inputs_data_type", "The data type of Inputs in fusion_group op.") "The data type of Inputs in fusion_group op.")
.SetDefault({}); .SetDefault({});
AddAttr<int>("type", "Fusion type.").SetDefault(0); AddAttr<int>("type", "Fusion type.").SetDefault(0);
AddAttr<std::string>("func_name", "Name of the generated functions.") AddAttr<std::string>("func_name", "Name of the generated functions.")
......
...@@ -24,14 +24,14 @@ namespace operators { ...@@ -24,14 +24,14 @@ namespace operators {
static void MutableMultiTypeData( static void MutableMultiTypeData(
std::vector<paddle::framework::LoDTensor*>* var, std::vector<paddle::framework::LoDTensor*>* var,
const std::vector<std::string>& data_type, const platform::Place& place) { const std::vector<int>& data_type, const platform::Place& place) {
for (size_t i = 0; i < var->size(); i++) { for (size_t i = 0; i < var->size(); i++) {
if (data_type[i] == "float") { if (data_type[i] == framework::proto::VarType::FP32) {
(*var)[i]->mutable_data<float>(place); (*var)[i]->mutable_data<float>(place);
} else if (data_type[i] == "double") { } else if (data_type[i] == framework::proto::VarType::FP16) {
(*var)[i]->mutable_data<double>(place);
} else if (data_type[i] == "::paddle::platform::float16") {
(*var)[i]->mutable_data<paddle::platform::float16>(place); (*var)[i]->mutable_data<paddle::platform::float16>(place);
} else if (data_type[i] == framework::proto::VarType::FP64) {
(*var)[i]->mutable_data<double>(place);
} }
} }
} }
...@@ -43,15 +43,15 @@ class FusionGroupKernel : public framework::OpKernel<T> { ...@@ -43,15 +43,15 @@ class FusionGroupKernel : public framework::OpKernel<T> {
auto ins = ctx.MultiInput<framework::LoDTensor>("Inputs"); auto ins = ctx.MultiInput<framework::LoDTensor>("Inputs");
auto outs = ctx.MultiOutput<framework::LoDTensor>("Outs"); auto outs = ctx.MultiOutput<framework::LoDTensor>("Outs");
int type = ctx.Attr<int>("type"); int type = ctx.Attr<int>("type");
auto outs_type = ctx.Attr<std::vector<std::string>>("outs_data_type"); const auto& outs_dtype = ctx.Attr<std::vector<int>>("outs_dtype");
auto inputs_type = ctx.Attr<std::vector<std::string>>("inputs_data_type"); const auto& inputs_dtype = ctx.Attr<std::vector<int>>("inputs_dtype");
size_t num_ins = ins.size(); size_t num_ins = ins.size();
size_t num_outs = outs.size(); size_t num_outs = outs.size();
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
MutableMultiTypeData(&outs, outs_type, place); MutableMultiTypeData(&outs, outs_dtype, place);
std::string func_name = ctx.Attr<std::string>("func_name"); std::string func_name = ctx.Attr<std::string>("func_name");
platform::DeviceCode* dev_code = platform::DeviceCode* dev_code =
...@@ -64,22 +64,22 @@ class FusionGroupKernel : public framework::OpKernel<T> { ...@@ -64,22 +64,22 @@ class FusionGroupKernel : public framework::OpKernel<T> {
args.push_back(&n); args.push_back(&n);
std::vector<const void*> ptrs(num_ins + num_outs); std::vector<const void*> ptrs(num_ins + num_outs);
for (size_t i = 0; i < num_ins; ++i) { for (size_t i = 0; i < num_ins; ++i) {
if (inputs_type[i] == "::paddle::platform::float16") { if (inputs_dtype[i] == framework::proto::VarType::FP16) {
ptrs[i] = ins[i]->data<paddle::platform::float16>(); ptrs[i] = ins[i]->data<paddle::platform::float16>();
} else if (inputs_type[i] == "double") { } else if (inputs_dtype[i] == framework::proto::VarType::FP32) {
ptrs[i] = ins[i]->data<double>();
} else if (inputs_type[i] == "float") {
ptrs[i] = ins[i]->data<float>(); ptrs[i] = ins[i]->data<float>();
} else if (inputs_dtype[i] == framework::proto::VarType::FP64) {
ptrs[i] = ins[i]->data<double>();
} }
args.push_back(&ptrs[i]); args.push_back(&ptrs[i]);
} }
for (size_t j = 0; j < num_outs; ++j) { for (size_t j = 0; j < num_outs; ++j) {
if (outs_type[j] == "::paddle::platform::float16") { if (outs_dtype[j] == framework::proto::VarType::FP16) {
ptrs[num_ins + j] = outs[j]->data<paddle::platform::float16>(); ptrs[num_ins + j] = outs[j]->data<paddle::platform::float16>();
} else if (outs_type[j] == "double") { } else if (outs_dtype[j] == framework::proto::VarType::FP32) {
ptrs[num_ins + j] = outs[j]->data<double>();
} else if (outs_type[j] == "float") {
ptrs[num_ins + j] = outs[j]->data<float>(); ptrs[num_ins + j] = outs[j]->data<float>();
} else if (outs_dtype[j] == framework::proto::VarType::FP64) {
ptrs[num_ins + j] = outs[j]->data<double>();
} }
args.push_back(&ptrs[num_ins + j]); args.push_back(&ptrs[num_ins + j]);
} }
......
...@@ -57,10 +57,14 @@ framework::OpDesc* CreateFusionGroupOp( ...@@ -57,10 +57,14 @@ framework::OpDesc* CreateFusionGroupOp(
const std::vector<std::string>& input_names, const std::vector<std::string>& input_names,
const std::vector<std::vector<int64_t>>& input_shapes, const std::vector<std::vector<int64_t>>& input_shapes,
const std::vector<std::string>& output_names, int type, const std::vector<std::string>& output_names, int type,
const std::vector<std::string>& inputs_data_type, std::string func_name) {
const std::vector<std::string>& outs_data_type, std::string func_name) {
EXPECT_EQ(input_names.size(), input_shapes.size()); EXPECT_EQ(input_names.size(), input_shapes.size());
std::vector<int> input_dtypes(input_names.size(),
framework::proto::VarType::FP32);
std::vector<int> output_dtypes(output_names.size(),
framework::proto::VarType::FP32);
for (size_t i = 0; i < input_names.size(); ++i) { for (size_t i = 0; i < input_names.size(); ++i) {
auto* var = program->MutableBlock(0)->Var(input_names[i]); auto* var = program->MutableBlock(0)->Var(input_names[i]);
var->SetType(framework::proto::VarType::LOD_TENSOR); var->SetType(framework::proto::VarType::LOD_TENSOR);
...@@ -77,8 +81,8 @@ framework::OpDesc* CreateFusionGroupOp( ...@@ -77,8 +81,8 @@ framework::OpDesc* CreateFusionGroupOp(
op->SetType("fusion_group"); op->SetType("fusion_group");
op->SetInput("Inputs", input_names); op->SetInput("Inputs", input_names);
op->SetOutput("Outs", output_names); op->SetOutput("Outs", output_names);
op->SetAttr("inputs_data_type", inputs_data_type); op->SetAttr("inputs_dtype", input_dtypes);
op->SetAttr("outs_data_type", outs_data_type); op->SetAttr("outs_dtype", output_dtypes);
op->SetAttr("type", type); op->SetAttr("type", type);
op->SetAttr("func_name", func_name); op->SetAttr("func_name", func_name);
op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(), op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
...@@ -133,8 +137,6 @@ void CheckOutputs(framework::Scope* scope, ...@@ -133,8 +137,6 @@ void CheckOutputs(framework::Scope* scope,
void TestMain(const std::vector<std::string>& input_names, void TestMain(const std::vector<std::string>& input_names,
const std::vector<std::vector<int64_t>>& input_shapes, const std::vector<std::vector<int64_t>>& input_shapes,
const std::vector<std::string>& output_names, int type, const std::vector<std::string>& output_names, int type,
const std::vector<std::string>& inputs_data_type,
const std::vector<std::string>& outs_data_type,
std::string func_name, std::string cuda_kernel_str, std::string func_name, std::string cuda_kernel_str,
CPUKernelFunc cpu_kernel_func) { CPUKernelFunc cpu_kernel_func) {
// Compile the device code // Compile the device code
...@@ -144,9 +146,8 @@ void TestMain(const std::vector<std::string>& input_names, ...@@ -144,9 +146,8 @@ void TestMain(const std::vector<std::string>& input_names,
// Create a ProgramDesc that has a fusion_group_op. // Create a ProgramDesc that has a fusion_group_op.
framework::ProgramDesc program; framework::ProgramDesc program;
framework::OpDesc* op_desc = framework::OpDesc* op_desc = CreateFusionGroupOp(
CreateFusionGroupOp(&program, input_names, input_shapes, output_names, &program, input_names, input_shapes, output_names, type, func_name);
type, inputs_data_type, outs_data_type, func_name);
auto fusion_group_op = framework::OpRegistry::CreateOp(*op_desc); auto fusion_group_op = framework::OpRegistry::CreateOp(*op_desc);
framework::Scope scope; framework::Scope scope;
...@@ -216,11 +217,8 @@ void elementwise_cuda_kernel_0(size_t n, float *x, float* y, float* z) { ...@@ -216,11 +217,8 @@ void elementwise_cuda_kernel_0(size_t n, float *x, float* y, float* z) {
} }
}; };
std::vector<std::string> inputs_data_type(input_names.size(), "float"); TestMain(input_names, input_shapes, output_names, 0,
std::vector<std::string> outs_data_type(output_names.size(), "float"); "elementwise_cuda_kernel_0", kernel, elementwise_cpu_kernel_0);
TestMain(input_names, input_shapes, output_names, 0, inputs_data_type,
outs_data_type, "elementwise_cuda_kernel_0", kernel,
elementwise_cpu_kernel_0);
} }
} // namespace operators } // namespace operators
......
...@@ -77,12 +77,13 @@ class FusionGroupPassTest(PassTest): ...@@ -77,12 +77,13 @@ class FusionGroupPassTest(PassTest):
self.check_output_with_place(fluid.CUDAPlace(0)) self.check_output_with_place(fluid.CUDAPlace(0))
class FusionGroupPassTest1(FusionGroupPassTest): class FusionGroupPassComplicatedTest(FusionGroupPassTest):
def build_program(self, dtype): def build_program(self, dtype):
with fluid.program_guard(self.main_program, self.startup_program): with fluid.program_guard(self.main_program, self.startup_program):
self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 5) self.feed_vars = self._prepare_feed_vars([32, 64], dtype, 5)
tmp_0 = layers.assign(self.feed_vars[0]) one = layers.fill_constant(shape=[1], dtype=dtype, value=1.0)
tmp_0 = one * self.feed_vars[0]
# subgraph with 9 op nodes # subgraph with 9 op nodes
tmp_1 = tmp_0 * layers.sigmoid(self.feed_vars[1]) + layers.sigmoid( tmp_1 = tmp_0 * layers.sigmoid(self.feed_vars[1]) + layers.sigmoid(
self.feed_vars[2]) * layers.tanh(self.feed_vars[3]) self.feed_vars[2]) * layers.tanh(self.feed_vars[3])
...@@ -94,7 +95,7 @@ class FusionGroupPassTest1(FusionGroupPassTest): ...@@ -94,7 +95,7 @@ class FusionGroupPassTest1(FusionGroupPassTest):
self.fetch_list = [tmp_2, self.grad(tmp_0)] self.fetch_list = [tmp_2, self.grad(tmp_0)]
class FusionGroupPassTest2(FusionGroupPassTest): class FusionGroupPassInplaceTest(FusionGroupPassTest):
def build_program(self, dtype): def build_program(self, dtype):
with fluid.program_guard(self.main_program, self.startup_program): with fluid.program_guard(self.main_program, self.startup_program):
self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 3) self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 3)
...@@ -103,15 +104,13 @@ class FusionGroupPassTest2(FusionGroupPassTest): ...@@ -103,15 +104,13 @@ class FusionGroupPassTest2(FusionGroupPassTest):
name="data3", shape=[128, 32], dtype=dtype)) name="data3", shape=[128, 32], dtype=dtype))
# subgraph with 3 op node # subgraph with 3 op node
tmp_0 = self.feed_vars[0] + self.feed_vars[1] tmp_0 = self.feed_vars[0] - self.feed_vars[1]
tmp_1 = layers.relu(self.feed_vars[2] * tmp_0) tmp_1 = tmp_0 * self.feed_vars[2]
# subgraph with 2 op nodes tmp_2 = layers.assign(tmp_1, output=tmp_0)
tmp_2 = layers.relu(layers.sigmoid(self.feed_vars[3])) tmp_3 = layers.mul(tmp_2, self.feed_vars[3])
tmp_3 = layers.mul(tmp_1, tmp_2)
self.append_gradients(tmp_3) self.num_fused_ops = 1
self.num_fused_ops = 2 self.fetch_list = [tmp_3]
self.fetch_list = [tmp_3, self.grad(tmp_1)]
class FusionGroupPassTestFP64(FusionGroupPassTest): class FusionGroupPassTestFP64(FusionGroupPassTest):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册