未验证 提交 8cd0d5b3 编写于 作者: S sunli 提交者: GitHub

lerp support 0 Tensor (#49667)

* lerp support 0 Tensor

* fix lerp grad

* fix lerp zero test

* fix 0D + ND/ND + 0D

* fix check

* update code

* fix lerp infer shape

* static backward test

* updata static graph test
上级 c24e7fe1
......@@ -598,9 +598,7 @@ void LerpInferMeta(const MetaTensor& x,
auto w_dims = weight.dims();
DDim out_dims;
out_dims = funcs::GetOutputDims(x_dims, y_dims);
if (w_dims.size() > 1 || w_dims[0] != 1) {
out_dims = funcs::GetOutputDims(out_dims, w_dims);
}
out_dims = funcs::GetOutputDims(out_dims, w_dims);
out->set_dims(out_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
......
......@@ -77,9 +77,15 @@ __global__ void LerpGradScalarKernelImpl(const T* weight,
bool XYNeedReduce(const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out) {
auto x_dims = x.dims();
auto y_dims = y.dims();
auto x_dims =
x.dims().size() ? x.dims() : make_ddim(std::vector<int64_t>(1, 1));
auto y_dims =
y.dims().size() ? y.dims() : make_ddim(std::vector<int64_t>(1, 1));
auto out_dims = out.dims();
if (out_dims.size() == 0) {
return false;
}
int x_rank = x_dims.size();
int y_rank = y_dims.size();
int out_rank = out_dims.size();
......@@ -166,10 +172,10 @@ void LerpGradKernel(const Context& ctx,
const int rank = out.dims().size();
PADDLE_ENFORCE_GE(
rank,
1,
0,
phi::errors::InvalidArgument(
"The number of dimensions for LerpGradOp must be "
"greater than or equal to 1, but the value received is %d.",
"greater than or equal to 0, but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank,
......@@ -231,9 +237,12 @@ void LerpGradKernel(const Context& ctx,
x_grad_data,
y_grad_data);
auto zero_dim = make_ddim(std::vector<int64_t>(1, 1));
if (x_grad) {
std::vector<int> reduce_axis_x =
funcs::GetReduceDim(x_grad->dims(), b_xgrad.dims(), -1);
funcs::GetReduceDim(x_grad->dims().size() ? x_grad->dims() : zero_dim,
b_xgrad.dims(),
-1);
if (!reduce_axis_x.empty()) {
phi::funcs::
ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
......@@ -245,7 +254,9 @@ void LerpGradKernel(const Context& ctx,
if (y_grad) {
std::vector<int> reduce_axis_y =
funcs::GetReduceDim(y_grad->dims(), b_ygrad.dims(), -1);
funcs::GetReduceDim(y_grad->dims().size() ? y_grad->dims() : zero_dim,
b_ygrad.dims(),
-1);
if (!reduce_axis_y.empty()) {
phi::funcs::
ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
......
......@@ -33,33 +33,36 @@ static void LerpGradFunction(const Context& ctx,
auto* dx = x_grad;
auto* dy = y_grad;
auto dout_dims = dout.dims();
auto& out_dims = out.dims();
DDim dx_dims;
DDim dy_dims;
auto w_dims = phi::funcs::ExtendDims2Rank(w.dims(), D);
auto g_dims = phi::funcs::ExtendDims2Rank(out_grad.dims(), D);
Eigen::DSizes<int, D> dx_bcast_dims;
Eigen::DSizes<int, D> dy_bcast_dims;
Eigen::DSizes<int, D> w_bcast_dims;
Eigen::DSizes<int, D> g_bcast_dims;
if (dx) {
dx_dims = phi::funcs::ExtendDims2Rank(dx->dims(), D);
phi::funcs::GetBroadcastDims<D>(dx_dims, dout_dims, &dx_bcast_dims);
phi::funcs::GetBroadcastDims<D>(dx_dims, out_dims, &dx_bcast_dims);
}
if (dy) {
dy_dims = phi::funcs::ExtendDims2Rank(dy->dims(), D);
phi::funcs::GetBroadcastDims<D>(dy_dims, dout_dims, &dy_bcast_dims);
phi::funcs::GetBroadcastDims<D>(dy_dims, out_dims, &dy_bcast_dims);
}
phi::funcs::GetBroadcastDims<D>(w_dims, dout_dims, &w_bcast_dims);
phi::funcs::GetBroadcastDims<D>(w_dims, out_dims, &w_bcast_dims);
phi::funcs::GetBroadcastDims<D>(g_dims, out_dims, &g_bcast_dims);
auto eigen_w = phi::EigenTensor<T, D>::From(w, w_dims);
auto eigen_dout = phi::EigenTensor<T, D>::From(dout);
auto eigen_dout = phi::EigenTensor<T, D>::From(dout, g_dims);
Eigen::DSizes<int, D * 2> dx_reshape_dims;
Eigen::DSizes<int, D * 2> dy_reshape_dims;
Eigen::DSizes<int, D> reduce_dims;
for (int i = 0; i < dout_dims.size(); ++i) {
for (int i = 0; i < out_dims.size(); ++i) {
if (dx) {
dx_reshape_dims[2 * i] = dx_bcast_dims[i];
dx_reshape_dims[2 * i + 1] = dx_dims[i];
......@@ -76,7 +79,8 @@ static void LerpGradFunction(const Context& ctx,
if (dx) {
ctx.template Alloc<T>(dx);
auto eigen_dx = phi::EigenTensor<T, D>::From(*dx, dx_dims);
auto eigen_expr = (1 - eigen_w.broadcast(w_bcast_dims)) * eigen_dout;
auto eigen_expr = (1 - eigen_w.broadcast(w_bcast_dims)) *
eigen_dout.broadcast(g_bcast_dims);
eigen_dx.device(place) = eigen_expr.reshape(dx_reshape_dims)
.sum(reduce_dims)
.reshape(eigen_dx.dimensions());
......@@ -84,13 +88,40 @@ static void LerpGradFunction(const Context& ctx,
if (dy) {
ctx.template Alloc<T>(dy);
auto eigen_dy = phi::EigenTensor<T, D>::From(*dy, dy_dims);
auto eigen_expr = eigen_w.broadcast(w_bcast_dims) * eigen_dout;
auto eigen_expr =
eigen_w.broadcast(w_bcast_dims) * eigen_dout.broadcast(g_bcast_dims);
eigen_dy.device(place) = eigen_expr.reshape(dy_reshape_dims)
.sum(reduce_dims)
.reshape(eigen_dy.dimensions());
}
}
template <typename Context, typename T>
static void LerpGradFunctionZero(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& weight,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto dim = make_ddim(std::vector<int64_t>(1, 1));
auto eigen_w = phi::EigenTensor<T, 1>::From(weight, dim);
auto eigen_dout = phi::EigenTensor<T, 1>::From(out_grad, dim);
auto& place = *ctx.eigen_device();
if (x_grad) {
ctx.template Alloc<T>(x_grad);
auto eigen_dx = phi::EigenTensor<T, 1>::From(*x_grad, dim);
eigen_dx.device(place) = (1 - eigen_w) * eigen_dout;
}
if (y_grad) {
ctx.template Alloc<T>(y_grad);
auto eigen_dy = phi::EigenTensor<T, 1>::From(*y_grad, dim);
eigen_dy.device(place) = eigen_w * eigen_dout;
}
}
template <typename T, typename Context>
void LerpGradKernel(const Context& ctx,
const DenseTensor& x,
......@@ -103,10 +134,10 @@ void LerpGradKernel(const Context& ctx,
int rank = out.dims().size();
PADDLE_ENFORCE_GE(
rank,
1,
0,
phi::errors::InvalidArgument(
"The number of dimensions for LerpGradOp must be "
"greater than or equal to 1, but the value received is %d.",
"greater than or equal to 0, but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank,
......@@ -116,6 +147,10 @@ void LerpGradKernel(const Context& ctx,
"less than or equal to 6, but the value received is %d.",
rank));
switch (rank) {
case 0:
LerpGradFunctionZero<Context, T>(
ctx, x, y, weight, out, out_grad, x_grad, y_grad);
break;
case 1:
LerpGradFunction<Context, T, 1>(
ctx, x, y, weight, out, out_grad, x_grad, y_grad);
......
......@@ -27,7 +27,6 @@ static void LerpFunction(const Context& ctx,
const DenseTensor& weight,
DenseTensor* out) {
ctx.template Alloc<T>(out);
const auto& out_dims = out->dims();
auto x_dims = phi::funcs::ExtendDims2Rank(x.dims(), D);
auto y_dims = phi::funcs::ExtendDims2Rank(y.dims(), D);
......@@ -51,6 +50,24 @@ static void LerpFunction(const Context& ctx,
(eigen_y.broadcast(y_bcast_dims) - eigen_x.broadcast(x_bcast_dims));
}
template <typename Context, typename T>
static void LerpFunctionZero(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& weight,
DenseTensor* out) {
ctx.template Alloc<T>(out);
auto dim = make_ddim(std::vector<int64_t>(1, 1));
auto eigen_x = phi::EigenTensor<T, 1>::From(x, dim);
auto eigen_y = phi::EigenTensor<T, 1>::From(y, dim);
auto eigen_w = phi::EigenTensor<T, 1>::From(weight, dim);
auto eigen_out = phi::EigenTensor<T, 1>::From(*out, dim);
auto& place = *ctx.eigen_device();
eigen_out.device(place) = eigen_x + eigen_w * (eigen_y - eigen_x);
}
template <typename T, typename Context>
void LerpKernel(const Context& ctx,
const DenseTensor& x,
......@@ -60,10 +77,10 @@ void LerpKernel(const Context& ctx,
int rank = out->dims().size();
PADDLE_ENFORCE_GE(
rank,
1,
0,
phi::errors::InvalidArgument(
"The number of dimensions for LerpOp must be "
"greater than or equal to 1, but the value received is %d.",
"greater than or equal to 0, but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank,
......@@ -73,6 +90,9 @@ void LerpKernel(const Context& ctx,
"less than or equal to 6, but the value received is %d.",
rank));
switch (rank) {
case 0:
LerpFunctionZero<Context, T>(ctx, x, y, weight, out);
break;
case 1:
LerpFunction<Context, T, 1>(ctx, x, y, weight, out);
break;
......
......@@ -971,6 +971,49 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x1.grad.numpy(), 0)
self.assertEqual(x2.grad.numpy(), 0)
def test_lerp(self):
# 0D + 0D
x0 = paddle.rand([])
y0 = paddle.rand([])
w0 = paddle.rand([])
x0.stop_gradient = False
y0.stop_gradient = False
out0 = paddle.lerp(x0, y0, w0)
out0.backward()
self.assertEqual(out0.shape, [])
self.assertEqual(x0.grad.shape, [])
self.assertEqual(y0.grad.shape, [])
# 0D + ND
x1 = paddle.rand([])
y1 = paddle.rand([64, 64])
w1 = paddle.rand([])
x1.stop_gradient = False
y1.stop_gradient = False
out1 = paddle.lerp(x1, y1, w1)
out1.backward()
self.assertEqual(out1.shape, [64, 64])
self.assertEqual(x1.grad.shape, [])
self.assertEqual(y1.grad.shape, [64, 64])
# ND + 0D
x2 = paddle.rand([64, 64])
y2 = paddle.rand([])
w2 = paddle.rand([])
x2.stop_gradient = False
y2.stop_gradient = False
out2 = paddle.lerp(x2, y2, w2)
out2.backward()
self.assertEqual(out2.shape, [64, 64])
self.assertEqual(x2.grad.shape, [64, 64])
self.assertEqual(y2.grad.shape, [])
def test_repeat_interleave(self):
places = ['cpu']
if paddle.is_compiled_with_cuda():
......@@ -1442,6 +1485,35 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
@prog_scope()
def test_lerp(self):
shapes = [
[(), (), (), ()],
[(), (64, 64), (), (64, 64)],
[(64, 64), (), (), (64, 64)],
]
for shape in shapes:
x = paddle.rand(shape[0])
y = paddle.rand(shape[1])
w = paddle.rand(shape[2])
x.stop_gradient = False
y.stop_gradient = False
out = paddle.lerp(x, y, w)
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(fluid.framework.grad_var_name(x.name))
y_grad = block.var(fluid.framework.grad_var_name(y.name))
out_grad = block.var(fluid.framework.grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, out_grad, y_grad, x_grad])
self.assertEqual(res[0].shape, shape[3])
self.assertEqual(res[1].shape, shape[3])
self.assertEqual(res[2].shape, shape[1])
self.assertEqual(res[3].shape, shape[0])
@prog_scope()
def test_repeat_interleave(self):
x = paddle.full([], 1.0, 'float32')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册