Unverified commit 5898e9ab, authored by YuanRisheng, committed by GitHub

[Phi]Move elementwise function to funcs directory (#39986)

* move elementwise function to funcs directory

* fix compile bugs

* modify according to comment
Parent 66196573
@@ -39,7 +39,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
 #else
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
-#include "paddle/phi/kernels/gpu/elementwise.h"
+#include "paddle/phi/kernels/gpu/elementwise_grad.h"
 #endif
 namespace ops = paddle::operators;
...
@@ -16,9 +16,6 @@
 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
-// only can include the headers in paddle/top/api dirs
-#include "paddle/phi/kernels/gpu/elementwise.h"
 namespace paddle {
 namespace operators {
...
@@ -31,6 +31,7 @@ limitations under the License. */
 #include "paddle/phi/api/lib/utils/tensor_utils.h"
 #include "paddle/phi/kernels/cpu/elementwise.h"
+#include "paddle/phi/kernels/cpu/elementwise_grad.h"
 #if defined(__NVCC__) || defined(__HIPCC__)
 #ifdef __NVCC__
@@ -133,7 +134,7 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
 inline framework::DDim trim_trailing_singular_dims(
     const framework::DDim &dims) {
-  return phi::funcs::trim_trailing_singular_dims(dims);
+  return phi::funcs::TrimTrailingSingularDims(dims);
 }
 template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
@@ -152,7 +153,7 @@ void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
                          Tout>(
         dev_ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
   } else {
-    phi::ElemwiseGradComputeWithBroadcast<T, DX_OP, DY_OP, Tout>(
+    phi::funcs::ElemwiseGradComputeWithBroadcast<T, DX_OP, DY_OP, Tout>(
        dev_ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
   }
 }
@@ -173,19 +174,9 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
                           const framework::Tensor *y, int axis, Functor func,
                           framework::Tensor *z) {
   z->mutable_data<OutType>(ctx.GetPlace());
-  if (platform::is_gpu_place(ctx.GetPlace())) {
-#if defined(__NVCC__) || defined(__HIPCC__)
-    const auto &dev_ctx =
-        ctx.template device_context<platform::CUDADeviceContext>();
-    phi::ElementwiseCompute<Functor, T, OutType>(dev_ctx, *x, *y, axis, func,
-                                                 z);
-#endif
-    return;
-  }
-  const auto &dev_ctx =
-      ctx.template device_context<platform::CPUDeviceContext>();
-  phi::ElementwiseCompute<Functor, T, OutType>(dev_ctx, *x, *y, axis, func, z);
+  const auto &dev_ctx = ctx.template device_context<DeviceContext>();
+  phi::funcs::ElementwiseCompute<Functor, T, OutType>(dev_ctx, *x, *y, axis,
+                                                      func, z);
 }
 // FusedElemwiseAndAct
@@ -443,8 +434,8 @@ void FusedElemwiseAndActComputeWithBroadcast(
   axis = (y_dim.size() == 0) ? x_dim.size() : axis;
   int pre, n, post, is_run_common_broadcast;
-  phi::funcs::get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post,
-                           &is_run_common_broadcast);
+  phi::funcs::GetMidDims(x_dim, y_dim, axis, &pre, &n, &post,
+                         &is_run_common_broadcast);
   if (post == 1) {
     int h = pre;
     int w = n;
@@ -991,8 +982,8 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
   axis = (y_dim.size() == 0) ? x_dim.size() : axis;
   int pre, n, post, is_run_common_broadcast;
-  phi::funcs::get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post,
-                           &is_run_common_broadcast);
+  phi::funcs::GetMidDims(x_dim, y_dim, axis, &pre, &n, &post,
+                         &is_run_common_broadcast);
   const T *x_data = nullptr;
   const T *y_data = nullptr;
   if (x->IsInitialized()) x_data = x->data<T>();
...
@@ -19,7 +19,7 @@ limitations under the License. */
 // only can include the headers in paddle/top/api dirs
 #include "paddle/phi/api/lib/utils/tensor_utils.h"
-#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/funcs/broadcast_function.h"
 namespace paddle {
 namespace operators {
...
@@ -151,12 +151,12 @@ struct GetInputIndex<false> {
                   const std::vector<int>& output_strides, int output_idx,
                   int* index_array, int* lhs_idx, int* rhs_idx) {
     int out_dims_size = output_strides.size();
-    *lhs_idx =
-        phi::GetElementwiseIndex(lhs_dims.data(), out_dims_size, index_array);
-    *rhs_idx =
-        phi::GetElementwiseIndex(rhs_dims.data(), out_dims_size, index_array);
-    phi::UpdateElementwiseIndexArray(output_dims.data(), out_dims_size,
-                                     index_array);
+    *lhs_idx = phi::funcs::GetElementwiseIndex(lhs_dims.data(), out_dims_size,
+                                               index_array);
+    *rhs_idx = phi::funcs::GetElementwiseIndex(rhs_dims.data(), out_dims_size,
+                                               index_array);
+    phi::funcs::UpdateElementwiseIndexArray(output_dims.data(), out_dims_size,
+                                            index_array);
   }
 };
...
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/elementwise_grad_base.h"
namespace phi {
// NOTE(dzhwinter): Only used in elementwise_add, elementwise_sub.
// explicit gradient can cut off X, Y, Out from gradient op
// In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse
// elementwise code.
template <typename T, typename DX_OP, typename DY_OP>
void ElemwiseExplicitGradCompute(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy,
DX_OP dx_op,
DY_OP dy_op) {
const DDim& x_dim = x.dims();
const DDim& y_dim = y.dims();
if (x.dims() == y.dims()) {
funcs::ElemwiseGradComputeNoBroadcast<CPUContext, T, DX_OP, DY_OP>(dev_ctx,
x_dim,
y_dim,
dout,
dout,
out,
dout,
axis,
dx,
dy,
dx_op,
dy_op);
} else {
funcs::ElemwiseGradComputeWithBroadcast<T, DX_OP, DY_OP>(dev_ctx,
x_dim,
y_dim,
dout,
dout,
out,
dout,
axis,
dx,
dy,
dx_op,
dy_op);
}
}
/*
******************************
Add Grad
******************************
*/
template <typename T>
struct IdentityGrad {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
};
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value>::type
ElementwiseAddGrad(const CPUContext& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy,
int axis = -1) {
auto blas = phi::funcs::GetBlas<CPUContext, T>(ctx);
if (dx) {
blas.VCOPY(
dout.numel(), dout.data<T>(), dx->mutable_data<T>(ctx.GetPlace()));
}
if (dy) {
blas.VCOPY(
dout.numel(), dout.data<T>(), dy->mutable_data<T>(ctx.GetPlace()));
}
}
template <typename T>
typename std::enable_if<!std::is_floating_point<T>::value>::type
ElementwiseAddGrad(const CPUContext& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy,
int axis = -1) {
ElemwiseExplicitGradCompute<T, IdentityGrad<T>, IdentityGrad<T>>(
ctx, x, y, out, dout, axis, dx, dy, IdentityGrad<T>(), IdentityGrad<T>());
}
/*
******************************
Sub Grad
******************************
*/
template <typename T>
struct SubGradDX {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
};
template <typename T>
struct SubGradDY {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout; }
};
template <typename T>
void ElementwiseSubGrad(const CPUContext& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy,
int axis = -1) {
ElemwiseExplicitGradCompute<T, SubGradDX<T>, SubGradDY<T>>(
ctx, x, y, out, dout, axis, dx, dy, SubGradDX<T>(), SubGradDY<T>());
}
} // namespace phi
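
Review note (not part of the commit): the add/sub backward functors above read only dout, which is why ElemwiseExplicitGradCompute can pass dout in place of X, Y, and Out. A minimal standalone C++ sketch of that rule, using hypothetical plain vectors instead of DenseTensor:

#include <cstdio>
#include <vector>

// Same contracts as IdentityGrad / SubGradDX / SubGradDY above:
// dAdd/dX = dout, dAdd/dY = dout, dSub/dX = dout, dSub/dY = -dout.
struct IdentityGradF { float operator()(float dout) const { return dout; } };
struct NegGradF      { float operator()(float dout) const { return -dout; } };

template <typename DXOp, typename DYOp>
void ElemwiseGradSameShape(const std::vector<float>& dout,
                           std::vector<float>* dx, std::vector<float>* dy,
                           DXOp dx_op, DYOp dy_op) {
  dx->resize(dout.size());
  dy->resize(dout.size());
  for (size_t i = 0; i < dout.size(); ++i) {
    (*dx)[i] = dx_op(dout[i]);  // x, y, out are never read
    (*dy)[i] = dy_op(dout[i]);
  }
}

int main() {
  std::vector<float> dout = {1.f, -2.f, 3.f}, dx, dy;
  ElemwiseGradSameShape(dout, &dx, &dy, IdentityGradF{}, NegGradF{});  // Sub grad
  for (size_t i = 0; i < dout.size(); ++i)
    std::printf("dx=%g dy=%g\n", dx[i], dy[i]);
  return 0;
}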
@@ -17,7 +17,8 @@
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/copy_kernel.h"
-#include "paddle/phi/kernels/cpu/elementwise.h"
+#include "paddle/phi/kernels/cpu/elementwise_grad.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
 #include "paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h"
@@ -33,7 +34,7 @@ void AddGradFunc(const CPUContext& dev_ctx,
                  DenseTensor* dy,
                  int axis = -1) {
   if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
-    elementwise_add_grad<T>(dev_ctx, x, y, out, dout, dx, dy);
+    ElementwiseAddGrad<T>(dev_ctx, x, y, out, dout, dx, dy);
   } else {
     ElemwiseExplicitGradCompute<T, IdentityGrad<T>, IdentityGrad<T>>(
         dev_ctx,
@@ -68,15 +69,7 @@ void AddDoubleGradKernel(const Context& dev_ctx,
                          const DenseTensor& dout,
                          int axis,
                          DenseTensor* ddout) {
-  phi::AddDoubleGradImpl<T>(dev_ctx,
-                            y,
-                            ddx,
-                            ddy,
-                            dout,
-                            axis,
-                            ddout,
-                            ElementwiseCompute<funcs::AddFunctor<T>, T>,
-                            ElementwiseCompute<funcs::InverseAddFunctor<T>, T>);
+  phi::AddDoubleGradImpl<T>(dev_ctx, y, ddx, ddy, dout, axis, ddout);
 }
 template <typename T, typename Context>
@@ -101,7 +94,7 @@ void SubtractGradKernel(const Context& dev_ctx,
                         DenseTensor* dy) {
   // skip out
   auto* out = &dout;
-  elementwise_sub_grad<T>(dev_ctx, x, y, *out, dout, dx, dy, axis);
+  ElementwiseSubGrad<T>(dev_ctx, x, y, *out, dout, dx, dy, axis);
 }
 template <typename T, typename Context>
@@ -112,15 +105,7 @@ void SubtractDoubleGradKernel(const Context& dev_ctx,
                               const DenseTensor& dout,
                               int axis,
                               DenseTensor* ddout) {
-  phi::SubtractDoubleGradImpl<T>(
-      dev_ctx,
-      y,
-      ddx,
-      ddy,
-      dout,
-      axis,
-      ddout,
-      ElementwiseCompute<funcs::SubtractFunctor<T>, T>);
+  phi::SubtractDoubleGradImpl<T>(dev_ctx, y, ddx, ddy, dout, axis, ddout);
 }
 }  // namespace phi
...
@@ -16,7 +16,7 @@
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/cpu/elementwise.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/funcs/logical_functor.h"
 // See Note [ Why still include the fluid headers? ]
@@ -24,15 +24,15 @@
 namespace phi {
 #define DEFINE_LOGICAL_BINARY_KERNEL(type)                                 \
   template <typename T, typename Context>                                 \
   void Logical##type##Kernel(const Context& dev_ctx,                      \
                              const DenseTensor& x,                        \
                              const DenseTensor& y,                        \
                              DenseTensor* out) {                          \
     funcs::Logical##type##Functor<T> binary_func;                         \
-    ElementwiseCompute<funcs::Logical##type##Functor<T>, T, bool>(        \
+    funcs::ElementwiseCompute<funcs::Logical##type##Functor<T>, T, bool>( \
         dev_ctx, x, y, -1, binary_func, out);                             \
   }
 DEFINE_LOGICAL_BINARY_KERNEL(And)
...
@@ -20,6 +20,7 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/cpu/elementwise.h"
 #include "paddle/phi/kernels/cpu/reduce.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
 #include "paddle/phi/kernels/funcs/reduce_functor.h"
@@ -45,10 +46,10 @@ namespace phi {
     auto x_dims = x.dims();                                              \
     auto y_dims = y.dims();                                              \
     if (x_dims.size() >= y_dims.size()) {                                \
-      ElementwiseCompute<funcs::name##Functor<T>, T>(                    \
+      funcs::ElementwiseCompute<funcs::name##Functor<T>, T>(             \
          dev_ctx, x, y, axis, funcs::name##Functor<T>(), out);           \
     } else {                                                             \
-      ElementwiseCompute<funcs::Inverse##name##Functor<T>, T>(           \
+      funcs::ElementwiseCompute<funcs::Inverse##name##Functor<T>, T>(    \
          dev_ctx, x, y, axis, funcs::Inverse##name##Functor<T>(), out);  \
     }                                                                    \
   }                                                                      \
@@ -93,10 +94,10 @@ void DivideRawKernel(const Context& dev_ctx,
   auto x_dims = x.dims();
   auto y_dims = y.dims();
   if (x_dims.size() >= y_dims.size()) {
-    ElementwiseCompute<funcs::DivideFunctor<T>, T>(
+    funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T>(
        dev_ctx, x, y, axis, funcs::DivideFunctor<T>(), out);
   } else {
-    ElementwiseCompute<funcs::InverseDivideFunctor<T>, T>(
+    funcs::ElementwiseCompute<funcs::InverseDivideFunctor<T>, T>(
        dev_ctx, x, y, axis, funcs::InverseDivideFunctor<T>(), out);
   }
 }
...
@@ -25,6 +25,8 @@ namespace kps = phi::kps;
 namespace phi {
 namespace funcs {
+#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
 struct DimensionsTransform {
   using DimVector = std::vector<int64_t>;
   typedef void (*MergeFunctor)(
@@ -183,8 +185,6 @@ struct DimensionsTransform {
   }
 };
-#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
 template <typename T, int VecSize, int Rank, bool IsBoundary = false>
 __device__ __forceinline__ void LoadData(
     T *dst,
@@ -578,6 +578,20 @@ void BroadcastKernel(const KPDevice &ctx,
   }
 }
+template <typename Functor, typename T, typename OutType = T>
+void ElementwiseCompute(const GPUContext &dev_ctx,
+                        const DenseTensor &x,
+                        const DenseTensor &y,
+                        int axis,
+                        Functor func,
+                        DenseTensor *z) {
+  std::vector<const DenseTensor *> ins = {&x, &y};
+  std::vector<DenseTensor *> outs = {z};
+  z->mutable_data<OutType>(dev_ctx.GetPlace());
+  BroadcastKernel<ElementwiseType::kBinary, T, OutType, Functor, 1>(
+      dev_ctx, ins, &outs, axis, func);
+}
 #endif
 }  // namespace funcs
...
@@ -18,7 +18,8 @@ limitations under the License. */
 #include "paddle/phi/backends/all_context.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/empty_kernel.h"
-#include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/funcs/common_shape.h"
+#include "paddle/phi/kernels/funcs/elementwise_utils.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
@@ -44,28 +45,6 @@ using ConditionalT =
 namespace funcs {
 using DDim = phi::DDim;
-template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
-struct ElemwiseGradNoBroadcast {
-  const T *x_;
-  const T *y_;
-  const Tout *out_;
-  const Tout *dout_;
-  HOSTDEVICE void operator()(size_t i) {
-    if (dx_ != nullptr) {
-      dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
-    }
-    if (dy_ != nullptr) {
-      dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
-    }
-  }
-  DX_OP dx_op_;
-  DY_OP dy_op_;
-  T *dx_;
-  T *dy_;
-};
 template <typename T, typename DeviceContext>
 class RowwiseTransformIterator;
@@ -293,73 +272,172 @@ class TransformFunctor {
   bool is_xsize_larger_;
 };
-inline DDim trim_trailing_singular_dims(const DDim &dims) {
-  // Remove trailing dimensions of size 1 for y
-  auto actual_dims_size = dims.size();
-  for (; actual_dims_size != 0; --actual_dims_size) {
-    if (dims[actual_dims_size - 1] != 1) break;
-  }
-  if (actual_dims_size == dims.size()) return dims;
-  std::vector<int> trim_dims;
-  trim_dims.resize(actual_dims_size);
-  for (int i = 0; i < actual_dims_size; ++i) {
-    trim_dims[i] = dims[i];
-  }
-  if (trim_dims.size() == 0) {
-    return DDim(phi::make_dim());
-  }
-  DDim actual_dims = phi::make_ddim(trim_dims);
-  return actual_dims;
-}
-/*
- * Out = X ⊙ Y
- * If Y's shape does not match X' shape, they will be reshaped.
- * For example:
- * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
- *    pre=2, n=3*4, post=5
- *    x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5)
- * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
- *    pre=2*3, n=4*5, post=1
- *    x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
- *
- * New parameter: *is_run_common_broadcast* is a flag to record whether to run
- * common broadcast code.
- */
-inline void get_mid_dims(const DDim &x_dims,
-                         const DDim &y_dims,
-                         const int axis,
-                         int *pre,
-                         int *n,
-                         int *post,
-                         int *is_run_common_broadcast) {
-  *pre = 1;
-  *n = 1;
-  *post = 1;
-  *is_run_common_broadcast = 0;
-  for (int i = 0; i < axis; ++i) {
-    (*pre) *= x_dims[i];
-  }
-  for (int i = 0; i < y_dims.size(); ++i) {
-    if (x_dims[i + axis] != y_dims[i]) {
-      PADDLE_ENFORCE_EQ(y_dims[i] == 1 || x_dims[i + axis] == 1,
-                        true,
-                        phi::errors::InvalidArgument(
-                            "Broadcast dimension mismatch. Operands "
-                            "could not be broadcast together with the shape of "
-                            "X = [%s] and the shape of Y = [%s]. Received [%d] "
-                            "in X is not equal to [%d] in Y.",
-                            x_dims,
-                            y_dims,
-                            x_dims[i + axis],
-                            y_dims[i]));
-      *is_run_common_broadcast = 1;
-      return;
-    }
-    (*n) *= y_dims[i];
-  }
-  for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
-    (*post) *= x_dims[i];
-  }
-}
+template <typename Functor, typename T, typename OutType = T>
+void CommonForwardBroadcastCPU(const DenseTensor &x,
+                               const DenseTensor &y,
+                               DenseTensor *z,
+                               int *x_dims_array,
+                               int *y_dims_array,
+                               int *out_dims_array,
+                               int max_dim,
+                               const CPUContext &ctx,
+                               Functor func,
+                               const bool is_xsize_larger = true) {
+  std::vector<int> index_array(max_dim, 0);
+  const T *x_data = x.data<T>();
+  const T *y_data = y.data<T>();
+  PADDLE_ENFORCE_NOT_NULL(
+      x_data, errors::InvalidArgument("The input X should not be empty."));
+  PADDLE_ENFORCE_NOT_NULL(
+      y_data, errors::InvalidArgument("The input Y should not be empty."));
+  OutType *out_data = ctx.Alloc<OutType>(z);
+  const int out_size = std::accumulate(
+      out_dims_array, out_dims_array + max_dim, 1, std::multiplies<int>());
+  int x_index, y_index;
+  for (int out_index = 0; out_index < out_size; ++out_index) {
+    x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data());
+    y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data());
+    if (is_xsize_larger) {
+      out_data[out_index] = func(x_data[x_index], y_data[y_index]);
+    } else {
+      out_data[out_index] = func(y_data[y_index], x_data[x_index]);
+    }
+    UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data());
+  }
+}
+template <typename Functor, typename T, typename OutType = T>
+void CommonElementwiseBroadcastForward(const CPUContext &dev_ctx,
+                                       const DenseTensor &x,
+                                       const DenseTensor &y,
+                                       DenseTensor *z,
+                                       const DDim &x_dims,
+                                       const DDim &y_dims,
+                                       Functor func,
+                                       int axis,
+                                       const bool is_xsize_larger = true) {
+  int max_dim = (std::max)(x_dims.size(), y_dims.size());
+  axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
+  PADDLE_ENFORCE_GE(
+      axis,
+      0,
+      phi::errors::InvalidArgument(
+          "Axis should be great than or equal to 0, but received axis is %d.",
+          axis));
+  PADDLE_ENFORCE_LT(axis,
+                    max_dim,
+                    phi::errors::InvalidArgument(
+                        "Axis should be less than %d, but received axis is %d.",
+                        max_dim,
+                        axis));
+  std::vector<int> x_dims_array(max_dim);
+  std::vector<int> y_dims_array(max_dim);
+  std::vector<int> out_dims_array(max_dim);
+  GetBroadcastDimsArrays(x_dims,
+                         y_dims,
+                         x_dims_array.data(),
+                         y_dims_array.data(),
+                         out_dims_array.data(),
+                         max_dim,
+                         axis);
+  CommonForwardBroadcastCPU<Functor, T, OutType>(x,
+                                                 y,
+                                                 z,
+                                                 x_dims_array.data(),
+                                                 y_dims_array.data(),
+                                                 out_dims_array.data(),
+                                                 max_dim,
+                                                 dev_ctx,
+                                                 func,
+                                                 is_xsize_larger);
+}
+// It is a common CPU implementation to compute binary calculation with the
+// support of broadcast. Note:
+// 1. CPU implementation cannot support the case when x needs broadcast, thus
+//    this function need to be called with XxxFunctor and XxxInverseFunctor,
+//    like AddFunctor and InverseAddFunctor.
+// 2. The corresponding GPU implementation supports all the broadcast cases,
+//    thus there is no need to define and call with XxxInverseFunctor.
+// TODO(liuyiqun): optimize the CPU implementation to support all broadcast
+// cases and avoid the need of XxxInverseFunctor.
+template <typename Functor, typename T, typename OutType = T>
+void ElementwiseCompute(const CPUContext &dev_ctx,
+                        const DenseTensor &x,
+                        const DenseTensor &y,
+                        int axis,
+                        Functor func,
+                        DenseTensor *z) {
+  dev_ctx.Alloc<OutType>(z);
+  auto x_dims = x.dims();
+  auto y_dims = y.dims();
+  bool is_xsize_larger = true;
+  int max_dim = x_dims.size();
+  if (x_dims.size() < y_dims.size()) {
+    is_xsize_larger = false;
+    max_dim = y_dims.size();
+  }
+  TransformFunctor<Functor, T, CPUContext, OutType> functor(
+      x, y, z, dev_ctx, func, is_xsize_larger);
+  if (x_dims == y_dims) {
+    functor.Run();
+    return;
+  }
+  axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
+  PADDLE_ENFORCE_GE(
+      axis,
+      0,
+      errors::InvalidArgument(
+          "Axis should be great than or equal to 0, but received axis is %d.",
+          axis));
+  PADDLE_ENFORCE_LT(axis,
+                    max_dim,
+                    errors::InvalidArgument(
+                        "Axis should be less than %d, but received axis is %d.",
+                        max_dim,
+                        axis));
+  int pre, n, post, is_run_common_broadcast, axis_trim = 0;
+  if (is_xsize_larger) {
+    auto y_dims_trimed = TrimTrailingSingularDims(y_dims);
+    axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis;
+    GetMidDims(x_dims,
+               y_dims_trimed,
+               axis_trim,
+               &pre,
+               &n,
+               &post,
+               &is_run_common_broadcast);
+  } else {
+    auto x_dims_trimed = TrimTrailingSingularDims(x_dims);
+    axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis;
+    GetMidDims(y_dims,
+               x_dims_trimed,
+               axis_trim,
+               &pre,
+               &n,
+               &post,
+               &is_run_common_broadcast);
+  }
+  // special case for common implementation.
+  // case 1: x=[2,3,1,5], y=[2,1,4,1]
+  // case 2: x=[2,3,4], y=[1,1,4]
+  if (is_run_common_broadcast == 1) {
+    CommonElementwiseBroadcastForward<Functor, T, OutType>(
+        dev_ctx, x, y, z, x_dims, y_dims, func, axis, is_xsize_larger);
+    return;
+  }
+  if (post == 1) {
+    functor.RunRowWise(n, pre);
+    return;
+  } else {
+    functor.RunMidWise(n, pre, post);
+    return;
+  }
+}
@@ -395,41 +473,11 @@ static inline void GetDoubleGradSafeTensor(const DeviceContext &dev_ctx,
     auto meta = phi::DenseTensorMeta(x.dtype(), x.dims(), x.layout());
     *ddx_safe = phi::Empty(dev_ctx, std::move(meta));
     ddx_safe->mutable_data(dev_ctx.GetPlace());
-    phi::funcs::SetConstant<DeviceContext, T> set_zero;
+    SetConstant<DeviceContext, T> set_zero;
     set_zero(dev_ctx, ddx_safe, static_cast<T>(0));
   }
 }
-template <typename DeviceContext,
-          typename T,
-          typename DX_OP,
-          typename DY_OP,
-          typename Tout = T>
-void ElemwiseGradComputeNoBroadcast(const DeviceContext &dev_ctx,
-                                    const DDim &x_dim,
-                                    const DDim &y_dim,
-                                    const DenseTensor &x,
-                                    const DenseTensor &y,
-                                    const DenseTensor &out,
-                                    const DenseTensor &dout,
-                                    int axis,
-                                    DenseTensor *dx,
-                                    DenseTensor *dy,
-                                    DX_OP dx_op,
-                                    DY_OP dy_op) {
-  size_t N = static_cast<size_t>(phi::product(x_dim));
-  phi::funcs::ForRange<DeviceContext> for_range(dev_ctx, N);
-  for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP, Tout>{
-      x.data<T>(),
-      y.data<T>(),
-      out.data<Tout>(),
-      dout.data<Tout>(),
-      dx_op,
-      dy_op,
-      dx == nullptr ? nullptr : dev_ctx.template Alloc<T>(dx),
-      dy == nullptr ? nullptr : dev_ctx.template Alloc<T>(dy)});
-}
 inline void ElementwiseGradPreProcess(const DenseTensor &dout,
                                       DenseTensor *dx) {
   if (dx != nullptr) {
@@ -806,6 +854,7 @@ void ElementwiseKernel(const KPDevice &ctx,
     }
   }
 }
 #endif
 }  // namespace funcs
...
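
Review note (not part of the commit): the CPU broadcast path added above walks one multi-index over the output and maps it into each input with GetElementwiseIndex / UpdateElementwiseIndexArray. A standalone sketch of that index walk, with the two helpers re-declared locally so it compiles outside of phi:

#include <cstdio>
#include <vector>

// Local copies of the index helpers (same logic as the phi::funcs versions).
int GetElementwiseIndex(const int* dims, int max_dim, const int* index) {
  int flat = 0;
  for (int i = 0; i < max_dim; ++i)
    if (dims[i] > 1) flat = flat * dims[i] + index[i];
  return flat;
}
void UpdateElementwiseIndexArray(const int* out_dims, int max_dim, int* index) {
  for (int i = max_dim - 1; i >= 0; --i) {
    if (++index[i] >= out_dims[i]) index[i] -= out_dims[i];
    else break;
  }
}

int main() {
  // x: [2, 3] and y: [1, 3] broadcast-added into out: [2, 3].
  std::vector<int> x_dims = {2, 3}, y_dims = {1, 3}, out_dims = {2, 3};
  std::vector<float> x = {1, 2, 3, 4, 5, 6}, y = {10, 20, 30}, out(6);
  std::vector<int> index(2, 0);
  for (int o = 0; o < 6; ++o) {
    int xi = GetElementwiseIndex(x_dims.data(), 2, index.data());
    int yi = GetElementwiseIndex(y_dims.data(), 2, index.data());
    out[o] = x[xi] + y[yi];
    UpdateElementwiseIndexArray(out_dims.data(), 2, index.data());
  }
  for (float v : out) std::printf("%g ", v);  // 11 22 33 14 25 36
  std::printf("\n");
  return 0;
}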
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
namespace phi {
namespace funcs {
using DDim = phi::DDim;
/*
* Out = X ⊙ Y
* If Y's shape does not match X' shape, they will be reshaped.
* For example:
* 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
* pre=2, n=3*4, post=5
* x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5)
* 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
* pre=2*3, n=4*5, post=1
* x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
*
* New parameter: *is_run_common_broadcast* is a flag to record whether to run
* common broadcast code.
*/
inline void GetMidDims(const DDim &x_dims,
const DDim &y_dims,
const int axis,
int *pre,
int *n,
int *post,
int *is_run_common_broadcast) {
*pre = 1;
*n = 1;
*post = 1;
*is_run_common_broadcast = 0;
for (int i = 0; i < axis; ++i) {
(*pre) *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
if (x_dims[i + axis] != y_dims[i]) {
PADDLE_ENFORCE_EQ(y_dims[i] == 1 || x_dims[i + axis] == 1,
true,
phi::errors::InvalidArgument(
"Broadcast dimension mismatch. Operands "
"could not be broadcast together with the shape of "
"X = [%s] and the shape of Y = [%s]. Received [%d] "
"in X is not equal to [%d] in Y.",
x_dims,
y_dims,
x_dims[i + axis],
y_dims[i]));
*is_run_common_broadcast = 1;
return;
}
(*n) *= y_dims[i];
}
for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
(*post) *= x_dims[i];
}
}
inline DDim TrimTrailingSingularDims(const DDim &dims) {
// Remove trailing dimensions of size 1 for y
auto actual_dims_size = dims.size();
for (; actual_dims_size != 0; --actual_dims_size) {
if (dims[actual_dims_size - 1] != 1) break;
}
if (actual_dims_size == dims.size()) return dims;
std::vector<int> trim_dims;
trim_dims.resize(actual_dims_size);
for (int i = 0; i < actual_dims_size; ++i) {
trim_dims[i] = dims[i];
}
if (trim_dims.size() == 0) {
return DDim(phi::make_dim());
}
DDim actual_dims = phi::make_ddim(trim_dims);
return actual_dims;
}
inline int GetElementwiseIndex(const int *x_dims_array,
const int max_dim,
const int *index_array) {
int index_ = 0;
for (int i = 0; i < max_dim; i++) {
if (x_dims_array[i] > 1) {
index_ = index_ * x_dims_array[i] + index_array[i];
}
}
return index_;
}
inline void UpdateElementwiseIndexArray(const int *out_dims_array,
const int max_dim,
int *index_array) {
for (int i = max_dim - 1; i >= 0; --i) {
++index_array[i];
if (index_array[i] >= out_dims_array[i]) {
index_array[i] -= out_dims_array[i];
} else {
break;
}
}
}
} // namespace funcs
} // namespace phi
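
Review note (not part of the commit): for the shapes in the comment above, GetMidDims collapses X into a (pre, n, post) view. A standalone sketch of the same arithmetic on plain vectors, minus the enforce checks and the common-broadcast early-out, checking the two documented cases:

#include <cstdio>
#include <vector>

// Same pre/n/post computation as GetMidDims, using std::vector<int> for DDim.
void GetMidDimsSketch(const std::vector<int>& x, const std::vector<int>& y,
                      int axis, int* pre, int* n, int* post) {
  *pre = *n = *post = 1;
  for (int i = 0; i < axis; ++i) *pre *= x[i];
  for (size_t i = 0; i < y.size(); ++i) *n *= y[i];
  for (size_t i = axis + y.size(); i < x.size(); ++i) *post *= x[i];
}

int main() {
  int pre, n, post;
  // Case 1: X = (2, 3, 4, 5), Y = (3, 4), axis = 1  ->  pre=2, n=12, post=5.
  GetMidDimsSketch({2, 3, 4, 5}, {3, 4}, 1, &pre, &n, &post);
  std::printf("%d %d %d\n", pre, n, post);
  // Case 2: X = (2, 3, 4, 5), Y = (4, 5), default axis = 2  ->  pre=6, n=20, post=1.
  GetMidDimsSketch({2, 3, 4, 5}, {4, 5}, 2, &pre, &n, &post);
  std::printf("%d %d %d\n", pre, n, post);
  return 0;
}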
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_grad_base.h"
#include "paddle/phi/kernels/gpu/reduce.h"
namespace phi {
/*
******************************
Add Grad
******************************
*/
template <typename T>
static __global__ void SimpleElemwiseAddGradCUDAKernel(
const T *__restrict__ dout, int size, int vec_size, T *dx, T *dy) {
int tid = BLOCK_ID_X * BLOCK_NUM_X + THREAD_ID_X;
int stride = GRID_NUM_X * BLOCK_NUM_X;
int loop = size / vec_size;
int remainder = size % vec_size;
const float4 *dout_vec = reinterpret_cast<const float4 *>(dout);
float4 *dx_vec = reinterpret_cast<float4 *>(dx);
float4 *dy_vec = reinterpret_cast<float4 *>(dy);
float4 tmp_loop;
for (int i = tid; i < loop; i += stride) {
tmp_loop = dout_vec[i];
dx_vec[i] = tmp_loop;
dy_vec[i] = tmp_loop;
}
if (tid == loop && remainder != 0) {
T tmp_rem;
while (remainder) {
int idx = size - remainder;
remainder--;
tmp_rem = dout[idx];
dx[idx] = tmp_rem;
dy[idx] = tmp_rem;
}
}
}
template <typename T>
void DefaultElementwiseAddGrad(const GPUContext &ctx,
const DenseTensor &x,
const DenseTensor &y,
const DenseTensor &out,
const DenseTensor &dout,
DenseTensor *dx,
DenseTensor *dy,
int axis = -1) {
auto *dout_data = dout.data<T>();
// dx
if (dx != nullptr) {
auto *dx_data = dx->mutable_data<T>(ctx.GetPlace());
if (dx->dims() == dout.dims()) {
if (dx_data != dout_data) {
phi::Copy(ctx, dout, ctx.GetPlace(), false, dx);
}
} else {
// For inplace strategy, dx will be stored in addr of dout, which makes
// the result of dy wrong.
if (dx->IsSharedBufferWith(dout)) {
dx->clear();
dx->mutable_data<T>(x.dims(), ctx.GetPlace());
}
std::vector<int> reduce_dims =
funcs::GetReduceDim(x.dims(), out.dims(), axis);
gpuStream_t stream = ctx.stream();
kernels::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
ctx, dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
// dy
if (dy != nullptr) {
auto *dy_data = dy->mutable_data<T>(ctx.GetPlace());
if (dy->dims() == dout.dims()) {
if (dy_data != dout_data) {
phi::Copy(ctx, dout, ctx.GetPlace(), false, dy);
}
} else {
std::vector<int> reduce_dims =
funcs::GetReduceDim(y.dims(), out.dims(), axis);
gpuStream_t stream = ctx.stream();
kernels::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
ctx, dout, dy, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
}
template <typename T>
void ElementwiseAddGrad(const GPUContext &ctx,
const DenseTensor &x,
const DenseTensor &y,
const DenseTensor &out,
const DenseTensor &dout,
DenseTensor *dx,
DenseTensor *dy) {
ctx.template Alloc<T>(dx);
ctx.template Alloc<T>(dy);
auto *dx_data = dx->data<T>();
auto *dy_data = dy->data<T>();
auto *dout_data = dout.data<T>();
if (dx_data == dout_data && dy_data != dout_data) {
VLOG(4) << "Special case when dx_data is the same as dout_data, "
"only need copy dout to dy";
phi::Copy(ctx, dout, ctx.GetPlace(), false, dy);
} else if (dx_data != dout_data && dy_data == dout_data) {
VLOG(4) << "Special case when dy_data is the same as dout_data, "
"only need copy dout to dx";
phi::Copy(ctx, dout, ctx.GetPlace(), false, dx);
} else if (dx_data != dout_data && dy_data != dout_data) {
auto size = x.numel();
int vec_size = max(static_cast<int>(sizeof(float4) / sizeof(T)), 1);
dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
dim3 grid_size =
dim3(((size + vec_size - 1) / vec_size + PREDEFINED_BLOCK_SIZE - 1) /
PREDEFINED_BLOCK_SIZE,
1);
SimpleElemwiseAddGradCUDAKernel<
T><<<grid_size, block_size, 0, ctx.stream()>>>(
dout.data<T>(),
size,
vec_size,
dx->mutable_data<T>(ctx.GetPlace()),
dy->mutable_data<T>(ctx.GetPlace()));
} else {
VLOG(4) << "Special case when dy_data is the same as dout_data, "
"and dx_data is the same as dout_data, do not need "
"any operator";
}
}
/*
******************************
Sub Grad
******************************
*/
template <typename T>
static __global__ void SimpleElemwiseSubGradCUDAKernel(const T *dout,
int64_t size,
T *dx,
T *dy) {
int col = BLOCK_ID_X * BLOCK_NUM_X + THREAD_ID_X;
while (col < size) {
if (dx != nullptr) {
dx[col] = dout[col];
}
dy[col] = -dout[col];
col += BLOCK_NUM_X * GRID_NUM_X;
}
}
template <typename T>
void default_elementwise_sub_grad(const GPUContext &ctx,
const DenseTensor &x,
const DenseTensor &y,
const DenseTensor &out,
const DenseTensor &dout,
DenseTensor *dx,
DenseTensor *dy,
int axis = -1) {
auto *dout_data = dout.data<T>();
// dx
if (dx != nullptr) {
auto *dx_data = dx->mutable_data<T>(ctx.GetPlace());
if (dx->dims() == dout.dims()) {
if (dx_data != dout_data) {
phi::Copy(ctx, dout, ctx.GetPlace(), false, dx);
}
} else {
// For inplace strategy, dx will be stored in addr of dout, which makes
// the result of dy wrong.
if (dx->IsSharedBufferWith(dout)) {
dx->clear();
dx->mutable_data<T>(x.dims(), ctx.GetPlace());
}
std::vector<int> reduce_dims =
funcs::GetReduceDim(x.dims(), out.dims(), axis);
gpuStream_t stream = ctx.stream();
kernels::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
ctx, dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
// dy
if (dy != nullptr) {
auto *dy_data = dy->mutable_data<T>(ctx.GetPlace());
if (dy->dims() == dout.dims()) {
if (dy_data != dout_data) {
dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
auto size = dy->numel();
dim3 grid_size =
dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1);
SimpleElemwiseSubGradCUDAKernel<
T><<<grid_size, block_size, 0, ctx.stream()>>>(
dout.data<T>(), size, nullptr, dy->mutable_data<T>(ctx.GetPlace()));
}
} else {
std::vector<int> reduce_dims =
funcs::GetReduceDim(y.dims(), out.dims(), axis);
gpuStream_t stream = ctx.stream();
kernels::TensorReduceImpl<T, T, kps::AddFunctor, kps::InverseFunctor<T>>(
ctx, dout, dy, kps::InverseFunctor<T>(), reduce_dims, stream);
}
}
}
template <typename T>
void elementwise_sub_grad(const GPUContext &ctx,
const DenseTensor &x,
const DenseTensor &y,
const DenseTensor &out,
const DenseTensor &dout,
DenseTensor *dx,
DenseTensor *dy) {
dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
auto size = x.numel();
dim3 grid_size =
dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1);
SimpleElemwiseSubGradCUDAKernel<
T><<<grid_size, block_size, 0, ctx.stream()>>>(
dout.data<T>(),
size,
dx->mutable_data<T>(ctx.GetPlace()),
dy->mutable_data<T>(ctx.GetPlace()));
}
} // namespace phi
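
Review note (not part of the commit): ElementwiseAddGrad above packs the dout copy into float4 loads. The host-side arithmetic below (plain C++, no CUDA required) shows how vec_size, the vectorized loop count, the scalar remainder, and the grid size relate, mirroring the launch math in the kernel wrapper; the block size of 512 is an assumed stand-in for PREDEFINED_BLOCK_SIZE, and the element count is hypothetical.

#include <algorithm>
#include <cstdio>

constexpr long long kBlockSize = 512;    // assumed value for PREDEFINED_BLOCK_SIZE
constexpr long long kFloat4Bytes = 16;   // sizeof(float4) on the device

template <typename T>
void AddGradLaunchMath(long long size) {
  long long vec_size = std::max<long long>(kFloat4Bytes / sizeof(T), 1);
  long long loop = size / vec_size;       // float4 iterations inside the kernel
  long long remainder = size % vec_size;  // scalar tail handled when tid == loop
  long long grid =
      ((size + vec_size - 1) / vec_size + kBlockSize - 1) / kBlockSize;
  std::printf("sizeof(T)=%zu: vec_size=%lld loop=%lld remainder=%lld grid=%lld\n",
              sizeof(T), vec_size, loop, remainder, grid);
}

int main() {
  AddGradLaunchMath<float>(1000003);
  AddGradLaunchMath<double>(1000003);
  return 0;
}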
@@ -17,8 +17,9 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/copy_kernel.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
-#include "paddle/phi/kernels/gpu/elementwise.h"
+#include "paddle/phi/kernels/gpu/elementwise_grad.h"
 #include "paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h"
 namespace phi {
@@ -33,9 +34,9 @@ void AddGradFunc(const GPUContext& dev_ctx,
                  DenseTensor* dy,
                  int axis = -1) {
   if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
-    elementwise_add_grad<T>(dev_ctx, x, y, out, dout, dx, dy);
+    ElementwiseAddGrad<T>(dev_ctx, x, y, out, dout, dx, dy);
   } else {
-    default_elementwise_add_grad<T>(dev_ctx, x, y, out, dout, dx, dy, axis);
+    DefaultElementwiseAddGrad<T>(dev_ctx, x, y, out, dout, dx, dy, axis);
   }
 }
@@ -58,15 +59,7 @@ void AddDoubleGradKernel(const Context& dev_ctx,
                          const DenseTensor& dout,
                          int axis,
                          DenseTensor* ddout) {
-  phi::AddDoubleGradImpl<T>(dev_ctx,
-                            y,
-                            ddx,
-                            ddy,
-                            dout,
-                            axis,
-                            ddout,
-                            ElementwiseCompute<funcs::AddFunctor<T>, T>,
-                            ElementwiseCompute<funcs::InverseAddFunctor<T>, T>);
+  phi::AddDoubleGradImpl<T>(dev_ctx, y, ddx, ddy, dout, axis, ddout);
 }
 template <typename T, typename Context>
@@ -106,15 +99,7 @@ void SubtractDoubleGradKernel(const Context& dev_ctx,
                               const DenseTensor& dout,
                               int axis,
                               DenseTensor* ddout) {
-  phi::SubtractDoubleGradImpl<T>(
-      dev_ctx,
-      y,
-      ddx,
-      ddy,
-      dout,
-      axis,
-      ddout,
-      ElementwiseCompute<funcs::SubtractFunctor<T>, T>);
+  phi::SubtractDoubleGradImpl<T>(dev_ctx, y, ddx, ddy, dout, axis, ddout);
 }
 }  // namespace phi
...
@@ -16,9 +16,8 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/logical_functor.h"
-#include "paddle/phi/kernels/gpu/elementwise.h"
 namespace phi {
...
@@ -15,8 +15,8 @@ limitations under the License. */
 #include "paddle/phi/kernels/math_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
-#include "paddle/phi/kernels/gpu/elementwise.h"
 #include "paddle/phi/kernels/gpu/reduce.h"
 #ifdef __NVCC__
...
@@ -15,7 +15,7 @@ limitations under the License. */
 #pragma once
 #include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
 namespace phi {
@@ -47,19 +47,14 @@ void AddGradImpl(const Context& dev_ctx,
   }
 }
-template <typename T,
-          typename Context,
-          typename GradFunc,
-          typename GradInverseFunc>
+template <typename T, typename Context>
 void AddDoubleGradImpl(const Context& dev_ctx,
                        const DenseTensor& y,
                        const paddle::optional<const DenseTensor&>& ddx,
                        const paddle::optional<const DenseTensor&>& ddy,
                        const DenseTensor& dout,
                        int axis,
-                       DenseTensor* ddout,
-                       GradFunc grad_func,
-                       GradInverseFunc grad_inverse_func) {
+                       DenseTensor* ddout) {
   // ddOut = ddx + ddy
   if (ddout) {
     DenseTensor ddx_safe, ddy_safe;
@@ -72,28 +67,28 @@ void AddDoubleGradImpl(const Context& dev_ctx,
     auto ddx_dims = ddx_safe.dims();
    auto ddy_dims = ddy_safe.dims();
     if (ddx_dims.size() >= ddy_dims.size()) {
-      grad_func(
+      funcs::ElementwiseCompute<funcs::AddFunctor<T>, T>(
          dev_ctx, ddx_safe, ddy_safe, axis, funcs::AddFunctor<T>(), ddout);
     } else {
-      grad_inverse_func(dev_ctx,
-                        ddx_safe,
-                        ddy_safe,
-                        axis,
-                        funcs::InverseAddFunctor<T>(),
-                        ddout);
+      funcs::ElementwiseCompute<funcs::InverseAddFunctor<T>, T>(
+          dev_ctx,
+          ddx_safe,
+          ddy_safe,
+          axis,
+          funcs::InverseAddFunctor<T>(),
+          ddout);
     }
   }
 }
-template <typename T, typename Context, typename GradFunc>
+template <typename T, typename Context>
 void SubtractDoubleGradImpl(const Context& dev_ctx,
                             const DenseTensor& y,
                             const paddle::optional<const DenseTensor&>& ddx,
                             const paddle::optional<const DenseTensor&>& ddy,
                             const DenseTensor& dout,
                             int axis,
-                            DenseTensor* ddout,
-                            GradFunc grad_func) {
+                            DenseTensor* ddout) {
   // DDOut = ddx - ddy
   if (ddout) {
     DenseTensor ddx_safe, ddy_safe;
@@ -103,7 +98,7 @@ void SubtractDoubleGradImpl(const Context& dev_ctx,
        dev_ctx, y, ddy.get_ptr(), &ddy_safe);
     ddout->mutable_data<T>(dev_ctx.GetPlace());
-    grad_func(
+    funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
        dev_ctx, ddx_safe, ddy_safe, axis, funcs::SubtractFunctor<T>(), ddout);
   }
 }
...
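
Review note (not part of the commit): with the GradFunc parameters removed, AddDoubleGradImpl and SubtractDoubleGradImpl always route through funcs::ElementwiseCompute; the math itself is unchanged. A standalone sketch of that math, where a missing ddx or ddy acts as a zero tensor, just as the GetDoubleGradSafeTensor helper arranges:

#include <cstdio>
#include <vector>

// ddout = ddx + ddy (use "-" for the subtract case); a missing input is
// treated as zeros, which is what the "safe tensor" helper provides.
std::vector<float> AddDoubleGradSketch(const std::vector<float>* ddx,
                                       const std::vector<float>* ddy,
                                       size_t n) {
  std::vector<float> zeros(n, 0.0f);
  const std::vector<float>& a = ddx ? *ddx : zeros;
  const std::vector<float>& b = ddy ? *ddy : zeros;
  std::vector<float> ddout(n);
  for (size_t i = 0; i < n; ++i) ddout[i] = a[i] + b[i];
  return ddout;
}

int main() {
  std::vector<float> ddy = {1.f, 2.f, 3.f};
  auto ddout = AddDoubleGradSketch(nullptr, &ddy, ddy.size());  // ddx absent
  for (float v : ddout) std::printf("%g ", v);                  // 1 2 3
  std::printf("\n");
  return 0;
}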