Unverified commit ab866777, authored by Chen Weihang, committed by GitHub

[PTen] Polish trace moving (#39510)

* polish trace moving

* remove useless header
Parent 4745234f
paddle/pten/kernels/cpu/trace_grad_kernel.cc

@@ -13,10 +13,10 @@
 // limitations under the License.

 #include "paddle/pten/kernels/trace_grad_kernel.h"
-#include "paddle/pten/kernels/impl/trace_kernel_impl.h"

 #include "paddle/pten/backends/cpu/cpu_context.h"
 #include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/kernels/impl/trace_grad_kernel_impl.h"

 PT_REGISTER_KERNEL(trace_grad,
                    CPU,
@@ -26,6 +26,6 @@ PT_REGISTER_KERNEL(trace_grad,
                    double,
                    int,
                    int64_t,
-                   paddle::platform::float16,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::float16,
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
paddle/pten/kernels/cpu/trace_kernel.cc

@@ -13,31 +13,31 @@
 // limitations under the License.

 #include "paddle/pten/kernels/trace_kernel.h"

 #include "paddle/pten/backends/cpu/cpu_context.h"
 #include "paddle/pten/core/kernel_registry.h"
-#include "paddle/pten/kernels/impl/trace_kernel_impl.h"
+#include "paddle/pten/kernels/funcs/diagonal.h"
+#include "paddle/pten/kernels/funcs/eigen/common.h"

 namespace pten {

 template <typename T, typename Context>
-void TraceKernel(const Context& ctx,
+void TraceKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  int offset,
                  int axis1,
                  int axis2,
                  DenseTensor* out) {
-  auto output_dims = out->dims();
-  T* out_data = out->mutable_data<T>(ctx.GetPlace());
-  const DenseTensor diag = Diagonal<T, Context>(ctx, &x, offset, axis1, axis2);
+  auto* out_data = dev_ctx.template Alloc<T>(out);
+
+  const DenseTensor diag =
+      funcs::Diagonal<T, Context>(dev_ctx, &x, offset, axis1, axis2);
   if (diag.numel() > 0) {
-    auto x = paddle::framework::EigenMatrix<T>::Reshape(diag,
-                                                        diag.dims().size() - 1);
-    auto output = paddle::framework::EigenVector<T>::Flatten(*out);
+    auto x = pten::EigenMatrix<T>::Reshape(diag, diag.dims().size() - 1);
+    auto output = pten::EigenVector<T>::Flatten(*out);
     auto reduce_dim = Eigen::array<int, 1>({1});
-    output.device(*ctx.eigen_device()) = x.sum(reduce_dim);
-    out->Resize(output_dims);
+    output.device(*dev_ctx.eigen_device()) = x.sum(reduce_dim);
+    out->Resize(out->dims());
   } else {
     std::fill(out_data, out_data + out->numel(), static_cast<T>(0));
   }
@@ -53,6 +53,6 @@ PT_REGISTER_KERNEL(trace,
                    double,
                    int,
                    int64_t,
-                   paddle::platform::float16,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::float16,
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
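For intuition: TraceKernel extracts the (possibly offset) diagonal of x and sums it along its last axis, falling back to a zero fill when the diagonal is empty. A minimal host-side sketch of the same algorithm for a row-major 2-D matrix (hypothetical standalone helper, not the Paddle kernel):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Sum the offset diagonal of a row-major rows x cols matrix.
template <typename T>
T Trace(const std::vector<T>& m, int64_t rows, int64_t cols, int64_t offset) {
  // Diagonal length mirrors the len1/len2 bookkeeping in funcs::Diagonal.
  int64_t len = offset >= 0 ? std::min(rows, cols - offset)
                            : std::min(rows + offset, cols);
  T sum = T(0);
  for (int64_t i = 0; i < len; ++i) {
    int64_t r = offset >= 0 ? i : i - offset;  // negative offset shifts rows
    int64_t c = offset >= 0 ? i + offset : i;  // positive offset shifts cols
    sum += m[r * cols + c];
  }
  return sum;  // empty diagonal: loop never runs, result is zero
}

int main() {
  std::vector<float> m = {1, 2, 3, 4, 5, 6, 7, 8, 9};  // 3x3
  std::cout << Trace(m, 3, 3, 0) << "\n";   // 1 + 5 + 9 = 15
  std::cout << Trace(m, 3, 3, 1) << "\n";   // 2 + 6 = 8
  std::cout << Trace(m, 3, 3, -1) << "\n";  // 4 + 8 = 12
}
```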
paddle/pten/kernels/funcs/diagonal.h (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#if defined(__NVCC__) || defined(__HIPCC__)
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#endif

#include <algorithm>

#include "paddle/fluid/platform/for_range.h"

namespace pten {
namespace funcs {
template <typename T>
struct DiagonalFunctor {
  DiagonalFunctor(const T* input,
                  const int64_t* diag_stride,
                  const int64_t* ret_strides,
                  int64_t pos,
                  int64_t dim_size,
                  T* diag)
      : input_(input),
        diag_stride_(diag_stride),
        ret_strides_(ret_strides),
        pos_(pos),
        dim_size_(dim_size),
        diag_(diag) {}

  HOSTDEVICE void operator()(size_t idx) const {
    int64_t position = pos_;
    int64_t num = idx;
    for (int64_t i = 0; i < dim_size_; i++) {
      position += num / diag_stride_[i] * ret_strides_[i];
      num = num % diag_stride_[i];
    }
    diag_[idx] = input_[position];
  }

  const T* input_;
  const int64_t* diag_stride_;
  const int64_t* ret_strides_;
  int64_t pos_;
  int64_t dim_size_;
  T* diag_;
};
template <typename T, typename DeviceContext>
DenseTensor Diagonal(const DeviceContext& context,
                     const DenseTensor* input,
                     int64_t offset,
                     int64_t dim1,
                     int64_t dim2) {
  auto* input_data = input->data<T>();
  auto input_dims = input->dims();
  auto input_stride = framework::stride(input_dims);
  auto dim1_ = dim1 < 0 ? input_dims.size() + dim1 : dim1;
  auto dim2_ = dim2 < 0 ? input_dims.size() + dim2 : dim2;
  auto len1 = input_dims[std::min(dim1_, dim2_)];
  auto len2 = input_dims[std::max(dim1_, dim2_)];
  auto stride1 = input_stride[std::min(dim1_, dim2_)];
  auto stride2 = input_stride[std::max(dim1_, dim2_)];
  int offset_stride = 0;
  if (offset >= 0) {
    offset_stride = stride2;
    len2 -= offset;
  } else {
    offset_stride = stride1;
    len1 += offset;
  }
  int diag_size = len2 < len1 ? len2 : len1;

  if (diag_size > 0) {
    auto ret_strides = vectorize(input_stride);
    auto ret_dims = vectorize(input_dims);
    ret_strides.erase(ret_strides.begin() + std::max(dim1_, dim2_));
    ret_strides.erase(ret_strides.begin() + std::min(dim1_, dim2_));
    ret_dims.erase(ret_dims.begin() + std::max(dim1_, dim2_));
    ret_dims.erase(ret_dims.begin() + std::min(dim1_, dim2_));
    if (ret_strides.empty()) {
      ret_strides.push_back(1);
      ret_dims.push_back(1);
    }
    ret_strides.push_back(stride1 + stride2);
    ret_dims.push_back(diag_size);

    DenseTensor diag;
    framework::DDim diag_dims = framework::make_ddim(ret_dims);
    auto dig_stride = framework::stride(diag_dims);
    auto diag_data = diag.mutable_data<T>(diag_dims, context.GetPlace());

    int64_t pos = std::abs(offset) * offset_stride;
    int64_t dim_size = ret_strides.size();
#if defined(__NVCC__) || defined(__HIPCC__)
    thrust::device_vector<int64_t> diag_vec(vectorize(dig_stride));
    const int64_t* diag_arr = thrust::raw_pointer_cast(diag_vec.data());
    thrust::device_vector<int64_t> ret_vec(ret_strides);
    const int64_t* ret_arr = thrust::raw_pointer_cast(ret_vec.data());
#else
    auto* diag_arr = dig_stride.Get();
    const auto* ret_arr = ret_strides.data();
#endif

    // auto& dev_ctx = context.template device_context<DeviceContext>();
    paddle::platform::ForRange<DeviceContext> for_range(context, diag.numel());
    DiagonalFunctor<T> functor(
        input_data, diag_arr, ret_arr, pos, dim_size, diag_data);
    for_range(functor);
    return diag;
  } else {
    return {};
  }
}

}  // namespace funcs
}  // namespace pten
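The heart of the new header is DiagonalFunctor's index arithmetic: a linear index into the diagonal tensor is decomposed into coordinates using the diagonal's own strides (diag_stride_), and those coordinates are re-linearized against the strides the corresponding dimensions keep in the input (ret_strides_), starting from the offset position pos_. A standalone check of that mapping, simplified to one dimension with hand-picked values (hypothetical, not Paddle code):

```cpp
#include <cstdint>
#include <iostream>

// Same arithmetic as DiagonalFunctor::operator(), as a free function.
int64_t MapIndex(int64_t idx, int64_t pos, const int64_t* diag_stride,
                 const int64_t* ret_strides, int64_t dim_size) {
  int64_t position = pos;
  int64_t num = idx;
  for (int64_t i = 0; i < dim_size; i++) {
    position += num / diag_stride[i] * ret_strides[i];
    num %= diag_stride[i];
  }
  return position;
}

int main() {
  // Main diagonal of a row-major 3x3 matrix: each diagonal step advances
  // one row (stride 3) and one column (stride 1), so the retained stride
  // is 3 + 1 = 4; the diagonal itself is contiguous (stride 1).
  const int64_t diag_stride[] = {1};
  const int64_t ret_strides[] = {4};
  for (int64_t idx = 0; idx < 3; ++idx) {
    std::cout << MapIndex(idx, /*pos=*/0, diag_stride, ret_strides, 1) << " ";
  }
  // Prints "0 4 8": the flat offsets of m[0][0], m[1][1], m[2][2].
}
```

On device builds, the #if branch above first copies both stride arrays into thrust::device_vectors because the functor dereferences them inside a GPU kernel, where host pointers would be invalid; on CPU the host arrays are used directly.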
paddle/pten/kernels/gpu/trace_grad_kernel.cu

@@ -12,11 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/pten/kernels/impl/trace_kernel_impl.h"
 #include "paddle/pten/kernels/trace_grad_kernel.h"

-#include "paddle/pten/backends/cpu/cpu_context.h"
+#include "paddle/pten/backends/gpu/gpu_context.h"
 #include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/kernels/impl/trace_grad_kernel_impl.h"

 PT_REGISTER_KERNEL(trace_grad,
                    GPU,
@@ -26,6 +26,6 @@ PT_REGISTER_KERNEL(trace_grad,
                    double,
                    int,
                    int64_t,
-                   paddle::platform::float16,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::float16,
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
paddle/pten/kernels/gpu/trace_kernel.cu

@@ -12,11 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/pten/kernels/trace_kernel.h"
 #include "paddle/pten/backends/gpu/gpu_context.h"
 #include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/kernels/funcs/diagonal.h"
 #include "paddle/pten/kernels/gpu/reduce.h"
-#include "paddle/pten/kernels/impl/trace_kernel_impl.h"
+#include "paddle/pten/kernels/trace_kernel.h"

 namespace pten {
@@ -27,8 +28,8 @@ void TraceKernel(const Context& ctx,
                  int axis1,
                  int axis2,
                  DenseTensor* out) {
-  T* out_data = out->mutable_data<T>(ctx.GetPlace());
-  auto diag = Diagonal<T, Context>(ctx, &x, offset, axis1, axis2);
+  T* out_data = ctx.template Alloc<T>(out);
+  auto diag = funcs::Diagonal<T, Context>(ctx, &x, offset, axis1, axis2);
   if (diag.numel() > 0) {
     auto stream = ctx.stream();
     std::vector<int> reduce_dims;
@@ -51,6 +52,6 @@ PT_REGISTER_KERNEL(trace,
                    double,
                    int,
                    int64_t,
-                   paddle::platform::float16,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
+                   pten::dtype::float16,
+                   pten::dtype::complex<float>,
+                   pten::dtype::complex<double>) {}
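On the GPU path the Eigen sum is replaced by pten's reduce kernel; the elided lines collect the diagonal's last axis into reduce_dims and launch the reduction on the stream. Conceptually it is just a batched sum over the last axis; a host-side sketch under an assumed [batch, len] layout (not the GPU code):

```cpp
#include <cstdint>
#include <vector>

// One trace per batch: sum the last axis of a [batch, len] diagonal tensor.
template <typename T>
std::vector<T> ReduceLastAxis(const std::vector<T>& diag, int64_t batch,
                              int64_t len) {
  std::vector<T> out(batch, T(0));
  for (int64_t b = 0; b < batch; ++b) {
    for (int64_t i = 0; i < len; ++i) {
      out[b] += diag[b * len + i];
    }
  }
  return out;
}
```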
paddle/pten/kernels/impl/trace_grad_kernel_impl.h

@@ -21,44 +21,10 @@
 #include <algorithm>

-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/for_range.h"
 #include "paddle/pten/kernels/funcs/math_function.h"

 namespace pten {

-template <typename T>
-struct DiagonalFunctor {
-  ... (body identical to the DiagonalFunctor moved into funcs/diagonal.h)
-};
-
 template <typename T>
 struct TraceGradFunctor {
@@ -114,73 +80,6 @@ struct TraceGradFunctor {
   T* d_x_;
 };

-template <typename T, typename DeviceContext>
-DenseTensor Diagonal(const DeviceContext& context,
-                     const DenseTensor* input,
-                     int64_t offset,
-                     int64_t dim1,
-                     int64_t dim2) {
-  ... (body identical to the Diagonal moved into funcs/diagonal.h)
-}
-
 template <typename T, typename Context>
 void TraceGradKernel(const Context& ctx,
                      const DenseTensor& out_grad,
@@ -195,7 +94,7 @@ void TraceGradKernel(const Context& ctx,
   auto output_stride = framework::stride(output_dims);

   auto* out_data = out_grad.data<T>();
-  T* x_data = in_grad->mutable_data<T>(ctx.GetPlace());
+  T* x_data = ctx.template Alloc<T>(in_grad);

   pten::funcs::SetConstant<Context, T> set_zero;
...
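Trace is linear in its input, so the backward pass scatters the output gradient onto the traced diagonal and zeroes everything else; that is what the set_zero call above prepares for before the elided TraceGradFunctor runs. A minimal 2-D, offset-0 sketch of the computation (hypothetical standalone helper, not the Paddle functor):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// d_x is zero everywhere except the traced diagonal, which receives d_out.
template <typename T>
std::vector<T> TraceGrad2D(T d_out, int64_t rows, int64_t cols) {
  std::vector<T> d_x(rows * cols, T(0));  // mirrors set_zero on in_grad
  for (int64_t i = 0; i < std::min(rows, cols); ++i) {
    d_x[i * cols + i] = d_out;  // scatter onto the main diagonal
  }
  return d_x;
}
```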