Unverified commit c57e12be authored by Leo Chen, committed by GitHub

refine dist_grad kernel (#44182)

* refine dist_grad kernel

* fix cpu kernel bug
Parent ee5cb5f2
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/dist_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/dist_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
dist_grad, CPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/dist_grad_kernel.h"
#include <tuple>
#include <vector>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/p_norm_grad_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/scale_kernel.h"
namespace phi {
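// Splits the axes of src_dim relative to dst_dim: axes that were introduced or
// expanded by broadcasting go into reduce_dims (to be summed over), the rest go
// into new_dims (the shape of the reduced result before it is resized back to
// dst_dim). Returns {reduce_dims, new_dims}.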
std::pair<std::vector<int64_t>, std::vector<int64_t>> GetReduceDims(
const DDim& src_dim, const DDim& dst_dim) {
std::vector<int64_t> reduce_dims, new_dims;
auto pre_dims = src_dim.size() - dst_dim.size();
for (auto i = 0; i < pre_dims; ++i) {
reduce_dims.push_back(i);
}
for (auto i = pre_dims; i < src_dim.size(); ++i) {
if (dst_dim[i - pre_dims] == 1 && src_dim[i] != 1) {
reduce_dims.push_back(i);
} else {
new_dims.push_back(dst_dim[i - pre_dims]);
}
}
return {reduce_dims, new_dims};
}
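// Gradient of dist(x, y) = ||x - y||_p. With z = x - y, d(dist)/dz is obtained
// from PNormGradKernel; d(dist)/dx = d(dist)/dz and d(dist)/dy = -d(dist)/dz,
// each summed over the broadcast dimensions so the results match the shapes of
// x and y.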
template <typename T, typename Context>
void DistGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& out_grad,
float p,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto t = Subtract<T, Context>(dev_ctx, x, y);
DenseTensor x_grad_tmp;
x_grad_tmp.Resize(t.dims());
DenseTensor y_grad_tmp;
y_grad_tmp.Resize(t.dims());
PNormGradKernel<T, Context>(
dev_ctx, t, out, out_grad, p, -1, 1e-12, false, true, &x_grad_tmp);
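// Since z = x - y, d(dist)/dy = -d(dist)/dx; flip the sign with ScaleKernel.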
ScaleKernel<T, Context>(dev_ctx, x_grad_tmp, -1.0, 0.0, false, &y_grad_tmp);
// Do the reduce. The CPU implementation of SumKernel has a bug: it changes
// the dims of the output internally, so we Resize x/y_grad twice.
auto res_x = GetReduceDims(x_grad_tmp.dims(), x.dims());
if (!std::get<0>(res_x).empty()) {
x_grad->Resize(phi::make_ddim(std::get<1>(res_x)));
SumKernel<T, Context>(
dev_ctx, x_grad_tmp, std::get<0>(res_x), x.dtype(), false, x_grad);
x_grad->Resize(x.dims());
} else {
x_grad->ShareBufferWith(x_grad_tmp);
}
auto res_y = GetReduceDims(y_grad_tmp.dims(), y.dims());
if (!std::get<0>(res_y).empty()) {
y_grad->Resize(phi::make_ddim(std::get<1>(res_y)));
SumKernel<T, Context>(
dev_ctx, y_grad_tmp, std::get<0>(res_y), y.dtype(), false, y_grad);
y_grad->Resize(y.dims());
} else {
y_grad->ShareBufferWith(y_grad_tmp);
}
}
} // namespace phi
PD_REGISTER_KERNEL(
dist_grad, CPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(
dist_grad, GPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/dist_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/dist_grad_kernel_impl.h"
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(dist_grad, GPU, ALL_LAYOUT, phi::DistGradKernel, float) {}
#else
PD_REGISTER_KERNEL(
dist_grad, GPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T,
size_t D,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using ETensor = phi::EigenTensor<T, D, MajorType, IndexType>;
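// Computes, per dimension, the broadcast factor for x and for y so that both
// can be expanded to the common shape max(x_dims, y_dims). If a pair of sizes
// is not divisible in either direction, the shapes are not broadcast-compatible
// and an error is raised.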
template <int Rank>
static void GetBraodcastDims(const phi::DDim& x_dims,
const phi::DDim& y_dims,
Eigen::DSizes<int, Rank>* x_bcast_dims,
Eigen::DSizes<int, Rank>* y_bcast_dims) {
int bcast_dims_remainder = 0;
for (int i = 0; i < x_dims.size(); ++i) {
if (x_dims[i] >= y_dims[i]) {
(*x_bcast_dims)[i] = 1;
(*y_bcast_dims)[i] = x_dims[i] / y_dims[i];
bcast_dims_remainder += x_dims[i] % y_dims[i];
} else {
(*y_bcast_dims)[i] = 1;
(*x_bcast_dims)[i] = y_dims[i] / x_dims[i];
bcast_dims_remainder += y_dims[i] % x_dims[i];
}
}
PADDLE_ENFORCE_EQ(bcast_dims_remainder,
0,
phi::errors::PreconditionNotMet(
"The input tensor of Op(dist) could not be broadcast, "
"X's shape is [%s], Y's shape is [%s].",
x_dims,
y_dims));
}
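// Left-pads in_dims with 1s so every input is viewed with the same rank, which
// lets a single Eigen expression of rank `Rank` handle x, y, and out together.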
static phi::DDim GetNewDims(const phi::DDim& in_dims, int rank) {
std::vector<int64_t> new_dims_vec(rank);
if (in_dims.size() < rank) {
for (int i = 0; i < rank - in_dims.size(); ++i) {
new_dims_vec[i] = 1;
}
for (int i = 0; i < in_dims.size(); ++i) {
new_dims_vec[i + rank - in_dims.size()] = in_dims[i];
}
} else {
new_dims_vec = vectorize(in_dims);
}
return phi::make_ddim(new_dims_vec);
}
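// Eigen-based gradient: broadcast x and y to a common shape, compute dz for
// z = x - y under the selected p-norm, then sum dz over the broadcast axes to
// recover x_grad (and -dz for y_grad) in the original shapes.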
template <typename Context, typename T, int Rank>
static void DistGradFunction(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& out_grad,
float p,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto x_dims = x.dims();
auto y_dims = y.dims();
auto out_dims = out.dims();
phi::DDim x_new_dims = GetNewDims(x_dims, Rank);
phi::DDim y_new_dims = GetNewDims(y_dims, Rank);
phi::DDim out_new_dims = GetNewDims(out_dims, Rank);
auto x_t = ETensor<T, Rank>::From(x, x_new_dims);
auto y_t = ETensor<T, Rank>::From(y, y_new_dims);
auto out_t = ETensor<T, Rank>::From(out, out_new_dims);
Eigen::DSizes<int, Rank> x_bcast_dims;
Eigen::DSizes<int, Rank> y_bcast_dims;
Eigen::DSizes<int, Rank> out_bcast_dims;
GetBraodcastDims<Rank>(x_new_dims, y_new_dims, &x_bcast_dims, &y_bcast_dims);
std::vector<int64_t> new_dims_vec(Rank);
for (int i = 0; i < Rank; ++i) {
new_dims_vec[i] = std::max(x_new_dims[i], y_new_dims[i]);
out_bcast_dims[i] = new_dims_vec[i];
}
phi::DDim new_dims = phi::make_ddim(new_dims_vec);
auto& place = *dev_ctx.eigen_device();
auto out_grad_t = ETensor<T, Rank>::From(out_grad, out_new_dims);
DenseTensor grad;
grad.Resize(new_dims);
dev_ctx.template Alloc<T>(&grad);
auto grad_t = ETensor<T, Rank>::From(grad);
auto x_minux_y = x_t.broadcast(x_bcast_dims) - y_t.broadcast(y_bcast_dims);
auto x_minux_y_abs = x_minux_y.abs();
auto sign =
(x_minux_y > static_cast<T>(0)).template cast<T>() * static_cast<T>(1.0) +
(x_minux_y < static_cast<T>(0)).template cast<T>() * static_cast<T>(-1.0);
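// sign(z): +1 where z > 0, -1 where z < 0, 0 where z == 0; epsilon keeps the
// division below well defined when out == 0.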
T epsilon = static_cast<T>(1.0e-10f);
// 1: Lp-norm(z), z = x-y, compute dz
if (p == 0) {
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, &grad, static_cast<T>(0));
} else if (p == INFINITY || p == -INFINITY) {
// p = inf or -inf: Lp-norm = |z_i| for the selected element z_i (the max or
// min of |z|), so dz_j is 0 for j != i and equals sign(z_i) * dout for j = i.
if (paddle::platform::is_cpu_place(dev_ctx.GetPlace())) {
grad_t.device(place) = (x_minux_y_abs == out_t.broadcast(out_bcast_dims))
.template cast<T>() *
sign.eval() * out_grad_t.broadcast(out_bcast_dims);
} else {
grad_t.device(place) = (x_minux_y_abs == out_t.broadcast(out_bcast_dims))
.template cast<T>() *
sign * out_grad_t.broadcast(out_bcast_dims);
}
} else {
// dz = pow(abs(x-y)/out, p-1) * sign(x-y) * dout
if (paddle::platform::is_cpu_place(dev_ctx.GetPlace())) {
grad_t.device(place) =
(x_minux_y_abs / (out_t + epsilon).broadcast(out_bcast_dims))
.pow(p - 1) *
sign.eval() * out_grad_t.broadcast(out_bcast_dims);
} else {
grad_t.device(place) =
(x_minux_y_abs / (out_t + epsilon).broadcast(out_bcast_dims))
.pow(p - 1) *
sign * out_grad_t.broadcast(out_bcast_dims);
}
}
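// To sum the broadcast copies back, view each dimension i of grad as the pair
// (broadcast_factor, original_size) and reduce over the even (factor) axes.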
Eigen::DSizes<int, Rank * 2> x_reshape_dims;
Eigen::DSizes<int, Rank * 2> y_reshape_dims;
Eigen::DSizes<int, Rank> reduce_dims;
for (int i = 0; i < x_new_dims.size(); ++i) {
x_reshape_dims[2 * i] = x_bcast_dims[i];
x_reshape_dims[2 * i + 1] = x_new_dims[i];
y_reshape_dims[2 * i] = y_bcast_dims[i];
y_reshape_dims[2 * i + 1] = y_new_dims[i];
reduce_dims[i] = 2 * i;
}
// 2: if x or y was broadcast in the forward function, the grad needs to be
// summed along the broadcast dimensions.
if (x_grad) {
dev_ctx.template Alloc<T>(x_grad);
auto x_grad_t = ETensor<T, Rank>::From(*x_grad, x_new_dims);
x_grad_t.device(place) = grad_t.reshape(x_reshape_dims)
.sum(reduce_dims)
.reshape(x_grad_t.dimensions());
}
if (y_grad) {
dev_ctx.template Alloc<T>(y_grad);
auto y_grad_t = ETensor<T, Rank>::From(*y_grad, y_new_dims);
y_grad_t.device(place) = -grad_t.reshape(y_reshape_dims)
.sum(reduce_dims)
.reshape(y_grad_t.dimensions());
}
}
template <typename T, typename Context>
void DistGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& out_grad,
float p,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto x_rank = x.dims().size();
auto y_rank = y.dims().size();
auto rank = std::max(x_rank, y_rank);
PADDLE_ENFORCE_LE(rank,
6,
phi::errors::Unimplemented(
"Op(dist) only support tensors with no more than 6 "
"dimensions, but X's rank is %d, Y's rank is %d.",
x_rank,
y_rank));
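// Dispatch on the runtime rank: the Eigen expressions above need the rank as a
// compile-time template parameter.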
switch (rank) {
case 1:
DistGradFunction<Context, T, 1>(
dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
break;
case 2:
DistGradFunction<Context, T, 2>(
dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
break;
case 3:
DistGradFunction<Context, T, 3>(
dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
break;
case 4:
DistGradFunction<Context, T, 4>(
dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
break;
case 5:
DistGradFunction<Context, T, 5>(
dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
break;
case 6:
DistGradFunction<Context, T, 6>(
dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
break;
}
}
} // namespace phi