Unverified commit 556d5097, authored by YuanRisheng, committed by GitHub

refactor impl of elementwise op part2 (#38898)

Parent 7f8d5bc8
@@ -549,4 +549,148 @@ static void ElemwiseGradBroadcast2CPU(const T* x,
}
}

template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void CommonElementwiseBroadcastBackward(const CPUContext& ctx,
const DDim& x_dims,
const DDim& y_dims,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy,
DX_OP dx_op,
DY_OP dy_op) {
int max_dim = std::max(x_dims.size(), y_dims.size());
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
funcs::GetBroadcastDimsArrays(x_dims,
y_dims,
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
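  // Illustrative example (not part of this patch): with x_dims = [2, 3, 4],
  // y_dims = [3, 4] and axis = 1, this produces x_dims_array = [2, 3, 4],
  // y_dims_array = [1, 3, 4] (y aligned at axis and padded with 1s) and
  // out_dims_array = [2, 3, 4], the elementwise max of the padded shapes.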
  // For the inplace strategy: dx may share its buffer with dout, in which
  // case a memset on dx would also clear dout and give a wrong result, so
  // release the shared buffer and reallocate dx instead.
if (dx && dx->IsSharedBufferWith(dout)) {
dx->clear();
dx->mutable_data<T>(x_dims, ctx.GetPlace());
}
VLOG(3) << "CommonElementwiseBroadcastBackward xdims:"
<< paddle::framework::make_ddim(x_dims_array)
<< " ydim:" << paddle::framework::make_ddim(y_dims_array);
CommonGradBroadcastCPU<T, DX_OP, DY_OP, Tout>(x,
y,
out,
dout,
dx,
dy,
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
ctx,
dx_op,
dy_op);
}
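
// Dispatches the CPU backward computation for a broadcasted elementwise op.
// The two shapes are first reduced to a (pre, n, post) pattern; when that
// succeeds, one of the specialized kernels (ElemwiseGradBroadcast1CPU /
// ElemwiseGradBroadcast2CPU) handles the gradients, otherwise the general
// CommonElementwiseBroadcastBackward path is taken.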
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void ElemwiseGradComputeWithBroadcast(const CPUContext& ctx,
const DDim& x_dims,
const DDim& y_dims,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy,
DX_OP dx_op,
DY_OP dy_op) {
bool is_xsize_larger = true;
int max_dim = x_dims.size();
if (x_dims.size() < y_dims.size()) {
is_xsize_larger = false;
max_dim = y_dims.size();
}
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
PADDLE_ENFORCE_GE(
axis,
0,
paddle::platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis,
max_dim,
paddle::platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
int pre, n, post, is_run_common_broadcast, axis_trim = 0;
if (is_xsize_larger) {
auto y_dims_trimed = funcs::trim_trailing_singular_dims(y_dims);
axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis;
funcs::get_mid_dims(x_dims,
y_dims_trimed,
axis_trim,
&pre,
&n,
&post,
&is_run_common_broadcast);
} else {
auto x_dims_trimed = funcs::trim_trailing_singular_dims(x_dims);
axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis;
funcs::get_mid_dims(y_dims,
x_dims_trimed,
axis_trim,
&pre,
&n,
&post,
&is_run_common_broadcast);
}
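  // Illustrative example (not part of this patch): x_dims = [2, 3, 4, 5],
  // y_dims = [3, 4], axis = 1 gives pre = 2, n = 3 * 4 = 12, post = 5 and
  // is_run_common_broadcast = 0, so the fast (pre, n, post) kernels apply.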
  // Fall back to the general backward implementation when the shapes cannot
  // be reduced to the (pre, n, post) pattern.
if (is_run_common_broadcast) {
CommonElementwiseBroadcastBackward<T, DX_OP, DY_OP, Tout>(
ctx, x_dims, y_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
return;
}
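  // Fast paths: when post == 1 the data is viewed as a (pre, n) matrix and
  // handled by ElemwiseGradBroadcast1CPU; otherwise it is viewed as a
  // (pre, n, post) block and handled by ElemwiseGradBroadcast2CPU. Either
  // way, the gradient of the smaller operand is reduced over the broadcast
  // dimensions.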
if (post == 1) {
ElemwiseGradBroadcast1CPU(
x.data<T>(),
y.data<T>(),
out.data<Tout>(),
dout.data<Tout>(),
pre,
n,
is_xsize_larger,
dx_op,
dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
} else {
ElemwiseGradBroadcast2CPU(
x.data<T>(),
y.data<T>(),
out.data<Tout>(),
dout.data<Tout>(),
pre,
n,
post,
is_xsize_larger,
dx_op,
dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
}
}
} // namespace pten
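
Below is a minimal, self-contained sketch (not part of this patch) of the (pre, n, post) backward pattern that ElemwiseGradBroadcast2CPU implements, specialized to z = x * y. The shapes and the Mul gradient rules (dx = dout * y; dy = dout * x, reduced over the broadcast axes) are illustrative assumptions:

#include <cstdio>
#include <vector>

int main() {
  // x and dout have shape (pre, n, post) = (2, 3, 2); y has shape (n,) = (3,).
  const int pre = 2, n = 3, post = 2;
  std::vector<float> x(pre * n * post), y(n), dout(pre * n * post, 1.0f);
  for (int i = 0; i < pre * n * post; ++i) x[i] = static_cast<float>(i);
  for (int j = 0; j < n; ++j) y[j] = static_cast<float>(j + 1);

  // dx keeps x's full shape; dy is accumulated (reduced) over pre and post.
  std::vector<float> dx(pre * n * post), dy(n, 0.0f);
  for (int i = 0; i < pre; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int k = 0; k < post; ++k) {
        const int idx = (i * n + j) * post + k;
        dx[idx] = dout[idx] * y[j];   // DX_OP for Mul: dout * y
        dy[j] += dout[idx] * x[idx];  // DY_OP for Mul: dout * x, then reduce
      }
    }
  }
  for (int j = 0; j < n; ++j) std::printf("dy[%d] = %g\n", j, dy[j]);
  return 0;
}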