未验证 提交 88ad79d2 编写于 作者: X xiaoguoguo626807 提交者: GitHub

【prim】sin and cos composite grad (#51748)

* init

* close cinn

* close cinn

* add public_python_api
上级 d2122c6f
......@@ -40,4 +40,6 @@
- put_along_axis
- greater_than
- less_equal
- sin
- cos
- where
......@@ -1137,6 +1137,18 @@ void dropout_grad(const Tensor& mask,
}
}
template <typename T>
void sin_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
  // Composite backward rule for sin: d(sin(x))/dx = cos(x), so by the
  // chain rule the input gradient is cos(x) * out_grad.
  set_output<T>(cos<T>(x) * out_grad, x_grad);
}
template <typename T>
void cos_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
  // Composite backward rule for cos: d(cos(x))/dx = -sin(x), so by the
  // chain rule the input gradient is -sin(x) * out_grad.
  set_output<T>(-sin<T>(x) * out_grad, x_grad);
}
template <typename T>
void batch_norm_grad(const Tensor& x,
const Tensor& scale,
......
......@@ -290,6 +290,7 @@
kernel :
func : cos_grad
backward : cos_double_grad
composite : cos_grad(x, out_grad, x_grad)
inplace : (out_grad -> x_grad)
- backward_op : cos_triple_grad
......@@ -1355,6 +1356,7 @@
kernel :
func : sin_grad
backward : sin_double_grad
composite : sin_grad(x, out_grad, x_grad)
inplace : (out_grad -> x_grad)
- backward_op : sin_triple_grad
......
......@@ -1479,8 +1479,12 @@ class TestCos(TestActivation):
def setUp(self):
self.op_type = "cos"
self.python_api = paddle.cos
self.public_python_api = paddle.cos
self.prim_op_type = "prim"
self.init_dtype()
self.init_shape()
# CINN is not supported by prim ops yet, so disable it for this test
self.enable_cinn = False
np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
......@@ -1495,7 +1499,7 @@ class TestCos(TestActivation):
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestCos_ZeroDim(TestCos):
......@@ -1611,6 +1615,8 @@ class TestSin(TestActivation, TestParameter):
def setUp(self):
self.op_type = "sin"
self.python_api = paddle.sin
self.public_python_api = paddle.sin
self.prim_op_type = "prim"
self.init_dtype()
self.init_shape()
# CINN is not supported by prim ops yet, so disable it for this test
......@@ -1629,7 +1635,7 @@ class TestSin(TestActivation, TestParameter):
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestSin_ZeroDim(TestSin):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册