Unverified commit 676903d5, authored by YuanRisheng, committed by GitHub

[PTen]Refactor impl of elementwise op grad_kernel (Part1) (#38873)

* refactor the impl of elementwise grad kernel

* refactor impl of elementwise grad kernel(cuda)

* fix compile bugs
Parent 572ba24e
@@ -150,9 +150,12 @@ struct GetInputIndex<false> {
const std::vector<int>& output_strides, int output_idx,
int* index_array, int* lhs_idx, int* rhs_idx) {
int out_dims_size = output_strides.size();
*lhs_idx = GetElementwiseIndex(lhs_dims.data(), out_dims_size, index_array);
*rhs_idx = GetElementwiseIndex(rhs_dims.data(), out_dims_size, index_array);
UpdateElementwiseIndexArray(output_dims.data(), out_dims_size, index_array);
*lhs_idx =
pten::GetElementwiseIndex(lhs_dims.data(), out_dims_size, index_array);
*rhs_idx =
pten::GetElementwiseIndex(rhs_dims.data(), out_dims_size, index_array);
pten::UpdateElementwiseIndexArray(output_dims.data(), out_dims_size,
index_array);
}
};
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/funcs/elementwise_base.h"
@@ -22,6 +23,8 @@ limitations under the License. */
namespace pten {
// FORWARD CODE
// Add
template <typename DevCtx, typename T, class Enable = void>
struct SameDimsAddFunctor {
@@ -206,6 +209,56 @@ inline int GetElementwiseIndex(const int* x_dims_array,
return index_;
}
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void CommonGradBroadcastCPU(const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy,
int* x_dims_array,
int* y_dims_array,
int* out_dims_array,
int max_dim,
const CPUContext& ctx,
DX_OP dx_op,
DY_OP dy_op) {
std::vector<int> index_array(max_dim, 0);
const T* x_data = x.data<T>();
const T* y_data = y.data<T>();
const Tout* out_data = out.data<Tout>();
const Tout* dout_data = dout.data<Tout>();
T* dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
T* dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
if (dx_data != nullptr) {
memset(dx_data, 0, dx->numel() * sizeof(T));
}
if (dy_data != nullptr) {
memset(dy_data, 0, dy->numel() * sizeof(T));
}
const int out_size = std::accumulate(
out_dims_array, out_dims_array + max_dim, 1, std::multiplies<int>());
int x_index, y_index;
for (int out_index = 0; out_index < out_size; ++out_index) {
x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data());
y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data());
if (dx_data != nullptr) {
dx_data[x_index] += dx_op(x_data[x_index],
y_data[y_index],
out_data[out_index],
dout_data[out_index]);
}
if (dy_data != nullptr) {
dy_data[y_index] += dy_op(x_data[x_index],
y_data[y_index],
out_data[out_index],
dout_data[out_index]);
}
UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data());
}
}
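// ---------------------------------------------------------------------------
// Illustrative sketch, not part of this commit: one way CommonGradBroadcastCPU
// could be driven for an elementwise add gradient. The AddGradDXSketch /
// AddGradDYSketch functors and the wrapper below are assumptions for
// illustration only; for z = x + y, both dx and dy are simply dout,
// accumulated over the broadcast dimensions.
template <typename T>
struct AddGradDXSketch {
  T operator()(T x, T y, T out, T dout) const { return dout; }
};
template <typename T>
struct AddGradDYSketch {
  T operator()(T x, T y, T out, T dout) const { return dout; }
};

template <typename T>
void AddGradBroadcastCPUSketch(const CPUContext& ctx,
                               const DenseTensor& x,
                               const DenseTensor& y,
                               const DenseTensor& out,
                               const DenseTensor& dout,
                               std::vector<int>* x_dims_array,
                               std::vector<int>* y_dims_array,
                               std::vector<int>* out_dims_array,
                               int max_dim,
                               DenseTensor* dx,
                               DenseTensor* dy) {
  // Assumes the three dims arrays were already padded/aligned to max_dim
  // entries (e.g. by funcs::GetBroadcastDimsArrays).
  CommonGradBroadcastCPU<T, AddGradDXSketch<T>, AddGradDYSketch<T>>(
      x, y, out, dout, dx, dy, x_dims_array->data(), y_dims_array->data(),
      out_dims_array->data(), max_dim, ctx, AddGradDXSketch<T>(),
      AddGradDYSketch<T>());
}
// ---------------------------------------------------------------------------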
template <typename Functor, typename T, typename OutType = T>
void CommonForwardBroadcastCPU(const DenseTensor& x,
const DenseTensor& y,
@@ -214,7 +267,7 @@ void CommonForwardBroadcastCPU(const DenseTensor& x,
int* y_dims_array,
int* out_dims_array,
int max_dim,
const paddle::platform::CPUDeviceContext& ctx,
const CPUContext& ctx,
Functor func,
const bool is_xsize_larger = true) {
std::vector<int> index_array(max_dim, 0);
@@ -245,8 +298,7 @@ void CommonForwardBroadcastCPU(const DenseTensor& x,
}
template <typename Functor, typename T, typename OutType = T>
void CommonElementwiseBroadcastForward(
const paddle::platform::CPUDeviceContext& dev_ctx,
void CommonElementwiseBroadcastForward(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* z,
@@ -302,7 +354,7 @@ void CommonElementwiseBroadcastForward(
// TODO(liuyiqun): optimize the CPU implementation to support all broadcast
// cases and avoid the need of XxxInverseFunctor.
template <typename Functor, typename T, typename OutType = T>
void ElementwiseCompute(const paddle::platform::CPUDeviceContext& dev_ctx,
void ElementwiseCompute(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
@@ -317,9 +369,8 @@ void ElementwiseCompute(const paddle::platform::CPUDeviceContext& dev_ctx,
is_xsize_larger = false;
max_dim = y_dims.size();
}
funcs::
TransformFunctor<Functor, T, paddle::platform::CPUDeviceContext, OutType>
functor(x, y, z, dev_ctx, func, is_xsize_larger);
funcs::TransformFunctor<Functor, T, CPUContext, OutType> functor(
x, y, z, dev_ctx, func, is_xsize_larger);
if (x_dims == y_dims) {
functor.Run();
return;
@@ -381,7 +432,7 @@ void ElementwiseCompute(const paddle::platform::CPUDeviceContext& dev_ctx,
template <typename Functor>
struct SameDimsElementwiseCompute {
void operator()(const paddle::platform::CPUDeviceContext& dev_ctx,
void operator()(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* z) {
@@ -389,4 +440,113 @@ struct SameDimsElementwiseCompute {
}
};
// BACKWARD CODE
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
static void ElemwiseGradBroadcast1CPU(const T* x,
const T* y,
const Tout* out,
const Tout* dout,
int h,
int w,
bool is_xsize_larger,
DX_OP dx_op,
DY_OP dy_op,
T* dx,
T* dy) {
if (is_xsize_larger) {
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
int x_offset = i * w + j;
if (dx != nullptr) {
dx[x_offset] =
dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
if (dy != nullptr) {
T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
if (i == 0) {
dy[j] = tmp;
} else {
dy[j] += tmp;
}
}
}
}
} else { // x.dims < y.dims, broadcast for x.
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
int y_offset = i * w + j;
if (dy != nullptr) {
dy[y_offset] =
dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
}
if (dx != nullptr) {
T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
if (i == 0) {
dx[j] = tmp;
} else {
dx[j] += tmp;
}
}
}
}
}
}
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
static void ElemwiseGradBroadcast2CPU(const T* x,
const T* y,
const Tout* out,
const Tout* dout,
int pre,
int n,
int post,
bool is_xsize_larger,
DX_OP dx_op,
DY_OP dy_op,
T* dx,
T* dy) {
if (is_xsize_larger) {
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
for (int k = 0; k < post; ++k) {
int x_offset = i * n * post + j * post + k;
if (dx != nullptr) {
dx[x_offset] =
dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
if (dy != nullptr) {
T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
if (i == 0 && k == 0) {
dy[j] = tmp;
} else {
dy[j] += tmp;
}
}
}
}
}
} else { // x.dims < y.dims, broadcast for x.
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
for (int k = 0; k < post; ++k) {
int y_offset = i * n * post + j * post + k;
if (dy != nullptr) {
dy[y_offset] =
dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
}
if (dx != nullptr) {
T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
if (i == 0 && k == 0) {
dx[j] = tmp;
} else {
dx[j] += tmp;
}
}
}
}
}
}
}
} // namespace pten
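As a quick illustration of the pre/n/post decomposition used by ElemwiseGradBroadcast2CPU above: for x of shape [2, 3, 4] and y of shape [3] broadcast along axis 1, pre = 2, n = 3, post = 4, and each dy[j] accumulates dy_op over all pre * post (i, k) pairs. The following sketch is illustrative only; the buffers, values, and multiply-grad lambdas are assumptions, not code from this commit:
#include <vector>
// Illustrative only: assumes the header above is included so that
// pten::ElemwiseGradBroadcast2CPU is visible.
void MulGradBroadcast2Sketch() {
  // x: [2, 3, 4], y: [3] broadcast along axis 1 -> pre = 2, n = 3, post = 4.
  std::vector<float> x(2 * 3 * 4, 2.0f), y(3, 3.0f);
  std::vector<float> out(2 * 3 * 4, 6.0f), dout(2 * 3 * 4, 1.0f);
  std::vector<float> dx(2 * 3 * 4, 0.0f), dy(3, 0.0f);
  // For z = x * y: dz/dx = dout * y, dz/dy = dout * x.
  auto dx_op = [](float xv, float yv, float ov, float dv) { return dv * yv; };
  auto dy_op = [](float xv, float yv, float ov, float dv) { return dv * xv; };
  pten::ElemwiseGradBroadcast2CPU(x.data(), y.data(), out.data(), dout.data(),
                                  /*pre=*/2, /*n=*/3, /*post=*/4,
                                  /*is_xsize_larger=*/true, dx_op, dy_op,
                                  dx.data(), dy.data());
  // Result: every dx element is 3.0f (dout * y); every dy[j] is 16.0f,
  // i.e. pre * post = 8 contributions of dout * x = 1.0f * 2.0f each.
}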
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/pten/backends/all_context.h"
#include "paddle/pten/core/dense_tensor.h"
@@ -23,6 +24,28 @@ namespace funcs {
using DDim = paddle::framework::DDim;
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
struct ElemwiseGradNoBroadcast {
const T *x_;
const T *y_;
const Tout *out_;
const Tout *dout_;
HOSTDEVICE void operator()(size_t i) {
if (dx_ != nullptr) {
dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
}
if (dy_ != nullptr) {
dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
}
}
DX_OP dx_op_;
DY_OP dy_op_;
T *dx_;
T *dy_;
};
template <typename T, typename DeviceContext>
class RowwiseTransformIterator;
@@ -378,5 +401,36 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
}
}
}
template <typename DeviceContext,
typename T,
typename DX_OP,
typename DY_OP,
typename Tout = T>
void ElemwiseGradComputeNoBroadcast(const DeviceContext &dev_ctx,
const DDim &x_dim,
const DDim &y_dim,
const DenseTensor &x,
const DenseTensor &y,
const DenseTensor &out,
const DenseTensor &dout,
int axis,
DenseTensor *dx,
DenseTensor *dy,
DX_OP dx_op,
DY_OP dy_op) {
size_t N = static_cast<size_t>(paddle::framework::product(x_dim));
paddle::platform::ForRange<DeviceContext> for_range(dev_ctx, N);
for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP, Tout>{
x.data<T>(),
y.data<T>(),
out.data<Tout>(),
dout.data<Tout>(),
dx_op,
dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(dev_ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(dev_ctx.GetPlace())});
}
} // namespace funcs
} // namespace pten
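For orientation, here is a minimal sketch of how ElemwiseGradComputeNoBroadcast might be consumed by a same-shape multiply grad kernel. The functor names and the wrapper are illustrative assumptions, not code from this commit:
#include "paddle/pten/kernels/funcs/elementwise_base.h"

namespace pten {

// For z = x * y with identical shapes: dz/dx = dout * y, dz/dy = dout * x.
template <typename T>
struct MulGradDXSketch {
  T operator()(T x, T y, T out, T dout) const { return dout * y; }
};
template <typename T>
struct MulGradDYSketch {
  T operator()(T x, T y, T out, T dout) const { return dout * x; }
};

template <typename T>
void MultiplyGradNoBroadcastSketch(const CPUContext& dev_ctx,
                                   const DenseTensor& x,
                                   const DenseTensor& y,
                                   const DenseTensor& out,
                                   const DenseTensor& dout,
                                   DenseTensor* dx,
                                   DenseTensor* dy) {
  // Same-shape path only: no reduction over broadcast dimensions is needed;
  // ElemwiseGradNoBroadcast is applied element by element through ForRange.
  funcs::ElemwiseGradComputeNoBroadcast<CPUContext, T>(
      dev_ctx, x.dims(), y.dims(), x, y, out, dout, /*axis=*/-1, dx, dy,
      MulGradDXSketch<T>(), MulGradDYSketch<T>());
}

}  // namespace pten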