Unverified  Commit 688743bf  authored by Yiqun Liu, committed by GitHub

Rename phi::funcs::TensorReduceImpl to phi::funcs::ReduceKernel. (#40183)

Parent c1d81ec1
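
In short: phi::funcs::TensorReduceImpl is renamed to phi::funcs::ReduceKernel, and the trailing stream argument is dropped; the kernel now obtains the stream from the device context it is given. The updated declaration, as it appears in the hunk that defines it below, is:

    template <typename Tx,
              typename Ty,
              template <typename> class ReduceOp,
              typename TransformOp>
    void ReduceKernel(const phi::GPUContext& dev_ctx,
                      const phi::DenseTensor& x,
                      phi::DenseTensor* y,
                      const TransformOp& transform,
                      const std::vector<int>& origin_reduce_dims);

Every call site touched by this commit changes accordingly: the explicit stream / dev_ctx.stream() argument is removed.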
@@ -36,9 +36,9 @@ void TensorReduceImpl(const platform::CUDADeviceContext& dev_ctx,
                       gpuStream_t stream) {
   y->mutable_data<Ty>(x.place());
-  phi::funcs::TensorReduceImpl<Tx, Ty, ReduceOp, TransformOp>(
-      static_cast<const phi::GPUContext&>(dev_ctx), x, y, transform,
-      origin_reduce_dims, stream);
+  phi::funcs::ReduceKernel<Tx, Ty, ReduceOp, TransformOp>(
+      static_cast<const phi::GPUContext&>(dev_ctx), x, y, transform,
+      origin_reduce_dims);
 }
 }  // namespace operators
......
@@ -45,13 +45,8 @@ class MatrixReduceSumFunctor<T, GPUContext> {
         out_reduce_dims.push_back(idx);
       }
     }
-    TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-        dev_ctx,
-        in,
-        out,
-        kps::IdentityFunctor<T>(),
-        out_reduce_dims,
-        dev_ctx.stream());
+    ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        dev_ctx, in, out, kps::IdentityFunctor<T>(), out_reduce_dims);
   }
 };
......
@@ -1087,12 +1087,12 @@ template <typename Tx,
           typename Ty,
           template <typename> class ReduceOp,
           typename TransformOp>
-void TensorReduceImpl(const phi::GPUContext& dev_ctx,
-                      const phi::DenseTensor& x,
-                      phi::DenseTensor* y,
-                      const TransformOp& transform,
-                      const std::vector<int>& origin_reduce_dims,
-                      KPStream stream) {
+void ReduceKernel(const phi::GPUContext& dev_ctx,
+                  const phi::DenseTensor& x,
+                  phi::DenseTensor* y,
+                  const TransformOp& transform,
+                  const std::vector<int>& origin_reduce_dims) {
+  auto stream = dev_ctx.stream();
   dev_ctx.Alloc<Ty>(y);
   auto x_dim = phi::vectorize<int>(x.dims());
......
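
For callers, the migration is mechanical. A minimal before/after sketch (assuming the surrounding kernel already has a phi::GPUContext dev_ctx, an input DenseTensor x, an output DenseTensor* y, and a std::vector<int> reduce_dims; these names are illustrative, not part of the diff):

    // Before: the stream had to be passed explicitly.
    phi::funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
        dev_ctx, x, y, kps::IdentityFunctor<T>(), reduce_dims, dev_ctx.stream());

    // After: ReduceKernel reads the stream from dev_ctx internally.
    phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
        dev_ctx, x, y, kps::IdentityFunctor<T>(), reduce_dims);

The remaining hunks below apply exactly this pattern to each caller.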
@@ -87,13 +87,12 @@ void BroadcastTensorsGradKernel(const Context& ctx,
             *input_tensor, ctx.GetPlace(), ctx, output_tensor);
       } else {
         // reduce_sum implementation on CUDA
-        funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
             ctx,
             *input_tensor,
             output_tensor,
             kps::IdentityFunctor<T>(),
-            reduce_dims_vec,
-            ctx.stream());
+            reduce_dims_vec);
       }
     }
   }
......
@@ -80,8 +80,8 @@ inline void CompareAllKernelImpl(const Context& ctx,
   for (int i = 0; i < reduce_dims.size(); ++i) {
     reduce_dims[i] = i;
   }
-  funcs::TensorReduceImpl<bool, bool, BitwiseAdd, kps::IdentityFunctor<bool>>(
-      ctx, tmp, out, kps::IdentityFunctor<bool>(), reduce_dims, ctx.stream());
+  funcs::ReduceKernel<bool, bool, BitwiseAdd, kps::IdentityFunctor<bool>>(
+      ctx, tmp, out, kps::IdentityFunctor<bool>(), reduce_dims);
 }
 }  // namespace phi
......
@@ -29,13 +29,8 @@ void ReduceWrapper(const GPUContext &dev_ctx,
                    DenseTensor *dst) {
   std::vector<int> reduce_dims =
       funcs::GetReduceDim(dst->dims(), src->dims(), axis);
-  funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-      dev_ctx,
-      *src,
-      dst,
-      kps::IdentityFunctor<T>(),
-      reduce_dims,
-      dev_ctx.stream());
+  funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+      dev_ctx, *src, dst, kps::IdentityFunctor<T>(), reduce_dims);
 }
 template <ElementwiseType ET, typename T, typename Functor>
@@ -172,9 +167,8 @@ void DefaultElementwiseAddGrad(const GPUContext &ctx,
     }
     std::vector<int> reduce_dims =
         funcs::GetReduceDim(x.dims(), out.dims(), axis);
-    gpuStream_t stream = ctx.stream();
-    funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-        ctx, dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
+    funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        ctx, dout, dx, kps::IdentityFunctor<T>(), reduce_dims);
   }
 }
 // dy
@@ -187,9 +181,8 @@ void DefaultElementwiseAddGrad(const GPUContext &ctx,
   } else {
     std::vector<int> reduce_dims =
         funcs::GetReduceDim(y.dims(), out.dims(), axis);
-    gpuStream_t stream = ctx.stream();
-    funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-        ctx, dout, dy, kps::IdentityFunctor<T>(), reduce_dims, stream);
+    funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        ctx, dout, dy, kps::IdentityFunctor<T>(), reduce_dims);
   }
 }
 }
@@ -285,9 +278,8 @@ void default_elementwise_sub_grad(const GPUContext &ctx,
     }
     std::vector<int> reduce_dims =
        funcs::GetReduceDim(x.dims(), out.dims(), axis);
-    gpuStream_t stream = ctx.stream();
-    funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-        ctx, dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
+    funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        ctx, dout, dx, kps::IdentityFunctor<T>(), reduce_dims);
   }
 }
 // dy
@@ -306,9 +298,8 @@ void default_elementwise_sub_grad(const GPUContext &ctx,
   } else {
     std::vector<int> reduce_dims =
        funcs::GetReduceDim(y.dims(), out.dims(), axis);
-    gpuStream_t stream = ctx.stream();
-    funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::InverseFunctor<T>>(
-        ctx, dout, dy, kps::InverseFunctor<T>(), reduce_dims, stream);
+    funcs::ReduceKernel<T, T, kps::AddFunctor, kps::InverseFunctor<T>>(
+        ctx, dout, dy, kps::InverseFunctor<T>(), reduce_dims);
   }
 }
 }
......
@@ -39,8 +39,6 @@ void Reduce(const KPDevice& dev_ctx,
     reduce_num *= (x.dims())[i];
   }
-  KPStream stream = dev_ctx.stream();
   if (out_dtype != phi::DataType::UNDEFINED && out_dtype != x.dtype()) {
     auto tmp_tensor = phi::Cast<T>(dev_ctx, x, out_dtype);
     PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES(
@@ -48,10 +46,10 @@ void Reduce(const KPDevice& dev_ctx,
         phi::DataType::INT64,
         phi::DataType::FLOAT16,
         out_dtype,
-        "TensorReduceImpl",
+        "ReduceKernel",
         ([&] {
           using MPType = typename kps::details::MPTypeTrait<data_t>::Type;
-          phi::funcs::TensorReduceImpl<data_t,
-                                       data_t,
-                                       ReduceOp,
-                                       TransformOp<data_t, MPType>>(
+          phi::funcs::ReduceKernel<data_t,
+                                   data_t,
+                                   ReduceOp,
+                                   TransformOp<data_t, MPType>>(
@@ -59,18 +57,12 @@ void Reduce(const KPDevice& dev_ctx,
               tmp_tensor,
               out,
               TransformOp<data_t, MPType>(reduce_num),
-              reduce_dims,
-              stream);
+              reduce_dims);
         }));
   } else {
     using MPType = typename kps::details::MPTypeTrait<T>::Type;
-    phi::funcs::TensorReduceImpl<T, T, ReduceOp, TransformOp<T, MPType>>(
-        dev_ctx,
-        x,
-        out,
-        TransformOp<T, MPType>(reduce_num),
-        reduce_dims,
-        stream);
+    phi::funcs::ReduceKernel<T, T, ReduceOp, TransformOp<T, MPType>>(
+        dev_ctx, x, out, TransformOp<T, MPType>(reduce_num), reduce_dims);
   }
 }
 }  // namespace phi
......
@@ -69,17 +69,12 @@ void SigmoidCrossEntropyWithLogitsGradKernel(const Context &dev_ctx,
   dev_ctx.template Alloc<T>(counts_tensor);
   counts_tensor->Resize(in_grad->dims());
-  int limit = in_grad->numel();
-  int blocks = NumBlocks(limit);
-  int threads = kNumCUDAThreads;
   std::vector<const DenseTensor *> ins = {&x, &label, &out_grad};
   std::vector<DenseTensor *> outs = {in_grad, counts_tensor};
   auto functor = SigmoidBwdFunctor<T>(ignore_index);
-  constexpr int Size = 2;
-  phi::funcs::ElementwiseKernel<T, decltype(functor), Size>(
+  phi::funcs::ElementwiseKernel<T, decltype(functor), 2>(
       dev_ctx, ins, &outs, functor);
   if (normalize) {
-    T *counts = dev_ctx.template Alloc<T>(counts_tensor);
     DenseTensor *norm_tensor = new DenseTensor();
     norm_tensor->Resize({sizeof(T)});
     dev_ctx.template Alloc<T>(norm_tensor);
@@ -89,13 +84,8 @@ void SigmoidCrossEntropyWithLogitsGradKernel(const Context &dev_ctx,
       reduce_dim.push_back(i);
     }
-    funcs::TensorReduceImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
-        dev_ctx,
-        *counts_tensor,
-        norm_tensor,
-        NonzeroFunctor<T>(),
-        reduce_dim,
-        dev_ctx.stream());
+    funcs::ReduceKernel<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
+        dev_ctx, *counts_tensor, norm_tensor, NonzeroFunctor<T>(), reduce_dim);
     T *norm = dev_ctx.template Alloc<T>(norm_tensor);
     auto norm_cpu_mem = paddle::memory::Alloc(phi::CPUPlace(), sizeof(T));
     T *norm_cpu_ptr = reinterpret_cast<T *>(norm_cpu_mem->ptr());
@@ -114,6 +104,7 @@ void SigmoidCrossEntropyWithLogitsGradKernel(const Context &dev_ctx,
     phi::funcs::ElementwiseKernel<T>(dev_ctx, div_ins, &div_outs, div_functor);
     delete norm_tensor;
   }
+  delete counts_tensor;
 }
 }  // namespace phi
......
@@ -69,17 +69,12 @@ void SigmoidCrossEntropyWithLogitsKernel(const Context &dev_ctx,
   dev_ctx.template Alloc<T>(counts_tensor);
   counts_tensor->Resize(out->dims());
-  int limit = out->numel();
-  int blocks = NumBlocks(limit);
-  int threads = kNumCUDAThreads;
   std::vector<const DenseTensor *> ins = {&x, &label};
   std::vector<DenseTensor *> outs = {out, counts_tensor};
   auto functor = SigmoidFwdFunctor<T>(ignore_index);
-  constexpr int Size = 2;
-  phi::funcs::ElementwiseKernel<T, decltype(functor), Size>(
+  phi::funcs::ElementwiseKernel<T, decltype(functor), 2>(
      dev_ctx, ins, &outs, functor);
   if (normalize) {
-    T *counts = dev_ctx.template Alloc<T>(counts_tensor);
     DenseTensor *norm_tensor = new DenseTensor();
     norm_tensor->Resize({sizeof(T)});
     dev_ctx.template Alloc<T>(norm_tensor);
@@ -89,13 +84,8 @@ void SigmoidCrossEntropyWithLogitsKernel(const Context &dev_ctx,
      reduce_dim.push_back(i);
    }
-    funcs::TensorReduceImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
-        dev_ctx,
-        *counts_tensor,
-        norm_tensor,
-        NonzeroFunctor<T>(),
-        reduce_dim,
-        dev_ctx.stream());
+    funcs::ReduceKernel<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
+        dev_ctx, *counts_tensor, norm_tensor, NonzeroFunctor<T>(), reduce_dim);
     T *norm = dev_ctx.template Alloc<T>(norm_tensor);
     auto norm_cpu_mem = paddle::memory::Alloc(phi::CPUPlace(), sizeof(T));
     T *norm_cpu_ptr = reinterpret_cast<T *>(norm_cpu_mem->ptr());
@@ -114,8 +104,8 @@ void SigmoidCrossEntropyWithLogitsKernel(const Context &dev_ctx,
     phi::funcs::ElementwiseKernel<T>(dev_ctx, div_ins, &div_outs, div_functor);
     delete norm_tensor;
-    delete counts_tensor;
   }
+  delete counts_tensor;
 }
 }  // namespace phi
......
@@ -31,11 +31,10 @@ void TraceKernel(const Context& ctx,
   T* out_data = ctx.template Alloc<T>(out);
   auto diag = funcs::Diagonal<T, Context>(ctx, &x, offset, axis1, axis2);
   if (diag.numel() > 0) {
-    auto stream = ctx.stream();
     std::vector<int> reduce_dims;
     reduce_dims.push_back(out->dims().size());
-    funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-        ctx, diag, out, kps::IdentityFunctor<T>(), reduce_dims, stream);
+    funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        ctx, diag, out, kps::IdentityFunctor<T>(), reduce_dims);
   } else {
     phi::funcs::SetConstant<Context, T> functor;
     functor(ctx, out, static_cast<T>(0));
......
@@ -59,9 +59,8 @@ struct ReduceSumForMatmulGrad<GPUContext, T> {
                   const DenseTensor& input,
                   DenseTensor* output,
                   const std::vector<int>& reduce_dims) {
-    auto stream = dev_ctx.stream();
-    funcs::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-        dev_ctx, input, output, kps::IdentityFunctor<T>(), reduce_dims, stream);
+    funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+        dev_ctx, input, output, kps::IdentityFunctor<T>(), reduce_dims);
   }
 };
 #endif
......