Unverified commit 3508bd28, authored by dyning, committed by GitHub

Add the op def for elementwise_mul and enhance layer_norm_fuse_pass (#33560)

Parent 11f5a400
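For orientation: the subgraph this pass matches computes layer normalization out of primitive ops (reduce_mean, elementwise_sub, elementwise_pow, elementwise_add, sqrt, elementwise_div, elementwise_mul). A minimal, self-contained C++ sketch of that reference computation follows; it uses plain std::vector and illustrative names, and is not part of this commit:

#include <cmath>
#include <cstddef>
#include <vector>

// Reference for the pattern fused into a single layer_norm op:
// y = (x - mean(x)) / sqrt(mean((x - mean(x))^2) + eps) * gamma + beta
std::vector<float> LayerNormReference(const std::vector<float>& x,
                                      const std::vector<float>& gamma,
                                      const std::vector<float>& beta,
                                      float eps) {
  const float n = static_cast<float>(x.size());
  float mean = 0.0f;  // reduce_mean
  for (float v : x) mean += v / n;
  float var = 0.0f;  // elementwise_sub -> elementwise_pow -> reduce_mean
  for (float v : x) var += (v - mean) * (v - mean) / n;
  const float denom = std::sqrt(var + eps);  // elementwise_add -> sqrt
  std::vector<float> y(x.size());
  // elementwise_div -> elementwise_mul -> elementwise_add
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = (x[i] - mean) / denom * gamma[i] + beta[i];
  }
  return y;
}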
@@ -99,6 +99,122 @@ void addIntermediateOut(Node* op_node, const std::string& out_name,
} // namespace
LayerNormFusePass::LayerNormFusePass() {
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Variance")
.IsTensor()
.IsOptional()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
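// Note: epsilon is accepted only in [0.0, 0.001], and begin_norm_axis must be positive.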
AddOpCompat(OpCompat("reduce_mean"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("dim")
.IsType<std::vector<int>>()
.End()
.AddAttr("keep_dim")
.IsBoolEQ(true)
.End();
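// keep_dim must be true so the reduced mean keeps its rank for the elementwise ops that follow.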
AddOpCompat(OpCompat("sqrt"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("elementwise_sub"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("elementwise_pow"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("elementwise_div"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("elementwise_mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
}
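// Every elementwise op in the pattern is required to carry axis == 1;
// the updated unit test below sets this attribute explicitly so the
// constructed subgraph still satisfies these compat checks.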
void LayerNormFusePass::ApplyImpl(Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
@@ -117,6 +233,10 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const {
int found_layer_norm_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Op compat check of the matched subgraph failed.";
return;
}
VLOG(4) << "Fuse LayerNorm from subgraph.";
GET_IR_NODE_FROM_SUBGRAPH(x, x, layer_norm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(x_mean, x_mean, layer_norm_pattern);
@@ -205,6 +325,12 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const {
ln_op_desc.SetAttr("begin_norm_axis", static_cast<int>(x_shape.size() - 1));
ln_op_desc.SetAttr("epsilon", *(eps_tensor->data<float>()));
ln_op_desc.SetAttr("is_test", true);
if (!IsCompat(ln_op_desc)) {
LOG(WARNING) << "The generated layer_norm op failed the op compat check.";
return;
}
Node* ln_op = g->CreateOpNode(&ln_op_desc);
addIntermediateOut(ln_op, "Mean", scope_name_, g);
...
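For context, a minimal sketch of how such a pass is typically retrieved and run through the pass registry; the API names are assumed from PaddlePaddle's ir::Pass conventions and are not part of this diff:

#include <memory>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

using paddle::framework::ir::Graph;
using paddle::framework::ir::PassRegistry;

// Fetch the registered pass by name and rewrite the graph in place.
void RunLayerNormFuse(std::unique_ptr<Graph>* graph) {
  auto pass = PassRegistry::Instance().Get("layer_norm_fuse_pass");
  graph->reset(pass->Apply(graph->release()));
}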
@@ -70,6 +70,7 @@ namespace ir {
*/
class LayerNormFusePass : public FusePassBase {
public:
LayerNormFusePass();
virtual ~LayerNormFusePass() {}
protected:
...
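The constructor declared above is new in this commit: it is where the AddOpCompat definitions for the pattern's ops are installed.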
@@ -66,12 +66,16 @@ class LayerNormFuseTest {
x_mean->SetAttr("keep_dim", true);
x_mean->SetAttr("reduce_all", false);
test::CreateOp(&m_prog, "elementwise_sub",
auto* x_sub = test::CreateOp(&m_prog, "elementwise_sub",
{{"X", "x"}, {"Y", "x_mean_out"}},
{{"Out", "x_sub_mean_out"}}, false);
test::CreateOp(&m_prog, "elementwise_pow",
x_sub->SetAttr("axis", 1);
auto* x_pow = test::CreateOp(&m_prog, "elementwise_pow",
{{"X", "x_sub_mean_out"}, {"Y", "sqr_pow"}},
{{"Out", "x_sub_mean_sqr_out"}}, false);
x_pow->SetAttr("axis", 1);
auto* std_dev =
test::CreateOp(&m_prog, "reduce_mean", {{"X", "x_sub_mean_sqr_out"}},
{{"Out", "std_dev_out"}}, false);
@@ -79,20 +83,29 @@ class LayerNormFuseTest {
std_dev->SetAttr("keep_dim", true);
std_dev->SetAttr("reduce_all", false);
test::CreateOp(&m_prog, "elementwise_add",
auto* x_add = test::CreateOp(&m_prog, "elementwise_add",
{{"X", "std_dev_out"}, {"Y", "eps"}},
{{"Out", "std_dev_eps_out"}}, false);
x_add->SetAttr("axis", 1);
test::CreateOp(&m_prog, "sqrt", {{"X", "std_dev_eps_out"}},
{{"Out", "std_dev_eps_sqrt_out"}}, false);
auto* x_div =
test::CreateOp(&m_prog, "elementwise_div",
{{"X", "x_sub_mean_out"}, {"Y", "std_dev_eps_sqrt_out"}},
{{"Out", "division_out"}}, false);
test::CreateOp(&m_prog, "elementwise_mul",
x_div->SetAttr("axis", 1);
auto* x_mul = test::CreateOp(&m_prog, "elementwise_mul",
{{"X", "division_out"}, {"Y", "gamma"}},
{{"Out", "scale_out"}}, false);
test::CreateOp(&m_prog, "elementwise_add",
{{"X", "scale_out"}, {"Y", "beta"}}, {{"Out", "shift_out"}},
false);
x_mul->SetAttr("axis", 1);
auto* x_add_v1 = test::CreateOp(&m_prog, "elementwise_add",
{{"X", "scale_out"}, {"Y", "beta"}},
{{"Out", "shift_out"}}, false);
x_add_v1->SetAttr("axis", 1);
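// The explicit axis attributes mirror the IsNumEQ(1) constraints the pass now
// enforces; without them the op compat check would reject this subgraph.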
}
template <typename Func>
...