diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator.cc b/paddle/fluid/framework/ir/fusion_group/code_generator.cc
index 431d3c05f6dd4b44074729716555744773f950e7..55449856d189065388facf3e3ce736f505e976fb 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator.cc
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator.cc
@@ -68,11 +68,35 @@ static bool HasInput(Node* n, std::string name) {
   return input_names_set.find(name) != input_names_set.end();
 }
 
+static Node* GetInputVar(Node* n, const std::string& name) {
+  PADDLE_ENFORCE_EQ(n && n->IsOp() && n->Op(), true,
+                    platform::errors::InvalidArgument(
+                        "Expected node %p to be an operator node.", n));
+  for (auto* in : n->inputs) {
+    if (in->Name() == name) {
+      return in;
+    }
+  }
+  return nullptr;
+}
+
+static Node* GetOutputVar(Node* n, const std::string& name) {
+  PADDLE_ENFORCE_EQ(n && n->IsOp() && n->Op(), true,
+                    platform::errors::InvalidArgument(
+                        "Expected node %p to be an operator node.", n));
+  for (auto* out : n->outputs) {
+    if (out->Name() == name) {
+      return out;
+    }
+  }
+  return nullptr;
+}
+
 std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
     SubGraph* subgraph) {
-  std::unordered_map<std::string, int> var_ids = EncodeVarNodes(subgraph);
-  std::vector<Node*> intermediate_out_nodes =
-      subgraph->GetIntermediateOutVarNodes();
+  std::unordered_map<Node*, int> var_ids = EncodeVarNodes(subgraph);
+  std::unordered_set<Node*> intermediate_out_vars_set =
+      subgraph->GetIntermediateOutVarNodesSet();
   std::vector<OperationExpression> expressions;
   for (auto* node : subgraph->SortedNodes()) {
     if (node && node->IsOp() && node->Op()) {
@@ -92,11 +116,12 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
         // "elementwise_add_grad", where "X", "Y" and "Out" are not used.
         if ((HasInput(node, name) && op->Input(name).size() >= 1U)) {
           for (size_t i = 0; i < op->Input(name).size(); i++) {
+            Node* input_var = GetInputVar(node, op->Input(name)[i]);
             PADDLE_ENFORCE_NE(
-                var_ids.find(op->Input(name)[i]), var_ids.end(),
+                var_ids.find(input_var), var_ids.end(),
                 platform::errors::InvalidArgument(
                     "Input(%s) of operation %s is not set.", name, op->Type()));
-            input_ids.push_back(var_ids[op->Input(name)[i]]);
+            input_ids.push_back(var_ids[input_var]);
           }
         } else {
           input_ids.push_back(-1);
@@ -106,31 +131,29 @@
       // Output ids should be set in fixed order, like:
       //  - dx, dy in backward operations
       std::vector<int> output_ids;
+      std::vector<int> intermediate_output_ids;
       std::vector<std::string> output_names =
           OperationMap::Instance().Get(op->Type()).output_names;
-      std::unordered_map<int, bool> intermediate_state;
       for (auto& name : output_names) {
+        Node* output_var = GetOutputVar(node, op->Output(name)[0]);
         PADDLE_ENFORCE_NE(
-            var_ids.find(op->Output(name)[0]), var_ids.end(),
+            var_ids.find(output_var), var_ids.end(),
            platform::errors::InvalidArgument(
                "Output(%s) of operation %s is not set.", name, op->Type()));
-        output_ids.push_back(var_ids[op->Output(name)[0]]);
-        bool enable_intermediate = false;
-        for (auto* n : intermediate_out_nodes) {
-          if (n->Name() == op->Output(name)[0]) {
-            enable_intermediate = true;
-            break;
-          }
+        output_ids.push_back(var_ids[output_var]);
+        if (!subgraph->SaveIntermediateOut() &&
+            intermediate_out_vars_set.find(output_var) !=
+                intermediate_out_vars_set.end()) {
+          intermediate_output_ids.push_back(var_ids[output_var]);
         }
-        intermediate_state[var_ids[op->Output(name)[0]]] = enable_intermediate;
       }
 
       std::string lhs_type = ExtractDataType(node->outputs);
       std::string rhs_type = ExtractDataType(node->inputs);
       auto expression =
           OperationExpression(node->Name(), input_ids, output_ids, rhs_type,
-                              lhs_type, intermediate_state);
+                              lhs_type, intermediate_output_ids);
       expression.SetAttr(attr);
       expressions.push_back(expression);
     }
@@ -146,17 +169,18 @@ std::string CodeGenerator::Generate(
   // TODO(liuyiqun): Check whether all expressions are elementwise operations.
   std::set<int> input_ids = std::move(DistilInputIds(expressions));
   std::set<int> output_ids = std::move(DistilOutputIds(expressions));
-  std::set<int> intermediate_ids =
+  std::set<int> intermediate_output_ids =
       std::move(DistilIntermediateIds(expressions));
   std::unordered_map<int, std::string> dtypes =
       std::move(DistilDtypes(expressions));
   TemplateVariable template_var;
   template_var.Add("func_name", func_name);
-  template_var.Add("parameters", EmitParameters(input_ids, output_ids,
-                                                intermediate_ids, dtypes));
+  template_var.Add(
+      "parameters",
+      EmitParameters(input_ids, output_ids, intermediate_output_ids, dtypes));
   template_var.Add("compute_body",
                    EmitComputeBody(expressions, input_ids, output_ids,
-                                   intermediate_ids, dtypes));
+                                   intermediate_output_ids, dtypes));
 
   std::set<std::string> all_dtype;
   for (const auto& type : dtypes) {
@@ -204,18 +228,14 @@ std::set<int> CodeGenerator::DistilOutputIds(
 
 std::set<int> CodeGenerator::DistilIntermediateIds(
     const std::vector<OperationExpression>& expressions) {
-  std::set<int> intermediate_ids;
+  std::set<int> intermediate_output_ids;
   // Use std::set to remove the reptead id and get a ordered list.
   for (size_t i = 0; i < expressions.size(); i++) {
-    for (auto id : expressions[i].GetOutputIds()) {
-      auto intermediate_state = expressions[i].GetIntermediateState();
-      if (intermediate_state.find(id) != intermediate_state.end() &&
-          intermediate_state[id]) {
-        intermediate_ids.insert(id);
-      }
+    for (auto id : expressions[i].GetIntermediateOutputIds()) {
+      intermediate_output_ids.insert(id);
     }
   }
-  return intermediate_ids;
+  return intermediate_output_ids;
 }
 
 std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
@@ -316,26 +336,29 @@ std::string CodeGenerator::EmitComputeBody(
   return load.str() + compute.str() + store.str();
 }
 
-std::unordered_map<std::string, int> CodeGenerator::EncodeVarNodes(
+std::unordered_map<Node*, int> CodeGenerator::EncodeVarNodes(
     SubGraph* subgraph) {
   const auto& input_var_nodes = subgraph->GetInputVarNodes();
-  const auto& output_var_nodes = subgraph->GetOutputVarNodes();
+  // Encode all var nodes, including intermediate output var nodes.
+  const auto& output_var_nodes = subgraph->GetOutputVarNodes(true);
 
   int id = 0;
-  std::unordered_map<std::string, int> var_ids;
+  std::unordered_map<Node*, int> var_ids;
   // Numbering input vars.
   for (auto* in : input_var_nodes) {
-    VLOG(3) << "Encoding input names:" << in->Name() << ", id:" << id;
-    if (var_ids.find(in->Name()) == var_ids.end()) {
-      var_ids[in->Name()] = id++;
+    VLOG(3) << "Encoding input names:" << in->Name() << "(" << in
+            << "), id:" << id;
+    if (var_ids.find(in) == var_ids.end()) {
+      var_ids[in] = id++;
    }
  }
   // Encoding output vars.
   for (auto* out : output_var_nodes) {
-    VLOG(3) << "Ecoding output names:" << out->Name() << ", id:" << id;
-    if (var_ids.find(out->Name()) == var_ids.end()) {
-      var_ids[out->Name()] = id++;
+    VLOG(3) << "Encoding output names:" << out->Name() << "(" << out
+            << "), id:" << id;
+    if (var_ids.find(out) == var_ids.end()) {
+      var_ids[out] = id++;
     }
   }
   return var_ids;
diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator.h b/paddle/fluid/framework/ir/fusion_group/code_generator.h
index 2b18657bbcfbe81d4504306a3753d5c0b82092fd..21773f239b9f6e5208aea45f481bf6f92745033f 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator.h
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator.h
@@ -61,7 +61,7 @@ class CodeGenerator {
       const std::unordered_map<int, std::string>& dtypes) const;
 
   // Encode all var nodes in the subgraph with an unique number.
-  std::unordered_map<std::string, int> EncodeVarNodes(SubGraph* subgraph);
+  std::unordered_map<Node*, int> EncodeVarNodes(SubGraph* subgraph);
 
  private:
   std::vector<CodeTemplate> code_templates_;
diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
index 03d28277afbbb7467f638d521414d702bc8e8179..910f71e65bed10a515f9401a5b09a27ba0929fcf 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
@@ -48,20 +48,20 @@ class OperationExpression {
       std::string op_type, const std::vector<int>& input_ids,
       const std::vector<int>& output_ids, std::string rhs_type,
       std::string lhs_type,
-      const std::unordered_map<int, bool>& intermediate_state = {})
+      const std::vector<int>& intermediate_output_ids = {})
       : op_type_(op_type),
         input_ids_(input_ids),
         output_ids_(output_ids),
         rhs_type_(rhs_type),
         lhs_type_(lhs_type),
-        intermediate_state_(intermediate_state) {}
+        intermediate_output_ids_(intermediate_output_ids) {}
 
   std::string GetOpType() const { return op_type_; }
-  std::unordered_map<int, bool> GetIntermediateState() const {
-    return intermediate_state_;
-  }
   std::vector<int> GetInputIds() const { return input_ids_; }
   std::vector<int> GetOutputIds() const { return output_ids_; }
+  std::vector<int> GetIntermediateOutputIds() const {
+    return intermediate_output_ids_;
+  }
   std::string GetRHSType() const { return rhs_type_; }
   std::string GetLHSType() const { return lhs_type_; }
   void SetAttr(AttributeMap attr) { attr_ = attr; }
@@ -84,7 +84,7 @@ class OperationExpression {
   AttributeMap attr_;
   std::string rhs_type_;
   std::string lhs_type_;
-  std::unordered_map<int, bool> intermediate_state_;
+  std::vector<int> intermediate_output_ids_;
 };
 
 class TemplateVariable {
diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc b/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc
index 89b05fc577bb46606ff5c43d0dd697bd7b8aed38..ebc89b14c265d3491f0f9bc64a36f52c6c9f2a18 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc
@@ -144,7 +144,6 @@ void CheckOutput(const std::vector<fusion_group::OperationExpression>& expressions,
       LOG(INFO) << "Precision check failed from i = " << id
                 << ", expect: " << expect << ", actual: " << actual;
       EXPECT_LT(fabs(actual - expect), eps);
-      break;
     }
   }
 }
@@ -465,7 +464,7 @@ TEST(code_generator, subgraph) {
   for (std::string dtype : {"float", "__half"}) {
     std::unique_ptr<paddle::framework::ir::Graph> graph =
        BuildGraph(false, dtype);
-    fusion_group::SubGraph subgraph(0, "elementwise_kernel_1", false,
+    fusion_group::SubGraph subgraph(0, "elementwise_kernel_1", true,
                                     graph->Nodes());
 
     // Expressions generated by code_generator (they may be different):
@@ -484,7 +483,7 @@ TEST(code_generator, subgraph_grad) {
   for (std::string dtype : {"float", "__half"}) {
     std::unique_ptr<paddle::framework::ir::Graph> graph =
        BuildGraph(true, dtype);
-    fusion_group::SubGraph subgraph(0, "elementwise_grad_kernel_1", false,
+    fusion_group::SubGraph subgraph(0, "elementwise_grad_kernel_1", true,
                                     DistilGradNodes(graph));
 
     // Expressions generated by code_generator (they may be different):
diff --git a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
index 5de253bb96743dc18b6394f04d7818a090a114c2..f6262762a2af6e1abec47fca2bce85a74116b5fd 100644
--- a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
+++ b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
@@ -63,7 +63,7 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
 bool GroupDetector::CheckPrecondition(const Node* n) {
   auto check_data_type = [&](const std::vector<Node*>& nodes) -> bool {
     bool is_first = true;
-    proto::VarType::Type data_type_0;
+    proto::VarType::Type data_type_0 = proto::VarType::BOOL;
     for (auto* n : nodes) {
       if (n && n->IsVar() && n->Var()) {
         if (n->Var()->GetType() != proto::VarType::LOD_TENSOR) {
diff --git a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
index 883347085926f08adb877d6a7fbe8e5c5e8e1c50..2cf71cdcefcd595c85da63ecb0782d16de5dddb8 100644
--- a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
+++ b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
@@ -63,11 +63,6 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
         std::unordered_set<Node*>(vec.begin(), vec.end()));
     VLOG(3) << "subgraph: {\n" << DebugString(subgraph.SortedNodes()) << "}\n";
 
-    // In elementwise fused kernel, memory is the bound of execution,
-    // here we remove the output id to use less memory and less time.
-    if (subgraph.RemoveIntermediateOut()) {
-      subgraph.DetectIntermediateOutWithGraph(graph);
-    }
     if (subgraph.IsValid(min_subgraph_size)) {
       subgraph.SetFuncName("fused_elementwise_" + std::to_string(index++));
       if (GenerateCode(&subgraph)) {
@@ -115,57 +110,52 @@ static int ExtractOpRole(fusion_group::SubGraph* subgraph) {
 
 void FusionGroupPass::InsertFusionGroupOp(
     Graph* graph, fusion_group::SubGraph* subgraph) const {
-  const std::vector<Node*>& input_vars_of_subgraph =
-      subgraph->GetInputVarNodes();
-  const std::vector<Node*>& output_vars_of_subgraph =
-      subgraph->GetOutputVarNodes();
-  const std::vector<Node*> intermediate_vars_of_subgraph =
-      subgraph->GetIntermediateOutVarNodes();
+  const std::vector<Node*>& input_vars = subgraph->GetInputVarNodes();
+  const std::vector<Node*>& output_vars =
+      subgraph->GetOutputVarNodes(subgraph->SaveIntermediateOut());
 
   std::unordered_set<Node*> external_nodes;
 
-  OpDesc op_desc;
-  op_desc.SetType("fusion_group");
-
+  // Prepare inputs.
   std::vector<std::string> input_names;
-  std::vector<std::string> inputs_data_types;
-  for (auto* n : input_vars_of_subgraph) {
-    input_names.push_back(n->Name());
-    inputs_data_types.push_back(DataTypeToString(n->Var()->GetDataType()));
-    external_nodes.insert(n);
+  std::vector<int> input_dtypes;
+  std::unordered_set<Node*> output_vars_set(output_vars.begin(),
+                                            output_vars.end());
+  for (auto* n : input_vars) {
+    // It is not an output var node.
+    if (output_vars_set.find(n) == output_vars_set.end()) {
+      input_names.push_back(n->Name());
+      input_dtypes.push_back(n->Var()->GetDataType());
+      external_nodes.insert(n);
+    }
   }
-  op_desc.SetInput("Inputs", input_names);
 
+  // Prepare outputs.
   std::vector<std::string> output_names;
-  std::vector<std::string> outs_data_types;
-  std::vector<Node*> output_var_without_intermediate;
-  for (auto* n : output_vars_of_subgraph) {
-    auto it_input =
-        find(input_vars_of_subgraph.begin(), input_vars_of_subgraph.end(), n);
-    auto it_intermediate = find(intermediate_vars_of_subgraph.begin(),
-                                intermediate_vars_of_subgraph.end(), n);
-    if (it_intermediate == intermediate_vars_of_subgraph.end() &&
-        it_input == input_vars_of_subgraph.end()) {
-      output_names.push_back(n->Name());
-      outs_data_types.push_back(DataTypeToString(n->Var()->GetDataType()));
-      output_var_without_intermediate.push_back(n);
-    }
+  std::vector<int> output_dtypes;
+  for (auto* n : output_vars) {
+    output_names.push_back(n->Name());
+    output_dtypes.push_back(n->Var()->GetDataType());
     external_nodes.insert(n);
   }
 
+  OpDesc op_desc;
+  op_desc.SetType("fusion_group");
+  op_desc.SetInput("Inputs", input_names);
   op_desc.SetOutput("Outs", output_names);
-  op_desc.SetAttr("inputs_data_type", inputs_data_types);
-  op_desc.SetAttr("outs_data_type", outs_data_types);
+  op_desc.SetAttr("inputs_dtype", input_dtypes);
+  op_desc.SetAttr("outs_dtype", output_dtypes);
   op_desc.SetAttr("type", subgraph->GetType());
   op_desc.SetAttr("func_name", subgraph->GetFuncName());
   op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                   ExtractOpRole(subgraph));
 
   Node* fusion_group_node = graph->CreateOpNode(&op_desc);
-  for (auto* in : input_vars_of_subgraph) {
-    IR_NODE_LINK_TO(in, fusion_group_node);
+  for (auto* in : input_vars) {
+    if (output_vars_set.find(in) == output_vars_set.end()) {
+      IR_NODE_LINK_TO(in, fusion_group_node);
+    }
   }
-
-  for (auto* out : output_var_without_intermediate) {
+  for (auto* out : output_vars) {
     IR_NODE_LINK_TO(fusion_group_node, out);
   }
diff --git a/paddle/fluid/framework/ir/fusion_group/operation.cc b/paddle/fluid/framework/ir/fusion_group/operation.cc
index b127d132bafb00f32e7e6c5d2681d6f0e78b4c34..921cf0904f632936862b18b2f083f18a33c760be 100644
--- a/paddle/fluid/framework/ir/fusion_group/operation.cc
+++ b/paddle/fluid/framework/ir/fusion_group/operation.cc
@@ -105,12 +105,6 @@ void OperationMap::InsertUnaryElementwiseOperations() {
   insert_handler("tanh", "%{2.0} / (%{1.0} + Exp(-%{2.0} * ${0})) - %{1.0}",
                  {"${2} * (%{1.0} - ${1} * ${1})"});
 
-  // cast:
-  // out = static_cast(x)
-  // TODO(wangchaochaohu): This is not the compelete definition of
-  // cast Op, We need refine it later.
-  insert_handler("cast", "${0}", {});
-
   // sqrt:
   // out = x^(1/2)
   // dx = dout * 0.5 / out
@@ -121,11 +115,21 @@
   // dx = dout * 2.0 * x
   insert_handler("square", "${0} * ${0}", {"${2} * %{2.0} * ${0}"});
 
+  // assign:
+  // out = x
+  insert_handler("assign", "${0}", {});
+
+  // cast:
+  // out = static_cast(x)
+  // TODO(wangchaochaohu): This is not the complete definition of
+  // cast Op, We need refine it later.
+  insert_handler("cast", "${0}", {});
+
   // scale
-  // out = (bias_after_scale) ? scale * X + bias : scale(X + bias)
-  // here we use '=' operator to seperate th default value
+  // out = (bias_after_scale) ? scale * X + bias : scale(X + bias)
+  // here we use '=' operator to separate the default value
   // TODO(wangchaochaohu): Later we need to support Tensor input for scale and
-  // bias.
+  // bias.
   insert_handler(
       "scale",
       "${bias_after_scale=true} ? (${scale=%{1.0}} * ${0} + "
diff --git a/paddle/fluid/framework/ir/fusion_group/subgraph.h b/paddle/fluid/framework/ir/fusion_group/subgraph.h
index 66b17e9f6fe95519b7687bc1c5725684c5c98610..5a29e875aea615c36711aa7dc044e4e1f563c297 100644
--- a/paddle/fluid/framework/ir/fusion_group/subgraph.h
+++ b/paddle/fluid/framework/ir/fusion_group/subgraph.h
@@ -66,11 +66,12 @@ class SubGraph {
   }
 
   int GetType() const { return type_; }
-  bool RemoveIntermediateOut() { return !save_intermediate_out_; }
 
   void SetFuncName(std::string func_name) { func_name_ = func_name; }
   std::string GetFuncName() const { return func_name_; }
 
+  bool SaveIntermediateOut() const { return save_intermediate_out_; }
+
   const std::unordered_set<Node*>& Nodes() const { return nodes_set_; }
   const std::vector<Node*>& SortedNodes() {
     if (!is_sorted_) {
@@ -118,66 +119,88 @@ class SubGraph {
     return input_vars;
   }
 
-  std::vector<Node*> GetOutputVarNodes() {
+  std::vector<Node*> GetOutputVarNodes(bool with_intermediate_out) {
     // The order of output nodes should be consistant anywhere..
-    std::vector<Node*> output_vars_all;
+    std::vector<Node*> output_vars;
     for (auto* n : SortedNodes()) {
-      if (n && n->IsVar() && n->Var()) {
+      if (IsOutputOfInternalOp(n)) {
         // If the var_node is the output of some op_node in the subgraph, it
         // is considered the output var node of the subgraph.
-        bool is_found = false;
-        for (auto* in : n->inputs) {
-          if (Has(in)) {
-            is_found = true;
+        if (with_intermediate_out) {
+          output_vars.push_back(n);
+        } else {
+          if (n->outputs.empty() || IsInputOfExternalOp(n)) {
+            output_vars.push_back(n);
           }
         }
-        if (is_found) {
-          output_vars_all.push_back(n);
-        }
       }
     }
-    return output_vars_all;
+    return output_vars;
   }
 
   std::vector<Node*> GetIntermediateOutVarNodes() {
-    return intermediate_out_nodes_;
+    // Intermediate output var nodes: the output of some op_node in the
+    // subgraph, but not referenced outside the subgraph.
+    std::vector<Node*> intermediate_out_vars;
+    for (auto* n : SortedNodes()) {
+      if (IsOutputOfInternalOp(n) && IsInputOfInternalOp(n) &&
+          !IsInputOfExternalOp(n)) {
+        // When the outputs size is 0, it is also considered an intermediate
+        // output. It may be an unused output or the fetching vars, so that we
+        // cannot eliminate it directly here.
+        intermediate_out_vars.push_back(n);
+      }
+    }
+    return intermediate_out_vars;
   }
 
-  void DetectIntermediateOutWithGraph(Graph* graph) {
-    auto graph_nodes = graph->Nodes();
-
-    for (auto* n : SortedNodes()) {
-      bool enable_remove = true;
+  std::unordered_set<Node*> GetIntermediateOutVarNodesSet() {
+    std::vector<Node*> intermediate_out_vars = GetIntermediateOutVarNodes();
+    return std::unordered_set<Node*>(intermediate_out_vars.begin(),
+                                     intermediate_out_vars.end());
+  }
 
-      if (n && n->IsVar() && n->Var()) {
-        bool leaf_graph = true;
-        for (auto* node : graph_nodes) {
-          if (node->IsOp()) {
-            auto inputs = node->inputs;
-            for (auto* in : inputs) {
-              if (in && in->Name() == n->Name()) {
-                if (!Has(node)) enable_remove = false;
-                leaf_graph = false;
-              }
-            }
-          }
-          if (!enable_remove) {
-            break;
-          }
+ private:
+  bool IsInputOfInternalOp(Node* n) {
+    bool is_input_of_internal_op = false;
+    if (Has(n) && n && n->IsVar() && n->Var()) {
+      for (auto* out : n->outputs) {
+        if (Has(out)) {
+          is_input_of_internal_op = true;
+          break;
         }
-        if (leaf_graph) enable_remove = false;
+      }
+    }
+    return is_input_of_internal_op;
+  }
 
-      } else {
-        enable_remove = false;
+  bool IsInputOfExternalOp(Node* n) {
+    // If n is the input of any one node outside the subgraph.
+    bool is_input_of_external_op = false;
+    if (Has(n) && n && n->IsVar() && n->Var()) {
+      for (auto* out : n->outputs) {
+        if (!Has(out)) {
+          is_input_of_external_op = true;
+          break;
+        }
       }
+    }
+    return is_input_of_external_op;
+  }
 
-      if (enable_remove) {
-        intermediate_out_nodes_.push_back(n);
+  bool IsOutputOfInternalOp(Node* n) {
+    bool is_output_of_internal_op = false;
+    if (Has(n) && n && n->IsVar() && n->Var()) {
+      for (auto* in : n->inputs) {
+        if (Has(in)) {
+          is_output_of_internal_op = true;
+          break;
+        }
       }
     }
+    return is_output_of_internal_op;
   }
 
- private:
   void TopologicalSort() {
     if (!is_sorted_) {
       std::unordered_map<Node*, std::vector<Node*>> inputs_map;
@@ -236,7 +259,6 @@ class SubGraph {
   bool save_intermediate_out_{true};
 
   std::unordered_set<Node*> nodes_set_;
-  std::vector<Node*> intermediate_out_nodes_{};
   bool is_sorted_{false};
   std::vector<Node*> sorted_nodes_;
 };
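(Illustrative aside, not part of the patch. A minimal, self-contained sketch, using a toy node type rather than Paddle's ir::Node, of the rule the new IsOutputOfInternalOp / IsInputOfInternalOp / IsInputOfExternalOp helpers encode: a var node counts as an intermediate output when it is produced by an op inside the subgraph, consumed by an op inside the subgraph, and never read by an op outside it.)

// Toy types only; the real pass works on paddle::framework::ir::Node.
#include <iostream>
#include <unordered_set>
#include <vector>

struct ToyNode {
  std::vector<ToyNode*> inputs;   // for a var node: the op nodes producing it
  std::vector<ToyNode*> outputs;  // for a var node: the op nodes consuming it
};

// Mirrors IsOutputOfInternalOp && IsInputOfInternalOp && !IsInputOfExternalOp.
bool IsIntermediateOut(const std::unordered_set<ToyNode*>& subgraph,
                       const ToyNode* var) {
  bool output_of_internal_op = false;
  for (ToyNode* in : var->inputs) {
    if (subgraph.count(in)) output_of_internal_op = true;
  }
  bool input_of_internal_op = false;
  bool input_of_external_op = false;
  for (ToyNode* out : var->outputs) {
    if (subgraph.count(out)) {
      input_of_internal_op = true;
    } else {
      input_of_external_op = true;
    }
  }
  return output_of_internal_op && input_of_internal_op && !input_of_external_op;
}

int main() {
  // x -> relu -> tmp -> sigmoid -> y -> fetch, where fetch is outside.
  ToyNode x, relu, tmp, sigmoid, y, fetch;
  relu.inputs = {&x};
  relu.outputs = {&tmp};
  tmp.inputs = {&relu};
  tmp.outputs = {&sigmoid};
  sigmoid.inputs = {&tmp};
  sigmoid.outputs = {&y};
  y.inputs = {&sigmoid};
  y.outputs = {&fetch};

  std::unordered_set<ToyNode*> subgraph{&relu, &sigmoid, &tmp, &y};
  std::cout << IsIntermediateOut(subgraph, &tmp) << "\n";  // 1: intermediate out
  std::cout << IsIntermediateOut(subgraph, &y) << "\n";    // 0: external output
  return 0;
}

Under this rule, tmp can stay private to the generated kernel, while y must remain a real output of the fusion_group op.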
" + "But recieved [%s] (name: %s) vs [%s] (name: %s).", + x_dims[0], input_names[0], x_dims[i], input_names[i])); } std::vector out_dims; for (size_t j = 0; j < num_outs; ++j) { @@ -76,11 +85,11 @@ class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Outs", "(std::vector) The outputs of fusion_group op.") .AsDuplicable(); - AddAttr>( - "outs_data_type", "The data type of Outputs in fusion_group op.") + AddAttr>("outs_dtype", + "The data type of Outputs in fusion_group op.") .SetDefault({}); - AddAttr>( - "inputs_data_type", "The data type of Inputs in fusion_group op.") + AddAttr>("inputs_dtype", + "The data type of Inputs in fusion_group op.") .SetDefault({}); AddAttr("type", "Fusion type.").SetDefault(0); AddAttr("func_name", "Name of the generated functions.") diff --git a/paddle/fluid/operators/fused/fusion_group_op.h b/paddle/fluid/operators/fused/fusion_group_op.h index 8449c6b63b1a176071c2197de063b90ec2a535eb..5e5f2c60ffbd48d801aa4cff1b074170c44ed88a 100644 --- a/paddle/fluid/operators/fused/fusion_group_op.h +++ b/paddle/fluid/operators/fused/fusion_group_op.h @@ -24,14 +24,14 @@ namespace operators { static void MutableMultiTypeData( std::vector* var, - const std::vector& data_type, const platform::Place& place) { + const std::vector& data_type, const platform::Place& place) { for (size_t i = 0; i < var->size(); i++) { - if (data_type[i] == "float") { + if (data_type[i] == framework::proto::VarType::FP32) { (*var)[i]->mutable_data(place); - } else if (data_type[i] == "double") { - (*var)[i]->mutable_data(place); - } else if (data_type[i] == "::paddle::platform::float16") { + } else if (data_type[i] == framework::proto::VarType::FP16) { (*var)[i]->mutable_data(place); + } else if (data_type[i] == framework::proto::VarType::FP64) { + (*var)[i]->mutable_data(place); } } } @@ -43,15 +43,15 @@ class FusionGroupKernel : public framework::OpKernel { auto ins = ctx.MultiInput("Inputs"); auto outs = ctx.MultiOutput("Outs"); int type = ctx.Attr("type"); - auto outs_type = ctx.Attr>("outs_data_type"); - auto inputs_type = ctx.Attr>("inputs_data_type"); + const auto& outs_dtype = ctx.Attr>("outs_dtype"); + const auto& inputs_dtype = ctx.Attr>("inputs_dtype"); size_t num_ins = ins.size(); size_t num_outs = outs.size(); auto place = ctx.GetPlace(); - MutableMultiTypeData(&outs, outs_type, place); + MutableMultiTypeData(&outs, outs_dtype, place); std::string func_name = ctx.Attr("func_name"); platform::DeviceCode* dev_code = @@ -64,22 +64,22 @@ class FusionGroupKernel : public framework::OpKernel { args.push_back(&n); std::vector ptrs(num_ins + num_outs); for (size_t i = 0; i < num_ins; ++i) { - if (inputs_type[i] == "::paddle::platform::float16") { + if (inputs_dtype[i] == framework::proto::VarType::FP16) { ptrs[i] = ins[i]->data(); - } else if (inputs_type[i] == "double") { - ptrs[i] = ins[i]->data(); - } else if (inputs_type[i] == "float") { + } else if (inputs_dtype[i] == framework::proto::VarType::FP32) { ptrs[i] = ins[i]->data(); + } else if (inputs_dtype[i] == framework::proto::VarType::FP64) { + ptrs[i] = ins[i]->data(); } args.push_back(&ptrs[i]); } for (size_t j = 0; j < num_outs; ++j) { - if (outs_type[j] == "::paddle::platform::float16") { + if (outs_dtype[j] == framework::proto::VarType::FP16) { ptrs[num_ins + j] = outs[j]->data(); - } else if (outs_type[j] == "double") { - ptrs[num_ins + j] = outs[j]->data(); - } else if (outs_type[j] == "float") { + } else if (outs_dtype[j] == framework::proto::VarType::FP32) { ptrs[num_ins + j] = outs[j]->data(); + } 
else if (outs_dtype[j] == framework::proto::VarType::FP64) { + ptrs[num_ins + j] = outs[j]->data(); } args.push_back(&ptrs[num_ins + j]); } diff --git a/paddle/fluid/operators/fused/fusion_group_op_test.cc b/paddle/fluid/operators/fused/fusion_group_op_test.cc index 48e7d6af397849491c1afeb65c02878b88ccd6cf..d50c829b475752cfad5a41500c6d66d1ecc4c8bf 100644 --- a/paddle/fluid/operators/fused/fusion_group_op_test.cc +++ b/paddle/fluid/operators/fused/fusion_group_op_test.cc @@ -57,10 +57,14 @@ framework::OpDesc* CreateFusionGroupOp( const std::vector& input_names, const std::vector>& input_shapes, const std::vector& output_names, int type, - const std::vector& inputs_data_type, - const std::vector& outs_data_type, std::string func_name) { + std::string func_name) { EXPECT_EQ(input_names.size(), input_shapes.size()); + std::vector input_dtypes(input_names.size(), + framework::proto::VarType::FP32); + std::vector output_dtypes(output_names.size(), + framework::proto::VarType::FP32); + for (size_t i = 0; i < input_names.size(); ++i) { auto* var = program->MutableBlock(0)->Var(input_names[i]); var->SetType(framework::proto::VarType::LOD_TENSOR); @@ -77,8 +81,8 @@ framework::OpDesc* CreateFusionGroupOp( op->SetType("fusion_group"); op->SetInput("Inputs", input_names); op->SetOutput("Outs", output_names); - op->SetAttr("inputs_data_type", inputs_data_type); - op->SetAttr("outs_data_type", outs_data_type); + op->SetAttr("inputs_dtype", input_dtypes); + op->SetAttr("outs_dtype", output_dtypes); op->SetAttr("type", type); op->SetAttr("func_name", func_name); op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(), @@ -133,8 +137,6 @@ void CheckOutputs(framework::Scope* scope, void TestMain(const std::vector& input_names, const std::vector>& input_shapes, const std::vector& output_names, int type, - const std::vector& inputs_data_type, - const std::vector& outs_data_type, std::string func_name, std::string cuda_kernel_str, CPUKernelFunc cpu_kernel_func) { // Compile the device code @@ -144,9 +146,8 @@ void TestMain(const std::vector& input_names, // Create a ProgramDesc that has a fusion_group_op. 
   framework::ProgramDesc program;
-  framework::OpDesc* op_desc =
-      CreateFusionGroupOp(&program, input_names, input_shapes, output_names,
-                          type, inputs_data_type, outs_data_type, func_name);
+  framework::OpDesc* op_desc = CreateFusionGroupOp(
+      &program, input_names, input_shapes, output_names, type, func_name);
   auto fusion_group_op = framework::OpRegistry::CreateOp(*op_desc);
 
   framework::Scope scope;
@@ -216,11 +217,8 @@ void elementwise_cuda_kernel_0(size_t n, float *x, float* y, float* z) {
     }
   };
 
-  std::vector<std::string> inputs_data_type(input_names.size(), "float");
-  std::vector<std::string> outs_data_type(output_names.size(), "float");
-  TestMain(input_names, input_shapes, output_names, 0, inputs_data_type,
-           outs_data_type, "elementwise_cuda_kernel_0", kernel,
-           elementwise_cpu_kernel_0);
+  TestMain(input_names, input_shapes, output_names, 0,
+           "elementwise_cuda_kernel_0", kernel, elementwise_cpu_kernel_0);
 }
 
 }  // namespace operators
diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
index 7edca281fff9df02436b2cc1af5409db0ea1981d..46d574dad0d0ae1f72617c6aaf3369b16195f76b 100644
--- a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
@@ -77,12 +77,13 @@ class FusionGroupPassTest(PassTest):
             self.check_output_with_place(fluid.CUDAPlace(0))
 
 
-class FusionGroupPassTest1(FusionGroupPassTest):
+class FusionGroupPassComplicatedTest(FusionGroupPassTest):
     def build_program(self, dtype):
         with fluid.program_guard(self.main_program, self.startup_program):
-            self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 5)
+            self.feed_vars = self._prepare_feed_vars([32, 64], dtype, 5)
 
-            tmp_0 = layers.assign(self.feed_vars[0])
+            one = layers.fill_constant(shape=[1], dtype=dtype, value=1.0)
+            tmp_0 = one * self.feed_vars[0]
             # subgraph with 9 op nodes
             tmp_1 = tmp_0 * layers.sigmoid(self.feed_vars[1]) + layers.sigmoid(
                 self.feed_vars[2]) * layers.tanh(self.feed_vars[3])
@@ -94,7 +95,7 @@ class FusionGroupPassTest1(FusionGroupPassTest):
         self.fetch_list = [tmp_2, self.grad(tmp_0)]
 
 
-class FusionGroupPassTest2(FusionGroupPassTest):
+class FusionGroupPassInplaceTest(FusionGroupPassTest):
     def build_program(self, dtype):
         with fluid.program_guard(self.main_program, self.startup_program):
             self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 3)
@@ -103,15 +104,13 @@ class FusionGroupPassTest2(FusionGroupPassTest):
                     name="data3", shape=[128, 32], dtype=dtype))
 
             # subgraph with 3 op node
-            tmp_0 = self.feed_vars[0] + self.feed_vars[1]
-            tmp_1 = layers.relu(self.feed_vars[2] * tmp_0)
-            # subgraph with 2 op nodes
-            tmp_2 = layers.relu(layers.sigmoid(self.feed_vars[3]))
-            tmp_3 = layers.mul(tmp_1, tmp_2)
+            tmp_0 = self.feed_vars[0] - self.feed_vars[1]
+            tmp_1 = tmp_0 * self.feed_vars[2]
+            tmp_2 = layers.assign(tmp_1, output=tmp_0)
+            tmp_3 = layers.mul(tmp_2, self.feed_vars[3])
 
-        self.append_gradients(tmp_3)
-        self.num_fused_ops = 2
-        self.fetch_list = [tmp_3, self.grad(tmp_1)]
+        self.num_fused_ops = 1
+        self.fetch_list = [tmp_3]
 
 
 class FusionGroupPassTestFP64(FusionGroupPassTest):
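(Illustrative aside, not part of the patch. EncodeVarNodes now keys ids by Node* instead of by name; the new FusionGroupPassInplaceTest above shows why, since layers.assign(tmp_1, output=tmp_0) rewrites tmp_0 in place and the graph then holds two distinct var nodes with the same name. A minimal sketch with a toy var type, not Paddle's ir::Node, makes the difference visible.)

// Toy type only; the real code maps paddle::framework::ir::Node* to int.
#include <iostream>
#include <string>
#include <unordered_map>

struct ToyVar {
  std::string name;
};

int main() {
  ToyVar tmp0_before_assign{"tmp_0"};  // tmp_0 produced by elementwise_sub
  ToyVar tmp0_after_assign{"tmp_0"};   // tmp_0 rewritten in place by assign

  std::unordered_map<std::string, int> ids_by_name;
  std::unordered_map<const ToyVar*, int> ids_by_node;
  int id = 0;
  for (const ToyVar* var : {&tmp0_before_assign, &tmp0_after_assign}) {
    if (ids_by_name.find(var->name) == ids_by_name.end()) {
      ids_by_name[var->name] = id;
    }
    if (ids_by_node.find(var) == ids_by_node.end()) {
      ids_by_node[var] = id;
    }
    ++id;
  }
  std::cout << "ids keyed by name: " << ids_by_name.size() << "\n";  // 1
  std::cout << "ids keyed by node: " << ids_by_node.size() << "\n";  // 2
  return 0;
}

With name keying, the two tmp_0 nodes would share one id and the generated kernel could not tell the pre-assign value from the post-assign value; pointer keying keeps them distinct.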