未验证 提交 30e0409c 编写于 作者: M Meteor Liu 提交者: GitHub

DLTP-66486:implement log_grad by primitive logic (#51296)

上级 c4a02ed0
...@@ -388,6 +388,14 @@ void expand_grad(const Tensor& x, ...@@ -388,6 +388,14 @@ void expand_grad(const Tensor& x,
} }
} }
template <typename T>
void log_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
// dx = dout / x
set_output<T>(out_grad / x, x_grad);
}
}
template <typename T> template <typename T>
void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) { if (x_grad) {
......
...@@ -804,6 +804,7 @@ ...@@ -804,6 +804,7 @@
kernel : kernel :
func : log_grad func : log_grad
backward : log_double_grad backward : log_double_grad
composite : log_grad(x, out_grad, x_grad)
inplace : (out_grad -> x_grad) inplace : (out_grad -> x_grad)
- backward_op : log_loss_grad - backward_op : log_loss_grad
......
...@@ -2623,10 +2623,15 @@ class TestLog(TestActivation): ...@@ -2623,10 +2623,15 @@ class TestLog(TestActivation):
def setUp(self): def setUp(self):
self.op_type = "log" self.op_type = "log"
self.check_eager = True self.check_eager = True
self.prim_op_type = "prim"
self.python_api = paddle.log self.python_api = paddle.log
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
if len(self.shape) == 0:
# for 0-D tensor, skip cinn testing
self.enable_cinn = False
np.random.seed(1024) np.random.seed(1024)
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
out = np.log(x) out = np.log(x)
...@@ -2637,7 +2642,7 @@ class TestLog(TestActivation): ...@@ -2637,7 +2642,7 @@ class TestLog(TestActivation):
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.float16: if self.dtype == np.float16:
return return
self.check_grad(['X'], 'Out', check_eager=True) self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)
def test_error(self): def test_error(self):
in1 = paddle.static.data(name="in1", shape=[11, 17], dtype="int32") in1 = paddle.static.data(name="in1", shape=[11, 17], dtype="int32")
...@@ -3846,7 +3851,7 @@ create_test_act_fp16_class(TestSoftRelu, grad_atol=0.85) ...@@ -3846,7 +3851,7 @@ create_test_act_fp16_class(TestSoftRelu, grad_atol=0.85)
create_test_act_fp16_class(TestELU) create_test_act_fp16_class(TestELU)
create_test_act_fp16_class(TestCELU) create_test_act_fp16_class(TestCELU)
create_test_act_fp16_class(TestReciprocal) create_test_act_fp16_class(TestReciprocal)
create_test_act_fp16_class(TestLog) create_test_act_fp16_class(TestLog, check_prim=True)
if core.is_compiled_with_rocm(): if core.is_compiled_with_rocm():
create_test_act_fp16_class(TestLog2, atol=5e-2, grad_atol=0.85) create_test_act_fp16_class(TestLog2, atol=5e-2, grad_atol=0.85)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册