Unverified commit ef61df30, authored by Rayman and committed by GitHub

[Hackathon No.36] Optimize the GPU performance of the lerp_grad op (#45946)

Parent commit: 5e0614a1
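For reference, the forward op computes out = x + weight * (y - x), so the gradients that the new kernels evaluate element-wise are:

    dx = dout - weight * dout   (i.e. (1 - weight) * dout)
    dy = weight * dout

The new GPU path launches dedicated CUDA kernels for these expressions, handles a scalar weight separately, and only falls back to an explicit broadcast-and-reduce step when x, y, or weight require broadcasting.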
......@@ -15,8 +15,249 @@
#include "paddle/phi/kernels/lerp_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/lerp_grad_kernel_impl.h"
#include "paddle/phi/kernels/broadcast_tensors_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/gpu/reduce.h"
namespace phi {
template <typename T>
__global__ void LerpGradKernelImpl(const T* weight,
                                   const T* dout,
                                   T* dx,
                                   T* dy,
                                   const int out_size,
                                   const int x_size,
                                   const int y_size) {
  CUDA_KERNEL_LOOP_TYPE(idx, out_size, int64_t) {
    T temp_dx = weight[idx] * dout[idx];
    if (dx) {
      if (idx < x_size) {
        dx[idx] = dout[idx] - temp_dx;
      }
    }
    if (dy) {
      if (idx < y_size) {
        dy[idx] = temp_dx;
      }
    }
  }
}
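// Note: dx / dy may be null when the corresponding gradient is not requested,
// and x_size / y_size may be smaller than out_size, so both checks above guard
// the writes.
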
template <typename T>
__global__ void LerpGradScalarKernelImpl(const T* weight,
                                         const T* dout,
                                         T* dx,
                                         T* dy,
                                         const int out_size,
                                         const int x_size,
                                         const int y_size) {
  T weight_scalar = weight[0];
  CUDA_KERNEL_LOOP_TYPE(idx, out_size, int64_t) {
    T temp_dx = weight_scalar * dout[idx];
    if (dx) {
      if (idx < x_size) {
        dx[idx] = dout[idx] - temp_dx;
      }
    }
    if (dy) {
      if (idx < y_size) {
        dy[idx] = temp_dx;
      }
    }
  }
}

bool XYNeedReduce(const DenseTensor& x,
                  const DenseTensor& y,
                  const DenseTensor& out) {
  auto x_dims = x.dims();
  auto y_dims = y.dims();
  auto out_dims = out.dims();
  int x_rank = x_dims.size();
  int y_rank = y_dims.size();
  int out_rank = out_dims.size();
  int smaller_rank = std::min(x_rank, y_rank);
  if (std::max(x_rank, y_rank) < out_rank) {
    return true;
  }
  for (int i = 1; i <= smaller_rank; ++i) {
    int x_idx = x_rank - i;
    int y_idx = y_rank - i;
    int out_idx = out_rank - i;
    if (x_dims[x_idx] != y_dims[y_idx]) {
      return true;
    }
    if (x_dims[x_idx] == 1 && y_dims[y_idx] == 1 && out_dims[out_idx] != 1) {
      return true;
    }
  }
  return false;
}
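// Example (mirrors the new TestLerpBroadXY unit test below): x with shape
// [2, 1, 2, 50] and y with shape [2, 2, 1, 50] broadcast to out = [2, 2, 2, 50];
// their trailing dims differ, so XYNeedReduce returns true and the gradients are
// first computed at out's shape and then reduced back to x / y's shapes.
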
template <typename T, typename Context>
void SwitchKernel(const Context& ctx,
                  const DenseTensor& weight,
                  const DenseTensor& out_grad,
                  const int x_grad_size,
                  const int y_grad_size,
                  T* x_grad_data,
                  T* y_grad_data) {
  if (weight.numel() == 1) {
    // weight is a scalar, so it does not need to be broadcast
    const T* weight_data = weight.data<T>();
    const T* out_grad_data = out_grad.data<T>();
    const int64_t out_size = out_grad.numel();
    const int64_t weight_size = weight.numel();
    auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, out_size);
    LerpGradScalarKernelImpl<T><<<gpu_config.GetGridSize(),
                                  gpu_config.GetBlockSize(),
                                  0,
                                  ctx.stream()>>>(weight_data,
                                                  out_grad_data,
                                                  x_grad_data,
                                                  y_grad_data,
                                                  out_size,
                                                  x_grad_size,
                                                  y_grad_size);
  } else {
    // broadcast weight to out_grad's shape
    const std::vector<const DenseTensor*> in_tensors = {&weight, &out_grad};
    DenseTensor b_weight = phi::EmptyLike<T>(ctx, out_grad);
    DenseTensor b_out = phi::EmptyLike<T>(ctx, out_grad);
    std::vector<DenseTensor*> out_tensors = {&b_weight, &b_out};
    phi::BroadcastTensorsKernel<T, Context>(ctx, in_tensors, out_tensors);
    const T* weight_data = b_weight.data<T>();
    const T* out_grad_data = b_out.data<T>();
    const int out_size = out_grad.numel();
    const int weight_size = weight.numel();
    auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, out_size);
    LerpGradKernelImpl<T><<<gpu_config.GetGridSize(),
                            gpu_config.GetBlockSize(),
                            0,
                            ctx.stream()>>>(weight_data,
                                            out_grad_data,
                                            x_grad_data,
                                            y_grad_data,
                                            out_size,
                                            x_grad_size,
                                            y_grad_size);
  }
}
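// Example (mirrors the new TestLerpBroadWToXY unit test below): a weight of
// shape [3] is broadcast against an out_grad of shape [50, 2, 2, 3], so
// LerpGradKernelImpl reads a per-element weight instead of a single scalar.
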
template <typename T, typename Context>
void LerpGradKernel(const Context& ctx,
                    const DenseTensor& x,
                    const DenseTensor& y,
                    const DenseTensor& weight,
                    const DenseTensor& out,
                    const DenseTensor& out_grad,
                    DenseTensor* x_grad,
                    DenseTensor* y_grad) {
  const int rank = out.dims().size();
  PADDLE_ENFORCE_GE(
      rank,
      1,
      phi::errors::InvalidArgument(
          "The number of dimensions for LerpGradOp must be "
          "greater than or equal to 1, but the value received is %d.",
          rank));
  PADDLE_ENFORCE_LE(
      rank,
      6,
      phi::errors::InvalidArgument(
          "The number of dimensions for LerpGradOp must be "
          "less than or equal to 6, but the value received is %d.",
          rank));
  // Check whether x_grad and y_grad need to be reduced: if x differs from y or
  // weight along some axis, the gradients are first computed at out_grad's
  // shape and then reduced back to the input shapes.
  bool reduce_flag = XYNeedReduce(x, y, out);
  if (!reduce_flag) {
    int x_grad_size = 0, y_grad_size = 0;
    T* x_grad_data = NULL;
    T* y_grad_data = NULL;
    if (x_grad) {
      x_grad_data = ctx.template Alloc<T>(x_grad);
      x_grad_size = x.numel();
    }
    if (y_grad) {
      y_grad_data = ctx.template Alloc<T>(y_grad);
      y_grad_size = y.numel();
    }
    SwitchKernel<T, Context>(ctx,
                             weight,
                             out_grad,
                             x_grad_size,
                             y_grad_size,
                             x_grad_data,
                             y_grad_data);
  } else {
    int x_grad_size = 0, y_grad_size = 0;
    DenseTensor b_xgrad = phi::EmptyLike<T, Context>(ctx, out_grad);
    DenseTensor b_ygrad = phi::EmptyLike<T, Context>(ctx, out_grad);
    T* x_grad_data = NULL;
    T* y_grad_data = NULL;
    if (x_grad) {
      x_grad_data = ctx.template Alloc<T>(&b_xgrad);
      x_grad_size = out.numel();
    }
    if (y_grad) {
      y_grad_data = ctx.template Alloc<T>(&b_ygrad);
      y_grad_size = out.numel();
    }
    SwitchKernel<T, Context>(ctx,
                             weight,
                             out_grad,
                             x_grad_size,
                             y_grad_size,
                             x_grad_data,
                             y_grad_data);
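    // Reduce the broadcast gradients back to the input shapes. For example,
    // b_xgrad of shape [2, 2, 2, 50] is summed over axis 1 to recover an x_grad
    // of shape [2, 1, 2, 50] (the TestLerpBroadXY case); if no axis needs
    // reducing, the buffer is shared directly.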
    if (x_grad) {
      std::vector<int> reduce_axis_x =
          funcs::GetReduceDim(x_grad->dims(), b_xgrad.dims(), -1);
      if (!reduce_axis_x.empty()) {
        phi::funcs::
            ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
                ctx, b_xgrad, x_grad, kps::IdentityFunctor<T>(), reduce_axis_x);
      } else {
        x_grad->ShareDataWith(b_xgrad);
      }
    }
    if (y_grad) {
      std::vector<int> reduce_axis_y =
          funcs::GetReduceDim(y_grad->dims(), b_ygrad.dims(), -1);
      if (!reduce_axis_y.empty()) {
        phi::funcs::
            ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
                ctx, b_ygrad, y_grad, kps::IdentityFunctor<T>(), reduce_axis_y);
      } else {
        y_grad->ShareDataWith(b_ygrad);
      }
    }
  }
}
} // namespace phi
PD_REGISTER_KERNEL(
    lerp_grad, GPU, ALL_LAYOUT, phi::LerpGradKernel, float, double) {}
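For readers unfamiliar with Paddle's CUDA_KERNEL_LOOP_TYPE macro: it implements a grid-stride loop, so each thread handles several elements when out_size exceeds the number of launched threads. A minimal, self-contained sketch of the same access pattern (illustrative only, assuming a simplified form of what the macro expands to, not Paddle's exact definition):

// Sketch of a grid-stride loop applying the same dx / dy guards as the kernels
// above; hypothetical name, not part of the committed file.
template <typename T>
__global__ void LerpGradSketch(const T* weight,
                               const T* dout,
                               T* dx,
                               T* dy,
                               int64_t out_size,
                               int64_t x_size,
                               int64_t y_size) {
  int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < out_size;
       idx += stride) {
    T temp_dx = weight[idx] * dout[idx];
    if (dx && idx < x_size) dx[idx] = dout[idx] - temp_dx;  // (1 - w) * dout
    if (dy && idx < y_size) dy[idx] = temp_dx;              // w * dout
  }
}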
......@@ -106,10 +106,11 @@ void BroadcastTensorsKernel(const Context& ctx,
    SWITCH_OUT_RANK_CASE(3)
    SWITCH_OUT_RANK_CASE(4)
    SWITCH_OUT_RANK_CASE(5)
    SWITCH_OUT_RANK_CASE(6)
    default: {
      PADDLE_THROW(paddle::platform::errors::InvalidArgument(
          "Target tensor rank out of range"
          "Maximum supported rank for broadcast is: 6"));
    }
  }
}
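The broadcast_tensors change above raises the maximum supported output rank from 5 to 6; the new lerp_grad path relies on this when it broadcasts weight against an out_grad with up to 6 dimensions (LerpGradKernel enforces rank <= 6).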
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
......@@ -78,6 +80,34 @@ class TestLerpWithDim6(TestLerp):
        self.shape = [2, 1, 2, 5, 1, 5]


class TestLerpBroadXY(TestLerp):
    def setUp(self):
        self.op_type = "lerp"
        self.python_api = paddle.lerp
        self.init_dtype()
        self.init_shape()
        x = np.arange(1., 201.).astype(self.dtype).reshape([2, 1, 2, 50])
        y = np.full(200, 10.).astype(self.dtype).reshape([2, 2, 1, 50])
        w = np.asarray([0.5]).astype(self.dtype)
        self.inputs = {'X': x, 'Y': y, 'Weight': w}
        self.outputs = {'Out': x + w * (y - x)}


class TestLerpBroadWToXY(TestLerp):
    def setUp(self):
        self.op_type = "lerp"
        self.python_api = paddle.lerp
        self.init_dtype()
        self.init_shape()
        x = np.full(600, 2.5).astype(self.dtype).reshape([50, 2, 2, 3])
        y = np.full(600, 1.).astype(self.dtype).reshape([50, 2, 2, 3])
        w = np.random.random([3]).astype(self.dtype)
        self.inputs = {'X': x, 'Y': y, 'Weight': w}
        self.outputs = {'Out': x + w * (y - x)}

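# Note: TestLerpBroadXY covers the case where x and y have different shapes, so
# lerp_grad computes the gradients at the broadcast shape and reduces them back;
# TestLerpBroadWToXY covers broadcasting a non-scalar weight against x and y.
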

class TestLerpAPI(unittest.TestCase):
    def init_dtype(self):
......