diff --git a/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc index 156124c21489542d420bcc2a0dc10f1dbfb7a7b5..f8639a0d10feed8e137b340cc61cd65dad9de99a 100644 --- a/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc @@ -13,9 +13,11 @@ // limitations under the License. #include "paddle/phi/kernels/nanmedian_grad_kernel.h" + #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h" namespace phi { diff --git a/paddle/phi/kernels/cpu/nanmedian_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_kernel.cc index ed38405c9179fcb8897491d2257326eb4af28210..03d7fe304be3ea663f5201bb49515657eabc36d7 100644 --- a/paddle/phi/kernels/cpu/nanmedian_kernel.cc +++ b/paddle/phi/kernels/cpu/nanmedian_kernel.cc @@ -13,8 +13,10 @@ // limitations under the License. #include "paddle/phi/kernels/nanmedian_kernel.h" + #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/nanmedian_kernel_impl.h" #include "paddle/phi/kernels/top_k_kernel.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu index a7cd49c0e53f3c27794115a2bd19093e8abbd04d..1661d396641af1e8657201390e70609d5e6795d7 100644 --- a/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu @@ -12,13 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "paddle/phi/kernels/nanmedian_grad_kernel.h" + #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/nanmedian_grad_kernel.h" +#include "paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/nanmedian_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_kernel.cu index 5975e2748997e94c184a68102d82417cd6f770f4..a67d64c257761c6bf5d2f1803b18394295c9fa4a 100644 --- a/paddle/phi/kernels/gpu/nanmedian_kernel.cu +++ b/paddle/phi/kernels/gpu/nanmedian_kernel.cu @@ -12,13 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/phi/kernels/nanmedian_kernel.h" + #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/full_kernel.h" -#include "paddle/phi/kernels/nanmedian_kernel.h" +#include "paddle/phi/kernels/impl/nanmedian_kernel_impl.h" #include "paddle/phi/kernels/top_k_kernel.h" namespace phi { diff --git a/paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h b/paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..f57434127620cc345de4b87ed72f29e9cc85de75 --- /dev/null +++ b/paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h @@ -0,0 +1,66 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/nanmedian_grad_kernel.h"
+
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PostprocessMedianGradKernel(const Context& dev_ctx,
+                                 DenseTensor* input,
+                                 const IntArray& raw_axes,
+                                 DenseTensor* x) {
+  auto input_dim = input->dims();
+  auto rank = input_dim.size();
+
+  std::vector<int64_t> axes = raw_axes.GetData();
+  int64_t axes_size = static_cast<int64_t>(axes.size());
+  for (int64_t i = 0; i < axes_size; i++) {
+    if (axes[i] < 0) {
+      axes[i] += rank;
+    }
+  }
+
+  std::vector<int> trans_back;
+  std::vector<int64_t> reshape_back;
+  trans_back.reserve(rank);
+  trans_back.resize(rank);
+
+  int offset = 0;
+  for (int64_t i = 0; i < rank; i++) {
+    if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
+      reshape_back.push_back(input_dim[i]);
+      trans_back[i] = offset;
+      offset += 1;
+    }
+  }
+
+  for (int64_t i = 0; i < rank; i++) {
+    if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
+      trans_back[i] = offset;
+      reshape_back.push_back(input_dim[i]);
+      offset += 1;
+    }
+  }
+
+  input->Resize(make_ddim(reshape_back));
+  funcs::TransCompute<Context, T>(
+      static_cast<int>(trans_back.size()), dev_ctx, *input, x, trans_back);
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/impl/nanmedian_kernel_impl.h b/paddle/phi/kernels/impl/nanmedian_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..57e9e5646e5595deaae74c805b57828cabed7ade
--- /dev/null
+++ b/paddle/phi/kernels/impl/nanmedian_kernel_impl.h
@@ -0,0 +1,69 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/nanmedian_kernel.h"
+
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PreprocessMedianKernel(const Context& dev_ctx,
+                            const DenseTensor& input,
+                            const IntArray& raw_axes,
+                            DenseTensor* x) {
+  auto input_dim = input.dims();
+  auto rank = input_dim.size();
+  std::vector<int> perm;
+  std::vector<int64_t> reshape;
+
+  std::vector<int64_t> axes = raw_axes.GetData();
+  int64_t axes_size = static_cast<int64_t>(axes.size());
+  for (int64_t i = 0; i < axes_size; i++) {
+    if (axes[i] < 0) {
+      axes[i] += rank;
+    }
+  }
+
+  for (int64_t i = 0; i < rank; i++) {
+    if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
+      perm.push_back(i);
+      reshape.push_back(input_dim[i]);
+    }
+  }
+
+  int64_t post_numel = 1;
+  for (int64_t i = 0; i < rank; i++) {
+    if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
+      perm.push_back(i);
+      post_numel *= input_dim[i];
+    }
+  }
+  reshape.push_back(post_numel);
+
+  DDim trans_dim(input_dim);
+  int ndims = perm.size();
+  for (int i = 0; i < ndims; i++) {
+    trans_dim[i] = input_dim[perm[i]];
+  }
+  x->Resize(trans_dim);
+  dev_ctx.template Alloc<T>(x);
+  funcs::TransCompute<Context, T>(ndims, dev_ctx, input, x, perm);
+
+  x->Resize(make_ddim(reshape));
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/nanmedian_grad_kernel.h b/paddle/phi/kernels/nanmedian_grad_kernel.h
index
dc7321c1aa75123a47aa7d470fe16d37dc420c9d..e8fb01b7060a78f2ea4ebfff6ba33aff2bf762cd 100644
--- a/paddle/phi/kernels/nanmedian_grad_kernel.h
+++ b/paddle/phi/kernels/nanmedian_grad_kernel.h
@@ -13,55 +13,12 @@
 // limitations under the License.
 
 #pragma once
