未验证 提交 9844aafb 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Add swish yaml and final state api (#41479)

* add swish yaml and final state api

* skip mkldnn test

* fix grad mkldnn test
上级 b3bcebbe
......@@ -113,6 +113,7 @@ class TestMKLDNNSwishDim2(TestSwish):
super(TestMKLDNNSwishDim2, self).setUp()
self.attrs["use_mkldnn"] = True
self.check_eager = False
def init_dtype(self):
self.dtype = np.float32
......@@ -284,6 +285,7 @@ class TestMKLDNNSwishDim4(TestSwish):
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {"use_mkldnn": True, "beta": beta}
self.check_eager = False
def init_dtype(self):
self.dtype = np.float32
......
......@@ -2940,7 +2940,9 @@ def ref_swish(x):
class TestSwish(TestActivation):
def setUp(self):
self.op_type = "swish"
self.python_api = paddle.nn.functional.swish
self.init_dtype()
self.check_eager = True
np.random.seed(1024)
x = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype)
......@@ -2952,7 +2954,10 @@ class TestSwish(TestActivation):
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
check_eager = False
if hasattr(self, 'check_eager'):
check_eager = self.check_eager
self.check_grad(['X'], 'Out', check_eager=check_eager)
class TestSwishAPI(unittest.TestCase):
......@@ -2987,6 +2992,10 @@ class TestSwishAPI(unittest.TestCase):
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
paddle.enable_static()
def test_dygraph_final_state_api(self):
with _test_eager_guard():
self.test_dygraph_api()
def test_fluid_api(self):
paddle.enable_static()
with fluid.program_guard(fluid.Program()):
......
......@@ -1181,8 +1181,9 @@ def swish(x, name=None):
x = paddle.to_tensor(np.array([-2., 0., 1.]))
out = F.swish(x) # [-0.238406, 0., 0.731059]
"""
if in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_swish(x, 1.0)
if _in_legacy_dygraph():
return _C_ops.swish(x, 'beta', 1.0)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'swish')
......
......@@ -1876,6 +1876,17 @@
data_type : x
backward : sum_grad
# The python API paddle.nn.functional.swish has no `bete` argument, it may be removed later
- api : swish
args : (Tensor x, float beta=1.0)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : swish
backward : swish_grad
# take_along_axis
- api : take_along_axis
args : (Tensor x, Tensor index, int axis)
......
......@@ -1410,6 +1410,16 @@
kernel :
func : sum_grad
- backward_api : swish_grad
forward : swish (Tensor x, float beta=1.0) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float bete=1.0)
output : Tensor(x_grad)
infer_meta :
func : GeneralUnaryGradInferMeta
param : [x]
kernel :
func : swish_grad
- backward_api : take_along_axis_grad
forward : take_along_axis (Tensor x, Tensor index, int axis) -> Tensor(out)
args : (Tensor x, Tensor index, Tensor out_grad, int axis)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册