未验证 提交 b7e4d974 编写于 作者: G GGBond8488 提交者: GitHub

add prim erf grad (#50436)

* add prim erf grad

* add yaml config for prim erf grad

* add math.h

* add cmath

* add math  defines

* use define math

* use define math

* define M_2_SQRTPI

* M_2_SQRTPI math

* try math.h

* fix typro

* remove pow in erf grad

* use new optest

* add fp16 fp32 test

* remove fp16 test
上级 f3cac966
...@@ -14,6 +14,12 @@ ...@@ -14,6 +14,12 @@
#pragma once #pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include <math.h>
#include "paddle/fluid/prim/api/all.h" #include "paddle/fluid/prim/api/all.h"
#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h" #include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/int_array.h"
...@@ -881,5 +887,16 @@ void gather_nd_grad(const Tensor& x, ...@@ -881,5 +887,16 @@ void gather_nd_grad(const Tensor& x,
} }
} }
template <typename T>
void erf_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
auto m_2_sqrt_pi = full<T>(phi::vectorize(x.dims()), M_2_SQRTPI, x.dtype());
auto neg_one = full<T>(phi::vectorize(x.dims()), -1.0, x.dtype());
auto neg_tmp = neg_one * x * x;
auto mul_tmp = m_2_sqrt_pi * exp<T>(neg_tmp);
set_output<T>(out_grad * mul_tmp, x_grad);
}
}
} // namespace prim } // namespace prim
} // namespace paddle } // namespace paddle
...@@ -423,6 +423,7 @@ ...@@ -423,6 +423,7 @@
kernel : kernel :
func : erf_grad func : erf_grad
data_type : out_grad data_type : out_grad
composite : erf_grad(x, out_grad, x_grad)
- backward_op : erfinv_grad - backward_op : erfinv_grad
forward : erfinv (Tensor x) -> Tensor(out) forward : erfinv (Tensor x) -> Tensor(out)
......
...@@ -26,6 +26,8 @@ import paddle.fluid.dygraph as dg ...@@ -26,6 +26,8 @@ import paddle.fluid.dygraph as dg
class TestErfOp(OpTest): class TestErfOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "erf" self.op_type = "erf"
self.prim_op_type = "prim"
self.enable_cinn = True
self.python_api = paddle.erf self.python_api = paddle.erf
self.dtype = self._init_dtype() self.dtype = self._init_dtype()
self.x_shape = [11, 17] self.x_shape = [11, 17]
...@@ -41,7 +43,7 @@ class TestErfOp(OpTest): ...@@ -41,7 +43,7 @@ class TestErfOp(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out', check_prim=True)
class TestErfLayer(unittest.TestCase): class TestErfLayer(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册