未验证 提交 ca9e77a8 编写于 作者: W wangchaochaohu 提交者: GitHub

add sum op support for fusion group (#22771)

* Add the codegen and auto fusion for sum Op  in fusion group
上级 b681215a
...@@ -60,18 +60,21 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions( ...@@ -60,18 +60,21 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
// - X, Y in forward operations // - X, Y in forward operations
// - X, Y, Out, out@GRAD in backward operations // - X, Y, Out, out@GRAD in backward operations
std::vector<int> input_ids; std::vector<int> input_ids;
std::vector<std::string> input_names = auto operation = OperationMap::Instance().Get(op->Type());
OperationMap::Instance().Get(op->Type()).input_names; std::vector<std::string> input_names = operation.input_names;
for (auto& name : input_names) { for (auto& name : input_names) {
// Some input vars are not used in grad ops, such as // Some input vars are not used in grad ops, such as
// "elementwise_add_grad", where "X", "Y" and "Out" are not used. // "elementwise_add_grad", where "X", "Y" and "Out" are not used.
if (HasInput(node, name) && op->Input(name).size() >= 1U) {
// TODO(liuyiqun): support duplicated input. if ((HasInput(node, name) && op->Input(name).size() >= 1U)) {
for (size_t i = 0; i < op->Input(name).size(); i++) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
var_ids.find(op->Input(name)[0]), var_ids.end(), var_ids.find(op->Input(name)[i]), var_ids.end(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Input(%s) of operation %s is not set.", name, op->Type())); "Input(%s) of operation %s is not set.", name, op->Type()));
input_ids.push_back(var_ids[op->Input(name)[0]]); input_ids.push_back(var_ids[op->Input(name)[i]]);
}
} else { } else {
input_ids.push_back(-1); input_ids.push_back(-1);
} }
......
...@@ -33,9 +33,32 @@ static T StringTo(const std::string& str) { ...@@ -33,9 +33,32 @@ static T StringTo(const std::string& str) {
return value; return value;
} }
// Expands a multivariate (variadic) operation template into a flat
// expression string.
//
// The template has the form "<head>[<component>]", where the bracketed
// component contains a single '?' placeholder for the input index, e.g.
//   "${0}[ + ${?}]" with input_size = 3  ->  "${0} + ${1} + ${2}"
// The component is appended once per input with index 1..input_size-1;
// index 0 is already present in the head.
static std::string ExpandMultivariateTemplate(const std::string& rhs,
                                              const size_t input_size) {
  size_t start_pos = rhs.find('[');
  size_t end_pos = rhs.find(']');
  if (start_pos == std::string::npos || end_pos == std::string::npos ||
      end_pos <= start_pos) {
    // Malformed or non-multivariate template: nothing to expand.
    return rhs;
  }
  std::string sum_rhs = rhs.substr(0, start_pos);
  std::string sum_rhs_component =
      rhs.substr(start_pos + 1, end_pos - start_pos - 1);
  size_t replace_pos = sum_rhs_component.find('?');
  if (replace_pos == std::string::npos) {
    return sum_rhs;
  }
  for (size_t i = 1; i < input_size; i++) {
    // Substitute the placeholder in a fresh copy each iteration. Mutating
    // sum_rhs_component in place (as a single replace(...) call would do)
    // corrupts the template as soon as the index needs more than one
    // digit, because later iterations would overwrite only one character
    // of the previously inserted multi-digit index.
    std::string append_str = sum_rhs_component;
    append_str.replace(replace_pos, 1, std::to_string(i));
    sum_rhs += append_str;
  }
  return sum_rhs;
}
std::string OperationExpression::GetRHS(std::unordered_set<int>* used, std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
size_t i) const { size_t exprs_index) const {
auto rhs = OperationMap::Instance().Get(op_type_).exprs[i]; auto rhs = OperationMap::Instance().Get(op_type_).exprs[exprs_index];
auto num_operands = OperationMap::Instance().Get(op_type_).num_operands;
if (num_operands == -1) {
size_t input_size = input_ids_.size();
rhs = ExpandMultivariateTemplate(rhs, input_size);
}
for (size_t i = 0; i < rhs.size(); i++) { for (size_t i = 0; i < rhs.size(); i++) {
size_t pos = i; size_t pos = i;
if (rhs[pos] == '$' && rhs[pos + 1] == '{') { if (rhs[pos] == '$' && rhs[pos + 1] == '{') {
......
...@@ -52,7 +52,8 @@ class OperationExpression { ...@@ -52,7 +52,8 @@ class OperationExpression {
private: private:
// TODO(wangchao): make offset more flexible we add stride and basic offset // TODO(wangchao): make offset more flexible we add stride and basic offset
std::string GetRHS(std::unordered_set<int>* used, size_t i = 0) const; std::string GetRHS(std::unordered_set<int>* used,
size_t exprs_index = 0) const;
std::string GetLHS(size_t i = 0) const; std::string GetLHS(size_t i = 0) const;
private: private:
......
...@@ -24,23 +24,13 @@ namespace framework { ...@@ -24,23 +24,13 @@ namespace framework {
namespace ir { namespace ir {
namespace fusion_group { namespace fusion_group {
static std::unordered_set<std::string> binary_op_types; static std::unordered_set<std::string> elementwise_op_types;
static std::unordered_set<std::string> unary_op_types;
static std::unordered_set<std::string>& GetBinaryOpTypes() { static std::unordered_set<std::string>& GetElementwiseOpTypes() {
if (binary_op_types.empty()) { if (elementwise_op_types.empty()) {
binary_op_types = elementwise_op_types = OperationMap::Instance().Find(/* type= */ 0);
OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 2);
} }
return binary_op_types; return elementwise_op_types;
}
static std::unordered_set<std::string>& GetUnaryOpTypes() {
if (unary_op_types.empty()) {
unary_op_types =
OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 1);
}
return unary_op_types;
} }
static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types, static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
...@@ -70,13 +60,8 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l, ...@@ -70,13 +60,8 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
return l.size() != 0U && r.size() != 0U && l == r; return l.size() != 0U && r.size() != 0U && l == r;
} }
static bool IsBinaryOp(const Node* n) { bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
if (IsSpecifiedOp(GetBinaryOpTypes(), n)) { if (IsSpecifiedOp(GetElementwiseOpTypes(), n)) {
if ((!IsGradOp(n) && n->inputs.size() != 2U) || n->inputs.size() == 0U) {
return false;
}
// The shape of all inputs should be the same.
std::vector<int64_t> shape_0; std::vector<int64_t> shape_0;
for (size_t i = 0; i < n->inputs.size(); ++i) { for (size_t i = 0; i < n->inputs.size(); ++i) {
auto* in_i = n->inputs[i]; auto* in_i = n->inputs[i];
...@@ -98,14 +83,6 @@ static bool IsBinaryOp(const Node* n) { ...@@ -98,14 +83,6 @@ static bool IsBinaryOp(const Node* n) {
return false; return false;
} }
static bool IsUnaryOp(const Node* n) {
return IsSpecifiedOp(GetUnaryOpTypes(), n);
}
bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
return IsBinaryOp(n) || IsUnaryOp(n);
}
std::vector<std::vector<Node*>> ElementwiseGroupDetector::operator()( std::vector<std::vector<Node*>> ElementwiseGroupDetector::operator()(
Graph* graph) { Graph* graph) {
auto teller = [&](const Node* n) -> bool { return IsElementwiseOp(n); }; auto teller = [&](const Node* n) -> bool { return IsElementwiseOp(n); };
......
...@@ -25,13 +25,13 @@ OperationMap* OperationMap::map = nullptr; ...@@ -25,13 +25,13 @@ OperationMap* OperationMap::map = nullptr;
OperationMap::OperationMap() { OperationMap::OperationMap() {
InsertUnaryElementwiseOperations(); InsertUnaryElementwiseOperations();
InsertBinaryElementwiseOperations(); InsertBinaryElementwiseOperations();
InsertMultivariateElementwiseOperations();
} }
std::unordered_set<std::string> OperationMap::Find(int type, int num_operands) { std::unordered_set<std::string> OperationMap::Find(int type) {
std::unordered_set<std::string> res; std::unordered_set<std::string> res;
for (auto& t : operations_) { for (auto& t : operations_) {
if ((t.second.type == type) && if (t.second.type == type) {
(num_operands < 0 || t.second.num_operands == num_operands)) {
res.insert(t.first); res.insert(t.first);
} }
} }
...@@ -153,6 +153,18 @@ void OperationMap::InsertBinaryElementwiseOperations() { ...@@ -153,6 +153,18 @@ void OperationMap::InsertBinaryElementwiseOperations() {
{"${3} * (${0} > ${1})", "${3} * (${0} <= ${1})"}); {"${3} * (${0} > ${1})", "${3} * (${0} <= ${1})"});
} }
void OperationMap::InsertMultivariateElementwiseOperations() {
  // Registers elementwise operations that take a variable number of inputs.
  // num_operands == -1 marks the operand count as dynamic; the bracketed
  // "[ + ${?}]" part of the expression template is repeated once per extra
  // input when the expression is expanded (see ExpandMultivariateTemplate).
  auto insert_handler = [&](std::string op_type, std::string expr,
                            std::vector<std::string> grad_exprs) {
    int type = 0;
    // -1 => variadic: the actual number of inputs is only known per-op.
    int num_operands = -1;
    Insert(type, num_operands, op_type, expr, grad_exprs, {"X"}, {"Out"});
  };
  insert_handler("sum", "${0}[ + ${?}]", {});
}
} // namespace fusion_group } // namespace fusion_group
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -84,7 +84,7 @@ class OperationMap { ...@@ -84,7 +84,7 @@ class OperationMap {
return *map; return *map;
} }
std::unordered_set<std::string> Find(int type, int num_operands = -1); std::unordered_set<std::string> Find(int type);
bool Has(std::string op_type) { bool Has(std::string op_type) {
return operations_.find(op_type) != operations_.end(); return operations_.find(op_type) != operations_.end();
...@@ -106,6 +106,7 @@ class OperationMap { ...@@ -106,6 +106,7 @@ class OperationMap {
void InsertUnaryElementwiseOperations(); void InsertUnaryElementwiseOperations();
void InsertBinaryElementwiseOperations(); void InsertBinaryElementwiseOperations();
void InsertMultivariateElementwiseOperations();
private: private:
static OperationMap* map; static OperationMap* map;
......
...@@ -138,5 +138,18 @@ class FusionGroupPassTestFP16(FusionGroupPassTest): ...@@ -138,5 +138,18 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
self.num_fused_ops = 1 self.num_fused_ops = 1
class FusionGroupPassSumTest(FusionGroupPassTest):
    # Checks that a chain containing variadic `sum` ops is fused by the
    # fusion_group pass into a single fused operation.
    def build_program(self, dtype):
        with fluid.program_guard(self.main_program, self.startup_program):
            # Five feed variables of shape [32, 128] with the given dtype.
            self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 5)
            tmp_0 = layers.elementwise_add(self.feed_vars[0], self.feed_vars[1])
            # 3-input sum exercises the multivariate (num_operands == -1) path.
            tmp_1 = layers.sum([tmp_0, self.feed_vars[2], self.feed_vars[3]])
            tmp_2 = layers.sum([tmp_1, self.feed_vars[4]])
            # NOTE(review): tmp_2 is built but not fetched, so the second sum
            # may be pruned from the executed program — confirm this is the
            # intended coverage (fetching tmp_2 would exercise both sums).
            self.fetch_list = [tmp_0, tmp_1]
            # The whole elementwise chain is expected to fuse into one group.
            self.num_fused_ops = 1
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册