diff --git a/paddle/pten/kernels/cpu/trace_grad_kernel.cc b/paddle/pten/kernels/cpu/trace_grad_kernel.cc index 136b941ea8b0f0d8d11ba458c35f118e1a6e685e..0ab70f4018a7bc2eb6f16f007057311830e1416c 100644 --- a/paddle/pten/kernels/cpu/trace_grad_kernel.cc +++ b/paddle/pten/kernels/cpu/trace_grad_kernel.cc @@ -13,10 +13,10 @@ // limitations under the License. #include "paddle/pten/kernels/trace_grad_kernel.h" -#include "paddle/pten/kernels/impl/trace_kernel_impl.h" #include "paddle/pten/backends/cpu/cpu_context.h" #include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/impl/trace_grad_kernel_impl.h" PT_REGISTER_KERNEL(trace_grad, CPU, @@ -26,6 +26,6 @@ PT_REGISTER_KERNEL(trace_grad, double, int, int64_t, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex) {} + pten::dtype::float16, + pten::dtype::complex, + pten::dtype::complex) {} diff --git a/paddle/pten/kernels/cpu/trace_kernel.cc b/paddle/pten/kernels/cpu/trace_kernel.cc index 4064b752ef4ca87994f6e365cf89349a74c57adc..c918d49954530eb60a5c6df4d3a11e9051a38d1d 100644 --- a/paddle/pten/kernels/cpu/trace_kernel.cc +++ b/paddle/pten/kernels/cpu/trace_kernel.cc @@ -13,31 +13,31 @@ // limitations under the License. #include "paddle/pten/kernels/trace_kernel.h" + #include "paddle/pten/backends/cpu/cpu_context.h" #include "paddle/pten/core/kernel_registry.h" -#include "paddle/pten/kernels/impl/trace_kernel_impl.h" +#include "paddle/pten/kernels/funcs/diagonal.h" +#include "paddle/pten/kernels/funcs/eigen/common.h" namespace pten { template -void TraceKernel(const Context& ctx, +void TraceKernel(const Context& dev_ctx, const DenseTensor& x, int offset, int axis1, int axis2, DenseTensor* out) { - auto output_dims = out->dims(); - - T* out_data = out->mutable_data(ctx.GetPlace()); + auto* out_data = dev_ctx.template Alloc(out); - const DenseTensor diag = Diagonal(ctx, &x, offset, axis1, axis2); + const DenseTensor diag = + funcs::Diagonal(dev_ctx, &x, offset, axis1, axis2); if (diag.numel() > 0) { - auto x = paddle::framework::EigenMatrix::Reshape(diag, - diag.dims().size() - 1); - auto output = paddle::framework::EigenVector::Flatten(*out); + auto x = pten::EigenMatrix::Reshape(diag, diag.dims().size() - 1); + auto output = pten::EigenVector::Flatten(*out); auto reduce_dim = Eigen::array({1}); - output.device(*ctx.eigen_device()) = x.sum(reduce_dim); - out->Resize(output_dims); + output.device(*dev_ctx.eigen_device()) = x.sum(reduce_dim); + out->Resize(out->dims()); } else { std::fill(out_data, out_data + out->numel(), static_cast(0)); } @@ -53,6 +53,6 @@ PT_REGISTER_KERNEL(trace, double, int, int64_t, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex) {} + pten::dtype::float16, + pten::dtype::complex, + pten::dtype::complex) {} diff --git a/paddle/pten/kernels/funcs/diagonal.h b/paddle/pten/kernels/funcs/diagonal.h new file mode 100644 index 0000000000000000000000000000000000000000..533010527f48fc924033dc7fd9577bb077b59aae --- /dev/null +++ b/paddle/pten/kernels/funcs/diagonal.h @@ -0,0 +1,130 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#if defined(__NVCC__) || defined(__HIPCC__) +#include +#include +#endif + +#include + +#include "paddle/fluid/platform/for_range.h" + +namespace pten { +namespace funcs { + +template +struct DiagonalFunctor { + DiagonalFunctor(const T* input, + const int64_t* diag_stride, + const int64_t* ret_strides, + int64_t pos, + int64_t dim_size, + T* diag) + : input_(input), + diag_stride_(diag_stride), + ret_strides_(ret_strides), + pos_(pos), + dim_size_(dim_size), + diag_(diag) {} + + HOSTDEVICE void operator()(size_t idx) const { + int64_t position = pos_; + int64_t num = idx; + for (int64_t i = 0; i < dim_size_; i++) { + position += num / diag_stride_[i] * ret_strides_[i]; + num = num % diag_stride_[i]; + } + diag_[idx] = input_[position]; + } + + const T* input_; + const int64_t* diag_stride_; + const int64_t* ret_strides_; + int64_t pos_; + int64_t dim_size_; + T* diag_; +}; + +template +DenseTensor Diagonal(const DeviceContext& context, + const DenseTensor* input, + int64_t offset, + int64_t dim1, + int64_t dim2) { + auto* input_data = input->data(); + auto input_dims = input->dims(); + auto input_stride = framework::stride(input_dims); + auto dim1_ = dim1 < 0 ? input_dims.size() + dim1 : dim1; + auto dim2_ = dim2 < 0 ? input_dims.size() + dim2 : dim2; + auto len1 = input_dims[std::min(dim1_, dim2_)]; + auto len2 = input_dims[std::max(dim1_, dim2_)]; + auto stride1 = input_stride[std::min(dim1_, dim2_)]; + auto stride2 = input_stride[std::max(dim1_, dim2_)]; + + int offset_stride = 0; + if (offset >= 0) { + offset_stride = stride2; + len2 -= offset; + } else { + offset_stride = stride1; + len1 += offset; + } + int diag_size = len2 < len1 ? len2 : len1; + + if (diag_size > 0) { + auto ret_strides = vectorize(input_stride); + auto ret_dims = vectorize(input_dims); + ret_strides.erase(ret_strides.begin() + std::max(dim1_, dim2_)); + ret_strides.erase(ret_strides.begin() + std::min(dim1_, dim2_)); + ret_dims.erase(ret_dims.begin() + std::max(dim1_, dim2_)); + ret_dims.erase(ret_dims.begin() + std::min(dim1_, dim2_)); + if (ret_strides.empty()) { + ret_strides.push_back(1); + ret_dims.push_back(1); + } + ret_strides.push_back(stride1 + stride2); + ret_dims.push_back(diag_size); + DenseTensor diag; + framework::DDim diag_dims = framework::make_ddim(ret_dims); + auto dig_stride = framework::stride(diag_dims); + auto diag_data = diag.mutable_data(diag_dims, context.GetPlace()); + + int64_t pos = std::abs(offset) * offset_stride; + int64_t dim_size = ret_strides.size(); +#if defined(__NVCC__) || defined(__HIPCC__) + thrust::device_vector diag_vec(vectorize(dig_stride)); + const int64_t* diag_arr = thrust::raw_pointer_cast(diag_vec.data()); + thrust::device_vector ret_vec(ret_strides); + const int64_t* ret_arr = thrust::raw_pointer_cast(ret_vec.data()); +#else + auto* diag_arr = dig_stride.Get(); + const auto* ret_arr = ret_strides.data(); +#endif + + // auto& dev_ctx = context.template device_context(); + paddle::platform::ForRange for_range(context, diag.numel()); + DiagonalFunctor functor( + input_data, diag_arr, ret_arr, pos, dim_size, diag_data); + for_range(functor); + return diag; + } else { + return {}; + } +} + +} // namespace funcs +} // namespace pten diff --git a/paddle/pten/kernels/gpu/trace_grad_kernel.cu b/paddle/pten/kernels/gpu/trace_grad_kernel.cu index b1b22e5ce549609b1f4d4adb843f726c0eefd7a3..d5b1bcb0a87b940289abe539dcae95db775ad8fd 100644 --- a/paddle/pten/kernels/gpu/trace_grad_kernel.cu +++ b/paddle/pten/kernels/gpu/trace_grad_kernel.cu @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pten/kernels/impl/trace_kernel_impl.h" #include "paddle/pten/kernels/trace_grad_kernel.h" -#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/backends/gpu/gpu_context.h" #include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/impl/trace_grad_kernel_impl.h" PT_REGISTER_KERNEL(trace_grad, GPU, @@ -26,6 +26,6 @@ PT_REGISTER_KERNEL(trace_grad, double, int, int64_t, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex) {} + pten::dtype::float16, + pten::dtype::complex, + pten::dtype::complex) {} diff --git a/paddle/pten/kernels/gpu/trace_kernel.cu b/paddle/pten/kernels/gpu/trace_kernel.cu index f552386fafdc76f6f92e91ac39d31262a3489e79..47db3dd7e6fa8376314de676892b037c0d468faf 100644 --- a/paddle/pten/kernels/gpu/trace_kernel.cu +++ b/paddle/pten/kernels/gpu/trace_kernel.cu @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/pten/kernels/trace_kernel.h" + #include "paddle/pten/backends/gpu/gpu_context.h" #include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/funcs/diagonal.h" #include "paddle/pten/kernels/gpu/reduce.h" -#include "paddle/pten/kernels/impl/trace_kernel_impl.h" -#include "paddle/pten/kernels/trace_kernel.h" namespace pten { @@ -27,8 +28,8 @@ void TraceKernel(const Context& ctx, int axis1, int axis2, DenseTensor* out) { - T* out_data = out->mutable_data(ctx.GetPlace()); - auto diag = Diagonal(ctx, &x, offset, axis1, axis2); + T* out_data = ctx.template Alloc(out); + auto diag = funcs::Diagonal(ctx, &x, offset, axis1, axis2); if (diag.numel() > 0) { auto stream = ctx.stream(); std::vector reduce_dims; @@ -51,6 +52,6 @@ PT_REGISTER_KERNEL(trace, double, int, int64_t, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex) {} + pten::dtype::float16, + pten::dtype::complex, + pten::dtype::complex) {} diff --git a/paddle/pten/kernels/impl/trace_kernel_impl.h b/paddle/pten/kernels/impl/trace_grad_kernel_impl.h similarity index 56% rename from paddle/pten/kernels/impl/trace_kernel_impl.h rename to paddle/pten/kernels/impl/trace_grad_kernel_impl.h index 1b499681bbbe4da8e5de4e5373057e6351d705e9..9ad038bf1059c206ed9a644575a946f863fc7a15 100644 --- a/paddle/pten/kernels/impl/trace_kernel_impl.h +++ b/paddle/pten/kernels/impl/trace_grad_kernel_impl.h @@ -21,44 +21,10 @@ #include -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace pten { -template -struct DiagonalFunctor { - DiagonalFunctor(const T* input, - const int64_t* diag_stride, - const int64_t* ret_strides, - int64_t pos, - int64_t dim_size, - T* diag) - : input_(input), - diag_stride_(diag_stride), - ret_strides_(ret_strides), - pos_(pos), - dim_size_(dim_size), - diag_(diag) {} - - HOSTDEVICE void operator()(size_t idx) const { - int64_t position = pos_; - int64_t num = idx; - for (int64_t i = 0; i < dim_size_; i++) { - position += num / diag_stride_[i] * ret_strides_[i]; - num = num % diag_stride_[i]; - } - diag_[idx] = input_[position]; - } - - const T* input_; - const int64_t* diag_stride_; - const int64_t* ret_strides_; - int64_t pos_; - int64_t dim_size_; - T* diag_; -}; template struct TraceGradFunctor { @@ -114,73 +80,6 @@ struct TraceGradFunctor { T* d_x_; }; -template -DenseTensor Diagonal(const DeviceContext& context, - const DenseTensor* input, - int64_t offset, - int64_t dim1, - int64_t dim2) { - auto* input_data = input->data(); - auto input_dims = input->dims(); - auto input_stride = framework::stride(input_dims); - auto dim1_ = dim1 < 0 ? input_dims.size() + dim1 : dim1; - auto dim2_ = dim2 < 0 ? input_dims.size() + dim2 : dim2; - auto len1 = input_dims[std::min(dim1_, dim2_)]; - auto len2 = input_dims[std::max(dim1_, dim2_)]; - auto stride1 = input_stride[std::min(dim1_, dim2_)]; - auto stride2 = input_stride[std::max(dim1_, dim2_)]; - - int offset_stride = 0; - if (offset >= 0) { - offset_stride = stride2; - len2 -= offset; - } else { - offset_stride = stride1; - len1 += offset; - } - int diag_size = len2 < len1 ? len2 : len1; - - if (diag_size > 0) { - auto ret_strides = vectorize(input_stride); - auto ret_dims = vectorize(input_dims); - ret_strides.erase(ret_strides.begin() + std::max(dim1_, dim2_)); - ret_strides.erase(ret_strides.begin() + std::min(dim1_, dim2_)); - ret_dims.erase(ret_dims.begin() + std::max(dim1_, dim2_)); - ret_dims.erase(ret_dims.begin() + std::min(dim1_, dim2_)); - if (ret_strides.empty()) { - ret_strides.push_back(1); - ret_dims.push_back(1); - } - ret_strides.push_back(stride1 + stride2); - ret_dims.push_back(diag_size); - DenseTensor diag; - framework::DDim diag_dims = framework::make_ddim(ret_dims); - auto dig_stride = framework::stride(diag_dims); - auto diag_data = diag.mutable_data(diag_dims, context.GetPlace()); - - int64_t pos = std::abs(offset) * offset_stride; - int64_t dim_size = ret_strides.size(); -#if defined(__NVCC__) || defined(__HIPCC__) - thrust::device_vector diag_vec(vectorize(dig_stride)); - const int64_t* diag_arr = thrust::raw_pointer_cast(diag_vec.data()); - thrust::device_vector ret_vec(ret_strides); - const int64_t* ret_arr = thrust::raw_pointer_cast(ret_vec.data()); -#else - auto* diag_arr = dig_stride.Get(); - const auto* ret_arr = ret_strides.data(); -#endif - - // auto& dev_ctx = context.template device_context(); - paddle::platform::ForRange for_range(context, diag.numel()); - DiagonalFunctor functor( - input_data, diag_arr, ret_arr, pos, dim_size, diag_data); - for_range(functor); - return diag; - } else { - return {}; - } -} - template void TraceGradKernel(const Context& ctx, const DenseTensor& out_grad, @@ -195,7 +94,7 @@ void TraceGradKernel(const Context& ctx, auto output_stride = framework::stride(output_dims); auto* out_data = out_grad.data(); - T* x_data = in_grad->mutable_data(ctx.GetPlace()); + T* x_data = ctx.template Alloc(in_grad); pten::funcs::SetConstant set_zero;