diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 74391958189186983360df0a14dcd273d211f2b1..d790d226b2d379e410b18c67ab56d88ed442bd61 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -598,9 +598,7 @@ void LerpInferMeta(const MetaTensor& x, auto w_dims = weight.dims(); DDim out_dims; out_dims = funcs::GetOutputDims(x_dims, y_dims); - if (w_dims.size() > 1 || w_dims[0] != 1) { - out_dims = funcs::GetOutputDims(out_dims, w_dims); - } + out_dims = funcs::GetOutputDims(out_dims, w_dims); out->set_dims(out_dims); out->set_dtype(x.dtype()); out->share_lod(x); diff --git a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu index b097e4ce4d07a76d210b8fa25ba8824a906013cd..f42f316aae98038179282cbf9ade4844f3065868 100644 --- a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu @@ -77,9 +77,15 @@ __global__ void LerpGradScalarKernelImpl(const T* weight, bool XYNeedReduce(const DenseTensor& x, const DenseTensor& y, const DenseTensor& out) { - auto x_dims = x.dims(); - auto y_dims = y.dims(); + auto x_dims = + x.dims().size() ? x.dims() : make_ddim(std::vector(1, 1)); + auto y_dims = + y.dims().size() ? y.dims() : make_ddim(std::vector(1, 1)); + auto out_dims = out.dims(); + if (out_dims.size() == 0) { + return false; + } int x_rank = x_dims.size(); int y_rank = y_dims.size(); int out_rank = out_dims.size(); @@ -166,10 +172,10 @@ void LerpGradKernel(const Context& ctx, const int rank = out.dims().size(); PADDLE_ENFORCE_GE( rank, - 1, + 0, phi::errors::InvalidArgument( "The number of dimensions for LerpGradOp must be " - "greater than or equal to 1, but the value received is %d.", + "greater than or equal to 0, but the value received is %d.", rank)); PADDLE_ENFORCE_LE( rank, @@ -231,9 +237,12 @@ void LerpGradKernel(const Context& ctx, x_grad_data, y_grad_data); + auto zero_dim = make_ddim(std::vector(1, 1)); if (x_grad) { std::vector reduce_axis_x = - funcs::GetReduceDim(x_grad->dims(), b_xgrad.dims(), -1); + funcs::GetReduceDim(x_grad->dims().size() ? x_grad->dims() : zero_dim, + b_xgrad.dims(), + -1); if (!reduce_axis_x.empty()) { phi::funcs:: ReduceKernel>( @@ -245,7 +254,9 @@ void LerpGradKernel(const Context& ctx, if (y_grad) { std::vector reduce_axis_y = - funcs::GetReduceDim(y_grad->dims(), b_ygrad.dims(), -1); + funcs::GetReduceDim(y_grad->dims().size() ? y_grad->dims() : zero_dim, + b_ygrad.dims(), + -1); if (!reduce_axis_y.empty()) { phi::funcs:: ReduceKernel>( diff --git a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h index b47acbda0da2d213930476467a45f00fa1e544ba..541de0cc162ccd7bc48e296db97819176644f14f 100644 --- a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h @@ -33,33 +33,36 @@ static void LerpGradFunction(const Context& ctx, auto* dx = x_grad; auto* dy = y_grad; - auto dout_dims = dout.dims(); + auto& out_dims = out.dims(); DDim dx_dims; DDim dy_dims; auto w_dims = phi::funcs::ExtendDims2Rank(w.dims(), D); + auto g_dims = phi::funcs::ExtendDims2Rank(out_grad.dims(), D); Eigen::DSizes dx_bcast_dims; Eigen::DSizes dy_bcast_dims; Eigen::DSizes w_bcast_dims; + Eigen::DSizes g_bcast_dims; if (dx) { dx_dims = phi::funcs::ExtendDims2Rank(dx->dims(), D); - phi::funcs::GetBroadcastDims(dx_dims, dout_dims, &dx_bcast_dims); + phi::funcs::GetBroadcastDims(dx_dims, out_dims, &dx_bcast_dims); } if (dy) { dy_dims = phi::funcs::ExtendDims2Rank(dy->dims(), D); - phi::funcs::GetBroadcastDims(dy_dims, dout_dims, &dy_bcast_dims); + phi::funcs::GetBroadcastDims(dy_dims, out_dims, &dy_bcast_dims); } - phi::funcs::GetBroadcastDims(w_dims, dout_dims, &w_bcast_dims); + phi::funcs::GetBroadcastDims(w_dims, out_dims, &w_bcast_dims); + phi::funcs::GetBroadcastDims(g_dims, out_dims, &g_bcast_dims); auto eigen_w = phi::EigenTensor::From(w, w_dims); - auto eigen_dout = phi::EigenTensor::From(dout); + auto eigen_dout = phi::EigenTensor::From(dout, g_dims); Eigen::DSizes dx_reshape_dims; Eigen::DSizes dy_reshape_dims; Eigen::DSizes reduce_dims; - for (int i = 0; i < dout_dims.size(); ++i) { + for (int i = 0; i < out_dims.size(); ++i) { if (dx) { dx_reshape_dims[2 * i] = dx_bcast_dims[i]; dx_reshape_dims[2 * i + 1] = dx_dims[i]; @@ -76,7 +79,8 @@ static void LerpGradFunction(const Context& ctx, if (dx) { ctx.template Alloc(dx); auto eigen_dx = phi::EigenTensor::From(*dx, dx_dims); - auto eigen_expr = (1 - eigen_w.broadcast(w_bcast_dims)) * eigen_dout; + auto eigen_expr = (1 - eigen_w.broadcast(w_bcast_dims)) * + eigen_dout.broadcast(g_bcast_dims); eigen_dx.device(place) = eigen_expr.reshape(dx_reshape_dims) .sum(reduce_dims) .reshape(eigen_dx.dimensions()); @@ -84,13 +88,40 @@ static void LerpGradFunction(const Context& ctx, if (dy) { ctx.template Alloc(dy); auto eigen_dy = phi::EigenTensor::From(*dy, dy_dims); - auto eigen_expr = eigen_w.broadcast(w_bcast_dims) * eigen_dout; + auto eigen_expr = + eigen_w.broadcast(w_bcast_dims) * eigen_dout.broadcast(g_bcast_dims); eigen_dy.device(place) = eigen_expr.reshape(dy_reshape_dims) .sum(reduce_dims) .reshape(eigen_dy.dimensions()); } } +template +static void LerpGradFunctionZero(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& weight, + const DenseTensor& out, + const DenseTensor& out_grad, + DenseTensor* x_grad, + DenseTensor* y_grad) { + auto dim = make_ddim(std::vector(1, 1)); + auto eigen_w = phi::EigenTensor::From(weight, dim); + auto eigen_dout = phi::EigenTensor::From(out_grad, dim); + + auto& place = *ctx.eigen_device(); + if (x_grad) { + ctx.template Alloc(x_grad); + auto eigen_dx = phi::EigenTensor::From(*x_grad, dim); + eigen_dx.device(place) = (1 - eigen_w) * eigen_dout; + } + if (y_grad) { + ctx.template Alloc(y_grad); + auto eigen_dy = phi::EigenTensor::From(*y_grad, dim); + eigen_dy.device(place) = eigen_w * eigen_dout; + } +} + template void LerpGradKernel(const Context& ctx, const DenseTensor& x, @@ -103,10 +134,10 @@ void LerpGradKernel(const Context& ctx, int rank = out.dims().size(); PADDLE_ENFORCE_GE( rank, - 1, + 0, phi::errors::InvalidArgument( "The number of dimensions for LerpGradOp must be " - "greater than or equal to 1, but the value received is %d.", + "greater than or equal to 0, but the value received is %d.", rank)); PADDLE_ENFORCE_LE( rank, @@ -116,6 +147,10 @@ void LerpGradKernel(const Context& ctx, "less than or equal to 6, but the value received is %d.", rank)); switch (rank) { + case 0: + LerpGradFunctionZero( + ctx, x, y, weight, out, out_grad, x_grad, y_grad); + break; case 1: LerpGradFunction( ctx, x, y, weight, out, out_grad, x_grad, y_grad); diff --git a/paddle/phi/kernels/impl/lerp_kernel_impl.h b/paddle/phi/kernels/impl/lerp_kernel_impl.h index 72fa0672a5f48b3749e4ccb82f2f1709b3b3e08a..668349e09b951e46b83ed9dd12343a537878a22a 100644 --- a/paddle/phi/kernels/impl/lerp_kernel_impl.h +++ b/paddle/phi/kernels/impl/lerp_kernel_impl.h @@ -27,7 +27,6 @@ static void LerpFunction(const Context& ctx, const DenseTensor& weight, DenseTensor* out) { ctx.template Alloc(out); - const auto& out_dims = out->dims(); auto x_dims = phi::funcs::ExtendDims2Rank(x.dims(), D); auto y_dims = phi::funcs::ExtendDims2Rank(y.dims(), D); @@ -51,6 +50,24 @@ static void LerpFunction(const Context& ctx, (eigen_y.broadcast(y_bcast_dims) - eigen_x.broadcast(x_bcast_dims)); } +template +static void LerpFunctionZero(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& weight, + DenseTensor* out) { + ctx.template Alloc(out); + + auto dim = make_ddim(std::vector(1, 1)); + auto eigen_x = phi::EigenTensor::From(x, dim); + auto eigen_y = phi::EigenTensor::From(y, dim); + auto eigen_w = phi::EigenTensor::From(weight, dim); + auto eigen_out = phi::EigenTensor::From(*out, dim); + + auto& place = *ctx.eigen_device(); + eigen_out.device(place) = eigen_x + eigen_w * (eigen_y - eigen_x); +} + template void LerpKernel(const Context& ctx, const DenseTensor& x, @@ -60,10 +77,10 @@ void LerpKernel(const Context& ctx, int rank = out->dims().size(); PADDLE_ENFORCE_GE( rank, - 1, + 0, phi::errors::InvalidArgument( "The number of dimensions for LerpOp must be " - "greater than or equal to 1, but the value received is %d.", + "greater than or equal to 0, but the value received is %d.", rank)); PADDLE_ENFORCE_LE( rank, @@ -73,6 +90,9 @@ void LerpKernel(const Context& ctx, "less than or equal to 6, but the value received is %d.", rank)); switch (rank) { + case 0: + LerpFunctionZero(ctx, x, y, weight, out); + break; case 1: LerpFunction(ctx, x, y, weight, out); break; diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 4023023ab5283ea4a14b1ffc823c9d48b9a1aecd..370a626630f2b1511e358c0d23bb67a159150565 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -971,6 +971,49 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(x1.grad.numpy(), 0) self.assertEqual(x2.grad.numpy(), 0) + def test_lerp(self): + # 0D + 0D + x0 = paddle.rand([]) + y0 = paddle.rand([]) + w0 = paddle.rand([]) + x0.stop_gradient = False + y0.stop_gradient = False + + out0 = paddle.lerp(x0, y0, w0) + out0.backward() + + self.assertEqual(out0.shape, []) + self.assertEqual(x0.grad.shape, []) + self.assertEqual(y0.grad.shape, []) + + # 0D + ND + x1 = paddle.rand([]) + y1 = paddle.rand([64, 64]) + w1 = paddle.rand([]) + x1.stop_gradient = False + y1.stop_gradient = False + + out1 = paddle.lerp(x1, y1, w1) + out1.backward() + + self.assertEqual(out1.shape, [64, 64]) + self.assertEqual(x1.grad.shape, []) + self.assertEqual(y1.grad.shape, [64, 64]) + + # ND + 0D + x2 = paddle.rand([64, 64]) + y2 = paddle.rand([]) + w2 = paddle.rand([]) + x2.stop_gradient = False + y2.stop_gradient = False + + out2 = paddle.lerp(x2, y2, w2) + out2.backward() + + self.assertEqual(out2.shape, [64, 64]) + self.assertEqual(x2.grad.shape, [64, 64]) + self.assertEqual(y2.grad.shape, []) + def test_repeat_interleave(self): places = ['cpu'] if paddle.is_compiled_with_cuda(): @@ -1442,6 +1485,35 @@ class TestSundryAPIStatic(unittest.TestCase): self.assertEqual(res[0].shape, ()) self.assertEqual(res[1].shape, ()) + @prog_scope() + def test_lerp(self): + shapes = [ + [(), (), (), ()], + [(), (64, 64), (), (64, 64)], + [(64, 64), (), (), (64, 64)], + ] + for shape in shapes: + x = paddle.rand(shape[0]) + y = paddle.rand(shape[1]) + w = paddle.rand(shape[2]) + + x.stop_gradient = False + y.stop_gradient = False + out = paddle.lerp(x, y, w) + paddle.static.append_backward(out.sum()) + + prog = paddle.static.default_main_program() + block = prog.global_block() + x_grad = block.var(fluid.framework.grad_var_name(x.name)) + y_grad = block.var(fluid.framework.grad_var_name(y.name)) + out_grad = block.var(fluid.framework.grad_var_name(out.name)) + + res = self.exe.run(prog, fetch_list=[out, out_grad, y_grad, x_grad]) + self.assertEqual(res[0].shape, shape[3]) + self.assertEqual(res[1].shape, shape[3]) + self.assertEqual(res[2].shape, shape[1]) + self.assertEqual(res[3].shape, shape[0]) + @prog_scope() def test_repeat_interleave(self): x = paddle.full([], 1.0, 'float32')