Unverified commit 08941eda, authored by chentianyu03, committed by GitHub

[pten] combine reduce_cuda codes (#38328)

* combine reduce_cuda codes

* support float16 in pten reduce_mean

* replace ReduceCudaKernel impl with pten reduce impl

* mv reduce funcs into reduce_cuda_impl

* rm unused code and headers

* mv GetReduceDim into reduce_cuda_impl

* recover GetReduceDim in reduce_op.h

* add new dispatch macro

* fix pool op output not being initialized, which caused errors when transforming to pten::DenseTensor

* fix output tensor not initialized error

* rename new dispatch macro and format code style

* rm reduce_functor_op.h file
Parent 5ab6ebaf
......
@@ -191,9 +191,10 @@ void SetConfigForColumnReduce(const int max_threads, const int reduce_num,
   int num_block = (max_threads / left_num);
   if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
-    *blocking_size = details::GetLastPow2(reduce_num / num_block);
+    *blocking_size =
+        pten::kernels::details::GetLastPow2(reduce_num / num_block);
     if (*blocking_size <= 1) {
-      *blocking_size = details::GetLastPow2(sqrt(reduce_num));
+      *blocking_size = pten::kernels::details::GetLastPow2(sqrt(reduce_num));
     } else if (*blocking_size * 2 < reduce_num) {
       *blocking_size *= 2;
     }
......
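Note: this hunk only re-points the call at pten's copy of the helper. GetLastPow2 itself rounds its argument down to a power of two, which keeps blocking_size warp-friendly. A minimal sketch of the assumed behavior (the bit-smearing body below is a standard implementation, not copied from this commit):

```cpp
#include <algorithm>
#include <iostream>

// Assumed behavior of details::GetLastPow2: largest power of two <= n
// (returns 1 for n <= 1), implemented by smearing the high bit downward.
inline int GetLastPow2(int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(1, n - (n >> 1));
}

int main() {
  std::cout << GetLastPow2(6) << " " << GetLastPow2(64) << std::endl;  // 4 64
}
```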
......
@@ -24,6 +24,7 @@ namespace cub = hipcub;
 #include "paddle/fluid/operators/margin_cross_entropy_op.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/softmax_impl.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 #include "paddle/fluid/string/string_helper.h"
......
......
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 #include "gtest/gtest.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
+#include "paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h"
 
 namespace paddle {
 namespace operators {
......
@@ -39,9 +39,9 @@ TEST(test_reduce_rank_check, all) {
       }
       if (is_valid) {
-        CheckReduceRank(reduce_rank, rank);
+        pten::kernels::details::CheckReduceRank(reduce_rank, rank);
       } else {
-        ASSERT_THROW(CheckReduceRank(reduce_rank, rank),
+        ASSERT_THROW(pten::kernels::details::CheckReduceRank(reduce_rank, rank),
                      paddle::platform::EnforceNotMet);
       }
     }
......
deleted file: paddle/fluid/operators/reduce_ops/reduce_functor_op.h
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include <cmath>
-#include <limits>
-
-#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
-#include "paddle/fluid/platform/hostdevice.h"
-#ifdef __HIPCC__
-#include <hip/hip_runtime.h>
-#endif
-
-namespace paddle {
-namespace operators {
-
-namespace kps = paddle::operators::kernel_primitives;
-
-template <typename Tx, typename Ty = Tx>
-struct CustomMin {
-  using Transformer = kps::IdentityFunctor<Tx>;
-
-  inline Ty initial() {
-    return static_cast<Ty>(std::numeric_limits<Ty>::max());
-  }
-
-  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
-    return (b < a) ? b : a;
-  }
-};
-
-template <typename Tx, typename Ty = Tx>
-struct CustomMax {
-  using Transformer = kps::IdentityFunctor<Tx>;
-
-  inline Ty initial() {
-    return static_cast<Ty>(std::numeric_limits<Ty>::lowest());
-  }
-
-  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
-    return (b > a) ? b : a;
-  }
-};
-
-// for cub::Reduce
-template <typename Tx, typename Ty = Tx>
-struct CustomSum {
-  using Transformer = kps::IdentityFunctor<Tx, Ty>;
-
-  inline Ty initial() { return static_cast<Ty>(0.0f); }
-
-  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
-    return b + a;
-  }
-};
-
-template <typename Tx, typename Ty = Tx>
-struct CustomSub {
-  using Transformer = kps::InverseFunctor<Tx>;
-
-  inline Ty initial() { return static_cast<Ty>(0.0f); }
-
-  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
-    return b + a;
-  }
-};
-
-template <typename Tx, typename Ty = Tx>
-struct CustomMean {
-  using Transformer = kps::DivideFunctor<Tx, Ty>;
-
-  inline Ty initial() { return static_cast<Ty>(0.0f); }
-
-  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
-    return b + a;
-  }
-};
-
-template <typename Tx, typename Ty = Tx>
-struct CustomMul {
-  using Transformer = kps::IdentityFunctor<Tx>;
-
-  inline Ty initial() { return static_cast<Ty>(1.0f); }
-
-  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
-    return b * a;
-  }
-};
-
-template <typename Tx, typename Ty = Tx>
-struct CustomLogicalOr {
-  using Transformer = kps::IdentityFunctor<Tx>;
-
-  inline Ty initial() { return static_cast<Ty>(false); }
-
-  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
-    return b || a;
-  }
-};
-
-template <typename Tx, typename Ty = Tx>
-struct CustomLogicalAnd {
-  using Transformer = kps::IdentityFunctor<Tx>;
-
-  inline Ty initial() { return static_cast<Ty>(true); }
-
-  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
-    return b && a;
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
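Each Custom* functor above bundles three things: a per-element Transformer, an identity element from initial(), and a binary combine op. A minimal host-side sketch of that contract, with Fold and MinFunctor as invented names standing in for the parallel cub/kernel-primitives reduction:

```cpp
#include <iostream>
#include <limits>
#include <vector>

// Illustrative stand-in for CustomMin: identity element + combine op.
template <typename T>
struct MinFunctor {
  T initial() const { return std::numeric_limits<T>::max(); }
  T operator()(const T& a, const T& b) const { return (b < a) ? b : a; }
};

// Serial fold; the CUDA path applies the same combine op in parallel
// within each thread block.
template <typename Functor, typename T>
T Fold(const std::vector<T>& v) {
  Functor f;
  T acc = f.initial();
  for (const T& x : v) acc = f(acc, x);
  return acc;
}

int main() {
  std::vector<float> v = {3.f, -1.f, 2.f};
  std::cout << Fold<MinFunctor<float>>(v) << std::endl;  // prints -1
}
```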
......
@@ -31,7 +31,7 @@ limitations under the License. */
 #include "paddle/pten/kernels/hybird/general/reduce_impl.h"
 #if defined(__HIPCC__) || defined(__NVCC__)
-#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
+#include "paddle/pten/kernels/hybird/cuda/reduce/reduce.h"
 #endif
 
 namespace paddle {
......
@@ -700,24 +700,28 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
     auto out_dtype = context.Attr<int>("out_dtype");
     std::vector<int> dims = context.Attr<std::vector<int>>("dim");
-    std::vector<int> reduce_dims =
-        GetReduceDim(dims, input->dims().size(), reduce_all);
-    int reduce_num = 1;
-    for (auto i : reduce_dims) {
-      reduce_num *= (input->dims())[i];
-    }
-    gpuStream_t stream = context.cuda_device_context().stream();
+    auto& dev_ctx = context.cuda_device_context();
     if (out_dtype >= 0) {
-      framework::VisitDataTypeSmall(
-          static_cast<framework::proto::VarType::Type>(out_dtype),
-          TensorReduceFunc<T, ReduceOp, TransformOp>(
-              *input, output, reduce_dims, reduce_num, stream));
+      output->mutable_data(
+          dev_ctx.GetPlace(),
+          static_cast<framework::proto::VarType::Type>(out_dtype));
     } else {
-      using MPType = typename details::MPTypeTrait<T>::Type;
-      TensorReduceFunctorImpl<T, T, ReduceOp, TransformOp<T, MPType>>(
-          *input, output, TransformOp<T, MPType>(reduce_num), reduce_dims,
-          stream);
+      output->mutable_data(
+          dev_ctx.GetPlace(),
+          static_cast<framework::proto::VarType::Type>(input->type()));
     }
+
+    auto pt_x = paddle::experimental::MakePtenDenseTensor(*input);
+    auto pt_out = paddle::experimental::MakePtenDenseTensor(*output);
+    std::vector<int64_t> dims_int64{dims.begin(), dims.end()};
+
+    auto pt_out_dtype = pten::TransToPtenDataType(
+        static_cast<framework::proto::VarType::Type>(out_dtype));
+
+    pten::Reduce<T, ReduceOp, TransformOp>(dev_ctx, *pt_x.get(), reduce_all,
+                                           dims_int64, false, pt_out_dtype,
+                                           pt_out.get());
   }
 };
 #endif
......
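The rewritten kernel delegates everything after output allocation to pten::Reduce. Its template shape, an element type plus a reduction-op template plus a transform template over the compute type, recurs throughout this diff. A serial sketch of that shape, where AddFunctor, DivideFunctor, and ReduceSketch are simplified stand-ins rather than the real kps functors:

```cpp
#include <iostream>

// Simplified stand-ins for the kps functors (assumption: shapes only,
// not the real definitions).
template <typename T>
struct AddFunctor {
  T operator()(T a, T b) const { return a + b; }
};

template <typename Tx, typename Ty>
struct DivideFunctor {
  explicit DivideFunctor(int n) : n_(n) {}
  Ty operator()(Tx x) const { return static_cast<Ty>(x) / n_; }
  int n_;
};

// Mirrors the kernel's template shape: ReduceOp combines, TransformOp
// maps each element into the compute type MPType first.
template <typename T,
          template <typename> class ReduceOp,
          template <typename, typename> class TransformOp>
T ReduceSketch(const T* data, int n) {
  using MPType = T;  // the real path widens float16 via MPTypeTrait
  TransformOp<T, MPType> transform(n);
  ReduceOp<MPType> op;
  MPType acc = 0;
  for (int i = 0; i < n; ++i) acc = op(acc, transform(data[i]));
  return static_cast<T>(acc);
}

int main() {
  float v[] = {1.f, 2.f, 3.f, 4.f};
  // Mean expressed as "divide-by-n transform + add": prints 2.5
  std::cout << ReduceSketch<float, AddFunctor, DivideFunctor>(v, 4)
            << std::endl;
}
```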
......
@@ -159,6 +159,73 @@ namespace paddle {
     }                                                                      \
   }()
 
+///////// Floating and Complex and other type Dispatch Macro ///////////
+
+#define PD_DISPATCH_FLOATING_AND_COMPLEX_AND_1_TYPES(                       \
+    SPECIFIED_TYPE, TYPE, NAME, ...)                                        \
+  [&] {                                                                     \
+    const auto& __dtype__ = TYPE;                                           \
+    switch (__dtype__) {                                                    \
+      PD_PRIVATE_CASE_TYPE(                                                 \
+          NAME,                                                             \
+          SPECIFIED_TYPE,                                                   \
+          ::paddle::experimental::DataTypeToCppType<SPECIFIED_TYPE>::type,  \
+          __VA_ARGS__)                                                      \
+      PD_PRIVATE_CASE_TYPE(                                                 \
+          NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__)            \
+      PD_PRIVATE_CASE_TYPE(                                                 \
+          NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__)           \
+      PD_PRIVATE_CASE_TYPE(NAME,                                            \
+                           ::paddle::DataType::COMPLEX64,                   \
+                           ::paddle::complex64,                             \
+                           __VA_ARGS__)                                     \
+      PD_PRIVATE_CASE_TYPE(NAME,                                            \
+                           ::paddle::DataType::COMPLEX128,                  \
+                           ::paddle::complex128,                            \
+                           __VA_ARGS__)                                     \
+      default:                                                              \
+        PD_THROW("function " #NAME " is not implemented for data type `",   \
+                 __dtype__,                                                 \
+                 "`");                                                      \
+    }                                                                       \
+  }()
+
+///////// Floating and Complex and 2 other type Dispatch Macro ///////////
+
+#define PD_DISPATCH_FLOATING_AND_COMPLEX_AND_2_TYPES(                       \
+    SPECIFIED_TYPE1, SPECIFIED_TYPE2, TYPE, NAME, ...)                      \
+  [&] {                                                                     \
+    const auto& __dtype__ = TYPE;                                           \
+    switch (__dtype__) {                                                    \
+      PD_PRIVATE_CASE_TYPE(                                                 \
+          NAME,                                                             \
+          SPECIFIED_TYPE1,                                                  \
+          ::paddle::experimental::DataTypeToCppType<SPECIFIED_TYPE1>::type, \
+          __VA_ARGS__)                                                      \
+      PD_PRIVATE_CASE_TYPE(                                                 \
+          NAME,                                                             \
+          SPECIFIED_TYPE2,                                                  \
+          ::paddle::experimental::DataTypeToCppType<SPECIFIED_TYPE2>::type, \
+          __VA_ARGS__)                                                      \
+      PD_PRIVATE_CASE_TYPE(                                                 \
+          NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__)            \
+      PD_PRIVATE_CASE_TYPE(                                                 \
+          NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__)           \
+      PD_PRIVATE_CASE_TYPE(NAME,                                            \
+                           ::paddle::DataType::COMPLEX64,                   \
+                           ::paddle::complex64,                             \
+                           __VA_ARGS__)                                     \
+      PD_PRIVATE_CASE_TYPE(NAME,                                            \
+                           ::paddle::DataType::COMPLEX128,                  \
+                           ::paddle::complex128,                            \
+                           __VA_ARGS__)                                     \
+      default:                                                              \
+        PD_THROW("function " #NAME " is not implemented for data type `",   \
+                 __dtype__,                                                 \
+                 "`");                                                      \
+    }                                                                       \
+  }()
+
 ///////// Floating, Integral and Complex Dispatch Macro ///////////
 #define PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
......
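Both new macros follow the file's existing dispatch idiom: an immediately invoked lambda switches on the runtime dtype and binds it to the compile-time alias data_t before calling the caller's generic lambda. A self-contained sketch of the idiom (the DEMO_* names and three-value DataType enum are invented; the real macros route through PD_PRIVATE_CASE_TYPE and DataTypeToCppType):

```cpp
#include <iostream>
#include <stdexcept>

enum class DataType { FLOAT32, FLOAT64, INT32 };

// Each case binds the runtime dtype to a compile-time alias `data_t`
// and invokes the caller-supplied generic lambda.
#define DEMO_CASE_TYPE(enum_v, cpp_t, ...) \
  case enum_v: {                           \
    using data_t = cpp_t;                  \
    return __VA_ARGS__();                  \
  }

#define DEMO_DISPATCH_AND_1_TYPE(SPECIFIED_ENUM, SPECIFIED_CPP, TYPE, ...) \
  [&] {                                                                   \
    switch (TYPE) {                                                       \
      DEMO_CASE_TYPE(SPECIFIED_ENUM, SPECIFIED_CPP, __VA_ARGS__)          \
      DEMO_CASE_TYPE(DataType::FLOAT32, float, __VA_ARGS__)               \
      DEMO_CASE_TYPE(DataType::FLOAT64, double, __VA_ARGS__)              \
      default:                                                            \
        throw std::runtime_error("unsupported dtype");                    \
    }                                                                     \
  }()

int main() {
  DataType runtime_dtype = DataType::INT32;
  // Inside the lambda, data_t resolves to int for this dispatch.
  DEMO_DISPATCH_AND_1_TYPE(DataType::INT32, int, runtime_dtype, [&] {
    std::cout << "sizeof(data_t) = " << sizeof(data_t) << std::endl;
  });
}
```

As in the real macros, passing a SPECIFIED_* type that duplicates one of the built-in cases would produce a duplicate case label at compile time.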
......
@@ -14,7 +14,6 @@ limitations under the License. */
 #include "paddle/pten/kernels/gpu/math.h"
 
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h"
 #include "paddle/pten/kernels/hybird/cuda/reduce/reduce.h"
 #include "paddle/pten/kernels/hybird/general/elementwise_functor.h"
......
@@ -35,6 +34,8 @@ namespace cub = hipcub;
 #include "paddle/pten/core/convert_utils.h"
 #include "paddle/pten/core/kernel_registry.h"
 
+namespace kps = paddle::operators::kernel_primitives;
+
 namespace pten {
 
 /**
......
@@ -64,7 +65,7 @@ void Mean(const GPUContext& dev_ctx,
           bool reduce_all,
           DenseTensor* out) {
   auto out_dtype = x.dtype();
-  pten::Reduce<T, paddle::operators::CustomMean>(
+  pten::Reduce<T, kps::AddFunctor, kps::DivideFunctor>(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
......
@@ -85,7 +86,7 @@ void Sum(const GPUContext& dev_ctx,
          bool reduce_all,
          DataType out_dtype,
          DenseTensor* out) {
-  pten::Reduce<T, paddle::operators::CustomSum>(
+  pten::Reduce<T, kps::AddFunctor, kps::IdentityFunctor>(
      dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
......
@@ -95,7 +96,8 @@ using float16 = paddle::platform::float16;
 using complex64 = ::paddle::platform::complex<float>;
 using complex128 = ::paddle::platform::complex<double>;
 
-PT_REGISTER_KERNEL(mean, GPU, ALL_LAYOUT, pten::Mean, float, double, bool) {}
+PT_REGISTER_KERNEL(
+    mean, GPU, ALL_LAYOUT, pten::Mean, float, double, bool, float16) {}
 PT_REGISTER_KERNEL(add,
                    GPU,
                    ALL_LAYOUT,
......
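Registering pten::Mean for float16 leans on the MPType machinery used elsewhere in this commit: half-precision inputs are accumulated in a wider compute type and cast back once at the end. A minimal sketch of the trait pattern, where the float16 struct and the trait bodies are toy stand-ins rather than Paddle's definitions:

```cpp
#include <cstdint>
#include <iostream>
#include <type_traits>

// Toy half-precision stand-in for paddle::platform::float16 (assumption).
struct float16 { uint16_t bits; };

// Mirrors the role of kps::details::MPTypeTrait in the new reduce path:
// float16 inputs are accumulated in float, everything else in T itself.
template <typename T> struct MPTypeTrait { using Type = T; };
template <> struct MPTypeTrait<float16> { using Type = float; };

int main() {
  std::cout << std::boolalpha
            << std::is_same<MPTypeTrait<float16>::Type, float>::value << "\n"
            << std::is_same<MPTypeTrait<double>::Type, double>::value << "\n";
}
```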
......
@@ -17,38 +17,16 @@
 // CUDA and HIP use same api
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 
+#include "paddle/pten/api/ext/dispatch.h"
 #include "paddle/pten/backends/gpu/gpu_context.h"
 #include "paddle/pten/common/scalar.h"
 #include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h"
 
 namespace pten {
 
-static inline std::vector<int64_t> GetReduceDim(
-    const std::vector<int64_t>& dims, int dim_size, bool reduce_all) {
-  std::vector<int64_t> reduce_dims;
-  if (reduce_all) {
-    reduce_dims.resize(dim_size);
-    int reduce_size = reduce_dims.size();
-    for (int i = 0; i < reduce_size; ++i) {
-      reduce_dims[i] = i;
-    }
-  } else {
-    for (auto e : dims) {
-      PADDLE_ENFORCE_LT(e,
-                        dim_size,
-                        paddle::platform::errors::InvalidArgument(
-                            "ReduceOp: invalid axis, when x_dims is %d, "
-                            "axis[i] should less than x_dims, but got %d.",
-                            dim_size,
-                            e));
-      reduce_dims.push_back(e >= 0 ? e : e + dim_size);
-    }
-  }
-  return reduce_dims;
-}
-
-template <typename T, template <typename, typename> class ReduceFunctor>
+template <typename T,
+          template <typename> class ReduceOp,
+          template <typename, typename> class TransformOp>
 void Reduce(const GPUContext& dev_ctx,
             const DenseTensor& x,
             bool reduce_all,
......
@@ -56,20 +34,35 @@ void Reduce(const GPUContext& dev_ctx,
             bool keep_dim,
             DataType out_dtype,
             DenseTensor* out) {
-  std::vector<int64_t> reduce_dims =
-      GetReduceDim(dims, x.dims().size(), reduce_all);
+  std::vector<int> reduce_dims =
+      pten::kernels::details::GetReduceDim(dims, x.dims().size(), reduce_all);
+
+  int reduce_num = 1;
+  for (auto i : reduce_dims) {
+    reduce_num *= (x.dims())[i];
+  }
+
   gpuStream_t stream = dev_ctx.stream();
+
   if (out_dtype != pten::DataType::UNDEFINED && out_dtype != x.dtype()) {
-    PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(
-        out_dtype, "TensorReduceFunctorImpl", ([&] {
-          pten::detail::TensorReduceFunctorImpl<T, data_t, ReduceFunctor>(
-              x, out, reduce_dims, stream);
-        }));
+    PD_DISPATCH_FLOATING_AND_COMPLEX_AND_2_TYPES(
+        pten::DataType::INT32,
+        pten::DataType::INT64,
+        out_dtype,
+        "TensorReduceFunctorImpl",
+        ([&] {
+          using MPType = typename kps::details::MPTypeTrait<data_t>::Type;
+          pten::kernels::TensorReduceFunctorImpl<T,
+                                                 data_t,
+                                                 ReduceOp,
+                                                 TransformOp<T, MPType>>(
+              x, out, TransformOp<T, MPType>(reduce_num), reduce_dims, stream);
+        }));
   } else {
-    pten::detail::TensorReduceFunctorImpl<T, T, ReduceFunctor>(
-        x, out, reduce_dims, stream);
+    using MPType = typename kps::details::MPTypeTrait<T>::Type;
+    pten::kernels::
+        TensorReduceFunctorImpl<T, T, ReduceOp, TransformOp<T, MPType>>(
+            x, out, TransformOp<T, MPType>(reduce_num), reduce_dims, stream);
   }
 }
......
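For reference, the axis normalization that moved behind pten::kernels::details::GetReduceDim (now returning std::vector<int>) behaves like this host-side restatement of the deleted helper, with plain std::invalid_argument standing in for PADDLE_ENFORCE_LT:

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> GetReduceDim(const std::vector<int64_t>& dims,
                                  int dim_size, bool reduce_all) {
  std::vector<int64_t> reduce_dims;
  if (reduce_all) {
    // Reduce over every axis: 0, 1, ..., dim_size - 1.
    for (int i = 0; i < dim_size; ++i) reduce_dims.push_back(i);
  } else {
    for (auto e : dims) {
      if (e >= dim_size)
        throw std::invalid_argument("ReduceOp: axis must be less than x_dims");
      // Negative axes count from the back: -1 is the last axis.
      reduce_dims.push_back(e >= 0 ? e : e + dim_size);
    }
  }
  return reduce_dims;
}
```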