未验证 提交 769e24ce 编写于 作者: M Meteor Liu 提交者: GitHub

implement floor_grad by primitive logic (#51059)

* implement floor_grad by primitive logic

* implement floor_grad by primitive logic

* Merge branch 'develop' into floor_grad
上级 a8a2b7f4
...@@ -284,6 +284,15 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { ...@@ -284,6 +284,15 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
} }
} }
template <typename T>
void floor_grad(const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
auto zero_tensor =
full<T>(phi::vectorize(out_grad.dims()), 0.0, out_grad.dtype());
set_output<T>(zero_tensor, x_grad);
}
}
template <typename T> template <typename T>
void concat_grad(const std::vector<Tensor>& x, void concat_grad(const std::vector<Tensor>& x,
const Tensor& out_grad, const Tensor& out_grad,
......
...@@ -533,6 +533,7 @@ ...@@ -533,6 +533,7 @@
param: [out_grad] param: [out_grad]
kernel : kernel :
func : floor_grad func : floor_grad
composite : floor_grad(out_grad, x_grad)
inplace : (out_grad -> x_grad) inplace : (out_grad -> x_grad)
- backward_op : fold_grad - backward_op : fold_grad
......
...@@ -1434,6 +1434,49 @@ class TestFloor_ZeroDim(TestFloor): ...@@ -1434,6 +1434,49 @@ class TestFloor_ZeroDim(TestFloor):
self.shape = [] self.shape = []
class TestFloorPrim(TestActivation):
def setUp(self):
self.op_type = "floor"
self.prim_op_type = "prim"
# the gradient on floor, ceil, round is undefined.
# we return zero as gradient, but the numpy return nan.
# for prim, we compare result with eager python api,
# so, we use only_prim flag to express we only test prim.
self.only_prim = True
self.check_eager = True
self.python_api = paddle.floor
self.init_dtype()
self.init_shape()
if len(self.shape) == 0:
# for 0-D tensor, skip cinn testing
self.enable_cinn = False
np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
out = np.floor(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def init_shape(self):
self.shape = [10, 12]
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True)
class TestFloorPrim_ZeroDim(TestFloorPrim):
def init_shape(self):
self.shape = []
class TestFloorPrimFp16(TestFloorPrim):
def init_dtype(self):
self.dtype = np.float16
class TestCos(TestActivation): class TestCos(TestActivation):
def setUp(self): def setUp(self):
self.op_type = "cos" self.op_type = "cos"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册