未验证 提交 978558be 编写于 作者: N niuliling123 提交者: GitHub

Revert "Replace EigenBroadcast with ElementwiseBroadcast in ReduceGrad (#38959)" (#39205)

This reverts commit 9059ef69.
上级 712ccfbf
......@@ -17,9 +17,15 @@
template <typename T>
using CUDAReduceMeanGradKernel =
ops::ReduceCudaGradKernel<T, kps::DivideFunctor>;
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
ops::MeanGradFunctor, true>;
using FP16CUDAReduceMeanGradKernel =
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16, ops::FP16MeanGradFunctor,
true>;
REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel<bool>,
CUDAReduceMeanGradKernel<paddle::platform::float16>,
FP16CUDAReduceMeanGradKernel,
CUDAReduceMeanGradKernel<float>,
CUDAReduceMeanGradKernel<double>);
......@@ -623,10 +623,9 @@ class ReduceGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
int out_dtype = ctx.Attr<int>("out_dtype");
int in_dtype = ctx.Attr<int>("in_dtype");
auto input_data_type =
(out_dtype >= 0)
? static_cast<framework::proto::VarType::Type>(out_dtype)
(in_dtype >= 0) ? static_cast<framework::proto::VarType::Type>(in_dtype)
: OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
......@@ -737,55 +736,6 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
pt_out.get());
}
};
template <typename T, template <typename, typename> class TransformOp>
class ReduceCudaGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
bool reduce_all = context.Attr<bool>("reduce_all");
std::vector<int> dims = context.Attr<std::vector<int>>("dim");
auto* in_x = context.Input<Tensor>("X");
auto* d_out =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto out_dtype = context.Attr<int>("in_dtype");
// get reduce_dim and reduce_num for reduce_mean_grad
int dim_size = in_x->dims().size();
std::vector<int> reduce_dims = GetReduceDim(dims, dim_size, reduce_all);
auto update_dims = vectorize(d_x->dims());
int reduce_num = 1;
for (auto i : reduce_dims) {
reduce_num *= (in_x->dims())[i];
update_dims[i] = 1;
}
// make new tensor
framework::Tensor new_d_out(d_out->type());
new_d_out.ShareDataWith(*d_out);
new_d_out.Resize(paddle::framework::make_ddim(update_dims));
auto& dev_ctx = context.cuda_device_context();
if (out_dtype > 0) {
d_x->mutable_data(
dev_ctx.GetPlace(),
static_cast<framework::proto::VarType::Type>(out_dtype));
} else {
d_x->mutable_data(
dev_ctx.GetPlace(),
static_cast<framework::proto::VarType::Type>(d_out->type()));
}
auto pt_d_out = paddle::experimental::MakePtenDenseTensor(new_d_out);
auto pt_d_x = paddle::experimental::MakePtenDenseTensor(*d_x);
auto pt_out_dtype = pten::TransToPtenDataType(
static_cast<framework::proto::VarType::Type>(out_dtype));
if (out_dtype <= 0) {
pt_out_dtype = pten::TransToPtenDataType(
static_cast<framework::proto::VarType::Type>(d_out->type()));
}
using MPType = typename kps::details::MPTypeTrait<T>::Type;
pten::ReduceGrad<T, TransformOp<T, MPType>>(
dev_ctx, pt_d_out.get(), pt_d_x.get(), pt_out_dtype,
TransformOp<T, MPType>(reduce_num));
}
};
#endif
} // namespace operators
......
......@@ -50,7 +50,7 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
int in_dtype = ctx.Attr<int>("out_dtype");
int in_dtype = ctx.Attr<int>("in_dtype");
if (in_dtype >= 0) {
return framework::OpKernelType(
static_cast<framework::proto::VarType::Type>(in_dtype),
......
......@@ -74,7 +74,7 @@ class ReduceSumGradKernel : public framework::OpKernel<T> {
auto dims = context.Attr<std::vector<int>>("dim");
if (context.GetPlace().GetType() == platform::CPUPlace().GetType() &&
dims.size() == 1) {
int in_dtype = context.Attr<int>("out_dtype");
int in_dtype = context.Attr<int>("in_dtype");
if (in_dtype >= 0) {
Tensor tmp_tensor;
......
......@@ -17,7 +17,8 @@
template <typename T>
using CUDAReduceSumGradKernel =
ops::ReduceCudaGradKernel<T, kps::IdentityFunctor>;
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
ops::SumGradFunctor, true>;
REGISTER_OP_CUDA_KERNEL(
reduce_sum_grad, CUDAReduceSumGradKernel<bool>,
......
......@@ -134,20 +134,13 @@ struct DimensionsTransform {
explicit DimensionsTransform(const std::vector<const DenseTensor *> &ins,
const pten::framework::DDim &dims,
int axis) {
const int N = max(static_cast<int>(ins.size()), 2);
const int N = ins.size();
dim_size = dims.size();
out_dims = pten::framework::vectorize<int64_t>(dims);
in_dims.resize(N);
if (ins.size() == 1) {
// when ins.size() = 1, broadcast input to output
in_dims[0] = pten::framework::vectorize<int64_t>(ins[0]->dims());
in_dims[1] = out_dims;
// Add out_dims to in_dims to avoid errors in dims merging
} else {
for (int j = 0; j < N; ++j) {
in_dims[j] = pten::framework::vectorize<int64_t>(ins[j]->dims());
}
}
InputDimensionsExtend(N, axis);
auto merge_sequential_dims = [](bool &equal,
......
......@@ -45,7 +45,8 @@ namespace cub = hipcub;
#include "paddle/pten/api/ext/dispatch.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/gpu/elementwise.h"
#include "paddle/pten/kernels/funcs/elementwise_base.h"
// Reduce split or not, Whether to use ReduceHigherDim
#define REDUCE_SPLIT_BOUNDARY 512
#define REDUCE_VEC_SIZE 4
......@@ -1253,24 +1254,6 @@ void Reduce(const GPUContext& dev_ctx,
x, out, TransformOp<T, MPType>(reduce_num), reduce_dims, stream);
}
}
template <typename InT, typename Functor>
void ReduceGrad(const GPUContext& dev_ctx,
DenseTensor* d_out,
DenseTensor* d_x,
DataType out_dtype,
Functor functor) {
std::vector<const DenseTensor*> inputs = {d_out};
std::vector<DenseTensor*> outputs = {d_x};
PD_VISIT_ALL_TYPES(
out_dtype, "LaunchBroadcastElementwiseCudaKernel", ([&] {
LaunchBroadcastElementwiseCudaKernel<pten::ElementwiseType::kUnary,
InT,
data_t>(
dev_ctx, inputs, &outputs, 0, functor);
}));
}
} // namespace pten
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册