未验证 提交 ca9e77a8 编写于 作者: W wangchaochaohu 提交者: GitHub

add sum op support for fusion group (#22771)

* Add the codegen and auto fusion for sum Op  in fusion group
上级 b681215a
......@@ -60,18 +60,21 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
// - X, Y in forward operations
// - X, Y, Out, out@GRAD in backward operations
std::vector<int> input_ids;
std::vector<std::string> input_names =
OperationMap::Instance().Get(op->Type()).input_names;
auto operation = OperationMap::Instance().Get(op->Type());
std::vector<std::string> input_names = operation.input_names;
for (auto& name : input_names) {
// Some input vars are not used in grad ops, such as
// "elementwise_add_grad", where "X", "Y" and "Out" are not used.
if (HasInput(node, name) && op->Input(name).size() >= 1U) {
// TODO(liuyiqun): support duplicated input.
if ((HasInput(node, name) && op->Input(name).size() >= 1U)) {
for (size_t i = 0; i < op->Input(name).size(); i++) {
PADDLE_ENFORCE_NE(
var_ids.find(op->Input(name)[0]), var_ids.end(),
var_ids.find(op->Input(name)[i]), var_ids.end(),
platform::errors::InvalidArgument(
"Input(%s) of operation %s is not set.", name, op->Type()));
input_ids.push_back(var_ids[op->Input(name)[0]]);
input_ids.push_back(var_ids[op->Input(name)[i]]);
}
} else {
input_ids.push_back(-1);
}
......
......@@ -33,9 +33,32 @@ static T StringTo(const std::string& str) {
return value;
}
// Expand a multivariate expression template into a concrete expression for a
// given number of inputs. The part of `rhs` enclosed in square brackets is
// repeated once per extra input, with the "?" placeholder replaced by that
// input's index. E.g. for input_size = 3:
//   "${0}[ + ${?}]"  ->  "${0} + ${1} + ${2}"
// Assumes `rhs` contains exactly one "[...]" section with a single "?" inside.
static std::string ExpandMultivariateTemplate(const std::string rhs,
                                              const size_t input_size) {
  size_t start_pos = rhs.find("[", 0);
  size_t end_pos = rhs.find("]", 0);
  std::string sum_rhs = rhs.substr(0, start_pos);
  std::string sum_rhs_component =
      rhs.substr(start_pos + 1, (end_pos - start_pos - 1));
  size_t replace_pos = sum_rhs_component.find("?", 0);
  for (size_t i = 1; i < input_size; i++) {
    // Substitute the index into a fresh copy of the template each iteration.
    // The previous implementation called replace() on sum_rhs_component
    // itself, mutating the template; once an index needs more than one digit
    // (i >= 10), std::to_string(i) replaces a single character with two, so
    // later iterations produced corrupted indices such as "${110}".
    std::string append_str = sum_rhs_component;
    append_str.replace(replace_pos, 1, std::to_string(i));
    sum_rhs += append_str;
  }
  return sum_rhs;
}
std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
size_t i) const {
auto rhs = OperationMap::Instance().Get(op_type_).exprs[i];
size_t exprs_index) const {
auto rhs = OperationMap::Instance().Get(op_type_).exprs[exprs_index];
auto num_operands = OperationMap::Instance().Get(op_type_).num_operands;
if (num_operands == -1) {
size_t input_size = input_ids_.size();
rhs = ExpandMultivariateTemplate(rhs, input_size);
}
for (size_t i = 0; i < rhs.size(); i++) {
size_t pos = i;
if (rhs[pos] == '$' && rhs[pos + 1] == '{') {
......
......@@ -52,7 +52,8 @@ class OperationExpression {
private:
// TODO(wangchao): make offset more flexible we add stride and basic offset
std::string GetRHS(std::unordered_set<int>* used, size_t i = 0) const;
std::string GetRHS(std::unordered_set<int>* used,
size_t exprs_index = 0) const;
std::string GetLHS(size_t i = 0) const;
private:
......
......@@ -24,23 +24,13 @@ namespace framework {
namespace ir {
namespace fusion_group {
static std::unordered_set<std::string> binary_op_types;
static std::unordered_set<std::string> unary_op_types;
static std::unordered_set<std::string> elementwise_op_types;
static std::unordered_set<std::string>& GetBinaryOpTypes() {
if (binary_op_types.empty()) {
binary_op_types =
OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 2);
static std::unordered_set<std::string>& GetElementwiseOpTypes() {
if (elementwise_op_types.empty()) {
elementwise_op_types = OperationMap::Instance().Find(/* type= */ 0);
}
return binary_op_types;
}
static std::unordered_set<std::string>& GetUnaryOpTypes() {
if (unary_op_types.empty()) {
unary_op_types =
OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 1);
}
return unary_op_types;
return elementwise_op_types;
}
static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
......@@ -70,13 +60,8 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
return l.size() != 0U && r.size() != 0U && l == r;
}
static bool IsBinaryOp(const Node* n) {
if (IsSpecifiedOp(GetBinaryOpTypes(), n)) {
if ((!IsGradOp(n) && n->inputs.size() != 2U) || n->inputs.size() == 0U) {
return false;
}
// The shape of all inputs should be the same.
bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
if (IsSpecifiedOp(GetElementwiseOpTypes(), n)) {
std::vector<int64_t> shape_0;
for (size_t i = 0; i < n->inputs.size(); ++i) {
auto* in_i = n->inputs[i];
......@@ -98,14 +83,6 @@ static bool IsBinaryOp(const Node* n) {
return false;
}
static bool IsUnaryOp(const Node* n) {
return IsSpecifiedOp(GetUnaryOpTypes(), n);
}
bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
return IsBinaryOp(n) || IsUnaryOp(n);
}
std::vector<std::vector<Node*>> ElementwiseGroupDetector::operator()(
Graph* graph) {
auto teller = [&](const Node* n) -> bool { return IsElementwiseOp(n); };
......
......@@ -25,13 +25,13 @@ OperationMap* OperationMap::map = nullptr;
OperationMap::OperationMap() {
InsertUnaryElementwiseOperations();
InsertBinaryElementwiseOperations();
InsertMultivariateElementwiseOperations();
}
std::unordered_set<std::string> OperationMap::Find(int type, int num_operands) {
std::unordered_set<std::string> OperationMap::Find(int type) {
std::unordered_set<std::string> res;
for (auto& t : operations_) {
if ((t.second.type == type) &&
(num_operands < 0 || t.second.num_operands == num_operands)) {
if (t.second.type == type) {
res.insert(t.first);
}
}
......@@ -153,6 +153,18 @@ void OperationMap::InsertBinaryElementwiseOperations() {
{"${3} * (${0} > ${1})", "${3} * (${0} <= ${1})"});
}
// Register elementwise operations that accept a variable number of inputs.
// Such operations use num_operands = -1 as a marker; the bracketed section of
// the expression template is expanded once per extra input at codegen time.
void OperationMap::InsertMultivariateElementwiseOperations() {
  auto insert_handler = [&](std::string op_type, std::string expr,
                            std::vector<std::string> grad_exprs) {
    const int type = 0;            // elementwise
    const int num_operands = -1;   // -1 means the input count varies per op
    Insert(type, num_operands, op_type, expr, grad_exprs, {"X"}, {"Out"});
  };
  // "sum" adds all of its inputs: ${0} + ${1} + ... + ${n-1}.
  insert_handler("sum", "${0}[ + ${?}]", {});
}
} // namespace fusion_group
} // namespace ir
} // namespace framework
......
......@@ -84,7 +84,7 @@ class OperationMap {
return *map;
}
std::unordered_set<std::string> Find(int type, int num_operands = -1);
std::unordered_set<std::string> Find(int type);
bool Has(std::string op_type) {
return operations_.find(op_type) != operations_.end();
......@@ -106,6 +106,7 @@ class OperationMap {
void InsertUnaryElementwiseOperations();
void InsertBinaryElementwiseOperations();
void InsertMultivariateElementwiseOperations();
private:
static OperationMap* map;
......
......@@ -138,5 +138,18 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
self.num_fused_ops = 1
class FusionGroupPassSumTest(FusionGroupPassTest):
    """Checks that chained sum ops are fused by the fusion_group pass.

    The web-rendered diff lost all Python indentation; this restores the
    syntactically valid structure of the test case.
    """

    def build_program(self, dtype):
        # Build: elementwise_add -> sum(3 inputs) -> sum(2 inputs), all
        # elementwise, so they are candidates for a single fused group.
        with fluid.program_guard(self.main_program, self.startup_program):
            self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 5)

            tmp_0 = layers.elementwise_add(self.feed_vars[0],
                                           self.feed_vars[1])
            tmp_1 = layers.sum([tmp_0, self.feed_vars[2], self.feed_vars[3]])
            tmp_2 = layers.sum([tmp_1, self.feed_vars[4]])

        # NOTE(review): tmp_2 is computed but not fetched, so the second sum
        # may be pruned from the executed program — confirm whether it should
        # be included in fetch_list.
        self.fetch_list = [tmp_0, tmp_1]
        self.num_fused_ops = 1
# Restore indentation lost in the web rendering of the diff.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册