diff --git a/paddle/phi/kernels/gpu/reduce_amin_amax_common.h b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
index ed6e0ef51558ac75ab95e6640d58604e018948c0..ff15b5698f8626e179475dbb4e7b503fa03500a7 100644
--- a/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
+++ b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
@@ -92,7 +92,14 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx,
       reduce_dims,
       false);
 
-  // 3. dx = Div(dout, equal_out)
+  // 3. dx = dout * equal_out
+  std::vector<const phi::DenseTensor*> mul_inputs = {&new_dout,
+                                                     &equal_out_tensor};
+  std::vector<phi::DenseTensor*> mul_outputs = {&equal_out_tensor};
+  funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
+      dev_ctx, mul_inputs, &mul_outputs, 0, funcs::MultiplyFunctor<T>());
+
+  // 4. dx = Div(dx, equal_count)
   std::vector<const phi::DenseTensor*> grad_inputs = {&equal_out_tensor,
                                                       equal_count};
   std::vector<phi::DenseTensor*> grad_outputs = {new_dx_tensor};
diff --git a/python/paddle/fluid/tests/unittests/test_max_min_amax_amin_op.py b/python/paddle/fluid/tests/unittests/test_max_min_amax_amin_op.py
index 33659356ba97b019da89f2e18acdb943e18d0af8..679dc7060f73967bbbe48ace0f639fed5dfd9c54 100644
--- a/python/paddle/fluid/tests/unittests/test_max_min_amax_amin_op.py
+++ b/python/paddle/fluid/tests/unittests/test_max_min_amax_amin_op.py
@@ -198,5 +198,43 @@ class TestMaxMinAmaxAminAPI6(TestMaxMinAmaxAminAPI):
         self.keepdim = False
 
 
+# test the input grad when out is used by another op such as multiply
+class TestMaxMinAmaxAminAPI7(TestMaxMinAmaxAminAPI):
+    def init_case(self):
+        self.x_np = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]).astype(
+            np.int32
+        )
+        self.shape = [2, 2, 2]
+        self.dtype = 'int32'
+        self.axis = (0, 1)
+        self.keepdim = False
+
+    # As gradients are easy to compute in dygraph mode, we check the
+    # gradient of the paddle API against numpy in dygraph.
+    def test_dygraph(self):
+        def _test_dygraph(func):
+            paddle.disable_static()
+            x = paddle.to_tensor(
+                self.x_np, dtype=self.dtype, stop_gradient=False
+            )
+            out = self._choose_paddle_func(func, x)
+            loss = out * 2
+            grad_tensor = paddle.ones_like(x)
+            paddle.autograd.backward([loss], [grad_tensor], True)
+
+            np.testing.assert_allclose(
+                self.np_out[func], out.numpy(), rtol=1e-05
+            )
+            np.testing.assert_allclose(
+                self.np_grad[func] * 2, x.grad, rtol=1e-05
+            )
+            paddle.enable_static()
+
+        _test_dygraph('amax')
+        _test_dygraph('amin')
+        _test_dygraph('max')
+        _test_dygraph('min')
+
+
 if __name__ == '__main__':
     unittest.main()
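For context on why the new multiply step is needed: before this change, step 3 computed dx directly as Div(equal_out, equal_count), so `dout` never entered the gradient and the result was only correct when the incoming gradient happened to be all ones. Below is a minimal dygraph sketch of the behavior the new test pins down; it assumes `paddle.amax` with a float32 input chosen for illustration (the test itself uses int32 and exercises all four of amax/amin/max/min):

```python
import numpy as np
import paddle

# Same values as TestMaxMinAmaxAminAPI7; each arg-max along axis (0, 1) is
# unique, so the equal-count divisor is 1 at every selected position.
x = paddle.to_tensor(
    np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.float32),
    stop_gradient=False,
)
out = paddle.amax(x, axis=(0, 1))  # -> [7., 8.]
loss = out * 2  # feed `out` into another op, so its incoming grad is 2
loss.sum().backward()

# d(loss)/dx is 2 at the two arg-max positions ([1, 1, 0] and [1, 1, 1])
# and 0 elsewhere:
# [[[0. 0.]
#   [0. 0.]]
#  [[0. 0.]
#   [2. 2.]]]
# The pre-fix kernel reported 1 at those positions, because its step 3
# divided the equal mask by equal_count without ever multiplying in dout.
print(x.grad.numpy())
```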