From f0d193a23c9d79c2dc792e2c2d16da217b35b17e Mon Sep 17 00:00:00 2001
From: wangchaochaohu
Date: Thu, 12 Mar 2020 11:39:54 +0800
Subject: [PATCH] Cast fusion for fusion group (#22876)

* add support for expression type convert and add cast Op support in fusion group
---
 .../ir/fusion_group/code_generator.cc         | 115 ++++++++++++------
 .../ir/fusion_group/code_generator.h          |   8 +-
 .../ir/fusion_group/code_generator_helper.cc  |  51 +++++++-
 .../ir/fusion_group/code_generator_helper.h   |  22 +++-
 .../ir/fusion_group/code_generator_tester.cc  |  61 +++++-----
 .../elementwise_group_detector.cc             |  48 +++++++-
 .../fusion_group/elementwise_group_detector.h |   7 +-
 .../ir/fusion_group/fusion_group_pass.cc      |   8 ++
 .../framework/ir/fusion_group/operation.cc    |  11 +-
 .../framework/ir/fusion_group/subgraph.h      |  35 +-----
 .../fluid/operators/fused/fusion_group_op.cc  |  15 ++-
 .../fluid/operators/fused/fusion_group_op.h   |  39 +++++-
 .../operators/fused/fusion_group_op_test.cc   |  19 ++-
 .../fluid/tests/unittests/ir/pass_test.py     |   4 +-
 .../unittests/ir/test_ir_fusion_group_pass.py |  29 ++++-
 15 files changed, 337 insertions(+), 135 deletions(-)

diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator.cc b/paddle/fluid/framework/ir/fusion_group/code_generator.cc
index 71a23141c37..2d21d15f455 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator.cc
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator.cc
@@ -24,6 +24,21 @@ namespace framework {
 namespace ir {
 namespace fusion_group {
 
+std::string ExtractDataType(const std::vector<Node*> nodes) {
+  std::string dtype_str = "float";
+  auto data_type = nodes.back()->Var()->GetDataType();
+
+  if (data_type == proto::VarType::FP32) {
+    dtype_str = "float";
+  } else if (data_type == proto::VarType::FP64) {
+    dtype_str = "double";
+  } else if (data_type == proto::VarType::FP16) {
+    dtype_str = "float16";
+  }
+
+  return dtype_str;
+}
+
 CodeGenerator::CodeGenerator() {
   // Only support elementwise operations now.
   code_templates_.resize(1);
@@ -34,8 +49,7 @@ CodeGenerator::CodeGenerator() {
 
 std::string CodeGenerator::Generate(SubGraph* subgraph) {
   std::vector<OperationExpression> expressions = ConvertToExpressions(subgraph);
-  return Generate(subgraph->GetFuncName(), subgraph->GetDataType(),
-                  expressions);
+  return Generate(subgraph->GetFuncName(), expressions);
 }
 
 static bool HasInput(Node* n, std::string name) {
@@ -95,8 +109,11 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
           "Output(%s) of operation %s is not set.", name, op->Type()));
       output_ids.push_back(var_ids[op->Output(name)[0]]);
     }
-    expressions.push_back(
-        OperationExpression(node->Name(), input_ids, output_ids));
+
+    std::string lhs_type = ExtractDataType(node->outputs);
+    std::string rhs_type = ExtractDataType(node->inputs);
+    expressions.emplace_back(OperationExpression(
+        node->Name(), input_ids, output_ids, rhs_type, lhs_type));
     }
   }
   return expressions;
@@ -105,25 +122,32 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
 // In order to get the right result of expression, we need to calculate and
 // store the expression as suffix Expressions using vector.
 std::string CodeGenerator::Generate(
-    std::string func_name, std::string dtype,
+    std::string func_name,
     const std::vector<OperationExpression>& expressions) {
   // TODO(liuyiqun): Check whether all expressions are elementwise operations.
  std::set<int> input_ids = DistilInputIds(expressions);
  std::set<int> output_ids = DistilOutputIds(expressions);
-
+  std::unordered_map<int, std::string> dtypes = DistilDtypes(expressions);
   TemplateVariable template_var;
   template_var.Add("func_name", func_name);
-  template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtype));
+  template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtypes));
   template_var.Add("compute_body",
-                   EmitComputeBody(expressions, input_ids, output_ids, dtype));
-
-  std::string predefined_cuda_functions;
-  if (dtype == "float") {
-    predefined_cuda_functions = predefined_cuda_functions_fp32;
-  } else if (dtype == "double") {
-    predefined_cuda_functions = predefined_cuda_functions_fp64;
-  } else if (dtype == "float16") {
-    predefined_cuda_functions = predefined_cuda_functions_fp16;
+                   EmitComputeBody(expressions, input_ids, output_ids, dtypes));
+
+  std::set<std::string> all_dtype;
+  for (const auto& type : dtypes) {
+    all_dtype.insert(type.second);
+  }
+  std::string predefined_cuda_functions = "";
+  if (all_dtype.find("float") != all_dtype.end() &&
+      all_dtype.find("float16") == all_dtype.end()) {
+    predefined_cuda_functions += predefined_cuda_functions_fp32;
+  }
+  if (all_dtype.find("double") != all_dtype.end()) {
+    predefined_cuda_functions += predefined_cuda_functions_fp64;
+  }
+  if (all_dtype.find("float16") != all_dtype.end()) {
+    predefined_cuda_functions += predefined_cuda_functions_fp16;
   }
   return predefined_cuda_functions + code_templates_[0].Format(template_var);
 }
@@ -154,10 +178,40 @@ std::set<int> CodeGenerator::DistilOutputIds(
   return output_ids;
 }
 
+std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
+    const std::vector<OperationExpression>& expressions) {
+  std::unordered_map<int, std::string> dtypes;
+  for (const auto& expression : expressions) {
+    for (auto id : expression.GetInputIds()) {
+      auto dtype = expression.GetRHSType();
+      if (dtypes.find(id) == dtypes.end()) {
+        dtypes[id] = dtype;
+      } else {
+        PADDLE_ENFORCE_EQ(
+            dtypes[id], dtype,
+            platform::errors::PreconditionNotMet(
+                "In fusion group, same Node id must have same data type"));
+      }
+    }
+    for (auto id : expression.GetOutputIds()) {
+      auto dtype = expression.GetLHSType();
+      if (dtypes.find(id) == dtypes.end()) {
+        dtypes[id] = dtype;
+      } else {
+        PADDLE_ENFORCE_EQ(
+            dtypes[id], dtype,
+            platform::errors::PreconditionNotMet(
+                "In fusion group, same Node id must have same data type"));
+      }
+    }
+  }
+  return dtypes;
+}
+
 // we get the parameter list code for the expression information
-std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
-                                          const std::set<int>& output_ids,
-                                          std::string dtype) {
+std::string CodeGenerator::EmitParameters(
+    const std::set<int>& input_ids, const std::set<int>& output_ids,
+    std::unordered_map<int, std::string> dtypes) {
   std::stringstream ret;
   ret << "int N, ";
 
@@ -165,13 +219,13 @@ std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
   // from the input list.
  for (auto id : input_ids) {
     if (output_ids.find(id) == output_ids.end()) {
-      ret << dtype << "* " << ArgName(id) << ", ";
+      ret << dtypes[id] << "* " << ArgName(id) << ", ";
     }
   }
 
   size_t index = 0;
   for (auto id : output_ids) {
-    ret << dtype << "* " << ArgName(id);
+    ret << dtypes[id] << "* " << ArgName(id);
     if (index != output_ids.size() - 1) {
       ret << ", ";
     }
@@ -184,13 +238,12 @@ std::string CodeGenerator::EmitComputeBody(
     const std::vector<OperationExpression>& expressions,
     const std::set<int>& input_ids, const std::set<int>& output_ids,
-    std::string dtype) {
+    std::unordered_map<int, std::string> dtypes) {
   std::ostringstream compute;
   std::unordered_set<int> used;
-  std::string compute_dtype = (dtype == "float16") ? "float" : dtype;
   for (size_t i = 0; i < expressions.size(); i++) {
     VLOG(3) << DebugString(expressions[i]);
-    compute << expressions[i].GetExpression(compute_dtype, &used);
+    compute << expressions[i].GetExpression(&used);
   }
 
   // Load input to temporal variables.
   for (auto id : input_ids) {
     if (output_ids.find(id) == output_ids.end() &&
         used.find(id) != used.end()) {
-      if (dtype == "float16") {
-        load << "float " << TmpName(id) << " = __half2float(" << ArgName(id)
-             << "[idx]);";
-      } else {
-        load << dtype << " " << TmpName(id) << " = " << ArgName(id) << "[idx];";
-      }
+      load << dtypes[id] << " " << TmpName(id) << " = " << VarName(id) << ";";
     }
   }
-
   // Store temporal variables to memory.
   std::ostringstream store;
   for (auto id : output_ids) {
-    if (dtype == "float16") {
-      store << ArgName(id) << "[idx] = __float2half(" << TmpName(id) << ");";
-    } else {
-      store << ArgName(id) << "[idx] = " << TmpName(id) << ";";
-    }
+    store << VarName(id) << " = " << TmpName(id) << ";";
   }
 
   return load.str() + compute.str() + store.str();
diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator.h b/paddle/fluid/framework/ir/fusion_group/code_generator.h
index ce1bcc48e65..ee908b2f575 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator.h
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator.h
@@ -30,7 +30,7 @@ class CodeGenerator {
  public:
   CodeGenerator();
 
-  std::string Generate(std::string func_name, std::string dtype,
+  std::string Generate(std::string func_name,
                        const std::vector<OperationExpression>& expressions);
   std::string Generate(SubGraph* subgraph);
 
@@ -42,16 +42,18 @@ class CodeGenerator {
       const std::vector<OperationExpression>& expressions);
   std::set<int> DistilOutputIds(
       const std::vector<OperationExpression>& expressions);
+  std::unordered_map<int, std::string> DistilDtypes(
+      const std::vector<OperationExpression>& expressions);
 
   // we get the parameter list code for the expression information
   std::string EmitParameters(const std::set<int>& input_ids,
                              const std::set<int>& output_ids,
-                             std::string dtype);
+                             std::unordered_map<int, std::string> dtypes);
 
   std::string EmitComputeBody(
       const std::vector<OperationExpression>& expressions,
       const std::set<int>& input_ids, const std::set<int>& output_ids,
-      std::string dtype);
+      std::unordered_map<int, std::string> dtypes);
 
   // Encode all var nodes in the subgraph with an unique number.
  std::unordered_map<std::string, int> EncodeVarNodes(SubGraph* subgraph);
diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
index fe395e5f992..e9ed38cac4d 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
@@ -50,10 +50,26 @@ static std::string ExpandMultivariateTemplate(const std::string rhs,
   return sum_rhs;
 }
 
+// In order to avoid multiple __half2float function calls, we do this
+// optimization.
+static std::string OptimzeFP16RHS(std::unordered_set<int>* used,
+                                  const int index,
+                                  const std::vector<int>& input_ids) {
+  std::stringstream ret;
+  if (used->find(input_ids[index]) == used->end()) {
+    ret << "float half2fp32_" + TmpName(input_ids[index]) + " = __half2float(" +
+               TmpName(input_ids[index]) + ");";
+  }
+
+  return ret.str();
+}
+
 std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
+                                        std::string* half2fp32_statement,
                                         size_t exprs_index) const {
   auto rhs = OperationMap::Instance().Get(op_type_).exprs[exprs_index];
   auto num_operands = OperationMap::Instance().Get(op_type_).num_operands;
+
   if (num_operands == -1) {
     size_t input_size = input_ids_.size();
     rhs = ExpandMultivariateTemplate(rhs, input_size);
@@ -78,7 +94,16 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
         platform::errors::InvalidArgument(
             "Expected %d-th input id > 0 for operation < %s >. Received %d.",
             index, op_type_, input_ids_[index]));
-    rhs.replace(pos, length + 3, TmpName(input_ids_[index]));
+    // TODO(wangchaochaohu): Here fp16 is converted to float to do the compute;
+    // we need to add general fp16 compute later.
+    std::string var_name;
+    if (rhs_type_ == "float16") {
+      half2fp32_statement->append(OptimzeFP16RHS(used, index, input_ids_));
+      var_name = "half2fp32_" + TmpName(input_ids_[index]);
+    } else {
+      var_name = TmpName(input_ids_[index]);
+    }
+    rhs.replace(pos, length + 3, var_name);
     used->insert(input_ids_[index]);
   }
 }
@@ -87,7 +112,7 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
 
 std::string OperationExpression::GetLHS(size_t i) const {
   std::stringstream ret;
-  ret << TmpName(output_ids_[i]);
+  ret << lhs_type_ << " " << TmpName(output_ids_[i]);
   return ret.str();
 }
 
@@ -98,15 +123,29 @@ bool OperationExpression::IsSupport() const {
 // we Traverse the graph and get the group , all input id and output id is
 // unique for the node which belong the group
 std::string OperationExpression::GetExpression(
-    std::string dtype, std::unordered_set<int>* used) const {
+    std::unordered_set<int>* used) const {
+  std::string half2fp32_statement;
   std::stringstream ret;
   if (IsSupport()) {
     for (size_t i = 0; i < output_ids_.size(); ++i) {
-      ret << dtype << " " << GetLHS(i) << " = " << GetRHS(used, i) << ";";
+      std::string cast_str = "";
+      if ((lhs_type_ == rhs_type_ && rhs_type_ != "float16") ||
+          (lhs_type_ != rhs_type_ && rhs_type_ == "float16")) {
+        ret << GetLHS(i) << " = " << GetRHS(used, &half2fp32_statement, i)
+            << ";";
+      } else {
+        if ((lhs_type_ == rhs_type_ && rhs_type_ == "float16") ||
+            lhs_type_ == "float16") {
+          cast_str = "__float2half";
+        } else {
+          cast_str = "static_cast<" + lhs_type_ + ">";
+        }
+        ret << GetLHS(i) << " = " << cast_str << "("
+            << GetRHS(used, &half2fp32_statement, i) << ");";
+      }
     }
   }
-
-  return ret.str();
+  return half2fp32_statement + ret.str();
 }
 
 }  // namespace fusion_group
diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
index 18652b53051..10bd2b20a2f 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
@@ -30,29 +30,41 @@ namespace fusion_group {
 static inline std::string ArgName(int index) {
   return "arg" + std::to_string(index);
 }
+
 static inline std::string TmpName(int index) {
   return "tmp" + std::to_string(index);
 }
 
+static inline std::string VarName(int index) {
+  return "arg" + std::to_string(index) + "[idx]";
+}
+
 class OperationExpression {
  public:
   explicit OperationExpression(std::string op_type, std::vector<int> input_ids,
-                               std::vector<int> output_ids)
-      : op_type_(op_type), input_ids_(input_ids), output_ids_(output_ids) {}
+                               std::vector<int> output_ids,
+                               std::string rhs_type, std::string lhs_type)
+      : op_type_(op_type),
+        input_ids_(input_ids),
+        output_ids_(output_ids),
+        rhs_type_(rhs_type),
+        lhs_type_(lhs_type) {}
 
   std::string GetOpType() const { return op_type_; }
   std::vector<int> GetInputIds() const { return input_ids_; }
   std::vector<int> GetOutputIds() const { return output_ids_; }
+  std::string GetRHSType() const { return rhs_type_; }
+  std::string GetLHSType() const { return lhs_type_; }
 
   // Check whether this operation type is supported in OperationMap.
   bool IsSupport() const;
 
-  std::string GetExpression(std::string dtype,
-                            std::unordered_set<int>* used) const;
+  std::string GetExpression(std::unordered_set<int>* used) const;
 
  private:
   // TODO(wangchao): make offset more flexible we add stride and basic offset
   std::string GetRHS(std::unordered_set<int>* used,
+                     std::string* half2fp32_statement,
                      size_t exprs_index = 0) const;
   std::string GetLHS(size_t i = 0) const;
 
@@ -60,6 +72,8 @@ class OperationExpression {
   std::string op_type_;
   std::vector<int> input_ids_;
   std::vector<int> output_ids_;
+  std::string rhs_type_;
+  std::string lhs_type_;
 };
 
 class TemplateVariable {
diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc b/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc
index 8f4eb7443ff..92dec555cab 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc
@@ -288,7 +288,7 @@ void TestMain(std::string func_name,
               std::string dtype) {
   fusion_group::OperationMap::Init();
   fusion_group::CodeGenerator code_generator;
-  std::string code_str = code_generator.Generate(func_name, dtype, expressions);
+  std::string code_str = code_generator.Generate(func_name, expressions);
   VLOG(3) << code_str;
 
   LOG(INFO) << "dtype: " << dtype;
@@ -297,7 +297,7 @@ void TestMain(std::string func_name,
 }
 
 void TestMain(fusion_group::SubGraph* subgraph, std::vector<int> input_ids,
-              std::vector<int> output_ids) {
+              std::vector<int> output_ids, std::string dtype) {
   fusion_group::OperationMap::Init();
   fusion_group::CodeGenerator code_generator;
   std::string code_str = code_generator.Generate(subgraph);
@@ -307,26 +307,28 @@ void TestMain(fusion_group::SubGraph* subgraph, std::vector<int> input_ids,
   std::vector<fusion_group::OperationExpression> expressions =
       code_generator.ConvertToExpressions(subgraph);
 
-  LOG(INFO) << "dtype: " << subgraph->GetDataType();
   TestElementwiseMain(subgraph->GetFuncName(), code_str, expressions, input_ids,
-                      output_ids, subgraph->GetDataType());
+                      output_ids, dtype);
 }
 
 TEST(code_generator, elementwise) {
-  // t2 = t0 * t1
-  // t4 = t2 + t3
-  // t6 = t4 - t5
-  // t7 = relu(t6)
-  // t8 = sigmoid(t7)
-  fusion_group::OperationExpression exp1("elementwise_mul", {0, 1}, {2});
-  fusion_group::OperationExpression
exp2("elementwise_add", {2, 3}, {4}); - fusion_group::OperationExpression exp3("elementwise_sub", {4, 5}, {6}); - fusion_group::OperationExpression exp4("relu", {6}, {7}); - fusion_group::OperationExpression exp5("sigmoid", {7}, {8}); - std::vector expressions = { - exp1, exp2, exp3, exp4, exp5}; - for (std::string dtype : {"float", "float16"}) { + // t2 = t0 * t1 + // t4 = t2 + t3 + // t6 = t4 - t5 + // t7 = relu(t6) + // t8 = sigmoid(t7) + fusion_group::OperationExpression exp1("elementwise_mul", {0, 1}, {2}, + dtype, dtype); + fusion_group::OperationExpression exp2("elementwise_add", {2, 3}, {4}, + dtype, dtype); + fusion_group::OperationExpression exp3("elementwise_sub", {4, 5}, {6}, + dtype, dtype); + fusion_group::OperationExpression exp4("relu", {6}, {7}, dtype, dtype); + fusion_group::OperationExpression exp5("sigmoid", {7}, {8}, dtype, dtype); + std::vector expressions = { + exp1, exp2, exp3, exp4, exp5}; + // Expressions: // Op(elementwise_mul), inputs:{0,1}, outputs:{2} // Op(elementwise_add), inputs:{2,3}, outputs:{4} @@ -340,17 +342,18 @@ TEST(code_generator, elementwise) { } TEST(code_generator, elementwise_grad) { - // The var order: t0, t1, t2, t3, t0', t1', t2', t3' - // t2 = t0 * t1 - // t3 = relu(t2) - // t2' = relu_grad(t2, t3, t3') - // t0', t1' = elementwise_mul_grad(t0, t1, t2, t2') - fusion_group::OperationExpression exp1("relu_grad", {-1, 3, 7}, {6}); - fusion_group::OperationExpression exp2("elementwise_mul_grad", {0, 1, 2, 6}, - {4, 5}); - std::vector expressions = {exp1, exp2}; - for (std::string dtype : {"float", "float16"}) { + // The var order: t0, t1, t2, t3, t0', t1', t2', t3' + // t2 = t0 * t1 + // t3 = relu(t2) + // t2' = relu_grad(t2, t3, t3') + // t0', t1' = elementwise_mul_grad(t0, t1, t2, t2') + fusion_group::OperationExpression exp1("relu_grad", {-1, 3, 7}, {6}, dtype, + dtype); + fusion_group::OperationExpression exp2("elementwise_mul_grad", {0, 1, 2, 6}, + {4, 5}, dtype, dtype); + std::vector expressions = {exp1, exp2}; + // Expressions: // Op(relu_grad), inputs:{2,3,7}, outputs:{6} // Op(elementwise_mul_grad), inputs:{0,1,2,6}, outputs:{4,5} @@ -474,7 +477,7 @@ TEST(code_generator, subgraph) { // Op(elementwise_add), inputs:{7,6}, outputs:{8} std::vector input_ids = {0, 1, 2, 3}; std::vector output_ids = {4, 5, 6, 7, 8}; - TestMain(&subgraph, input_ids, output_ids); + TestMain(&subgraph, input_ids, output_ids, dtype); } } @@ -493,7 +496,7 @@ TEST(code_generator, subgraph_grad) { // Op(tanh_grad), inputs:{9,4,13}, outputs:{14} std::vector input_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; std::vector output_ids = {10, 11, 12, 13, 14, 15, 16, 17}; - TestMain(&subgraph, input_ids, output_ids); + TestMain(&subgraph, input_ids, output_ids, dtype); } } #endif diff --git a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc index 66b4130b0b5..93986945273 100644 --- a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc +++ b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc @@ -60,6 +60,50 @@ static bool IsEqualAndNotEmpty(const std::vector& l, return l.size() != 0U && r.size() != 0U && l == r; } +bool GroupDetector::IsFusionGroupOp(const Node* n) { + if (!(n && n->IsOp() && n->Op())) return false; + bool is_first = true; + proto::VarType::Type i_data_type = proto::VarType::FP32; + proto::VarType::Type o_data_type = proto::VarType::FP32; + + for (auto* i_node : n->inputs) { + if (!i_node->Var()) return false; + if (i_node->Var()->GetType() 
!= proto::VarType::LOD_TENSOR) { + return false; + } + if (is_first) { + i_data_type = i_node->Var()->GetDataType(); + is_first = false; + } else { + if (i_data_type != i_node->Var()->GetDataType()) return false; + } + } + + is_first = true; + for (auto* o_node : n->outputs) { + if (!o_node->Var()) return false; + if (o_node->Var()->GetType() != proto::VarType::LOD_TENSOR) { + return false; + } + if (is_first) { + o_data_type = o_node->Var()->GetDataType(); + is_first = false; + } else { + if (o_data_type != o_node->Var()->GetDataType()) return false; + } + } + + if (!(i_data_type == proto::VarType::FP32 || + i_data_type == proto::VarType::FP64 || + i_data_type == proto::VarType::FP16) || + !(o_data_type == proto::VarType::FP32 || + o_data_type == proto::VarType::FP64 || + o_data_type == proto::VarType::FP16)) + return false; + + return true; +} + bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) { if (IsSpecifiedOp(GetElementwiseOpTypes(), n)) { std::vector shape_0; @@ -85,7 +129,9 @@ bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) { std::vector> ElementwiseGroupDetector::operator()( Graph* graph) { - auto teller = [&](const Node* n) -> bool { return IsElementwiseOp(n); }; + auto teller = [&](const Node* n) -> bool { + return IsFusionGroupOp(n) && IsElementwiseOp(n); + }; return SubgraphDetector(graph, teller)(); } diff --git a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h index ff4db720f5d..58601c6ad77 100644 --- a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h +++ b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h @@ -23,7 +23,12 @@ namespace framework { namespace ir { namespace fusion_group { -class ElementwiseGroupDetector { +class GroupDetector { + protected: + bool IsFusionGroupOp(const Node* n); +}; + +class ElementwiseGroupDetector : GroupDetector { public: std::vector> operator()(Graph* graph); diff --git a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc index 787bfe58987..b672a80662e 100644 --- a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc +++ b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc @@ -110,18 +110,25 @@ void FusionGroupPass::InsertFusionGroupOp( op_desc.SetType("fusion_group"); std::vector input_names; + std::vector inputs_data_types; for (auto* n : input_vars_of_subgraph) { input_names.push_back(n->Name()); + inputs_data_types.push_back(DataTypeToString(n->Var()->GetDataType())); external_nodes.insert(n); } op_desc.SetInput("Inputs", input_names); std::vector output_names; + std::vector outs_data_types; for (auto* n : output_vars_of_subgraph) { output_names.push_back(n->Name()); + outs_data_types.push_back(DataTypeToString(n->Var()->GetDataType())); external_nodes.insert(n); } + op_desc.SetOutput("Outs", output_names); + op_desc.SetAttr("inputs_data_type", inputs_data_types); + op_desc.SetAttr("outs_data_type", outs_data_types); op_desc.SetAttr("type", subgraph->GetType()); op_desc.SetAttr("func_name", subgraph->GetFuncName()); op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), @@ -131,6 +138,7 @@ void FusionGroupPass::InsertFusionGroupOp( for (auto* in : input_vars_of_subgraph) { IR_NODE_LINK_TO(in, fusion_group_node); } + for (auto* out : output_vars_of_subgraph) { IR_NODE_LINK_TO(fusion_group_node, out); } diff --git a/paddle/fluid/framework/ir/fusion_group/operation.cc 
b/paddle/fluid/framework/ir/fusion_group/operation.cc index 57572af2c2e..f6846676c72 100644 --- a/paddle/fluid/framework/ir/fusion_group/operation.cc +++ b/paddle/fluid/framework/ir/fusion_group/operation.cc @@ -102,6 +102,13 @@ void OperationMap::InsertUnaryElementwiseOperations() { // dx = dout * (1 - out * out) insert_handler("tanh", "2.0 / (1.0 + real_exp(-2.0 * ${0})) - 1.0", {"${2} * (1.0 - ${1} * ${1})"}); + + // cast + // out = static_cast(d) + // dx = static_cast(d_out) + // TODO(wangchaochaohu): This is not the compelete definition of + // cast Op, We need refine it later. + insert_handler("cast", "${0}", {"${0}"}); } void OperationMap::InsertBinaryElementwiseOperations() { @@ -158,10 +165,12 @@ void OperationMap::InsertMultivariateElementwiseOperations() { std::vector grad_exprs) { int type = 0; int num_oprands = -1; - // here ... represent the number of input is changed Insert(type, num_oprands, op_type, expr, grad_exprs, {"X"}, {"Out"}); }; + // here [] represent the number of input is positive(>=0). + // if input list size of Sum Op is 3, It will expand as + // ${0} + ${1} + ${2} insert_handler("sum", "${0}[ + ${?}]", {}); } diff --git a/paddle/fluid/framework/ir/fusion_group/subgraph.h b/paddle/fluid/framework/ir/fusion_group/subgraph.h index 35247ece490..4cf2bf48d5d 100644 --- a/paddle/fluid/framework/ir/fusion_group/subgraph.h +++ b/paddle/fluid/framework/ir/fusion_group/subgraph.h @@ -49,7 +49,6 @@ class SubGraph { } } } - ExtractDataType(); } bool IsValid(int min_subgraph_size) { @@ -61,11 +60,10 @@ class SubGraph { return false; } - return ExtractDataType(); + return true; } int GetType() const { return type_; } - std::string GetDataType() const { return data_type_; } void SetFuncName(std::string func_name) { func_name_ = func_name; } std::string GetFuncName() const { return func_name_; } @@ -162,37 +160,6 @@ class SubGraph { } private: - bool ExtractDataType() { - bool is_first = true; - proto::VarType::Type data_type = proto::VarType::FP32; - for (auto* n : nodes_set_) { - if (n && n->IsVar() && n->Var()) { - if (n->Var()->GetType() != proto::VarType::LOD_TENSOR) { - // All var node in a subgraph should hold a LoDTensor. - return false; - } - if (is_first) { - data_type = n->Var()->GetDataType(); - is_first = false; - } else if (n->Var()->GetDataType() != data_type) { - // DataType of VarDesc in a subgraph is not the same. 
- return false; - } - } - } - if (data_type == proto::VarType::FP32) { - data_type_ = "float"; - } else if (data_type == proto::VarType::FP64) { - data_type_ = "double"; - } else if (data_type == proto::VarType::FP16) { - data_type_ = "float16"; - } else { - VLOG(2) << "Only support fp32, fp64 and fp16 in fusion_group."; - return false; - } - return true; - } - void TopologicalSort() { if (!is_sorted_) { std::unordered_map> inputs_map; diff --git a/paddle/fluid/operators/fused/fusion_group_op.cc b/paddle/fluid/operators/fused/fusion_group_op.cc index 503c0355855..c9e8af6153b 100644 --- a/paddle/fluid/operators/fused/fusion_group_op.cc +++ b/paddle/fluid/operators/fused/fusion_group_op.cc @@ -21,7 +21,7 @@ class FusionGroupOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { const size_t num_ins = ctx->Inputs("Inputs").size(); const size_t num_outs = ctx->Outputs("Outs").size(); @@ -58,6 +58,13 @@ class FusionGroupOp : public framework::OperatorWithKernel { ctx->ShareLoD("Inputs", /*->*/ "Outs", 0, j); } } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, + platform::CUDAPlace(0)); + }; }; class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker { @@ -69,6 +76,12 @@ class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Outs", "(std::vector) The outputs of fusion_group op.") .AsDuplicable(); + AddAttr>( + "outs_data_type", "The data type of Outputs in fusion_group op.") + .SetDefault({}); + AddAttr>( + "inputs_data_type", "The data type of Inputs in fusion_group op.") + .SetDefault({}); AddAttr("type", "Fusion type.").SetDefault(0); AddAttr("func_name", "Name of the generated functions.") .SetDefault(""); diff --git a/paddle/fluid/operators/fused/fusion_group_op.h b/paddle/fluid/operators/fused/fusion_group_op.h index cc8af48792f..ed879cfae55 100644 --- a/paddle/fluid/operators/fused/fusion_group_op.h +++ b/paddle/fluid/operators/fused/fusion_group_op.h @@ -22,6 +22,20 @@ limitations under the License. 
 */
 
 namespace paddle {
 namespace operators {
 
+static void MutableMultiTypeData(
+    std::vector<framework::LoDTensor*>* var,
+    const std::vector<std::string>& data_type, const platform::Place& place) {
+  for (size_t i = 0; i < (*var).size(); i++) {
+    if (data_type[i] == "float") {
+      (*var)[i]->mutable_data<float>(place);
+    } else if (data_type[i] == "double") {
+      (*var)[i]->mutable_data<double>(place);
+    } else if (data_type[i] == "::paddle::platform::float16") {
+      (*var)[i]->mutable_data<paddle::platform::float16>(place);
+    }
+  }
+}
+
 template <typename DeviceContext, typename T>
 class FusionGroupKernel : public framework::OpKernel<T> {
  public:
@@ -29,14 +43,15 @@ class FusionGroupKernel : public framework::OpKernel<T> {
     auto ins = ctx.MultiInput<framework::LoDTensor>("Inputs");
     auto outs = ctx.MultiOutput<framework::LoDTensor>("Outs");
     int type = ctx.Attr<int>("type");
+    auto outs_type = ctx.Attr<std::vector<std::string>>("outs_data_type");
+    auto inputs_type = ctx.Attr<std::vector<std::string>>("inputs_data_type");
 
     size_t num_ins = ins.size();
     size_t num_outs = outs.size();
 
     auto place = ctx.GetPlace();
-    for (size_t i = 0; i < num_outs; ++i) {
-      outs[i]->mutable_data<float>(place);
-    }
+
+    MutableMultiTypeData(&outs, outs_type, place);
 
     std::string func_name = ctx.Attr<std::string>("func_name");
     platform::DeviceCode* dev_code =
@@ -47,13 +62,25 @@ class FusionGroupKernel : public framework::OpKernel<T> {
     size_t n = ins[0]->numel();
     std::vector<void*> args;
     args.push_back(&n);
-    std::vector<float*> ptrs(num_ins + num_outs);
+    std::vector<void*> ptrs(num_ins + num_outs);
     for (size_t i = 0; i < num_ins; ++i) {
-      ptrs[i] = ins[i]->data<float>();
+      if (inputs_type[i] == "::paddle::platform::float16") {
+        ptrs[i] = ins[i]->data<paddle::platform::float16>();
+      } else if (inputs_type[i] == "double") {
+        ptrs[i] = ins[i]->data<double>();
+      } else if (inputs_type[i] == "float") {
+        ptrs[i] = ins[i]->data<float>();
+      }
       args.push_back(&ptrs[i]);
     }
     for (size_t j = 0; j < num_outs; ++j) {
-      ptrs[num_ins + j] = outs[j]->data<float>();
+      if (outs_type[j] == "::paddle::platform::float16") {
+        ptrs[num_ins + j] = outs[j]->data<paddle::platform::float16>();
+      } else if (outs_type[j] == "double") {
+        ptrs[num_ins + j] = outs[j]->data<double>();
+      } else if (outs_type[j] == "float") {
+        ptrs[num_ins + j] = outs[j]->data<float>();
+      }
       args.push_back(&ptrs[num_ins + j]);
     }
     dev_code->Launch(n, &args);
diff --git a/paddle/fluid/operators/fused/fusion_group_op_test.cc b/paddle/fluid/operators/fused/fusion_group_op_test.cc
index 81acb0791c6..48e7d6af397 100644
--- a/paddle/fluid/operators/fused/fusion_group_op_test.cc
+++ b/paddle/fluid/operators/fused/fusion_group_op_test.cc
@@ -57,7 +57,8 @@ framework::OpDesc* CreateFusionGroupOp(
     const std::vector<std::string>& input_names,
     const std::vector<std::vector<int64_t>>& input_shapes,
     const std::vector<std::string>& output_names, int type,
-    std::string func_name) {
+    const std::vector<std::string>& inputs_data_type,
+    const std::vector<std::string>& outs_data_type, std::string func_name) {
   EXPECT_EQ(input_names.size(), input_shapes.size());
 
   for (size_t i = 0; i < input_names.size(); ++i) {
@@ -76,6 +77,8 @@ framework::OpDesc* CreateFusionGroupOp(
   op->SetType("fusion_group");
   op->SetInput("Inputs", input_names);
   op->SetOutput("Outs", output_names);
+  op->SetAttr("inputs_data_type", inputs_data_type);
+  op->SetAttr("outs_data_type", outs_data_type);
   op->SetAttr("type", type);
   op->SetAttr("func_name", func_name);
   op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
@@ -130,6 +133,8 @@ void CheckOutputs(framework::Scope* scope,
 void TestMain(const std::vector<std::string>& input_names,
               const std::vector<std::vector<int64_t>>& input_shapes,
               const std::vector<std::string>& output_names, int type,
+              const std::vector<std::string>& inputs_data_type,
+              const std::vector<std::string>& outs_data_type,
               std::string func_name, std::string cuda_kernel_str,
               CPUKernelFunc cpu_kernel_func) {
   // Compile the device code
@@ -139,8 +144,9 @@ void TestMain(const std::vector<std::string>&
input_names, // Create a ProgramDesc that has a fusion_group_op. framework::ProgramDesc program; - framework::OpDesc* op_desc = CreateFusionGroupOp( - &program, input_names, input_shapes, output_names, type, func_name); + framework::OpDesc* op_desc = + CreateFusionGroupOp(&program, input_names, input_shapes, output_names, + type, inputs_data_type, outs_data_type, func_name); auto fusion_group_op = framework::OpRegistry::CreateOp(*op_desc); framework::Scope scope; @@ -210,8 +216,11 @@ void elementwise_cuda_kernel_0(size_t n, float *x, float* y, float* z) { } }; - TestMain(input_names, input_shapes, output_names, 0, - "elementwise_cuda_kernel_0", kernel, elementwise_cpu_kernel_0); + std::vector inputs_data_type(input_names.size(), "float"); + std::vector outs_data_type(output_names.size(), "float"); + TestMain(input_names, input_shapes, output_names, 0, inputs_data_type, + outs_data_type, "elementwise_cuda_kernel_0", kernel, + elementwise_cpu_kernel_0); } } // namespace operators diff --git a/python/paddle/fluid/tests/unittests/ir/pass_test.py b/python/paddle/fluid/tests/unittests/ir/pass_test.py index 73953bd2db4..67c569f8174 100644 --- a/python/paddle/fluid/tests/unittests/ir/pass_test.py +++ b/python/paddle/fluid/tests/unittests/ir/pass_test.py @@ -142,8 +142,8 @@ class PassTest(unittest.TestCase): self.assertTrue( np.allclose( outs_opt[i], outs[i], atol=atol), - "Output < {} > has diff at {}".format(self.fetch_list[i].name, - str(place))) + "Output < {} > has diff at {}, expected {} but got {}".format( + self.fetch_list[i].name, str(place), outs_opt[i], outs[i])) def _check_fused_ops(self, program): ''' diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py index c2069a6a855..0ec4a852986 100644 --- a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py @@ -125,17 +125,15 @@ class FusionGroupPassTestFP16(FusionGroupPassTest): fluid.data( name="data2", shape=[128, 128], dtype=dtype)) - # subgraph with only 1 op node tmp_0 = self.feed_vars[0] * self.feed_vars[1] tmp_1 = layers.mul(tmp_0, self.feed_vars[2]) - tmp_2 = layers.cast(tmp_0, dtype="float16") tmp_3 = layers.cast(tmp_1, dtype="float16") - # subgraph with 2 op nodes + tmp_2 = layers.cast(tmp_0, dtype="float16") tmp_4 = layers.relu(tmp_2 + tmp_3) tmp_5 = layers.cast(tmp_4, dtype=dtype) - self.fetch_list = [tmp_5] - self.num_fused_ops = 1 + self.fetch_list = [tmp_0, tmp_1, tmp_2, tmp_3, tmp_4, tmp_5] + self.num_fused_ops = 2 class FusionGroupPassSumTest(FusionGroupPassTest): @@ -147,9 +145,28 @@ class FusionGroupPassSumTest(FusionGroupPassTest): tmp_1 = layers.sum([tmp_0, self.feed_vars[2], self.feed_vars[3]]) tmp_2 = layers.sum([tmp_1, self.feed_vars[4]]) - self.fetch_list = [tmp_0, tmp_1] + self.fetch_list = [tmp_0, tmp_1, tmp_2] + self.num_fused_ops = 1 + + +class FusionGroupPassCastTest(FusionGroupPassTest): + def build_program(self, dtype): + with fluid.program_guard(self.main_program, self.startup_program): + self.feed_vars = self._prepare_feed_vars([2, 2], dtype, 2) + + tmp_0 = layers.elementwise_add(self.feed_vars[0], self.feed_vars[1]) + tmp_1 = layers.cast(tmp_0, dtype="double") + tmp_2 = layers.cast(tmp_1, dtype="float32") + + self.fetch_list = [tmp_0, tmp_1, tmp_2] self.num_fused_ops = 1 + def setUp(self): + self.build_program("float64") + self.feeds = self._feed_random_data(self.feed_vars) + self.pass_names = "fusion_group_pass" + 
self.fused_op_type = "fusion_group" + if __name__ == "__main__": unittest.main() -- GitLab
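
Editor's note: to illustrate the net effect of this patch on the generated device code, the sketch below shows roughly what the code generator would now emit for a subgraph that adds two float tensors and casts the result to float16. It is a hand-written approximation, not output captured from the pass: the function and argument names follow the ArgName/TmpName/VarName helpers but are illustrative, the grid-stride loop mirrors the elementwise kernel template only approximately, and the float16 alias is assumed here (the real preamble comes from predefined_cuda_functions_fp16).

```cuda
#include <cuda_fp16.h>

// Assumption for this sketch: the generated preamble provides an equivalent alias.
typedef __half float16;

// Sketch of a generated kernel for:
//   tmp2 = arg0 + arg1   (float, elementwise_add: lhs_type == rhs_type, no cast)
//   tmp3 = cast(tmp2)    (float -> float16, GetExpression wraps the RHS in __float2half)
// Parameter dtypes per argument come from DistilDtypes(); loads/stores from EmitComputeBody().
extern "C" __global__ void fused_elementwise_cast_0(int N, float* arg0, float* arg1,
                                                    float* arg2, float16* arg3) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N;
       idx += gridDim.x * blockDim.x) {
    float tmp0 = arg0[idx];              // load input with its own dtype
    float tmp1 = arg1[idx];
    float tmp2 = tmp0 + tmp1;            // elementwise_add expression
    float16 tmp3 = __float2half(tmp2);   // cast op: lhs_type is float16
    arg2[idx] = tmp2;                    // every output id is stored via VarName(id)
    arg3[idx] = tmp3;
  }
}
```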