From b23914c26819cadf76cb47a4a7ec173ce88c3212 Mon Sep 17 00:00:00 2001
From: Aganlengzi
Date: Wed, 1 Jun 2022 19:55:30 +0800
Subject: [PATCH] [fix] split nanmedian fluid deps (#43135)

---
 .../phi/kernels/cpu/nanmedian_grad_kernel.cc  |  2 +
 paddle/phi/kernels/cpu/nanmedian_kernel.cc    |  2 +
 .../phi/kernels/gpu/nanmedian_grad_kernel.cu  |  4 +-
 paddle/phi/kernels/gpu/nanmedian_kernel.cu    |  4 +-
 .../kernels/impl/nanmedian_grad_kernel_impl.h | 66 ++++++++++++++++++
 .../phi/kernels/impl/nanmedian_kernel_impl.h  | 69 +++++++++++++++++++
 paddle/phi/kernels/nanmedian_grad_kernel.h    | 45 +-----------
 paddle/phi/kernels/nanmedian_kernel.h         | 48 +------------
 8 files changed, 147 insertions(+), 93 deletions(-)
 create mode 100644 paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h
 create mode 100644 paddle/phi/kernels/impl/nanmedian_kernel_impl.h

diff --git a/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc
index 156124c2148..f8639a0d10f 100644
--- a/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc
@@ -13,9 +13,11 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/nanmedian_grad_kernel.h"
+
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h"
 
 namespace phi {
 
diff --git a/paddle/phi/kernels/cpu/nanmedian_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_kernel.cc
index ed38405c917..03d7fe304be 100644
--- a/paddle/phi/kernels/cpu/nanmedian_kernel.cc
+++ b/paddle/phi/kernels/cpu/nanmedian_kernel.cc
@@ -13,8 +13,10 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/nanmedian_kernel.h"
+
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/nanmedian_kernel_impl.h"
 #include "paddle/phi/kernels/top_k_kernel.h"
 
 namespace phi {
diff --git a/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu
index a7cd49c0e53..1661d396641 100644
--- a/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu
@@ -12,13 +12,15 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/phi/kernels/nanmedian_grad_kernel.h"
+
 #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_meta.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
-#include "paddle/phi/kernels/nanmedian_grad_kernel.h"
+#include "paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h"
 
 namespace phi {
 
diff --git a/paddle/phi/kernels/gpu/nanmedian_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_kernel.cu
index 5975e274899..a67d64c2577 100644
--- a/paddle/phi/kernels/gpu/nanmedian_kernel.cu
+++ b/paddle/phi/kernels/gpu/nanmedian_kernel.cu
@@ -12,13 +12,15 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/phi/kernels/nanmedian_kernel.h" + #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/full_kernel.h" -#include "paddle/phi/kernels/nanmedian_kernel.h" +#include "paddle/phi/kernels/impl/nanmedian_kernel_impl.h" #include "paddle/phi/kernels/top_k_kernel.h" namespace phi { diff --git a/paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h b/paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h new file mode 100644 index 00000000000..f5743412762 --- /dev/null +++ b/paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h @@ -0,0 +1,66 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/nanmedian_grad_kernel.h" + +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void PostprocessMedianGradKernel(const Context& dev_ctx, + DenseTensor* input, + const IntArray& raw_axes, + DenseTensor* x) { + auto input_dim = input->dims(); + auto rank = input_dim.size(); + + std::vector axes = raw_axes.GetData(); + int64_t axes_size = static_cast(axes.size()); + for (int64_t i = 0; i < axes_size; i++) { + if (axes[i] < 0) { + axes[i] += rank; + } + } + + std::vector trans_back; + std::vector reshape_back; + trans_back.reserve(rank); + trans_back.resize(rank); + + int offset = 0; + for (int64_t i = 0; i < rank; i++) { + if (std::find(axes.begin(), axes.end(), i) == axes.end()) { + reshape_back.push_back(input_dim[i]); + trans_back[i] = offset; + offset += 1; + } + } + + for (int64_t i = 0; i < rank; i++) { + if (std::find(axes.begin(), axes.end(), i) != axes.end()) { + trans_back[i] = offset; + reshape_back.push_back(input_dim[i]); + offset += 1; + } + } + + input->Resize(make_ddim(reshape_back)); + funcs::TransCompute( + static_cast(trans_back.size()), dev_ctx, *input, x, trans_back); +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/nanmedian_kernel_impl.h b/paddle/phi/kernels/impl/nanmedian_kernel_impl.h new file mode 100644 index 00000000000..57e9e5646e5 --- /dev/null +++ b/paddle/phi/kernels/impl/nanmedian_kernel_impl.h @@ -0,0 +1,69 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include "paddle/phi/kernels/nanmedian_kernel.h"
+
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PreprocessMedianKernel(const Context& dev_ctx,
+                            const DenseTensor& input,
+                            const IntArray& raw_axes,
+                            DenseTensor* x) {
+  auto input_dim = input.dims();
+  auto rank = input_dim.size();
+  std::vector<int> perm;
+  std::vector<int64_t> reshape;
+
+  std::vector<int64_t> axes = raw_axes.GetData();
+  int64_t axes_size = static_cast<int64_t>(axes.size());
+  for (int64_t i = 0; i < axes_size; i++) {
+    if (axes[i] < 0) {
+      axes[i] += rank;
+    }
+  }
+
+  for (int64_t i = 0; i < rank; i++) {
+    if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
+      perm.push_back(i);
+      reshape.push_back(input_dim[i]);
+    }
+  }
+
+  int64_t post_numel = 1;
+  for (int64_t i = 0; i < rank; i++) {
+    if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
+      perm.push_back(i);
+      post_numel *= input_dim[i];
+    }
+  }
+  reshape.push_back(post_numel);
+
+  DDim trans_dim(input_dim);
+  int ndims = perm.size();
+  for (int i = 0; i < ndims; i++) {
+    trans_dim[i] = input_dim[perm[i]];
+  }
+  x->Resize(trans_dim);
+  dev_ctx.template Alloc<T>(x);
+  funcs::TransCompute<Context, T>(ndims, dev_ctx, input, x, perm);
+
+  x->Resize(make_ddim(reshape));
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/nanmedian_grad_kernel.h b/paddle/phi/kernels/nanmedian_grad_kernel.h
index dc7321c1aa7..e8fb01b7060 100644
--- a/paddle/phi/kernels/nanmedian_grad_kernel.h
+++ b/paddle/phi/kernels/nanmedian_grad_kernel.h
@@ -13,55 +13,12 @@
 // limitations under the License.
 
 #pragma once
+
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
-template <typename T, typename Context>
-void PostprocessMedianGradKernel(const Context& dev_ctx,
-                                 DenseTensor* input,
-                                 const IntArray& raw_axes,
-                                 DenseTensor* x) {
-  auto input_dim = input->dims();
-  auto rank = input_dim.size();
-
-  std::vector<int64_t> axes = raw_axes.GetData();
-  int64_t axes_size = static_cast<int64_t>(axes.size());
-  for (int64_t i = 0; i < axes_size; i++) {
-    if (axes[i] < 0) {
-      axes[i] += rank;
-    }
-  }
-
-  std::vector<int> trans_back;
-  std::vector<int64_t> reshape_back;
-  trans_back.reserve(rank);
-  trans_back.resize(rank);
-
-  int offset = 0;
-  for (int64_t i = 0; i < rank; i++) {
-    if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
-      reshape_back.push_back(input_dim[i]);
-      trans_back[i] = offset;
-      offset += 1;
-    }
-  }
-
-  for (int64_t i = 0; i < rank; i++) {
-    if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
-      trans_back[i] = offset;
-      reshape_back.push_back(input_dim[i]);
-      offset += 1;
-    }
-  }
-
-  input->Resize(make_ddim(reshape_back));
-  funcs::TransCompute<Context, T>(
-      static_cast<int>(trans_back.size()), dev_ctx, *input, x, trans_back);
-}
-
 template <typename T, typename Context>
 void NanmedianGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
diff --git a/paddle/phi/kernels/nanmedian_kernel.h b/paddle/phi/kernels/nanmedian_kernel.h
index 374f420381b..4bb382a4431 100644
--- a/paddle/phi/kernels/nanmedian_kernel.h
+++ b/paddle/phi/kernels/nanmedian_kernel.h
@@ -13,58 +13,12 @@
 // limitations under the License.
 
 #pragma once
+
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
-template <typename T, typename Context>
-void PreprocessMedianKernel(const Context& dev_ctx,
-                            const DenseTensor& input,
-                            const IntArray& raw_axes,
-                            DenseTensor* x) {
-  auto input_dim = input.dims();
-  auto rank = input_dim.size();
-  std::vector<int> perm;
-  std::vector<int64_t> reshape;
-
-  std::vector<int64_t> axes = raw_axes.GetData();
-  int64_t axes_size = static_cast<int64_t>(axes.size());
-  for (int64_t i = 0; i < axes_size; i++) {
-    if (axes[i] < 0) {
-      axes[i] += rank;
-    }
-  }
-
-  for (int64_t i = 0; i < rank; i++) {
-    if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
-      perm.push_back(i);
-      reshape.push_back(input_dim[i]);
-    }
-  }
-
-  int64_t post_numel = 1;
-  for (int64_t i = 0; i < rank; i++) {
-    if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
-      perm.push_back(i);
-      post_numel *= input_dim[i];
-    }
-  }
-  reshape.push_back(post_numel);
-
-  DDim trans_dim(input_dim);
-  int ndims = perm.size();
-  for (int i = 0; i < ndims; i++) {
-    trans_dim[i] = input_dim[perm[i]];
-  }
-  x->Resize(trans_dim);
-  dev_ctx.template Alloc<T>(x);
-  funcs::TransCompute<Context, T>(ndims, dev_ctx, input, x, perm);
-
-  x->Resize(make_ddim(reshape));
-}
-
 template <typename T, typename Context>
 void NanmedianKernel(const Context& dev_ctx,
                      const DenseTensor& x,
--
GitLab
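
Note on the helpers this patch relocates: both impl headers do pure shape bookkeeping. PreprocessMedianKernel permutes the axes being reduced to the back of the tensor and folds them into a single trailing dimension, so the median is taken along one contiguous axis; PostprocessMedianGradKernel builds the inverse permutation (trans_back) to transpose the gradient back to the original layout. The standalone C++ sketch below mirrors that bookkeeping with plain std::vector shapes so it can be compiled and checked without building Paddle; it is illustration only, not Paddle code, and ComputePermAndShape is a hypothetical name introduced here.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors PreprocessMedianKernel's index bookkeeping: non-reduced axes keep
// their relative order at the front; reduced axes are folded into one
// trailing dimension whose extent is the product of their sizes.
void ComputePermAndShape(const std::vector<int64_t>& dims,
                         std::vector<int64_t> axes,  // axes to reduce over
                         std::vector<int>* perm,
                         std::vector<int64_t>* reshape) {
  const int64_t rank = static_cast<int64_t>(dims.size());
  for (auto& a : axes) {
    if (a < 0) a += rank;  // normalize negative axes, as the patch does
  }
  for (int64_t i = 0; i < rank; i++) {
    if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
      perm->push_back(static_cast<int>(i));
      reshape->push_back(dims[i]);
    }
  }
  int64_t post_numel = 1;
  for (int64_t i = 0; i < rank; i++) {
    if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
      perm->push_back(static_cast<int>(i));
      post_numel *= dims[i];
    }
  }
  reshape->push_back(post_numel);
}

int main() {
  // A [2, 3, 4, 5] tensor with the median over axes {1, -1} (i.e. {1, 3}):
  // expected perm = [0, 2, 1, 3] and reshape = [2, 4, 15].
  std::vector<int> perm;
  std::vector<int64_t> reshape;
  ComputePermAndShape({2, 3, 4, 5}, {1, -1}, &perm, &reshape);
  for (int p : perm) std::cout << p << ' ';
  std::cout << "| ";
  for (int64_t d : reshape) std::cout << d << ' ';
  std::cout << '\n';  // prints: 0 2 1 3 | 2 4 15
}

Running this prints "0 2 1 3 | 2 4 15": the non-reduced axes 0 and 2 stay in order at the front, while axes 1 and 3 (sizes 3 and 5) collapse into one trailing dimension of 15. The grad-side helper is simply the inverse of this mapping, which is why splitting both into shared impl/ headers removes the duplicated logic from the public kernel headers.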