Unverified commit d2801060, authored by W wangchaochaohu, committed by GitHub

Add support for attr type Op and add fill_constant Op and scale Op (#23163)

* add attr support for fusion group and add support for fill_constant and scale Op
Parent 3a45767d
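What this commit does, in short: fusion_group expression templates may now reference operator attributes as ${attr_name} or ${attr_name=default_value}, in addition to positional inputs such as ${0}; at code-generation time each placeholder is resolved against the op's AttributeMap, and fill_constant and scale are registered on top of this mechanism. The sketch below illustrates the resolution rule in isolation; RefineToken and the plain std::map standing in for Paddle's AttributeMap are illustrative assumptions, not the commit's API (the real function is RefineTemplateWithAttr in the diff below).

    // Minimal sketch of the placeholder-resolution rule (assumed names, not
    // Paddle's API): a token is either a pure number ("0"), which stays a
    // positional input index, or an attribute reference with an optional
    // default ("scale=1.0").
    #include <map>
    #include <string>

    static std::string RefineToken(
        const std::string& token,
        const std::map<std::string, std::string>& attrs) {
      // Pure numbers keep their meaning as positional input indices.
      if (token.find_first_not_of("0123456789") == std::string::npos) {
        return token;
      }
      std::string name = token;
      std::string fallback;
      std::string::size_type eq = token.find('=');
      if (eq != std::string::npos) {
        name = token.substr(0, eq);       // attribute name before '='
        fallback = token.substr(eq + 1);  // default value after '='
      }
      auto it = attrs.find(name);
      return it != attrs.end() ? it->second : fallback;
    }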
@@ -75,8 +75,8 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
   for (auto* node : subgraph->SortedNodes()) {
     if (node && node->IsOp() && node->Op()) {
       auto* op = node->Op();
+      AttributeMap attr = *(op->MutableAttrMap());
       // Input ids should be set in fixed order, like:
       //  - X, Y in forward operations
       //  - X, Y, Out, out@GRAD in backward operations
       std::vector<int> input_ids;
@@ -118,8 +118,10 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
       std::string lhs_type = ExtractDataType(node->outputs);
       std::string rhs_type = ExtractDataType(node->inputs);
-      expressions.emplace_back(OperationExpression(
-          node->Name(), input_ids, output_ids, rhs_type, lhs_type));
+      auto expression = OperationExpression(node->Name(), input_ids, output_ids,
+                                            rhs_type, lhs_type);
+      expression.SetAttr(attr);
+      expressions.push_back(expression);
     }
   }
   return expressions;
......
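With the change above, every OperationExpression built by ConvertToExpressions carries its op's attribute map. A hedged usage sketch (the op type, ids and data types are made up for illustration):

    // Build an expression for a hypothetical scale op with input id 0 and
    // output id 1, then attach the op's attributes as the new code does.
    AttributeMap attr = *(op->MutableAttrMap());
    auto expression = OperationExpression("scale", {0}, {1}, "float", "float");
    expression.SetAttr(attr);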
@@ -18,7 +18,14 @@ limitations under the License. */
 #include <string>
 #include <vector>
 #include "glog/logging.h"
+#include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/ir/fusion_group/operation.h"
+#include "paddle/fluid/framework/op_call_stack.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/shape_inference.h"
+#include "paddle/fluid/framework/var_type_inference.h"

 namespace paddle {
 namespace framework {
@@ -33,7 +40,7 @@ static T StringTo(const std::string& str) {
   return value;
 }

-static std::string ExpandMultivariateTemplate(const std::string rhs,
+static std::string ExpandMultivariateTemplate(const std::string& rhs,
                                               const size_t input_size) {
   int start_pos = rhs.find("[", 0);
   int end_pos = rhs.find("]", 0);
@@ -50,6 +57,66 @@ static std::string ExpandMultivariateTemplate(const std::string rhs,
   return sum_rhs;
 }
+static std::string RefineTemplateWithAttr(const std::string& op_type,
+                                          const std::string& exp_definition,
+                                          const AttributeMap& attrs) {
+  std::string ret;
+  // str_cvt converts numeric attribute values (int/long/float) to strings,
+  // e.g. for the str_value attribute of fill_constant.
+  std::stringstream str_cvt;
+  auto IsNumber = [exp_definition]() -> bool {
+    return exp_definition.find_first_not_of("0123456789") == std::string::npos;
+  };
+
+  if (!IsNumber()) {
+    // Get the attribute value by name. For now only simple attribute kinds
+    // are supported.
+    std::string attr_name, default_value;
+    if (exp_definition.find("=") != std::string::npos) {
+      attr_name = exp_definition.substr(0, exp_definition.find("="));
+      default_value = exp_definition.substr(exp_definition.rfind("=") + 1,
+                                            exp_definition.length() - 1);
+      ret = default_value;
+    } else {
+      attr_name = exp_definition;
+    }
+    auto it = attrs.find(attr_name);
+    if (it == attrs.end()) {
+      return ret;
+    }
+    Attribute attr = it->second;
+    proto::AttrType attr_type =
+        static_cast<proto::AttrType>(it->second.which() - 1);
+    if (attr_type == proto::AttrType::BOOLEAN) {
+      bool result = boost::get<bool>(attr);
+      if (result) {
+        ret = "true";
+      } else {
+        ret = "false";
+      }
+    } else if (attr_type == proto::AttrType::INT) {
+      int result = boost::get<int>(attr);
+      str_cvt << result;
+      ret = str_cvt.str();
+    } else if (attr_type == proto::AttrType::LONG) {
+      int64_t result = boost::get<int64_t>(attr);
+      str_cvt << result;
+      ret = str_cvt.str();
+    } else if (attr_type == proto::AttrType::FLOAT) {
+      float result = boost::get<float>(attr);
+      str_cvt << result;
+      ret = str_cvt.str();
+    } else if (attr_type == proto::AttrType::STRING) {
+      std::string result = boost::get<std::string>(attr);
+      ret = result;
+    }
+  } else {
+    ret = exp_definition;
+  }
+  return ret;
+}
 // In order to avoid multiple __half2float function calls, we do this
 // optimization.
 static std::string OptimzeFP16RHS(std::unordered_set<int>* used,
@@ -74,7 +141,6 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
     size_t input_size = input_ids_.size();
     rhs = ExpandMultivariateTemplate(rhs, input_size);
   }
-
   for (size_t i = 0; i < rhs.size(); i++) {
     size_t pos = i;
     if (rhs[pos] == '$' && rhs[pos + 1] == '{') {
@@ -83,28 +149,36 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
         length++;
       }
       std::string index_str = rhs.substr(pos + 2, length);
-      int index = StringTo<int>(index_str);
-      PADDLE_ENFORCE_LT(
-          index, input_ids_.size(),
-          platform::errors::InvalidArgument(
-              "Only %d inputs are provided, but need %d for operation < %s >.",
-              input_ids_.size(), index + 1, op_type_));
-      PADDLE_ENFORCE_GE(
-          input_ids_[index], 0,
-          platform::errors::InvalidArgument(
-              "Expected %d-th input id > 0 for operation < %s >. Received %d.",
-              index, op_type_, input_ids_[index]));
-      // TODO(wangchaochaohu): Here fp16 is converted to float for compute;
-      // we need to add general fp16 compute later.
-      std::string var_name;
-      if (rhs_type_ == "float16") {
-        half2fp32_statement->append(OptimzeFP16RHS(used, index, input_ids_));
-        var_name = "half2fp32_" + TmpName(input_ids_[index]);
-      } else {
-        var_name = TmpName(input_ids_[index]);
-      }
-      rhs.replace(pos, length + 3, var_name);
-      used->insert(input_ids_[index]);
+      std::string refine_str =
+          RefineTemplateWithAttr(op_type_, index_str, attr_);
+      if (index_str == refine_str) {
+        int index = StringTo<int>(index_str);
+        PADDLE_ENFORCE_LT(index, input_ids_.size(),
+                          platform::errors::InvalidArgument(
+                              "Only %d inputs are provided, but need %d for "
+                              "operation < %s >.",
+                              input_ids_.size(), index + 1, op_type_));
+        PADDLE_ENFORCE_GE(input_ids_[index], 0,
+                          platform::errors::InvalidArgument(
+                              "Expected %d-th input id > 0 for operation < %s "
+                              ">. Received %d.",
+                              index, op_type_, input_ids_[index]));
+        // TODO(wangchaochaohu): Here fp16 is converted to float for compute;
+        // we need to add general fp16 compute later.
+        std::string var_name;
+        if (rhs_type_ == "float16") {
+          half2fp32_statement->append(OptimzeFP16RHS(used, index, input_ids_));
+          var_name = "half2fp32_" + TmpName(input_ids_[index]);
+        } else {
+          var_name = TmpName(input_ids_[index]);
+        }
+        rhs.replace(pos, length + 3, var_name);
+        used->insert(input_ids_[index]);
+      } else {
+        std::string var_name = refine_str;
+        rhs.replace(pos, length + 3, var_name);
+      }
     }
   }
   return rhs;
......
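To see what the index_str == refine_str branch above distinguishes, here is a hedged walkthrough using the simplified RefineToken sketch from the top of this page (attribute values illustrative):

    // Assumes the RefineToken sketch defined earlier on this page.
    #include <cassert>

    int main() {
      std::map<std::string, std::string> attrs{{"scale", "3.0"}};
      // Pure number: returned unchanged, so GetRHS treats it as input 0
      // and substitutes the corresponding tmp variable.
      assert(RefineToken("0", attrs) == "0");
      // Attribute present: the placeholder becomes the attribute's value,
      // taking the else-branch in GetRHS.
      assert(RefineToken("scale=1.0", attrs) == "3.0");
      // Attribute absent: the default after '=' is used.
      assert(RefineToken("bias=2.0", attrs) == "2.0");
      return 0;
    }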
@@ -20,6 +20,9 @@ limitations under the License. */
 #include <unordered_set>
 #include <vector>
+#include "paddle/fluid/framework/attribute.h"
+#include "paddle/fluid/framework/type_defs.h"
+#include "paddle/fluid/framework/var_desc.h"
 #include "paddle/fluid/platform/enforce.h"

 namespace paddle {
@@ -55,7 +58,8 @@ class OperationExpression {
   std::vector<int> GetOutputIds() const { return output_ids_; }
   std::string GetRHSType() const { return rhs_type_; }
   std::string GetLHSType() const { return lhs_type_; }
+  void SetAttr(AttributeMap attr) { attr_ = attr; }
+  AttributeMap GetAttr() { return attr_; }

   // Check whether this operation type is supported in OperationMap.
   bool IsSupport() const;
@@ -72,6 +76,7 @@ class OperationExpression {
   std::string op_type_;
   std::vector<int> input_ids_;
   std::vector<int> output_ids_;
+  AttributeMap attr_;
   std::string rhs_type_;
   std::string lhs_type_;
 };
......
@@ -118,6 +118,16 @@ void OperationMap::InsertUnaryElementwiseOperations() {
   // out = x^2
   // dx = dout * 2.0 * x
   insert_handler("square", "${0} * ${0}", {"${2} * 2.0 * ${0}"});
+  // scale
+  // out = (bias_after_scale) ? scale * X + bias : scale * (X + bias)
+  // Here the '=' operator separates an attribute name from its default value.
+  // TODO(wangchaochaohu): Later we need to support Tensor input for scale and
+  // bias.
+  insert_handler("scale",
+                 "${bias_after_scale=true} ? (${scale=1.0} * ${0} + "
+                 "${bias=0.0}) : (${scale=1.0} * (${0} + ${bias=0.0}))",
+                 {});
 }

 void OperationMap::InsertBinaryElementwiseOperations() {
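As an illustration (not part of the commit): for a scale op with scale = 3.0, bias = 1.0 and bias_after_scale = true, and assuming the input placeholder ${0} maps to a temporary named tmp_0, the template above would be refined into roughly this generated expression:

    // Hypothetical refined RHS for scale(scale=3.0, bias=1.0,
    // bias_after_scale=true); tmp_0 stands for the op's single input.
    true ? (3.0 * tmp_0 + 1.0) : (3.0 * (tmp_0 + 1.0))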
@@ -185,6 +195,15 @@ void OperationMap::InsertMultivariateElementwiseOperations() {
   // For example, sum with 4 inputs, the expanded expression is:
   //   ${0} + ${1} + ${2} + ${3}
   insert_handler("sum", "${0}[ + ${?}]", {});
+
+  auto insert_handler_without_input = [&](std::string op_type, std::string expr,
+                                          std::vector<std::string> grad_exprs) {
+    int type = 0;
+    int num_oprands = -1;
+    Insert(type, num_oprands, op_type, expr, grad_exprs, {}, {"Out"});
+  };
+  // fill_constant:
+  insert_handler_without_input("fill_constant", "${str_value}", {});
 }

 }  // namespace fusion_group
......
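Likewise, fill_constant registers with no inputs (num_oprands = -1, empty input slots) and a single output Out, so its whole RHS is the str_value attribute. For a fill_constant op whose str_value is "2.0", the generated assignment would look roughly like this (the output variable name is illustrative):

    // Hypothetical generated code for fill_constant with str_value = "2.0".
    tmp_1 = 2.0;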
@@ -165,7 +165,7 @@ class FusionGroupPassSumTest(FusionGroupPassTest):
         self.append_gradients(tmp_3)

-        self.num_fused_ops = 3
+        self.num_fused_ops = 4
         self.fetch_list = [tmp_3, self.grad(tmp_0)]
@@ -190,5 +190,28 @@ class FusionGroupPassCastTest(FusionGroupPassTest):
         self.fused_op_type = "fusion_group"

+
+class FusionGroupPassFillConstantTest(FusionGroupPassTest):
+    def build_program(self, dtype):
+        with fluid.program_guard(self.main_program, self.startup_program):
+            self.feed_vars = self._prepare_feed_vars([2, 2], dtype, 2)
+
+            tmp_0 = layers.elementwise_add(self.feed_vars[0],
+                                           self.feed_vars[1])
+            tmp_1 = layers.fill_constant(shape=[2, 2], dtype=dtype, value=2.0)
+            tmp_2 = layers.scale(
+                tmp_1, scale=3.0, bias=1.0, bias_after_scale=True)
+            tmp_3 = layers.elementwise_mul(tmp_2, tmp_0)
+
+        self.append_gradients(tmp_3)
+
+        self.num_fused_ops = 1
+        self.fetch_list = [tmp_2, self.grad(tmp_0)]
+
+    def setUp(self):
+        self.build_program("float32")
+        self.feeds = self._feed_random_data(self.feed_vars)
+        self.pass_names = "fusion_group_pass"
+        self.fused_op_type = "fusion_group"
+
+
 if __name__ == "__main__":
     unittest.main()