+
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
-template <typename T, typename Context>
-void PostprocessMedianGradKernel(const Context& dev_ctx,
-                                 DenseTensor* input,
-                                 const IntArray& raw_axes,
-                                 DenseTensor* x) {
-  auto input_dim = input->dims();
-  auto rank = input_dim.size();
-
-  std::vector<int64_t> axes = raw_axes.GetData();
-  int64_t axes_size = static_cast<int64_t>(axes.size());
-  for (int64_t i = 0; i < axes_size; i++) {
-    if (axes[i] < 0) {
-      axes[i] += rank;
-    }
-  }
-
-  std::vector<int> trans_back;
-  std::vector<int64_t> reshape_back;
-  trans_back.reserve(rank);
-  trans_back.resize(rank);
-
-  int offset = 0;
-  for (int64_t i = 0; i < rank; i++) {
-    if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
-      reshape_back.push_back(input_dim[i]);
-      trans_back[i] = offset;
-      offset += 1;
-    }
-  }
-
-  for (int64_t i = 0; i < rank; i++) {
-    if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
-      trans_back[i] = offset;
-      reshape_back.push_back(input_dim[i]);
-      offset += 1;
-    }
-  }
-
-  input->Resize(make_ddim(reshape_back));
-  funcs::TransCompute<Context, T>(
-      static_cast<int>(trans_back.size()), dev_ctx, *input, x, trans_back);
-}
-
 template <typename T, typename Context>
 void NanmedianGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
diff --git a/paddle/phi/kernels/nanmedian_kernel.h b/paddle/phi/kernels/nanmedian_kernel.h
index 374f420381bdc8f46f4382ed80ee6180128312c5..4bb382a443144f598da132ebb0213b7317e68b76 100644
--- a/paddle/phi/kernels/nanmedian_kernel.h
+++ b/paddle/phi/kernels/nanmedian_kernel.h
@@ -13,58 +13,12 @@
 // limitations under the License.
 
 #pragma once
+
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
-template <typename T, typename Context>
-void PreprocessMedianKernel(const Context& dev_ctx,
-                            const DenseTensor& input,
-                            const IntArray& raw_axes,
-                            DenseTensor* x) {
-  auto input_dim = input.dims();
-  auto rank = input_dim.size();
-  std::vector<int> perm;
-  std::vector<int64_t> reshape;
-
-  std::vector<int64_t> axes = raw_axes.GetData();
-  int64_t axes_size = static_cast<int64_t>(axes.size());
-  for (int64_t i = 0; i < axes_size; i++) {
-    if (axes[i] < 0) {
-      axes[i] += rank;
-    }
-  }
-
-  for (int64_t i = 0; i < rank; i++) {
-    if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
-      perm.push_back(i);
-      reshape.push_back(input_dim[i]);
-    }
-  }
-
-  int64_t post_numel = 1;
-  for (int64_t i = 0; i < rank; i++) {
-    if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
-      perm.push_back(i);
-      post_numel *= input_dim[i];
-    }
-  }
-  reshape.push_back(post_numel);
-
-  DDim trans_dim(input_dim);
-  int ndims = perm.size();
-  for (int i = 0; i < ndims; i++) {
-    trans_dim[i] = input_dim[perm[i]];
-  }
-  x->Resize(trans_dim);
-  dev_ctx.template Alloc<T>(x);
-  funcs::TransCompute<Context, T>(ndims, dev_ctx, input, x, perm);
-
-  x->Resize(make_ddim(reshape));
-}
-
 template <typename T, typename Context>
 void NanmedianKernel(const Context& dev_ctx,
                      const DenseTensor& x,