From d280106007d43cd216d7e67e0769455772a28a7d Mon Sep 17 00:00:00 2001
From: wangchaochaohu
Date: Wed, 1 Apr 2020 13:15:03 +0800
Subject: [PATCH] Add support for attr type Op and add fill_constant Op and
 scale Op (#23163)

* add attr support for fusion group and add support for fill_constant and
  scale Op
---
 .../ir/fusion_group/code_generator.cc         |   8 +-
 .../ir/fusion_group/code_generator_helper.cc  | 118 ++++++++++++++----
 .../ir/fusion_group/code_generator_helper.h   |   7 +-
 .../framework/ir/fusion_group/operation.cc    |  19 +++
 .../unittests/ir/test_ir_fusion_group_pass.py |  25 +++-
 5 files changed, 150 insertions(+), 27 deletions(-)

diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator.cc b/paddle/fluid/framework/ir/fusion_group/code_generator.cc
index 133f1dfce7c..24c3f876b94 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator.cc
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator.cc
@@ -75,8 +75,8 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
   for (auto* node : subgraph->SortedNodes()) {
     if (node && node->IsOp() && node->Op()) {
       auto* op = node->Op();
+      AttributeMap attr = *(op->MutableAttrMap());
-
       // Input ids should be set in fixed order, like:
       //  - X, Y in forward operations
       //  - X, Y, Out, out@GRAD in backward operations
       std::vector<int> input_ids;
@@ -118,8 +118,10 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
       std::string lhs_type = ExtractDataType(node->outputs);
       std::string rhs_type = ExtractDataType(node->inputs);
-      expressions.emplace_back(OperationExpression(
-          node->Name(), input_ids, output_ids, rhs_type, lhs_type));
+      auto expression = OperationExpression(node->Name(), input_ids, output_ids,
+                                            rhs_type, lhs_type);
+      expression.SetAttr(attr);
+      expressions.push_back(expression);
     }
   }
   return expressions;
diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
index 573f4c4de3e..1b8c6b2fdfe 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
@@ -18,7 +18,14 @@ limitations under the License. */
 #include
 #include
 #include "glog/logging.h"
+#include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/ir/fusion_group/operation.h"
+#include "paddle/fluid/framework/op_call_stack.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/shape_inference.h"
+#include "paddle/fluid/framework/var_type_inference.h"

 namespace paddle {
 namespace framework {
@@ -33,7 +40,7 @@ static T StringTo(const std::string& str) {
   return value;
 }

-static std::string ExpandMultivariateTemplate(const std::string rhs,
+static std::string ExpandMultivariateTemplate(const std::string& rhs,
                                               const size_t input_size) {
   int start_pos = rhs.find("[", 0);
   int end_pos = rhs.find("]", 0);
@@ -50,6 +57,66 @@ static std::string ExpandMultivariateTemplate(const std::string rhs,
   return sum_rhs;
 }

+static std::string RefineTemplateWithAttr(const std::string& op_type,
+                                          const std::string& exp_definition,
+                                          const AttributeMap& attrs) {
+  std::string ret;
+  // str_cvt converts numeric attr values (int, int64_t, float) to strings,
+  // e.g. for the value filled by fill_constant
+  std::stringstream str_cvt;
+  auto IsNumber = [exp_definition]() -> bool {
+    return exp_definition.find_first_not_of("0123456789") == std::string::npos;
+  };
+
+  if (!IsNumber()) {
+    // Get the attr value according to its type. For now we only support the
+    // simple attr types listed below.
+    std::string attr_name, default_value;
+    if (exp_definition.find("=") != std::string::npos) {
+      attr_name = exp_definition.substr(0, exp_definition.find("="));
+      default_value = exp_definition.substr(exp_definition.rfind("=") + 1,
+                                            exp_definition.length() - 1);
+      ret = default_value;
+    } else {
+      attr_name = exp_definition;
+    }
+    auto it = attrs.find(attr_name);
+    if (it == attrs.end()) {
+      return ret;
+    }
+    Attribute attr = it->second;
+    proto::AttrType attr_type =
+        static_cast<proto::AttrType>(it->second.which() - 1);
+    if (attr_type == proto::AttrType::BOOLEAN) {
+      bool result = boost::get<bool>(attr);
+      if (result) {
+        ret = "true";
+      } else {
+        ret = "false";
+      }
+    } else if (attr_type == proto::AttrType::INT) {
+      int result = boost::get<int>(attr);
+      str_cvt << result;
+      ret = str_cvt.str();
+    } else if (attr_type == proto::AttrType::LONG) {
+      int64_t result = boost::get<int64_t>(attr);
+      str_cvt << result;
+      ret = str_cvt.str();
+    } else if (attr_type == proto::AttrType::FLOAT) {
+      float result = boost::get<float>(attr);
+      str_cvt << result;
+      ret = str_cvt.str();
+    } else if (attr_type == proto::AttrType::STRING) {
+      std::string result = boost::get<std::string>(attr);
+      ret = result;
+    }
+  } else {
+    ret = exp_definition;
+  }
+
+  return ret;
+}
+
 // In order to avoid multiple __half2float function calls, we do this
 // optimization
 static std::string OptimzeFP16RHS(std::unordered_set<int>* used,
@@ -74,7 +141,6 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
     size_t input_size = input_ids_.size();
     rhs = ExpandMultivariateTemplate(rhs, input_size);
   }
-
   for (size_t i = 0; i < rhs.size(); i++) {
     size_t pos = i;
     if (rhs[pos] == '$' && rhs[pos + 1] == '{') {
@@ -83,28 +149,36 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
         length++;
       }
       std::string index_str = rhs.substr(pos + 2, length);
-      int index = StringTo<int>(index_str);
-      PADDLE_ENFORCE_LT(
-          index, input_ids_.size(),
-          platform::errors::InvalidArgument(
-              "Only %d inputs are provided, but need %d for operation < %s >.",
-              input_ids_.size(), index + 1, op_type_));
-      PADDLE_ENFORCE_GE(
-          input_ids_[index], 0,
-          platform::errors::InvalidArgument(
-              "Expected %d-th input id > 0 for operation < %s >. Received %d.",
-              index, op_type_, input_ids_[index]));
-      // TODO(wangchaochaohu): Here fp16 convert to float to do comupte, we need
-      // to add general fp16 compute later.
-      std::string var_name;
-      if (rhs_type_ == "float16") {
-        half2fp32_statement->append(OptimzeFP16RHS(used, index, input_ids_));
-        var_name = "half2fp32_" + TmpName(input_ids_[index]);
+      std::string refine_str =
+          RefineTemplateWithAttr(op_type_, index_str, attr_);
+      if (index_str == refine_str) {
+        int index = StringTo<int>(index_str);
+        PADDLE_ENFORCE_LT(index, input_ids_.size(),
+                          platform::errors::InvalidArgument(
+                              "Only %d inputs are provided, but need %d for "
+                              "operation < %s >.",
+                              input_ids_.size(), index + 1, op_type_));
+        PADDLE_ENFORCE_GE(input_ids_[index], 0,
+                          platform::errors::InvalidArgument(
+                              "Expected %d-th input id > 0 for operation < %s "
+                              ">. Received %d.",
+                              index, op_type_, input_ids_[index]));
+        // TODO(wangchaochaohu): Here fp16 is converted to float to do the
+        // compute; we need to add general fp16 compute later.
+        std::string var_name;
+        if (rhs_type_ == "float16") {
+          half2fp32_statement->append(OptimzeFP16RHS(used, index, input_ids_));
+          var_name = "half2fp32_" + TmpName(input_ids_[index]);
+        } else {
+          var_name = TmpName(input_ids_[index]);
+        }
+        rhs.replace(pos, length + 3, var_name);
+        used->insert(input_ids_[index]);
       } else {
-        var_name = TmpName(input_ids_[index]);
+        std::string var_name = refine_str;
+        rhs.replace(pos, length + 3, var_name);
       }
-      rhs.replace(pos, length + 3, var_name);
-      used->insert(input_ids_[index]);
     }
   }
   return rhs;
diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
index 10bd2b20a2f..2a35ac6f632 100644
--- a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
+++ b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
@@ -20,6 +20,9 @@ limitations under the License. */
 #include
 #include
+#include "paddle/fluid/framework/attribute.h"
+#include "paddle/fluid/framework/type_defs.h"
+#include "paddle/fluid/framework/var_desc.h"
 #include "paddle/fluid/platform/enforce.h"

 namespace paddle {
@@ -55,7 +58,8 @@ class OperationExpression {
   std::vector<int> GetOutputIds() const { return output_ids_; }
   std::string GetRHSType() const { return rhs_type_; }
   std::string GetLHSType() const { return lhs_type_; }
-
+  void SetAttr(AttributeMap attr) { attr_ = attr; }
+  AttributeMap GetAttr() { return attr_; }
   // Check whether this operation type is supported in OperationMap.
   bool IsSupport() const;
@@ -72,6 +76,7 @@ class OperationExpression {
   std::string op_type_;
   std::vector<int> input_ids_;
   std::vector<int> output_ids_;
+  AttributeMap attr_;
   std::string rhs_type_;
   std::string lhs_type_;
 };
diff --git a/paddle/fluid/framework/ir/fusion_group/operation.cc b/paddle/fluid/framework/ir/fusion_group/operation.cc
index 0ff9754bb29..99ae8be37b9 100644
--- a/paddle/fluid/framework/ir/fusion_group/operation.cc
+++ b/paddle/fluid/framework/ir/fusion_group/operation.cc
@@ -118,6 +118,16 @@ void OperationMap::InsertUnaryElementwiseOperations() {
   // out = x^2
   // dx = dout * 2.0 * x
   insert_handler("square", "${0} * ${0}", {"${2} * 2.0 * ${0}"});
+
+  // scale
+  // out = (bias_after_scale) ? scale * X + bias : scale * (X + bias)
+  // here we use the '=' operator to separate the attr name from its default value
+  // TODO(wangchaochaohu): Later we need to support Tensor input for scale and
+  // bias.
+ insert_handler("scale", + "${bias_after_scale=true} ? (${scale=1.0} * ${0} + " + "${bias=0.0}) : (${scale=1.0} * (${0} + ${bias=0.0}))", + {}); } void OperationMap::InsertBinaryElementwiseOperations() { @@ -185,6 +195,15 @@ void OperationMap::InsertMultivariateElementwiseOperations() { // For example, sum with 4 inputs, the expanded expression is: // ${0} + ${1} + ${2} + ${3} insert_handler("sum", "${0}[ + ${?}]", {}); + + auto insert_handler_without_input = [&](std::string op_type, std::string expr, + std::vector grad_exprs) { + int type = 0; + int num_oprands = -1; + Insert(type, num_oprands, op_type, expr, grad_exprs, {}, {"Out"}); + }; + // fill_constant: + insert_handler_without_input("fill_constant", "${str_value}", {}); } } // namespace fusion_group diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py index ef57752e87e..d33658c89f1 100644 --- a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py @@ -165,7 +165,7 @@ class FusionGroupPassSumTest(FusionGroupPassTest): self.append_gradients(tmp_3) - self.num_fused_ops = 3 + self.num_fused_ops = 4 self.fetch_list = [tmp_3, self.grad(tmp_0)] @@ -190,5 +190,28 @@ class FusionGroupPassCastTest(FusionGroupPassTest): self.fused_op_type = "fusion_group" +class FusionGroupPassFillConstantTest(FusionGroupPassTest): + def build_program(self, dtype): + with fluid.program_guard(self.main_program, self.startup_program): + self.feed_vars = self._prepare_feed_vars([2, 2], dtype, 2) + + tmp_0 = layers.elementwise_add(self.feed_vars[0], self.feed_vars[1]) + tmp_1 = layers.fill_constant(shape=[2, 2], dtype=dtype, value=2.0) + tmp_2 = layers.scale( + tmp_1, scale=3.0, bias=1.0, bias_after_scale=True) + tmp_3 = layers.elementwise_mul(tmp_2, tmp_0) + + self.append_gradients(tmp_3) + + self.num_fused_ops = 1 + self.fetch_list = [tmp_2, self.grad(tmp_0)] + + def setUp(self): + self.build_program("float32") + self.feeds = self._feed_random_data(self.feed_vars) + self.pass_names = "fusion_group_pass" + self.fused_op_type = "fusion_group" + + if __name__ == "__main__": unittest.main() -- GitLab