Unverified commit d2801060 authored by wangchaochaohu, committed by GitHub

Add support for attr type Op and add fill_constant Op and scale Op (#23163)

* Add attr support for fusion_group, and add support for the fill_constant and scale Ops
Parent 3a45767d
@@ -75,8 +75,8 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
for (auto* node : subgraph->SortedNodes()) {
if (node && node->IsOp() && node->Op()) {
auto* op = node->Op();
AttributeMap attr = *(op->MutableAttrMap());
// Input ids should be set in fixed order, like:
// - X, Y in forward operations
// - X, Y, Out, out@GRAD in backward operations
std::vector<int> input_ids;
@@ -118,8 +118,10 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
std::string lhs_type = ExtractDataType(node->outputs);
std::string rhs_type = ExtractDataType(node->inputs);
expressions.emplace_back(OperationExpression(
node->Name(), input_ids, output_ids, rhs_type, lhs_type));
auto expression = OperationExpression(node->Name(), input_ids, output_ids,
rhs_type, lhs_type);
expression.SetAttr(attr);
expressions.push_back(expression);
}
}
return expressions;
......
@@ -18,7 +18,14 @@ limitations under the License. */
#include <string>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/op_call_stack.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_type_inference.h"
namespace paddle {
namespace framework {
@@ -33,7 +40,7 @@ static T StringTo(const std::string& str) {
return value;
}
static std::string ExpandMultivariateTemplate(const std::string rhs,
static std::string ExpandMultivariateTemplate(const std::string& rhs,
const size_t input_size) {
int start_pos = rhs.find("[", 0);
int end_pos = rhs.find("]", 0);
@@ -50,6 +57,66 @@ static std::string ExpandMultivariateTemplate(const std::string rhs,
return sum_rhs;
}
static std::string RefineTemplateWithAttr(const std::string& op_type,
const std::string& exp_definition,
const AttributeMap& attrs) {
std::string ret;
  // Here str_cvt is used to convert numeric attrs to strings,
  // for example str_value in fill_constant.
std::stringstream str_cvt;
auto IsNumber = [exp_definition]() -> bool {
return exp_definition.find_first_not_of("0123456789") == std::string::npos;
};
if (!IsNumber()) {
    // Get the attr value according to its type. For now we only support
    // simple attr types here.
std::string attr_name, default_value;
if (exp_definition.find("=") != std::string::npos) {
attr_name = exp_definition.substr(0, exp_definition.find("="));
default_value = exp_definition.substr(exp_definition.rfind("=") + 1,
exp_definition.length() - 1);
ret = default_value;
} else {
attr_name = exp_definition;
}
auto it = attrs.find(attr_name);
if (it == attrs.end()) {
return ret;
}
Attribute attr = it->second;
proto::AttrType attr_type =
static_cast<proto::AttrType>(it->second.which() - 1);
if (attr_type == proto::AttrType::BOOLEAN) {
bool result = boost::get<bool>(attr);
if (result) {
ret = "true";
} else {
ret = "false";
}
} else if (attr_type == proto::AttrType::INT) {
int result = boost::get<int>(attr);
str_cvt << result;
ret = str_cvt.str();
} else if (attr_type == proto::AttrType::LONG) {
int64_t result = boost::get<int64_t>(attr);
str_cvt << result;
ret = str_cvt.str();
} else if (attr_type == proto::AttrType::FLOAT) {
float result = boost::get<float>(attr);
str_cvt << result;
ret = str_cvt.str();
} else if (attr_type == proto::AttrType::STRING) {
std::string result = boost::get<std::string>(attr);
ret = result;
}
} else {
ret = exp_definition;
}
return ret;
}
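For readers skimming the diff, the new RefineTemplateWithAttr above does three things: it leaves pure-index tokens alone, it splits an optional `name=default` placeholder, and it renders the looked-up attr value as a string. A minimal standalone sketch of that behavior, in illustrative Python with simplified types (this is not the Paddle implementation):

```python
# Simplified, illustrative sketch of the attr-refinement logic above
# (not Paddle's code; attr values are plain Python objects here).
def refine_template_with_attr(exp_definition, attrs):
    # Pure digit strings such as "0" are input indices, not attr names.
    if exp_definition.isdigit():
        return exp_definition
    # "${scale=1.0}" arrives here as "scale=1.0"; split off the default.
    attr_name, _, default_value = exp_definition.partition("=")
    if attr_name not in attrs:
        return default_value
    value = attrs[attr_name]
    if isinstance(value, bool):
        return "true" if value else "false"
    return str(value)  # int, int64, float and string attrs render directly


print(refine_template_with_attr("scale=1.0", {"scale": 3.0}))   # -> "3.0"
print(refine_template_with_attr("bias_after_scale=true", {}))   # -> "true"
```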
// In order to avoid multiple __half2float function calls, we do this
// optimization
static std::string OptimzeFP16RHS(std::unordered_set<int>* used,
@@ -74,7 +141,6 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
size_t input_size = input_ids_.size();
rhs = ExpandMultivariateTemplate(rhs, input_size);
}
for (size_t i = 0; i < rhs.size(); i++) {
size_t pos = i;
if (rhs[pos] == '$' && rhs[pos + 1] == '{') {
@@ -83,28 +149,36 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
length++;
}
std::string index_str = rhs.substr(pos + 2, length);
int index = StringTo<int>(index_str);
PADDLE_ENFORCE_LT(
index, input_ids_.size(),
platform::errors::InvalidArgument(
"Only %d inputs are provided, but need %d for operation < %s >.",
input_ids_.size(), index + 1, op_type_));
PADDLE_ENFORCE_GE(
input_ids_[index], 0,
platform::errors::InvalidArgument(
"Expected %d-th input id > 0 for operation < %s >. Received %d.",
index, op_type_, input_ids_[index]));
      // TODO(wangchaochaohu): Here fp16 is converted to float to do the compute;
      // we need to add general fp16 compute later.
std::string var_name;
if (rhs_type_ == "float16") {
half2fp32_statement->append(OptimzeFP16RHS(used, index, input_ids_));
var_name = "half2fp32_" + TmpName(input_ids_[index]);
std::string refine_str =
RefineTemplateWithAttr(op_type_, index_str, attr_);
if (index_str == refine_str) {
int index = StringTo<int>(index_str);
PADDLE_ENFORCE_LT(index, input_ids_.size(),
platform::errors::InvalidArgument(
"Only %d inputs are provided, but need %d for "
"operation < %s >.",
input_ids_.size(), index + 1, op_type_));
PADDLE_ENFORCE_GE(input_ids_[index], 0,
platform::errors::InvalidArgument(
"Expected %d-th input id > 0 for operation < %s "
">. Received %d.",
index, op_type_, input_ids_[index]));
      // TODO(wangchaochaohu): Here fp16 is converted to float to do the
      // compute; we need to add general fp16 compute later.
std::string var_name;
if (rhs_type_ == "float16") {
half2fp32_statement->append(OptimzeFP16RHS(used, index, input_ids_));
var_name = "half2fp32_" + TmpName(input_ids_[index]);
} else {
var_name = TmpName(input_ids_[index]);
}
rhs.replace(pos, length + 3, var_name);
used->insert(input_ids_[index]);
} else {
var_name = TmpName(input_ids_[index]);
std::string var_name = refine_str;
rhs.replace(pos, length + 3, var_name);
}
rhs.replace(pos, length + 3, var_name);
used->insert(input_ids_[index]);
}
}
return rhs;
......
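With that helper in place, the rewritten GetRHS loop above asks RefineTemplateWithAttr about every `${...}` token: if the refined string equals the token, it is treated as an input index and resolved to a variable name as before; otherwise the refined attr value is substituted directly. A compact, hedged model of that combined substitution, reusing the `refine_template_with_attr` sketch from above (the `tmp_7` temporary name is made up):

```python
import re

# Illustrative-only model of the new substitution path in GetRHS():
# index tokens become input variables, attr tokens become literal values.
def substitute_rhs(rhs, input_ids, attrs):
    def render(match):
        token = match.group(1)
        refined = refine_template_with_attr(token, attrs)
        if token == refined:              # unchanged => treat as input index
            return "tmp_%d" % input_ids[int(token)]
        return refined                    # attr placeholder => refined value
    return re.sub(r"\$\{([^}]*)\}", render, rhs)


rhs = ("${bias_after_scale=true} ? (${scale=1.0} * ${0} + ${bias=0.0}) : "
       "(${scale=1.0} * (${0} + ${bias=0.0}))")
print(substitute_rhs(rhs, [7], {"scale": 3.0, "bias": 1.0,
                                "bias_after_scale": True}))
# -> "true ? (3.0 * tmp_7 + 1.0) : (3.0 * (tmp_7 + 1.0))"
```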
@@ -20,6 +20,9 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
@@ -55,7 +58,8 @@ class OperationExpression {
std::vector<int> GetOutputIds() const { return output_ids_; }
std::string GetRHSType() const { return rhs_type_; }
std::string GetLHSType() const { return lhs_type_; }
void SetAttr(AttributeMap attr) { attr_ = attr; }
AttributeMap GetAttr() { return attr_; }
// Check whether this operation type is supported in OperationMap.
bool IsSupport() const;
@@ -72,6 +76,7 @@ class OperationExpression {
std::string op_type_;
std::vector<int> input_ids_;
std::vector<int> output_ids_;
AttributeMap attr_;
std::string rhs_type_;
std::string lhs_type_;
};
......
@@ -118,6 +118,16 @@ void OperationMap::InsertUnaryElementwiseOperations() {
// out = x^2
// dx = dout * 2.0 * x
insert_handler("square", "${0} * ${0}", {"${2} * 2.0 * ${0}"});
// scale
  // out = (bias_after_scale) ? scale * X + bias : scale * (X + bias)
  // Here we use the '=' operator to separate the attr name from its default
  // value.
// TODO(wangchaochaohu): Later we need to support Tensor input for scale and
// bias.
insert_handler("scale",
"${bias_after_scale=true} ? (${scale=1.0} * ${0} + "
"${bias=0.0}) : (${scale=1.0} * (${0} + ${bias=0.0}))",
{});
}
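A quick worked example of the two variants encoded in that ternary, with arbitrary numbers in the spirit of the test added below:

```python
# Worked example of the scale formula: out = scale * x + bias when
# bias_after_scale is true, otherwise out = scale * (x + bias).
x, scale, bias = 2.0, 3.0, 1.0
print(scale * x + bias)      # bias_after_scale = True  -> 7.0
print(scale * (x + bias))    # bias_after_scale = False -> 9.0
```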
void OperationMap::InsertBinaryElementwiseOperations() {
@@ -185,6 +195,15 @@ void OperationMap::InsertMultivariateElementwiseOperations() {
// For example, sum with 4 inputs, the expanded expression is:
// ${0} + ${1} + ${2} + ${3}
insert_handler("sum", "${0}[ + ${?}]", {});
auto insert_handler_without_input = [&](std::string op_type, std::string expr,
std::vector<std::string> grad_exprs) {
int type = 0;
int num_oprands = -1;
Insert(type, num_oprands, op_type, expr, grad_exprs, {}, {"Out"});
};
// fill_constant:
insert_handler_without_input("fill_constant", "${str_value}", {});
}
} // namespace fusion_group
......
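Because fill_constant is registered with no inputs (num_oprands is -1 and only "Out" is listed), its entire RHS is just the refined str_value attr. A hedged illustration, reusing the refine_template_with_attr sketch from earlier and assuming the op's str_value attr carries the string form of the constant:

```python
# Assuming layers.fill_constant(value=2.0) records "2.0" in the op's
# str_value attr (an assumption here), the template "${str_value}"
# refines straight to the literal constant.
print(refine_template_with_attr("str_value", {"str_value": "2.0"}))  # -> "2.0"
```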
@@ -165,7 +165,7 @@ class FusionGroupPassSumTest(FusionGroupPassTest):
self.append_gradients(tmp_3)
self.num_fused_ops = 3
self.num_fused_ops = 4
self.fetch_list = [tmp_3, self.grad(tmp_0)]
@@ -190,5 +190,28 @@ class FusionGroupPassCastTest(FusionGroupPassTest):
self.fused_op_type = "fusion_group"

class FusionGroupPassFillConstantTest(FusionGroupPassTest):
    def build_program(self, dtype):
        with fluid.program_guard(self.main_program, self.startup_program):
            self.feed_vars = self._prepare_feed_vars([2, 2], dtype, 2)

            tmp_0 = layers.elementwise_add(self.feed_vars[0], self.feed_vars[1])
            tmp_1 = layers.fill_constant(shape=[2, 2], dtype=dtype, value=2.0)
            tmp_2 = layers.scale(
                tmp_1, scale=3.0, bias=1.0, bias_after_scale=True)
            tmp_3 = layers.elementwise_mul(tmp_2, tmp_0)

            self.append_gradients(tmp_3)

            self.num_fused_ops = 1
            self.fetch_list = [tmp_2, self.grad(tmp_0)]

    def setUp(self):
        self.build_program("float32")
        self.feeds = self._feed_random_data(self.feed_vars)
        self.pass_names = "fusion_group_pass"
        self.fused_op_type = "fusion_group"

if __name__ == "__main__":
unittest.main()