未验证 提交 3239a7b3 编写于 作者: F FlyingQianMM 提交者: GitHub

fix wrong backward grad of amin/amax (#51301)

上级 64a21f71
...@@ -92,7 +92,14 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx, ...@@ -92,7 +92,14 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx,
reduce_dims, reduce_dims,
false); false);
// 3. dx = Div(dout, equal_out) // 3. dx = dout * 1
std::vector<const phi::DenseTensor*> mul_inputs = {&new_dout,
&equal_out_tensor};
std::vector<phi::DenseTensor*> mul_outputs = {&equal_out_tensor};
funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx, mul_inputs, &mul_outputs, 0, funcs::MultiplyFunctor<T>());
// 4. dx = Div(dx, equal_out)
std::vector<const phi::DenseTensor*> grad_inputs = {&equal_out_tensor, std::vector<const phi::DenseTensor*> grad_inputs = {&equal_out_tensor,
equal_count}; equal_count};
std::vector<phi::DenseTensor*> grad_outputs = {new_dx_tensor}; std::vector<phi::DenseTensor*> grad_outputs = {new_dx_tensor};
......
...@@ -198,5 +198,43 @@ class TestMaxMinAmaxAminAPI6(TestMaxMinAmaxAminAPI): ...@@ -198,5 +198,43 @@ class TestMaxMinAmaxAminAPI6(TestMaxMinAmaxAminAPI):
self.keepdim = False self.keepdim = False
# test input grad when out is operated like mutiply
class TestMaxMinAmaxAminAPI7(TestMaxMinAmaxAminAPI):
def init_case(self):
self.x_np = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]).astype(
np.int32
)
self.shape = [2, 2, 2]
self.dtype = 'int32'
self.axis = (0, 1)
self.keepdim = False
# As dygraph is easy to compute gradient, we check the gradient between
# paddle API and numpy in dygraph.
def test_dygraph(self):
def _test_dygraph(func):
paddle.disable_static()
x = paddle.to_tensor(
self.x_np, dtype=self.dtype, stop_gradient=False
)
out = self._choose_paddle_func(func, x)
loss = out * 2
grad_tensor = paddle.ones_like(x)
paddle.autograd.backward([loss], [grad_tensor], True)
np.testing.assert_allclose(
self.np_out[func], out.numpy(), rtol=1e-05
)
np.testing.assert_allclose(
self.np_grad[func] * 2, x.grad, rtol=1e-05
)
paddle.enable_static()
_test_dygraph('amax')
_test_dygraph('amin')
_test_dygraph('max')
_test_dygraph('min')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册