From ca9e77a8d497be1ff1f3e40931442a0ff6c5b740 Mon Sep 17 00:00:00 2001
From: wangchaochaohu
Date: Sun, 1 Mar 2020 22:59:53 +0800
Subject: [PATCH] add sum op support for fusion group (#22771)

* Add the codegen and auto fusion for sum Op in fusion group
---
 .../ir/fusion_group/code_generator.cc         | 21 ++++++-----
 .../ir/fusion_group/code_generator_helper.cc  | 27 +++++++++++++-
 .../ir/fusion_group/code_generator_helper.h   |  3 +-
 .../elementwise_group_detector.cc             | 37 ++++---------------
 .../framework/ir/fusion_group/operation.cc    | 18 +++++++--
 .../framework/ir/fusion_group/operation.h     |  3 +-
 .../unittests/ir/test_ir_fusion_group_pass.py | 13 +++++++
 7 files changed, 76 insertions(+), 46 deletions(-)

diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator.cc b/paddle/fluid/framework/ir/fusion_group/code_generator.cc
index b7a75d376a5..71a23141c37 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator.cc
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator.cc
@@ -60,18 +60,21 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
     // - X, Y in forward operations
     // - X, Y, Out, out@GRAD in backward operations
     std::vector<int> input_ids;
-    std::vector<std::string> input_names =
-        OperationMap::Instance().Get(op->Type()).input_names;
+    auto operation = OperationMap::Instance().Get(op->Type());
+    std::vector<std::string> input_names = operation.input_names;
+
     for (auto& name : input_names) {
       // Some input vars are not used in grad ops, such as
       // "elementwise_add_grad", where "X", "Y" and "Out" are not used.
       if (HasInput(node, name) && op->Input(name).size() >= 1U) {
-        // TODO(liuyiqun): support duplicated input.
-        PADDLE_ENFORCE_NE(
-            var_ids.find(op->Input(name)[0]), var_ids.end(),
-            platform::errors::InvalidArgument(
-                "Input(%s) of operation %s is not set.", name, op->Type()));
-        input_ids.push_back(var_ids[op->Input(name)[0]]);
+        // Record every variable of the input slot, not only the first one,
+        // so that variadic ops such as sum are covered.
+        for (size_t i = 0; i < op->Input(name).size(); i++) {
+          PADDLE_ENFORCE_NE(
+              var_ids.find(op->Input(name)[i]), var_ids.end(),
+              platform::errors::InvalidArgument(
+                  "Input(%s) of operation %s is not set.", name, op->Type()));
+          input_ids.push_back(var_ids[op->Input(name)[i]]);
+        }
       } else {
         input_ids.push_back(-1);
       }
diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
index be06a620f78..fe395e5f992 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
@@ -33,9 +33,32 @@ static T StringTo(const std::string& str) {
   return value;
 }
 
+static std::string ExpandMultivariateTemplate(const std::string& rhs,
+                                              const size_t input_size) {
+  size_t start_pos = rhs.find("[", 0);
+  size_t end_pos = rhs.find("]", 0);
+  std::string sum_rhs = rhs.substr(0, start_pos);
+  std::string sum_rhs_component =
+      rhs.substr(start_pos + 1, (end_pos - start_pos - 1));
+  size_t replace_pos = sum_rhs_component.find("?", 0);
+
+  for (size_t i = 1; i < input_size; i++) {
+    // Replace "?" in a fresh copy of the component, so that indices with
+    // more than one digit do not corrupt the template on later iterations.
+    std::string append_str = sum_rhs_component;
+    append_str.replace(replace_pos, 1, std::to_string(i));
+    sum_rhs = sum_rhs + append_str;
+  }
+  return sum_rhs;
+}
+
 std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
-                                        size_t i) const {
-  auto rhs = OperationMap::Instance().Get(op_type_).exprs[i];
+                                        size_t exprs_index) const {
+  auto rhs = OperationMap::Instance().Get(op_type_).exprs[exprs_index];
+  auto num_operands = OperationMap::Instance().Get(op_type_).num_operands;
+
+  if (num_operands == -1) {
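+    // The stored expression of a variadic op is a template such as
+    // "${0}[ + ${?}]"; expand it here to match the actual input count.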
+    size_t input_size = input_ids_.size();
+    rhs = ExpandMultivariateTemplate(rhs, input_size);
+  }
+
   for (size_t i = 0; i < rhs.size(); i++) {
     size_t pos = i;
     if (rhs[pos] == '$' && rhs[pos + 1] == '{') {
diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
index 5749755d3ab..18652b53051 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
@@ -52,7 +52,8 @@ class OperationExpression {
 
  private:
   // TODO(wangchao): make the offset more flexible by adding stride and basic offset
-  std::string GetRHS(std::unordered_set<int>* used, size_t i = 0) const;
+  std::string GetRHS(std::unordered_set<int>* used,
+                     size_t exprs_index = 0) const;
   std::string GetLHS(size_t i = 0) const;
 
  private:
diff --git a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
index 970d97e8e3b..66b4130b0b5 100644
--- a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
+++ b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
@@ -24,23 +24,13 @@ namespace framework {
 namespace ir {
 namespace fusion_group {
 
-static std::unordered_set<std::string> binary_op_types;
-static std::unordered_set<std::string> unary_op_types;
+static std::unordered_set<std::string> elementwise_op_types;
 
-static std::unordered_set<std::string>& GetBinaryOpTypes() {
-  if (binary_op_types.empty()) {
-    binary_op_types =
-        OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 2);
+static std::unordered_set<std::string>& GetElementwiseOpTypes() {
+  if (elementwise_op_types.empty()) {
+    elementwise_op_types = OperationMap::Instance().Find(/* type= */ 0);
   }
-  return binary_op_types;
-}
-
-static std::unordered_set<std::string>& GetUnaryOpTypes() {
-  if (unary_op_types.empty()) {
-    unary_op_types =
-        OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 1);
-  }
-  return unary_op_types;
+  return elementwise_op_types;
 }
 
 static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
@@ -70,13 +60,8 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
   return l.size() != 0U && r.size() != 0U && l == r;
 }
 
-static bool IsBinaryOp(const Node* n) {
-  if (IsSpecifiedOp(GetBinaryOpTypes(), n)) {
-    if ((!IsGradOp(n) && n->inputs.size() != 2U) || n->inputs.size() == 0U) {
-      return false;
-    }
-
-    // The shape of all inputs should be the same.
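+// An op joins an elementwise group when its type is registered in
+// OperationMap, regardless of arity, and all of its input shapes match.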
+bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
+  if (IsSpecifiedOp(GetElementwiseOpTypes(), n)) {
     std::vector<int64_t> shape_0;
     for (size_t i = 0; i < n->inputs.size(); ++i) {
       auto* in_i = n->inputs[i];
@@ -98,14 +83,6 @@
   return false;
 }
 
-static bool IsUnaryOp(const Node* n) {
-  return IsSpecifiedOp(GetUnaryOpTypes(), n);
-}
-
-bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
-  return IsBinaryOp(n) || IsUnaryOp(n);
-}
-
 std::vector<std::vector<Node*>> ElementwiseGroupDetector::operator()(
     Graph* graph) {
   auto teller = [&](const Node* n) -> bool { return IsElementwiseOp(n); };
diff --git a/paddle/fluid/framework/ir/fusion_group/operation.cc b/paddle/fluid/framework/ir/fusion_group/operation.cc
index 966cc2752c0..57572af2c2e 100644
--- a/paddle/fluid/framework/ir/fusion_group/operation.cc
+++ b/paddle/fluid/framework/ir/fusion_group/operation.cc
@@ -25,13 +25,13 @@ OperationMap* OperationMap::map = nullptr;
 OperationMap::OperationMap() {
   InsertUnaryElementwiseOperations();
   InsertBinaryElementwiseOperations();
+  InsertMultivariateElementwiseOperations();
 }
 
-std::unordered_set<std::string> OperationMap::Find(int type, int num_operands) {
+std::unordered_set<std::string> OperationMap::Find(int type) {
   std::unordered_set<std::string> res;
   for (auto& t : operations_) {
-    if ((t.second.type == type) &&
-        (num_operands < 0 || t.second.num_operands == num_operands)) {
+    if (t.second.type == type) {
       res.insert(t.first);
     }
   }
@@ -153,6 +153,18 @@ void OperationMap::InsertBinaryElementwiseOperations() {
                  {"${3} * (${0} > ${1})", "${3} * (${0} <= ${1})"});
 }
 
+void OperationMap::InsertMultivariateElementwiseOperations() {
+  auto insert_handler = [&](std::string op_type, std::string expr,
+                            std::vector<std::string> grad_exprs) {
+    int type = 0;
+    // num_operands is set to -1 because the number of inputs of a
+    // multivariate op such as sum is not fixed.
+    int num_operands = -1;
+    Insert(type, num_operands, op_type, expr, grad_exprs, {"X"}, {"Out"});
+  };
+
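+  // In the expression template, "${0}" is the first input; the bracketed
+  // component "[ + ${?}]" is stamped out once per additional input, with
+  // "?" replaced by that input's index, e.g. "${0} + ${1} + ${2}".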
+  insert_handler("sum", "${0}[ + ${?}]", {});
+}
+
 }  // namespace fusion_group
 }  // namespace ir
 }  // namespace framework
diff --git a/paddle/fluid/framework/ir/fusion_group/operation.h b/paddle/fluid/framework/ir/fusion_group/operation.h
index d23bea8a437..74abbdaad0b 100644
--- a/paddle/fluid/framework/ir/fusion_group/operation.h
+++ b/paddle/fluid/framework/ir/fusion_group/operation.h
@@ -84,7 +84,7 @@ class OperationMap {
     return *map;
   }
 
-  std::unordered_set<std::string> Find(int type, int num_operands = -1);
+  std::unordered_set<std::string> Find(int type);
 
   bool Has(std::string op_type) {
     return operations_.find(op_type) != operations_.end();
@@ -106,6 +106,7 @@ class OperationMap {
 
   void InsertUnaryElementwiseOperations();
   void InsertBinaryElementwiseOperations();
+  void InsertMultivariateElementwiseOperations();
 
  private:
   static OperationMap* map;
diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
index e0121e08eff..c2069a6a855 100644
--- a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
@@ -138,5 +138,18 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
         self.num_fused_ops = 1
 
 
+class FusionGroupPassSumTest(FusionGroupPassTest):
+    def build_program(self, dtype):
+        with fluid.program_guard(self.main_program, self.startup_program):
+            self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 5)
+
+            tmp_0 = layers.elementwise_add(self.feed_vars[0], self.feed_vars[1])
+            tmp_1 = layers.sum([tmp_0, self.feed_vars[2], self.feed_vars[3]])
+            tmp_2 = layers.sum([tmp_1, self.feed_vars[4]])
+
+        self.fetch_list = [tmp_0, tmp_1, tmp_2]
+        self.num_fused_ops = 1
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab
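
Reviewer note, not part of the patch: the standalone C++ sketch below traces
the multivariate template expansion added in code_generator_helper.cc. The
driver is hypothetical (the real function is file-static inside PaddlePaddle
and not exported); it only illustrates how "${0}[ + ${?}]" unfolds for a
given input count.

    #include <iostream>
    #include <string>

    // Same string handling as the patched ExpandMultivariateTemplate.
    static std::string ExpandMultivariateTemplate(const std::string& rhs,
                                                  size_t input_size) {
      size_t start_pos = rhs.find('[');
      size_t end_pos = rhs.find(']');
      std::string sum_rhs = rhs.substr(0, start_pos);
      std::string component =
          rhs.substr(start_pos + 1, end_pos - start_pos - 1);
      size_t replace_pos = component.find('?');
      for (size_t i = 1; i < input_size; i++) {
        std::string append_str = component;  // fresh copy per input index
        append_str.replace(replace_pos, 1, std::to_string(i));
        sum_rhs += append_str;
      }
      return sum_rhs;
    }

    int main() {
      // A 3-input sum: prints "${0} + ${1} + ${2}".
      std::cout << ExpandMultivariateTemplate("${0}[ + ${?}]", 3) << "\n";
      return 0;
    }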