Unverified commit 7e076e7b, authored by Zhong Hui, committed by GitHub

[PHI] Move dist op to phi (#40178)

* move dist op to phi

* fix

* fix

* fix as per review comments
Parent a3f28a31
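This change follows the fluid-to-phi kernel migration pattern: the device-templated DistKernel/DistGradKernel OpKernel classes in dist_op.h and their REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL registrations are removed, and replaced by functional phi kernels registered with PD_REGISTER_KERNEL, a DistInferMeta shape function wired to the operator through an infer-shape functor, and an argument-mapping function for the grad kernel. A condensed sketch of the before/after registration pattern, assembled only from code that appears in this diff (includes and namespaces elided):

// Before (fluid): a device-templated OpKernel class, registered per device.
REGISTER_OP_CPU_KERNEL(
    dist, ops::DistKernel<paddle::platform::CPUDeviceContext, float>,
    ops::DistKernel<paddle::platform::CPUDeviceContext, double>);

// After (phi): a free template function over DenseTensor, registered per backend.
template <typename T, typename Context>
void DistKernel(const Context& dev_ctx, const DenseTensor& x,
                const DenseTensor& y, float p, DenseTensor* out);
PD_REGISTER_KERNEL(dist, CPU, ALL_LAYOUT, phi::DistKernel, float, double) {}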
@@ -12,10 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/dist_op.h"
 #include <string>
 #include <vector>
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/binary.h"
 namespace paddle {
 namespace operators {
@@ -121,13 +124,11 @@ class DistGradOpMaker : public framework::SingleGradOpMaker<T> {
 }  // namespace paddle
 namespace ops = paddle::operators;
+DELCARE_INFER_SHAPE_FUNCTOR(dist, DistInferShapeFunctor,
+                            PT_INFER_META(phi::DistInferMeta));
 REGISTER_OPERATOR(dist, ops::DistOp, ops::DistOpMaker,
                   ops::DistGradOpMaker<paddle::framework::OpDesc>,
-                  ops::DistGradOpMaker<paddle::imperative::OpBase>);
+                  ops::DistGradOpMaker<paddle::imperative::OpBase>,
+                  DistInferShapeFunctor);
 REGISTER_OPERATOR(dist_grad, ops::DistOpGrad);
-REGISTER_OP_CPU_KERNEL(
-    dist, ops::DistKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::DistKernel<paddle::platform::CPUDeviceContext, double>);
-REGISTER_OP_CPU_KERNEL(
-    dist_grad, ops::DistGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::DistGradKernel<paddle::platform::CPUDeviceContext, double>)
@@ -456,6 +456,29 @@ void BCELossInferMeta(const MetaTensor& input,
out->share_lod(input);
}
void DistInferMeta(const MetaTensor& x,
const MetaTensor& y,
float p,
MetaTensor* out) {
auto x_dims = x.dims();
auto y_dims = y.dims();
PADDLE_ENFORCE_NE(phi::product(x_dims),
0,
phi::errors::InvalidArgument(
"The Input(X) has not been initialized properly. The "
"shape of Input(X) = [%s].",
x_dims));
PADDLE_ENFORCE_NE(phi::product(y_dims),
0,
phi::errors::InvalidArgument(
"The Input(Y) has not been initialized properly. The "
"shape of Input(Y) = [%s].",
y_dims));
out->set_dims({1});
out->set_dtype(x.dtype());
}
void GatherNdInferMeta(const MetaTensor& x,
const MetaTensor& index,
MetaTensor* out) {
@@ -85,6 +85,11 @@ void BCELossInferMeta(const MetaTensor& input,
MetaTensor* out,
MetaConfig config = MetaConfig());
void DistInferMeta(const MetaTensor& x,
const MetaTensor& y,
float p,
MetaTensor* out);
void GatherNdInferMeta(const MetaTensor& x,
const MetaTensor& index,
MetaTensor* out);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/dist_grad_kernel.h"
#include "paddle/phi/kernels/impl/dist_grad_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(
dist_grad, CPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/dist_kernel.h"
#include "paddle/phi/kernels/impl/dist_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(dist, CPU, ALL_LAYOUT, phi::DistKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
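// Gradient of out = dist(x, y, p): given the forward output `out` and its
// incoming gradient `out_grad`, computes the gradients w.r.t. x and y;
// `x_grad` / `y_grad` may be null, in which case that output is skipped.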
template <typename T, typename Context>
void DistGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& out_grad,
float p,
DenseTensor* x_grad,
DenseTensor* y_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void DistKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
float p,
DenseTensor* out);
} // namespace phi
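For reference, the quantity DistKernel (declared above) computes, restating the comments found in the implementation further below, with z = x - y after broadcasting:

dist(x, y, p) =
  \begin{cases}
    \#\{\, i : z_i \neq 0 \,\}                & p = 0 \\
    \max_i |z_i|                              & p = +\infty \\
    \min_i |z_i|                              & p = -\infty \\
    \bigl( \sum_i |z_i|^{p} \bigr)^{1/p}      & \text{otherwise}
  \end{cases}

The result is a single-element tensor, matching the {1} output shape set by DistInferMeta.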
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/dist_grad_kernel.h"
#include "paddle/phi/kernels/impl/dist_grad_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(dist_grad, GPU, ALL_LAYOUT, phi::DistGradKernel, float) {}
#else
PD_REGISTER_KERNEL(
dist_grad, GPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,21 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/dist_kernel.h"
#include "paddle/phi/kernels/impl/dist_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#ifdef PADDLE_WITH_HIP
// Eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h:922
// do not support double in HIPCC platform (Eigen3 to be fixed)
PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float) {}
#else
PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float, double) {}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T,
size_t D,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using ETensor = phi::EigenTensor<T, D, MajorType, IndexType>;
template <int Rank>
static void GetBraodcastDims(const phi::DDim& x_dims,
const phi::DDim& y_dims,
Eigen::DSizes<int, Rank>* x_bcast_dims,
Eigen::DSizes<int, Rank>* y_bcast_dims) {
int bcast_dims_remainder = 0;
@@ -46,14 +43,16 @@ static void GetBraodcastDims(const framework::DDim& x_dims,
bcast_dims_remainder += y_dims[i] % x_dims[i];
}
}
PADDLE_ENFORCE_EQ(bcast_dims_remainder,
0,
phi::errors::PreconditionNotMet(
"The input tensor of Op(dist) could not be broadcast, "
"X's shape is [%s], Y's shape is [%s].",
x_dims,
y_dims));
}
static phi::DDim GetNewDims(const phi::DDim& in_dims, int rank) {
std::vector<int64_t> new_dims_vec(rank);
if (in_dims.size() < rank) {
for (int i = 0; i < rank - in_dims.size(); ++i) {
@@ -68,80 +67,25 @@ static framework::DDim GetNewDims(const framework::DDim& in_dims, int rank) {
return phi::make_ddim(new_dims_vec);
}
template <typename Context, typename T, int Rank>
static void DistGradFunction(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& out_grad,
float p,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto x_dims = x.dims();
auto y_dims = y.dims();
auto out_dims = out.dims();
phi::DDim x_new_dims = GetNewDims(x_dims, Rank);
phi::DDim y_new_dims = GetNewDims(y_dims, Rank);
phi::DDim out_new_dims = GetNewDims(out_dims, Rank);
auto x_t = ETensor<T, Rank>::From(x, x_new_dims);
auto y_t = ETensor<T, Rank>::From(y, y_new_dims);
auto out_t = ETensor<T, Rank>::From(out, out_new_dims);
Eigen::DSizes<int, Rank> x_bcast_dims;
Eigen::DSizes<int, Rank> y_bcast_dims;
@@ -153,14 +97,14 @@ static void DistGradFunction(const framework::ExecutionContext& context) {
new_dims_vec[i] = std::max(x_new_dims[i], y_new_dims[i]);
out_bcast_dims[i] = new_dims_vec[i];
}
phi::DDim new_dims = phi::make_ddim(new_dims_vec);
auto& place = *dev_ctx.eigen_device();
auto out_grad_t = ETensor<T, Rank>::From(out_grad, out_new_dims);
DenseTensor grad;
grad.Resize(new_dims);
dev_ctx.template Alloc<T>(&grad);
auto grad_t = ETensor<T, Rank>::From(grad);
auto x_minux_y = x_t.broadcast(x_bcast_dims) - y_t.broadcast(y_bcast_dims);
auto x_minux_y_abs = x_minux_y.abs();
@@ -171,13 +115,12 @@ static void DistGradFunction(const framework::ExecutionContext& context) {
// 1: Lp-norm(z), z = x-y, compute dz
if (p == 0) {
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, &grad, static_cast<T>(0));
} else if (p == INFINITY || p == -INFINITY) {
// p=inf or -inf, Lp-norm = |z_i|, the j-th element of dz tends to 0 if
// j!=i, or equals to sign(z_i) * dout if j=i.
if (paddle::platform::is_cpu_place(dev_ctx.GetPlace())) {
grad_t.device(place) = (x_minux_y_abs == out_t.broadcast(out_bcast_dims))
.template cast<T>() *
sign.eval() * out_grad_t.broadcast(out_bcast_dims);
@@ -188,7 +131,7 @@ static void DistGradFunction(const framework::ExecutionContext& context) {
}
} else {
// dz = pow(abs(x-y)/out, p-1) * sign(x-y) * dout
if (paddle::platform::is_cpu_place(dev_ctx.GetPlace())) {
grad_t.device(place) =
(x_minux_y_abs / (out_t + epsilon).broadcast(out_bcast_dims))
.pow(p - 1) *
@@ -215,90 +158,66 @@ static void DistGradFunction(const framework::ExecutionContext& context) {
// 2: if x or y is broadcasted in forward function,
// the grad need to be sum along the broadcasted dimensions
if (x_grad) {
dev_ctx.template Alloc<T>(x_grad);
auto x_grad_t = ETensor<T, Rank>::From(*x_grad, x_new_dims);
x_grad_t.device(place) = grad_t.reshape(x_reshape_dims)
.sum(reduce_dims)
.reshape(x_grad_t.dimensions());
}
if (y_grad) {
dev_ctx.template Alloc<T>(y_grad);
auto y_grad_t = ETensor<T, Rank>::From(*y_grad, y_new_dims);
y_grad_t.device(place) = -grad_t.reshape(y_reshape_dims)
.sum(reduce_dims)
.reshape(y_grad_t.dimensions());
}
}
template <typename T, typename Context>
void DistGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& out_grad,
float p,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto x_rank = x.dims().size();
auto y_rank = y.dims().size();
auto rank = std::max(x_rank, y_rank);
PADDLE_ENFORCE_LE(rank,
6,
phi::errors::Unimplemented(
"Op(dist) only support tensors with no more than 6 "
"dimensions, but X's rank is %d, Y's rank is %d.",
x_rank,
y_rank));
switch (rank) {
case 1:
DistGradFunction<Context, T, 1>(
dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
break;
case 2:
DistGradFunction<Context, T, 2>(
dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
break;
case 3:
DistGradFunction<Context, T, 3>(
dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
break;
case 4:
DistGradFunction<Context, T, 4>(
dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
break;
case 5:
DistGradFunction<Context, T, 5>(
dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
break;
case 6:
DistGradFunction<Context, T, 6>(
dev_ctx, x, y, out, out_grad, p, x_grad, y_grad);
break;
}
}
} // namespace phi
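The Eigen expressions above implement the following (sub)gradient with respect to z = x - y, restating the in-code comments (the small epsilon added to `out` for numerical safety is omitted here):

\frac{\partial\, dist}{\partial z_i} =
  \begin{cases}
    0                                                                        & p = 0 \\
    \operatorname{sign}(z_i)\, \mathbf{1}[\, |z_i| = out \,]\; dout          & p = \pm\infty \\
    \left( \frac{|z_i|}{out} \right)^{p-1} \operatorname{sign}(z_i)\; dout   & \text{otherwise}
  \end{cases}

X's gradient is then dz summed over the dimensions along which x was broadcast in the forward pass, and Y's gradient is the corresponding sum of -dz.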
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <math.h>
#include <algorithm>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T,
size_t D,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using ETensor = phi::EigenTensor<T, D, MajorType, IndexType>;
template <int Rank>
static void GetBraodcastDims(const phi::DDim& x_dims,
const phi::DDim& y_dims,
Eigen::DSizes<int, Rank>* x_bcast_dims,
Eigen::DSizes<int, Rank>* y_bcast_dims) {
int bcast_dims_remainder = 0;
for (int i = 0; i < x_dims.size(); ++i) {
if (x_dims[i] >= y_dims[i]) {
(*x_bcast_dims)[i] = 1;
(*y_bcast_dims)[i] = x_dims[i] / y_dims[i];
bcast_dims_remainder += x_dims[i] % y_dims[i];
} else {
(*y_bcast_dims)[i] = 1;
(*x_bcast_dims)[i] = y_dims[i] / x_dims[i];
bcast_dims_remainder += y_dims[i] % x_dims[i];
}
}
PADDLE_ENFORCE_EQ(bcast_dims_remainder,
0,
phi::errors::PreconditionNotMet(
"The input tensor of Op(dist) could not be broadcast, "
"X's shape is [%s], Y's shape is [%s].",
x_dims,
y_dims));
}
static phi::DDim GetNewDims(const phi::DDim& in_dims, int rank) {
std::vector<int64_t> new_dims_vec(rank);
if (in_dims.size() < rank) {
for (int i = 0; i < rank - in_dims.size(); ++i) {
new_dims_vec[i] = 1;
}
for (int i = 0; i < in_dims.size(); ++i) {
new_dims_vec[i + rank - in_dims.size()] = in_dims[i];
}
} else {
new_dims_vec = vectorize(in_dims);
}
return phi::make_ddim(new_dims_vec);
}
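// Illustrative example: with rank = 3, x_dims = (4, 3) and y_dims = (2, 4, 3)
// become x_new_dims = (1, 4, 3) and y_new_dims = (2, 4, 3); GetBraodcastDims
// then yields x_bcast_dims = (2, 1, 1) and y_bcast_dims = (1, 1, 1), i.e. x is
// replicated twice along the first dimension.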
template <typename Context, typename T, int Rank>
static void DistFunction(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
float p,
DenseTensor* out) {
if (out) {
dev_ctx.template Alloc<T>(out);
}
auto x_dims = x.dims();
auto y_dims = y.dims();
// new dims with same size as rank, e.g. (rank=3, (4, 3) => (1, 4, 3))
phi::DDim x_new_dims = GetNewDims(x_dims, Rank);
phi::DDim y_new_dims = GetNewDims(y_dims, Rank);
auto x_t = ETensor<T, Rank>::From(x, x_new_dims);
auto y_t = ETensor<T, Rank>::From(y, y_new_dims);
auto out_t = ETensor<T, 1>::From(*out);
auto& place = *dev_ctx.eigen_device();
Eigen::DSizes<int, Rank> x_bcast_dims;
Eigen::DSizes<int, Rank> y_bcast_dims;
GetBraodcastDims<Rank>(x_new_dims, y_new_dims, &x_bcast_dims, &y_bcast_dims);
// p=0 means number of non-zero elements of (x-y)
// p=inf means the maximum of |x-y|
// p=-inf means the minimum of |x-y|
// otherwise, Lp-norm = pow(sum(pow(|x-y|, p)), 1/p)
if (p == 0) {
out_t.device(place) =
(x_t.broadcast(x_bcast_dims) != y_t.broadcast(y_bcast_dims))
.template cast<T>()
.sum();
} else if (p == INFINITY) {
out_t.device(place) =
(x_t.broadcast(x_bcast_dims) - y_t.broadcast(y_bcast_dims))
.abs()
.maximum();
} else if (p == -INFINITY) {
out_t.device(place) =
(x_t.broadcast(x_bcast_dims) - y_t.broadcast(y_bcast_dims))
.abs()
.minimum();
} else {
out_t.device(place) =
(x_t.broadcast(x_bcast_dims) - y_t.broadcast(y_bcast_dims))
.abs()
.pow(p)
.sum()
.pow(1.0 / p);
}
}
template <typename T, typename Context>
void DistKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
float p,
DenseTensor* out) {
auto x_rank = x.dims().size();
auto y_rank = y.dims().size();
auto rank = std::max(x_rank, y_rank);
PADDLE_ENFORCE_LE(rank,
6,
phi::errors::Unimplemented(
"Op(dist) only support tensors with no more than 6 "
"dimensions, but X's rank is %d, Y's rank is %d.",
x_rank,
y_rank));
switch (rank) {
case 1:
DistFunction<Context, T, 1>(dev_ctx, x, y, p, out);
break;
case 2:
DistFunction<Context, T, 2>(dev_ctx, x, y, p, out);
break;
case 3:
DistFunction<Context, T, 3>(dev_ctx, x, y, p, out);
break;
case 4:
DistFunction<Context, T, 4>(dev_ctx, x, y, p, out);
break;
case 5:
DistFunction<Context, T, 5>(dev_ctx, x, y, p, out);
break;
case 6:
DistFunction<Context, T, 6>(dev_ctx, x, y, p, out);
break;
}
}
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature DistGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("dist_grad",
{"X", "Y", "Out", GradVarName("Out")},
{"p"},
{GradVarName("X"), GradVarName("Y")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(dist_grad, phi::DistGradOpArgumentMapping);