Unverified commit 99ae88f1, authored by cyber-pioneer, committed by GitHub

fix addn infermeta (#56934)

* fix addn infermeta

* fix rule bug
Parent c71f5f9c
@@ -30,6 +30,13 @@ vjp_interface_declare_gen_op_list = [
"add",
"concat",
"split",
"gelu",
"matmul",
"erf",
"multiply",
"subtract",
"pow",
"rsqrt",
]
vjp_interface_implementation_gen_op_list = [
"tanh",
@@ -38,4 +45,11 @@ vjp_interface_implementation_gen_op_list = [
"add",
"concat",
"split",
"gelu",
"matmul",
"erf",
"multiply",
"subtract",
"pow",
"rsqrt",
]
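Both generator lists gain the same seven entries (gelu, matmul, erf, multiply, subtract, pow, rsqrt): the first registers the ops for generated VJP interface declarations, the second for the matching implementations. This is what lets the new test below build gradients through matmul and gelu in the new IR.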
@@ -39,8 +39,8 @@ OpInfoTuple AddNOp::GetOpInfo() {
std::vector<paddle::dialect::OpAttributeInfo> attributes = {};
std::vector<paddle::dialect::OpOutputInfo> outputs = {
OpOutputInfo("out", "paddle::dialect::DenseTensorType", false, false)};
paddle::dialect::OpRunTimeInfo run_time_info =
OpRunTimeInfo("", {""}, {""}, {""}, {""}, {}, {}, {});
paddle::dialect::OpRunTimeInfo run_time_info = OpRunTimeInfo(
"AddNInferMeta", {"inputs"}, {"add_n"}, {"inputs"}, {}, {}, {}, {});
return std::make_tuple(inputs, attributes, outputs, run_time_info, "add_n");
}
......
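This is the core of the fix: the old `OpRunTimeInfo` held only empty placeholder strings, so the new IR had no usable infer-meta or kernel binding for `add_n`. The new entry names `AddNInferMeta` as the infer-meta function (fed by `inputs`) and `add_n` as the kernel. Infer-meta derives output metadata (shape, dtype) only, never data; roughly, as a hedged Python sketch rather than Paddle's actual C++:

```python
# Hypothetical sketch of what an add_n infer-meta derives: the output
# tensor takes the inputs' common shape and dtype; no data is computed.
def add_n_infer_meta_sketch(input_shapes, input_dtypes):
    shape = input_shapes[0]
    assert all(s == shape for s in input_shapes), "add_n inputs must share a shape"
    return shape, input_dtypes[0]


# Example: summing two [2, 3] float32 tensors yields a [2, 3] float32 output.
assert add_n_infer_meta_sketch([[2, 3], [2, 3]], ["float32", "float32"]) == (
    [2, 3],
    "float32",
)
```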
@@ -34,7 +34,7 @@ paddle::dialect::AddNOp, paddle::dialect::SplitGradOp
namespace paddle {
namespace dialect {
class AddNOp : public ir::Op<AddNOp, OpYamlInfoInterface> {
class AddNOp : public ir::Op<AddNOp, OpYamlInfoInterface, InferMetaInterface> {
public:
using Op::Op;
static const char *name() { return "pd.add_n"; }
......
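The header-side half of the same fix: `AddNOp` now also declares `InferMetaInterface` alongside `OpYamlInfoInterface`, so the new IR can actually dispatch to the infer-meta function registered in `GetOpInfo()` above.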
@@ -169,8 +169,11 @@ def decompose(
dst_vars,
op_filter,
)
        for item in dst_vars:
        for idx, item in enumerate(dst_vars):
            if not isinstance(item, ir.OpResult):
                if item is None:
                    dst_vars[idx] = src_vars[idx]
                else:
                    raise TypeError(
                        f"Each var in dst_vars should map to the corresponding var in src_vars, but got type {type(item)} in {dst_vars}."
                    )
......
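With this change, `decompose` tolerates outputs the pass never rewrote: a `None` slot in `dst_vars` now falls back to the corresponding `src_vars` entry, and only genuinely wrong types still raise. The fallback rule in isolation (plain Python with hypothetical names, not the Paddle API):

```python
def fill_unmapped(src_vars, dst_vars):
    """Replace None slots (vars the pass left untouched) with the source var."""
    return [src if dst is None else dst for src, dst in zip(src_vars, dst_vars)]


# Only the first output was decomposed; the second is kept as-is.
assert fill_unmapped(["x", "y"], ["x_decomposed", None]) == ["x_decomposed", "y"]
```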
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import _ir_ops
from .primitives import * # noqa: F403
from .register import register_decomp
@@ -34,3 +36,40 @@ def mean(x, axis, keepdim):
    )
    res = divide(sum_x, norm)
    return res


@register_decomp('pd.gelu')
def gelu_composite(x, approximate):
    """Define the composite rule of op gelu."""
    M_SQRT1_2 = (
        0.70710678118654752440  # 1/sqrt(2), copied from gelu-kernel.cc
    )
    M_2_SQRTPI = 1.12837916709551257390  # 2/sqrt(pi)
    full_shape = x.shape if len(x.shape) == 0 else [1]
    one = ones(full_shape, x.dtype)
    half = full(full_shape, 0.5, x.dtype)
    # TODO(cz): once symbol overloading is available, add and multiply
    # below can be replaced by "+" and "*".
    if approximate:
        # gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
        kAlpha = full(full_shape, M_2_SQRTPI * M_SQRT1_2, x.dtype)
        GELU_CONSTANT = full(full_shape, 0.044715, x.dtype)
        tanh_out = tanh(kAlpha * (x + GELU_CONSTANT * x * x * x))
        out = x * half * (one + tanh_out)
        return out
    else:
        # gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
        cdf = _ir_ops.multiply(
            half,
            _ir_ops.add(
                one,
                _ir_ops.erf(
                    _ir_ops.multiply(x, full(x.shape, M_SQRT1_2, x.dtype))
                ),
            ),
        )
        out = _ir_ops.multiply(x, cdf)
        return out
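As a quick numerical cross-check of the two branches above, the same formulas can be written in plain NumPy (a standalone sanity check, not part of this commit). The tanh approximation tracks the erf form closely but not exactly:

```python
import math

import numpy as np


def gelu_erf(x):
    # Exact form: gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + np.vectorize(math.erf)(x * math.sqrt(0.5)))


def gelu_tanh(x):
    # Approximate form: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    k_alpha = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + np.tanh(k_alpha * (x + 0.044715 * x**3)))


x = np.linspace(-4.0, 4.0, 9)
print(np.max(np.abs(gelu_tanh(x) - gelu_erf(x))))  # small but nonzero, ~1e-4
```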
set(TEST_PRIM_PURE_NEW_IR_CASES test_prim_program)
set(TEST_PRIM_PURE_NEW_IR_CASES test_prim_program test_prim_simpnet)
foreach(target ${TEST_PRIM_PURE_NEW_IR_CASES})
py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1
......
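On the CMake side, `test_prim_simpnet` joins `TEST_PRIM_PURE_NEW_IR_CASES`, so the new file below is picked up by the existing `py_test_modules` harness (with `GLOG_v=1` in the environment, as for the other cases).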
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle import _ir_ops, nn
from paddle.autograd.ir_backward import grad
from paddle.decomposition import decompose
from paddle.framework import core

paddle.enable_static()


class SimpNet(nn.Layer):
    def __init__(self):
        super().__init__()

    def forward(self, x, linear1_weight, linear2_weight):
        x2 = _ir_ops.matmul(x, linear1_weight, False, False)
        x3 = _ir_ops.gelu(x2, False)
        res = _ir_ops.matmul(x3, linear2_weight, False, False)
        return res


class TestPrimMode(unittest.TestCase):
    def setUp(self):
        np.random.seed(2023)
        self.shape_x = [2, 1024, 1024]
        self.shape_y = [2, 1024, 1024]
        self.shape_l1_w = [2, 1024, 4096]
        self.shape_l2_w = [2, 4096, 1024]
        self.x = np.random.random(self.shape_x).astype("float32")
        self.y = np.random.random(self.shape_y).astype("float32")
        self.l1_w = np.random.random(self.shape_l1_w).astype("float32")
        self.l2_w = np.random.random(self.shape_l2_w).astype("float32")

    def base_net(self, flag=None):
        if flag == "all":
            core._set_prim_all_enabled(True)
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program):
            net = SimpNet()
            x = paddle.static.data('x', self.shape_x, dtype='float32')
            y = paddle.static.data('y', self.shape_y, dtype='float32')
            x.stop_gradient = False
            y.stop_gradient = False
            l1_w = paddle.static.data('l1_w', self.shape_l1_w, dtype='float32')
            l2_w = paddle.static.data('l2_w', self.shape_l2_w, dtype='float32')
            divide_out = paddle.divide(x, y)
            res = net(divide_out, l1_w, l2_w)
            [res2] = decompose(main_program, [res])
            gradients = grad(res2, (x, y))
            exe = paddle.static.Executor()
            outs = exe.run(
                feed={
                    'x': self.x,
                    'y': self.y,
                    'l1_w': self.l1_w,
                    'l2_w': self.l2_w,
                },
                fetch_list=[res2, gradients[0], gradients[1]],
            )
        whole_ops = [op.name() for op in main_program.block().ops]
        if flag == "all":
            core._set_prim_all_enabled(False)
            assert (
                'pd.gelu' not in whole_ops
                and 'pd.divide_grad' not in whole_ops
            )
        return outs

    def test_prim_all(self):
        res_ref = self.base_net()
        res = self.base_net("all")
        for ref, actual in zip(res_ref, res):
            np.testing.assert_allclose(ref, actual, rtol=1e-6)


if __name__ == "__main__":
    unittest.main()
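Since the file ends with `unittest.main()`, the case also runs standalone via `python test_prim_simpnet.py` (the filename is inferred from the CMake list above). The `assert` in `base_net` is the real regression check: after `decompose` plus `grad`, neither `pd.gelu` nor `pd.divide_grad` may survive in the program, i.e. both the forward gelu and the divide backward must have been lowered to primitives, while `np.testing.assert_allclose` confirms the lowered program still matches the un-decomposed reference numerically.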