Unverified · Commit 7a4a512d authored by chentianyu03, committed by GitHub

[pten]Move reduce code new (#38648)

* change 'math' to 'math_kernel'

* fix compile bugs

* merge develop

* fix compile bugs

* fix compile bugs

* move reduce files by new rule

* add set header

* format code style

* merge develop and fix conflict

* merge develop and fix conflict
Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>
Parent commit: c90a652d
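In short: this commit moves the reduce implementations out of paddle/pten/kernels/hybird/ into the per-backend headers paddle/pten/kernels/cpu/reduce.h and paddle/pten/kernels/gpu/reduce.h, drops the nested pten::general namespace, and renames the kernel entry points to the *Kernel convention. A minimal before/after sketch of a call site, distilled from the hunks below (argument values illustrative):

    // before this commit
    pten::general::Reduce<CPUContext, T, Functor>(
        dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
    pten::Mean<T, Context>(dev_ctx, x, dims, keep_dim, reduce_all, out);

    // after this commit
    pten::Reduce<CPUContext, T, Functor>(
        dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
    pten::MeanKernel<T, Context>(dev_ctx, x, dims, keep_dim, reduce_all, out);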
@@ -13,7 +13,7 @@
 // limitations under the License.
 #include "gtest/gtest.h"
-#include "paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h"
+#include "paddle/pten/kernels/gpu/reduce.h"
 namespace paddle {
 namespace operators {
@@ -32,7 +32,7 @@ namespace cub = hipcub;
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/pten/core/dense_tensor.h"
-#include "paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h"
+#include "paddle/pten/kernels/gpu/reduce.h"
 namespace paddle {
 namespace operators {
@@ -28,10 +28,10 @@ limitations under the License. */
 #include "paddle/pten/api/lib/utils/tensor_utils.h"
 #include "paddle/pten/include/core.h"
 #include "paddle/pten/include/math.h"
-#include "paddle/pten/kernels/hybird/general/reduce_impl.h"
+#include "paddle/pten/kernels/cpu/reduce.h"
 #if defined(__HIPCC__) || defined(__NVCC__)
-#include "paddle/pten/kernels/hybird/cuda/reduce/reduce.h"
+#include "paddle/pten/kernels/gpu/reduce.h"
 #endif
 namespace paddle {
@@ -259,7 +259,7 @@ class ReduceKernel : public framework::OpKernel<T> {
       std::vector<int64_t> tmp_dims(dims.begin(), dims.end());
       // call new kernel
-      pten::general::Reduce<DeviceContext, T, Functor>(
+      pten::Reduce<DeviceContext, T, Functor>(
           dev_ctx, *pt_x.get(), reduce_all, tmp_dims, keep_dim,
           pten::TransToPtenDataType(cast_out_dtype), pt_out.get());
     }
@@ -45,7 +45,7 @@ DenseTensor Mean(const ContextT& dev_ctx,
           dev_ctx.GetPlace()),
       std::move(out_meta));
   bool reduce_all = false;
-  Mean<T, ContextT>(dev_ctx, x, axis, keep_dim, reduce_all, &dense_out);
+  MeanKernel<T, ContextT>(dev_ctx, x, axis, keep_dim, reduce_all, &dense_out);
   return dense_out;
 }
@@ -65,7 +65,7 @@ DenseTensor Sum(const ContextT& dev_ctx,
   // so use default value(false) is OK.
   bool reduce_all = false;
-  Sum<T, ContextT>(
+  SumKernel<T, ContextT>(
       dev_ctx, x, axis, keep_dim, reduce_all, out_meta.dtype, &dense_out);
   return dense_out;
 }
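A note on the layering this hunk preserves: the functional API keeps its short name (DenseTensor Sum(...) allocates and returns the result) while the renamed SumKernel fills a caller-provided output tensor. The wrapper body now reads (copied from the hunk, comments added):

    bool reduce_all = false;  // per the comment above, the default is fine here
    // delegate to the renamed kernel, which writes into dense_out
    SumKernel<T, ContextT>(
        dev_ctx, x, axis, keep_dim, reduce_all, out_meta.dtype, &dense_out);
    return dense_out;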
@@ -18,13 +18,10 @@
 #include "paddle/pten/backends/cpu/cpu_context.h"
 #include "paddle/pten/common/scalar.h"
 #include "paddle/pten/core/kernel_registry.h"
 #include "paddle/pten/kernels/cpu/elementwise_impl.h"
+#include "paddle/pten/kernels/cpu/reduce.h"
 #include "paddle/pten/kernels/funcs/elementwise_functor.h"
-#include "paddle/pten/kernels/hybird/eigen/reduce.h"
-#include "paddle/pten/kernels/hybird/general/reduce_impl.h"
 // See Note [ Why still include the fluid headers? ]
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/platform/bfloat16.h"
@@ -57,14 +54,14 @@ namespace pten {
 }
 template <typename T, typename Context>
-void Mean(const Context& dev_ctx,
-          const DenseTensor& x,
-          const std::vector<int64_t>& dims,
-          bool keep_dim,
-          bool reduce_all,
-          DenseTensor* out) {
+void MeanKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const std::vector<int64_t>& dims,
+                bool keep_dim,
+                bool reduce_all,
+                DenseTensor* out) {
   auto out_dtype = x.dtype();
-  pten::general::Reduce<CPUContext, T, pten::eigen::MeanFunctor>(
+  pten::Reduce<CPUContext, T, pten::eigen::MeanFunctor>(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
@@ -93,14 +90,14 @@ void DivideKernel(const Context& dev_ctx,
 }
 template <typename T, typename Context>
-void Sum(const Context& dev_ctx,
-         const DenseTensor& x,
-         const std::vector<int64_t>& dims,
-         bool keep_dim,
-         bool reduce_all,
-         DataType out_dtype,
-         DenseTensor* out) {
-  pten::general::Reduce<CPUContext, T, pten::eigen::SumFunctor>(
+void SumKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const std::vector<int64_t>& dims,
+               bool keep_dim,
+               bool reduce_all,
+               DataType out_dtype,
+               DenseTensor* out) {
+  pten::Reduce<CPUContext, T, pten::eigen::SumFunctor>(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
@@ -120,8 +117,8 @@ using complex128 = ::paddle::platform::complex<double>;
 // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
 // using bfloat16 = ::paddle::platform::bfloat16;
-PT_REGISTER_CTX_KERNEL(mean, CPU, ALL_LAYOUT, pten::Mean, float, double, bool) {
-}
+PT_REGISTER_CTX_KERNEL(
+    mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {}
 PT_REGISTER_CTX_KERNEL(add,
                        CPU,
                        ALL_LAYOUT,
@@ -166,7 +163,7 @@ PT_REGISTER_CTX_KERNEL(multiply,
 PT_REGISTER_CTX_KERNEL(sum,
                        CPU,
                        ALL_LAYOUT,
-                       pten::Sum,
+                       pten::SumKernel,
                        bool,
                        float,
                        double,
@@ -13,14 +13,15 @@
 // limitations under the License.
 #pragma once
-#include "paddle/fluid/platform/transform.h"
+#include <set>
 #include "paddle/pten/api/ext/dispatch.h"
-#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/backends/cpu/cpu_context.h"
+#include "paddle/pten/kernels/cast_kernel.h"
 #include "paddle/pten/kernels/hybird/eigen/reduce.h"
 namespace pten {
-namespace general {
 template <typename DeviceContext, typename T, typename Functor>
 void Reduce(const DeviceContext& dev_ctx,
@@ -71,6 +72,4 @@ void Reduce(const DeviceContext& dev_ctx,
   }
 }
-}  // namespace general
 }  // namespace pten
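With pten::general gone, the CPU reduce entry point is reachable as plain pten::Reduce. A direct call, shaped after the math_kernel.cc hunk above (the functor choice is illustrative):

    // sum-reduce a DenseTensor on CPU over `dims`, optionally casting to out_dtype
    pten::Reduce<pten::CPUContext, T, pten::eigen::SumFunctor>(
        dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);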
@@ -16,9 +16,8 @@ limitations under the License. */
 #include "paddle/pten/backends/gpu/gpu_context.h"
 #include "paddle/pten/kernels/funcs/elementwise_functor.h"
+#include "paddle/pten/kernels/gpu/reduce.h"
 #include "paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h"
-#include "paddle/pten/kernels/hybird/cuda/reduce/reduce.h"
-#include "paddle/pten/kernels/hybird/general/reduce_impl.h"
 #ifdef __NVCC__
 #include "cub/cub.cuh"
@@ -76,12 +75,12 @@ struct DivideFunctor {
  */
 template <typename T, typename Context>
-void Mean(const Context& dev_ctx,
-          const DenseTensor& x,
-          const std::vector<int64_t>& dims,
-          bool keep_dim,
-          bool reduce_all,
-          DenseTensor* out) {
+void MeanKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const std::vector<int64_t>& dims,
+                bool keep_dim,
+                bool reduce_all,
+                DenseTensor* out) {
   auto out_dtype = x.dtype();
   pten::Reduce<T, kps::AddFunctor, kps::DivideFunctor>(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
@@ -97,13 +96,13 @@ DEFINE_CUDA_ELEMENTWISE_OP(Multiply)
 DEFINE_CUDA_ELEMENTWISE_OP(Divide)
 template <typename T, typename Context>
-void Sum(const Context& dev_ctx,
-         const DenseTensor& x,
-         const std::vector<int64_t>& dims,
-         bool keep_dim,
-         bool reduce_all,
-         DataType out_dtype,
-         DenseTensor* out) {
+void SumKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const std::vector<int64_t>& dims,
+               bool keep_dim,
+               bool reduce_all,
+               DataType out_dtype,
+               DenseTensor* out) {
   pten::Reduce<T, kps::AddFunctor, kps::IdentityFunctor>(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
@@ -115,7 +114,7 @@ using complex64 = ::paddle::platform::complex<float>;
 using complex128 = ::paddle::platform::complex<double>;
 PT_REGISTER_CTX_KERNEL(
-    mean, GPU, ALL_LAYOUT, pten::Mean, float, double, bool, float16) {}
+    mean, GPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool, float16) {}
 PT_REGISTER_CTX_KERNEL(add,
                        GPU,
                        ALL_LAYOUT,
@@ -164,7 +163,7 @@ PT_REGISTER_CTX_KERNEL(multiply,
 PT_REGISTER_CTX_KERNEL(sum,
                        GPU,
                        ALL_LAYOUT,
-                       pten::Sum,
+                       pten::SumKernel,
                        bool,
                        float,
                        double,
@@ -14,6 +14,9 @@
 #pragma once
+// CUDA and HIP use same api
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 #include <algorithm>
 #include <cmath>
 #include <numeric>
@@ -40,6 +43,7 @@ namespace cub = hipcub;
 #include "paddle/fluid/string/string_helper.h"
 #include "paddle/pten/api/ext/dispatch.h"
+#include "paddle/pten/backends/gpu/gpu_context.h"
 #include "paddle/pten/core/dense_tensor.h"
 #include "paddle/pten/kernels/cast_kernel.h"
 #include "paddle/pten/kernels/copy_kernel.h"
@@ -1230,4 +1234,48 @@ void TensorReduceFunctorImpl(const pten::DenseTensor& x,
 }
 }  // namespace kernels
+template <typename T,
+          template <typename> class ReduceOp,
+          template <typename, typename> class TransformOp>
+void Reduce(const GPUContext& dev_ctx,
+            const DenseTensor& x,
+            bool reduce_all,
+            const std::vector<int64_t>& dims,
+            bool keep_dim,
+            DataType out_dtype,
+            DenseTensor* out) {
+  std::vector<int> reduce_dims =
+      pten::kernels::details::GetReduceDim(dims, x.dims().size(), reduce_all);
+  int reduce_num = 1;
+  for (auto i : reduce_dims) {
+    reduce_num *= (x.dims())[i];
+  }
+  gpuStream_t stream = dev_ctx.stream();
+  if (out_dtype != pten::DataType::UNDEFINED && out_dtype != x.dtype()) {
+    PD_DISPATCH_FLOATING_AND_COMPLEX_AND_2_TYPES(
+        pten::DataType::INT32,
+        pten::DataType::INT64,
+        out_dtype,
+        "TensorReduceFunctorImpl",
+        ([&] {
+          using MPType = typename kps::details::MPTypeTrait<data_t>::Type;
+          pten::kernels::TensorReduceFunctorImpl<T,
+                                                 data_t,
+                                                 ReduceOp,
+                                                 TransformOp<T, MPType>>(
+              x, out, TransformOp<T, MPType>(reduce_num), reduce_dims, stream);
+        }));
+  } else {
+    using MPType = typename kps::details::MPTypeTrait<T>::Type;
+    pten::kernels::
+        TensorReduceFunctorImpl<T, T, ReduceOp, TransformOp<T, MPType>>(
+            x, out, TransformOp<T, MPType>(reduce_num), reduce_dims, stream);
+  }
+}
 }  // namespace pten
+#endif
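For readers unfamiliar with the macro: PD_DISPATCH_FLOATING_AND_COMPLEX_AND_2_TYPES follows the usual dtype-dispatch pattern, switching on the runtime out_dtype and running the lambda with data_t aliased to the matching C++ type. A rough conceptual sketch, not the literal expansion:

    switch (out_dtype) {
      case pten::DataType::INT32:   { using data_t = int32_t; /* lambda body */ break; }
      case pten::DataType::INT64:   { using data_t = int64_t; /* lambda body */ break; }
      case pten::DataType::FLOAT32: { using data_t = float;   /* lambda body */ break; }
      case pten::DataType::FLOAT64: { using data_t = double;  /* lambda body */ break; }
      // ...COMPLEX64 / COMPLEX128 handled likewise...
      default: /* error: unsupported data type */ break;
    }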
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-#pragma once
-// CUDA and HIP use same api
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-#include "paddle/pten/api/ext/dispatch.h"
-#include "paddle/pten/backends/gpu/gpu_context.h"
-#include "paddle/pten/common/scalar.h"
-#include "paddle/pten/core/dense_tensor.h"
-#include "paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h"
-namespace pten {
-template <typename T,
-          template <typename> class ReduceOp,
-          template <typename, typename> class TransformOp>
-void Reduce(const GPUContext& dev_ctx,
-            const DenseTensor& x,
-            bool reduce_all,
-            const std::vector<int64_t>& dims,
-            bool keep_dim,
-            DataType out_dtype,
-            DenseTensor* out) {
-  std::vector<int> reduce_dims =
-      pten::kernels::details::GetReduceDim(dims, x.dims().size(), reduce_all);
-  int reduce_num = 1;
-  for (auto i : reduce_dims) {
-    reduce_num *= (x.dims())[i];
-  }
-  gpuStream_t stream = dev_ctx.stream();
-  if (out_dtype != pten::DataType::UNDEFINED && out_dtype != x.dtype()) {
-    PD_DISPATCH_FLOATING_AND_COMPLEX_AND_2_TYPES(
-        pten::DataType::INT32,
-        pten::DataType::INT64,
-        out_dtype,
-        "TensorReduceFunctorImpl",
-        ([&] {
-          using MPType = typename kps::details::MPTypeTrait<data_t>::Type;
-          pten::kernels::TensorReduceFunctorImpl<T,
-                                                 data_t,
-                                                 ReduceOp,
-                                                 TransformOp<T, MPType>>(
-              x, out, TransformOp<T, MPType>(reduce_num), reduce_dims, stream);
-        }));
-  } else {
-    using MPType = typename kps::details::MPTypeTrait<T>::Type;
-    pten::kernels::
-        TensorReduceFunctorImpl<T, T, ReduceOp, TransformOp<T, MPType>>(
-            x, out, TransformOp<T, MPType>(reduce_num), reduce_dims, stream);
-  }
-}
-}  // namespace pten
-#endif
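The fully removed block above is the old GPU reduce front-end header (presumably paddle/pten/kernels/hybird/cuda/reduce/reduce.h, given that it includes reduce_cuda_impl.h from the hybird path); its Reduce definition now lives at the tail of paddle/pten/kernels/gpu/reduce.h, as shown in the previous hunk. Note that unlike the CPU overload, the GPU Reduce is parameterized by functor templates rather than a device context; callers such as SumKernel in math_kernel.cu reach it as (mirroring the SumKernel body shown earlier, values illustrative):

    // sum-reduce on GPU: AddFunctor combines, IdentityFunctor transforms inputs
    pten::Reduce<T, kps::AddFunctor, kps::IdentityFunctor>(
        dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);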
@@ -21,12 +21,12 @@ limitations under the License. */
 namespace pten {
 template <typename T, typename Context>
-void Mean(const Context& dev_ctx,
-          const DenseTensor& x,
-          const std::vector<int64_t>& dims,
-          bool keep_dim,
-          bool reduce_all,
-          DenseTensor* out);
+void MeanKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const std::vector<int64_t>& dims,
+                bool keep_dim,
+                bool reduce_all,
+                DenseTensor* out);
 template <typename T, typename Context>
 void AddKernel(const Context& dev_ctx,
@@ -57,13 +57,13 @@ void MultiplyKernel(const Context& dev_ctx,
                     DenseTensor* out);
 template <typename T, typename Context>
-void Sum(const Context& dev_ctx,
-         const DenseTensor& x,
-         const std::vector<int64_t>& dims,
-         bool keep_dim,
-         bool reduce_all,
-         DataType out_dtype,
-         DenseTensor* out);
+void SumKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const std::vector<int64_t>& dims,
+               bool keep_dim,
+               bool reduce_all,
+               DataType out_dtype,
+               DenseTensor* out);
 template <typename T, typename ContextT>
 DenseTensor Add(const ContextT& dev_ctx,