From c48bd3ffe7e60aee306886a5a6898d1919e6c3ce Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Thu, 6 Jan 2022 18:46:01 +0800 Subject: [PATCH] [pten]move reduce files and dev_api (#38715) * move eigen/reduce.h imple into cpu/reduce.h * ctx to dev_ctx --- paddle/pten/include/math.h | 37 --- paddle/pten/kernels/cpu/math_kernel.cc | 5 +- paddle/pten/kernels/cpu/reduce.h | 180 ++++++++++++++- paddle/pten/kernels/funcs/reduce_functor.h | 37 +++ paddle/pten/kernels/hybird/eigen/reduce.h | 214 ------------------ paddle/pten/kernels/math_kernel.h | 31 +++ .../pten/tests/kernels/test_mean_dev_api.cc | 2 +- paddle/pten/tests/kernels/test_sum_dev_api.cc | 2 +- 8 files changed, 250 insertions(+), 258 deletions(-) create mode 100644 paddle/pten/kernels/funcs/reduce_functor.h delete mode 100644 paddle/pten/kernels/hybird/eigen/reduce.h diff --git a/paddle/pten/include/math.h b/paddle/pten/include/math.h index e46f460260a..faa4c8db8da 100644 --- a/paddle/pten/include/math.h +++ b/paddle/pten/include/math.h @@ -18,7 +18,6 @@ limitations under the License. */ #include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/include/infermeta.h" #include "paddle/pten/kernels/complex_kernel.h" -#include "paddle/pten/kernels/math_kernel.h" #include "paddle/pten/kernels/scale_kernel.h" namespace pten { @@ -34,42 +33,6 @@ DenseTensor Sign(const ContextT& dev_ctx, const DenseTensor& x) { return dense_out; } -template -DenseTensor Mean(const ContextT& dev_ctx, - const DenseTensor& x, - const std::vector& axis, - bool keep_dim) { - auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - bool reduce_all = false; - MeanKernel(dev_ctx, x, axis, keep_dim, reduce_all, &dense_out); - return dense_out; -} - -template -DenseTensor Sum(const ContextT& dev_ctx, - const DenseTensor& x, - const std::vector& axis, - DataType dtype, - bool keep_dim) { - auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim, dtype); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - out_meta); - - // The real value of reduce_all will be get in kernel - // so use default value(false) is OK. - bool reduce_all = false; - - SumKernel( - dev_ctx, x, axis, keep_dim, reduce_all, out_meta.dtype, &dense_out); - return dense_out; -} - template DenseTensor Scale(const ContextT& dev_ctx, const DenseTensor& x, diff --git a/paddle/pten/kernels/cpu/math_kernel.cc b/paddle/pten/kernels/cpu/math_kernel.cc index 2a696584bc7..be0d52355bc 100644 --- a/paddle/pten/kernels/cpu/math_kernel.cc +++ b/paddle/pten/kernels/cpu/math_kernel.cc @@ -21,6 +21,7 @@ #include "paddle/pten/kernels/cpu/elementwise.h" #include "paddle/pten/kernels/cpu/reduce.h" #include "paddle/pten/kernels/funcs/elementwise_functor.h" +#include "paddle/pten/kernels/funcs/reduce_functor.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/framework/eigen.h" @@ -61,7 +62,7 @@ void MeanKernel(const Context& dev_ctx, bool reduce_all, DenseTensor* out) { auto out_dtype = x.dtype(); - pten::Reduce( + pten::Reduce( dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); } @@ -97,7 +98,7 @@ void SumKernel(const Context& dev_ctx, bool reduce_all, DataType out_dtype, DenseTensor* out) { - pten::Reduce( + pten::Reduce( dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); } diff --git a/paddle/pten/kernels/cpu/reduce.h b/paddle/pten/kernels/cpu/reduce.h index fc5dbe9d58d..fa603b21630 100644 --- a/paddle/pten/kernels/cpu/reduce.h +++ b/paddle/pten/kernels/cpu/reduce.h @@ -19,10 +19,184 @@ #include "paddle/pten/api/ext/dispatch.h" #include "paddle/pten/backends/cpu/cpu_context.h" #include "paddle/pten/kernels/cast_kernel.h" -#include "paddle/pten/kernels/hybird/eigen/reduce.h" +#include "paddle/pten/api/lib/utils/storage.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/kernels/hybird/eigen/common.h" +#include "paddle/pten/kernels/hybird/transpose.h" +// See Note [ Why still include the fluid headers? ] +#include "paddle/fluid/operators/eigen/eigen_function.h" namespace pten { +template +void ReduceFunctor(const DeviceContext& context, + const pten::DenseTensor& input, + pten::DenseTensor* output, + const std::vector& dims, + bool keep_dim) { + auto x = EigenTensor::From(input); + auto x_rank = static_cast(x.dimensions().size()); + auto reduce_dim = Eigen::array(); + std::vector dims_ref = dims; + for (size_t i = 0; i < dims_ref.size(); ++i) { + if (dims_ref[i] < 0) dims_ref[i] = x_rank + dims_ref[i]; + reduce_dim[i] = dims_ref[i]; + } + // construct the squeezed output tensor + DDim out_dims = output->dims(); + if (keep_dim && x_rank > 1) { + const int kDelFlag = -2; + auto dims_vector = paddle::framework::vectorize(out_dims); + for (size_t i = 0; i < dims_ref.size(); ++i) { + dims_vector[dims_ref[i]] = kDelFlag; + } + dims_vector.erase(remove(dims_vector.begin(), dims_vector.end(), kDelFlag), + dims_vector.end()); + out_dims = paddle::framework::make_ddim(dims_vector); + } + auto& place = *context.eigen_device(); + Functor functor; + + if (D == 1) { + auto out = EigenScalar::From(*output); + functor(place, &x, &out, reduce_dim); + } else { + auto out = EigenTensor::From(*output, out_dims); + functor(place, &x, &out, reduce_dim); + } +} + +#define HANDLE_REDUCE_DIM(NDIM, RDIM) \ + if (ndim == NDIM && rdim == RDIM) { \ + ReduceFunctor( \ + dev_ctx, input, output, dims, keep_dim); \ + } +//////////////// HandleLargeDim + +inline void GetShuffledDim(const DDim& src_dims, + DDim* dst_dims, + const std::vector& reduced_dims, + std::vector* perm_axis) { + // check if it's a reduced dim + std::vector src_dims_check(src_dims.size(), false); + size_t src_size = src_dims.size(); + size_t reduce_size = reduced_dims.size(); + std::vector regular_reduced_dims = reduced_dims; + for (size_t i = 0; i < regular_reduced_dims.size(); i++) { + if (regular_reduced_dims[i] < 0) { + regular_reduced_dims[i] = src_size + regular_reduced_dims[i]; + } + } + + for (size_t i = 0; i < reduce_size; ++i) { + dst_dims->at(src_size - reduce_size + i) = + src_dims[regular_reduced_dims[i]]; + (*perm_axis)[src_size - reduce_size + i] = regular_reduced_dims[i]; + src_dims_check[regular_reduced_dims[i]] = true; + } + + size_t offset = 0; + for (size_t i = 0; i < src_dims_check.size(); ++i) { + bool is_reduced = src_dims_check[i]; + if (!is_reduced) { + (*perm_axis)[offset] = i; + dst_dims->at(offset++) = src_dims[i]; + } + } +} + +template +void GetShuffledInput(const DeviceContext& dev_ctx, + const pten::DenseTensor& input, + pten::DenseTensor* shuffled_input, + const std::vector& dims) { + DDim shuffled_dims(input.dims()); + std::vector perm_axis(input.dims().size()); + GetShuffledDim(input.dims(), &shuffled_dims, dims, &perm_axis); + + shuffled_input->Resize(shuffled_dims); + shuffled_input->mutable_data(); + + pten::math::TransposeNormal trans; + trans(dev_ctx, input, shuffled_input, perm_axis); +} + +template +void HandleLargeDim(const DeviceContext& dev_ctx, + const pten::DenseTensor& input, + pten::DenseTensor* output, + const std::vector& dims, + bool keep_dim) { + // shuffle the reduced dim to the end + pten::DenseTensor shuffled_input = pten::DenseTensor( + pten::make_intrusive(input.place()), + input.meta()); + + GetShuffledInput(dev_ctx, input, &shuffled_input, dims); + + // transpose to 2D tensor whose shape is {unreduced, reduced}. + const int64_t unreduced = output->numel(); + const int64_t reduced = shuffled_input.numel() / unreduced; + shuffled_input.Resize({unreduced, reduced}); + DDim output_dim = output->dims(); + output->Resize({unreduced}); + ReduceFunctor( + dev_ctx, shuffled_input, output, {1}, keep_dim); + output->Resize(output_dim); +} + +////////////// ReduceKernel + +template +void ReduceKernelImpl(const DeviceContext& dev_ctx, + const pten::DenseTensor& input, + pten::DenseTensor* output, + const std::vector& dims, + bool keep_dim, + bool reduce_all) { + output->mutable_data(); + + if (reduce_all) { + // Flatten and reduce 1-D tensor + auto x = EigenVector::Flatten(input); + auto out = EigenScalar::From(*output); + auto& dev = *dev_ctx.eigen_device(); + auto reduce_dim = Eigen::array({{0}}); + + Functor functor; + functor(dev, &x, &out, reduce_dim); + } else { + int ndim = input.dims().size(); + int rdim = dims.size(); + if (ndim > 6) { + HandleLargeDim( + dev_ctx, input, output, dims, keep_dim); + + } else { + HANDLE_REDUCE_DIM(6, 5); + HANDLE_REDUCE_DIM(6, 4); + HANDLE_REDUCE_DIM(6, 3); + HANDLE_REDUCE_DIM(6, 2); + HANDLE_REDUCE_DIM(6, 1); + HANDLE_REDUCE_DIM(5, 4); + HANDLE_REDUCE_DIM(5, 3); + HANDLE_REDUCE_DIM(5, 2); + HANDLE_REDUCE_DIM(5, 1); + HANDLE_REDUCE_DIM(4, 3); + HANDLE_REDUCE_DIM(4, 2); + HANDLE_REDUCE_DIM(4, 1); + HANDLE_REDUCE_DIM(3, 2); + HANDLE_REDUCE_DIM(3, 1); + HANDLE_REDUCE_DIM(2, 1); + HANDLE_REDUCE_DIM(1, 1); + } + } +} + template void Reduce(const DeviceContext& dev_ctx, const DenseTensor& x, @@ -52,7 +226,7 @@ void Reduce(const DeviceContext& dev_ctx, // do reduce sum PD_VISIT_ALL_TYPES( out_dtype, "ReduceKernelImpl", ([&] { - pten::eigen::ReduceKernelImpl( + pten::ReduceKernelImpl( dev_ctx, x, out, dims, keep_dim, reduce_all); })); } else { @@ -66,7 +240,7 @@ void Reduce(const DeviceContext& dev_ctx, // do reduce sum PD_VISIT_ALL_TYPES( out_dtype, "ReduceKernelImpl", ([&] { - pten::eigen::ReduceKernelImpl( + pten::ReduceKernelImpl( dev_ctx, tmp_tensor, out, dims, keep_dim, reduce_all); })); } diff --git a/paddle/pten/kernels/funcs/reduce_functor.h b/paddle/pten/kernels/funcs/reduce_functor.h new file mode 100644 index 00000000000..64ada023189 --- /dev/null +++ b/paddle/pten/kernels/funcs/reduce_functor.h @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace pten { +namespace funcs { + +//////// Sum Functor /////// +struct SumFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = x->sum(dim); + } +}; + +//////// Mean Functor /////// +struct MeanFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = x->mean(dim); + } +}; + +} // namespace funcs +} // namespace pten diff --git a/paddle/pten/kernels/hybird/eigen/reduce.h b/paddle/pten/kernels/hybird/eigen/reduce.h deleted file mode 100644 index d60a416dfdb..00000000000 --- a/paddle/pten/kernels/hybird/eigen/reduce.h +++ /dev/null @@ -1,214 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/pten/api/lib/utils/storage.h" -#include "paddle/pten/core/dense_tensor.h" -#include "paddle/pten/kernels/hybird/eigen/common.h" -#include "paddle/pten/kernels/hybird/transpose.h" - -// See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/operators/eigen/eigen_function.h" - -namespace pten { -namespace eigen { - -template -void ReduceFunctor(const DeviceContext& context, - const pten::DenseTensor& input, - pten::DenseTensor* output, - const std::vector& dims, - bool keep_dim) { - auto x = EigenTensor::From(input); - auto x_rank = static_cast(x.dimensions().size()); - auto reduce_dim = Eigen::array(); - std::vector dims_ref = dims; - for (size_t i = 0; i < dims_ref.size(); ++i) { - if (dims_ref[i] < 0) dims_ref[i] = x_rank + dims_ref[i]; - reduce_dim[i] = dims_ref[i]; - } - // construct the squeezed output tensor - DDim out_dims = output->dims(); - if (keep_dim && x_rank > 1) { - const int kDelFlag = -2; - auto dims_vector = paddle::framework::vectorize(out_dims); - for (size_t i = 0; i < dims_ref.size(); ++i) { - dims_vector[dims_ref[i]] = kDelFlag; - } - dims_vector.erase(remove(dims_vector.begin(), dims_vector.end(), kDelFlag), - dims_vector.end()); - out_dims = paddle::framework::make_ddim(dims_vector); - } - auto& place = *context.eigen_device(); - Functor functor; - - if (D == 1) { - auto out = EigenScalar::From(*output); - functor(place, &x, &out, reduce_dim); - } else { - auto out = EigenTensor::From(*output, out_dims); - functor(place, &x, &out, reduce_dim); - } -} - -#define HANDLE_REDUCE_DIM(NDIM, RDIM) \ - if (ndim == NDIM && rdim == RDIM) { \ - ReduceFunctor( \ - dev_ctx, input, output, dims, keep_dim); \ - } -//////////////// HandleLargeDim - -inline void GetShuffledDim(const DDim& src_dims, - DDim* dst_dims, - const std::vector& reduced_dims, - std::vector* perm_axis) { - // check if it's a reduced dim - std::vector src_dims_check(src_dims.size(), false); - size_t src_size = src_dims.size(); - size_t reduce_size = reduced_dims.size(); - std::vector regular_reduced_dims = reduced_dims; - for (size_t i = 0; i < regular_reduced_dims.size(); i++) { - if (regular_reduced_dims[i] < 0) { - regular_reduced_dims[i] = src_size + regular_reduced_dims[i]; - } - } - - for (size_t i = 0; i < reduce_size; ++i) { - dst_dims->at(src_size - reduce_size + i) = - src_dims[regular_reduced_dims[i]]; - (*perm_axis)[src_size - reduce_size + i] = regular_reduced_dims[i]; - src_dims_check[regular_reduced_dims[i]] = true; - } - - size_t offset = 0; - for (size_t i = 0; i < src_dims_check.size(); ++i) { - bool is_reduced = src_dims_check[i]; - if (!is_reduced) { - (*perm_axis)[offset] = i; - dst_dims->at(offset++) = src_dims[i]; - } - } -} - -template -void GetShuffledInput(const DeviceContext& dev_ctx, - const pten::DenseTensor& input, - pten::DenseTensor* shuffled_input, - const std::vector& dims) { - DDim shuffled_dims(input.dims()); - std::vector perm_axis(input.dims().size()); - GetShuffledDim(input.dims(), &shuffled_dims, dims, &perm_axis); - - shuffled_input->Resize(shuffled_dims); - shuffled_input->mutable_data(); - - pten::math::TransposeNormal trans; - trans(dev_ctx, input, shuffled_input, perm_axis); -} - -template -void HandleLargeDim(const DeviceContext& dev_ctx, - const pten::DenseTensor& input, - pten::DenseTensor* output, - const std::vector& dims, - bool keep_dim) { - // shuffle the reduced dim to the end - pten::DenseTensor shuffled_input = pten::DenseTensor( - pten::make_intrusive(input.place()), - input.meta()); - - GetShuffledInput(dev_ctx, input, &shuffled_input, dims); - - // transpose to 2D tensor whose shape is {unreduced, reduced}. - const int64_t unreduced = output->numel(); - const int64_t reduced = shuffled_input.numel() / unreduced; - shuffled_input.Resize({unreduced, reduced}); - DDim output_dim = output->dims(); - output->Resize({unreduced}); - ReduceFunctor( - dev_ctx, shuffled_input, output, {1}, keep_dim); - output->Resize(output_dim); -} - -////////////// ReduceKernel - -template -void ReduceKernelImpl(const DeviceContext& dev_ctx, - const pten::DenseTensor& input, - pten::DenseTensor* output, - const std::vector& dims, - bool keep_dim, - bool reduce_all) { - output->mutable_data(); - - if (reduce_all) { - // Flatten and reduce 1-D tensor - auto x = EigenVector::Flatten(input); - auto out = EigenScalar::From(*output); - auto& dev = *dev_ctx.eigen_device(); - auto reduce_dim = Eigen::array({{0}}); - - Functor functor; - functor(dev, &x, &out, reduce_dim); - } else { - int ndim = input.dims().size(); - int rdim = dims.size(); - if (ndim > 6) { - HandleLargeDim( - dev_ctx, input, output, dims, keep_dim); - - } else { - HANDLE_REDUCE_DIM(6, 5); - HANDLE_REDUCE_DIM(6, 4); - HANDLE_REDUCE_DIM(6, 3); - HANDLE_REDUCE_DIM(6, 2); - HANDLE_REDUCE_DIM(6, 1); - HANDLE_REDUCE_DIM(5, 4); - HANDLE_REDUCE_DIM(5, 3); - HANDLE_REDUCE_DIM(5, 2); - HANDLE_REDUCE_DIM(5, 1); - HANDLE_REDUCE_DIM(4, 3); - HANDLE_REDUCE_DIM(4, 2); - HANDLE_REDUCE_DIM(4, 1); - HANDLE_REDUCE_DIM(3, 2); - HANDLE_REDUCE_DIM(3, 1); - HANDLE_REDUCE_DIM(2, 1); - HANDLE_REDUCE_DIM(1, 1); - } - } -} - -//////// Sum Functor /////// -struct SumFunctor { - template - void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { - y->device(place) = x->sum(dim); - } -}; - -//////// Mean Functor /////// -struct MeanFunctor { - template - void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { - y->device(place) = x->mean(dim); - } -}; - -} // namespace eigen -} // namespace pten diff --git a/paddle/pten/kernels/math_kernel.h b/paddle/pten/kernels/math_kernel.h index b1e5188f3aa..f87d0a31b47 100644 --- a/paddle/pten/kernels/math_kernel.h +++ b/paddle/pten/kernels/math_kernel.h @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/include/infermeta.h" +#include "paddle/pten/kernels/empty_kernel.h" namespace pten { @@ -121,4 +122,34 @@ DenseTensor Multiply(const ContextT& dev_ctx, return dense_out; } +template +DenseTensor Mean(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axis, + bool keep_dim) { + auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim); + auto dense_out = pten::Empty(dev_ctx, std::move(out_meta)); + bool reduce_all = false; + MeanKernel(dev_ctx, x, axis, keep_dim, reduce_all, &dense_out); + return dense_out; +} + +template +DenseTensor Sum(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axis, + DataType dtype, + bool keep_dim) { + auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim, dtype); + auto dense_out = pten::Empty(dev_ctx, std::move(out_meta)); + + // The real value of reduce_all will be get in kernel + // so use default value(false) is OK. + bool reduce_all = false; + + SumKernel( + dev_ctx, x, axis, keep_dim, reduce_all, out_meta.dtype, &dense_out); + return dense_out; +} + } // namespace pten diff --git a/paddle/pten/tests/kernels/test_mean_dev_api.cc b/paddle/pten/tests/kernels/test_mean_dev_api.cc index 4d062977e23..4b254e7e6c1 100644 --- a/paddle/pten/tests/kernels/test_mean_dev_api.cc +++ b/paddle/pten/tests/kernels/test_mean_dev_api.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include #include -#include "paddle/pten/include/math.h" +#include "paddle/pten/kernels/math_kernel.h" #include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/core/dense_tensor.h" diff --git a/paddle/pten/tests/kernels/test_sum_dev_api.cc b/paddle/pten/tests/kernels/test_sum_dev_api.cc index 381b8fe44f5..afaf9030637 100644 --- a/paddle/pten/tests/kernels/test_sum_dev_api.cc +++ b/paddle/pten/tests/kernels/test_sum_dev_api.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include #include -#include "paddle/pten/include/math.h" +#include "paddle/pten/kernels/math_kernel.h" #include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/core/dense_tensor.h" -- GitLab