Unverified commit f181d47f, authored by Zhang Zheng, committed by GitHub

[Phi] Move kron kernel to phi (#40427)

* first commit

* fix

* fix

* fix compile error

* fix

* fix complex

* fix

* fix

* fix npu

* fix

* modify according to comments

* fix
Parent 5d08a447
@@ -17,7 +17,6 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/kron_op.h"
#include "paddle/fluid/platform/device/npu/npu_info.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
@@ -17,9 +17,7 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/kron_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
@@ -178,27 +176,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(kron, ops::KronOp, ops::KronOpMaker,
ops::KronGradOpMaker<paddle::framework::OpDesc>,
ops::KronGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
kron, ops::KronKernel<paddle::platform::CPUDeviceContext, float>,
ops::KronKernel<paddle::platform::CPUDeviceContext, double>,
ops::KronKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::KronKernel<paddle::platform::CPUDeviceContext, int>,
ops::KronKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::KronKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::KronKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OPERATOR(kron_grad, ops::KronGradOp);
REGISTER_OP_CPU_KERNEL(
kron_grad, ops::KronGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::KronGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::KronGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::KronGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::KronGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::KronGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::KronGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/kron_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
kron, ops::KronKernel<paddle::platform::CUDADeviceContext, float>,
ops::KronKernel<paddle::platform::CUDADeviceContext, double>,
ops::KronKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::KronKernel<paddle::platform::CUDADeviceContext, int>,
ops::KronKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::KronKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::KronKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
kron_grad, ops::KronGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::KronGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::KronGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::KronGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::KronGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::KronGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::KronGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
@@ -16,7 +16,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/kron_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/kron_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/kron_grad_kernel_impl.h"
PD_REGISTER_KERNEL(kron_grad,
CPU,
ALL_LAYOUT,
phi::KronGradKernel,
int,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/kron_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/kron_kernel_impl.h"
PD_REGISTER_KERNEL(kron,
CPU,
ALL_LAYOUT,
phi::KronKernel,
int,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/kron_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/kron_grad_kernel_impl.h"
PD_REGISTER_KERNEL(kron_grad,
GPU,
ALL_LAYOUT,
phi::KronGradKernel,
int,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/kron_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/kron_kernel_impl.h"
PD_REGISTER_KERNEL(kron,
GPU,
ALL_LAYOUT,
phi::KronKernel,
int,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "thrust/device_vector.h"
#endif
namespace paddle {
namespace operators {
#include "paddle/phi/kernels/impl/kron_kernel_impl.h"
// Process an element in the output, used with a parallel-for
template <typename T>
struct KronElemFunctor {
KronElemFunctor(const T* a, const T* b, T* out, const int64_t* shape_b,
const int64_t* stride_a, const int64_t* stride_b,
const int64_t* stride_out, int ndims)
: a_(a),
b_(b),
out_(out),
shape_b_(shape_b),
stride_a_(stride_a),
stride_b_(stride_b),
stride_out_(stride_out),
ndims_(ndims) {}
HOSTDEVICE void operator()(int64_t idx) const {
// computes one element of the output
int64_t index = idx;
int64_t index_a = 0;
int64_t index_b = 0;
for (int i = 0; i < ndims_; i++) {
auto pos_i = index / stride_out_[i];
index = index % stride_out_[i];
auto pos_ai = pos_i / shape_b_[i];
auto pos_bi = pos_i % shape_b_[i];
index_a += stride_a_[i] * pos_ai;
index_b += stride_b_[i] * pos_bi;
}
out_[idx] = a_[index_a] * b_[index_b];
}
private:
const T* a_;
const T* b_;
T* out_;
const int64_t* shape_b_;
const int64_t* stride_a_;
const int64_t* stride_b_;
const int64_t* stride_out_;
const int ndims_;
};
template <typename DeviceContext, typename T>
struct KronOpFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor& x,
const framework::Tensor& y, framework::Tensor* out) {
int ndims = out->dims().size();
int64_t numel = out->numel();
const framework::DDim& dim_x = x.dims();
const framework::DDim& dim_y = y.dims();
const framework::DDim& dim_out = out->dims();
const framework::DDim stride_x = phi::stride(dim_x);
const framework::DDim stride_y = phi::stride(dim_y);
const framework::DDim stride_out = phi::stride(dim_out);
const int64_t *p_stride_x = nullptr, *p_stride_y = nullptr,
*p_stride_out = nullptr, *p_shape_y = nullptr;
#if defined(__NVCC__) || defined(__HIPCC__)
thrust::device_vector<int64_t> d_stride_x(ndims);
thrust::device_vector<int64_t> d_stride_y(ndims);
thrust::device_vector<int64_t> d_stride_out(ndims);
thrust::device_vector<int64_t> d_shape_y(ndims);
thrust::copy(stride_x.Get(), stride_x.Get() + ndims, d_stride_x.begin());
thrust::copy(stride_y.Get(), stride_y.Get() + ndims, d_stride_y.begin());
thrust::copy(stride_out.Get(), stride_out.Get() + ndims,
d_stride_out.begin());
thrust::copy(dim_y.Get(), dim_y.Get() + ndims, d_shape_y.begin());
p_stride_x = thrust::raw_pointer_cast(d_stride_x.data());
p_stride_y = thrust::raw_pointer_cast(d_stride_y.data());
p_stride_out = thrust::raw_pointer_cast(d_stride_out.data());
p_shape_y = thrust::raw_pointer_cast(d_shape_y.data());
#else
p_stride_x = stride_x.Get();
p_stride_y = stride_y.Get();
p_stride_out = stride_out.Get();
p_shape_y = dim_y.Get();
#endif
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
KronElemFunctor<T> functor(x.data<T>(), y.data<T>(), out->data<T>(),
p_shape_y, p_stride_x, p_stride_y, p_stride_out,
ndims);
for_range(functor);
}
};
namespace phi {
template <typename T>
struct KronGradElemFunctor {
KronGradElemFunctor(const T* dout, const T* A, const T* B, T* dout_a,
T* dout_b, const int64_t* stride_dout,
const int64_t* stride_a, const int64_t* stride_b,
const int64_t* shape_b, const int64_t numel_a,
const int64_t numel_b, const int ndims)
KronGradElemFunctor(const T* dout,
const T* A,
const T* B,
T* dout_a,
T* dout_b,
const int64_t* stride_dout,
const int64_t* stride_a,
const int64_t* stride_b,
const int64_t* shape_b,
const int64_t numel_a,
const int64_t numel_b,
const int ndims)
: dout_(dout),
A_(A),
B_(B),
@@ -173,15 +84,19 @@ struct KronGradElemFunctor {
};
template <typename T>
struct KronGradElemFunctor<platform::complex<T>> {
KronGradElemFunctor(const platform::complex<T>* dout,
const platform::complex<T>* A,
const platform::complex<T>* B,
platform::complex<T>* dout_a,
platform::complex<T>* dout_b, const int64_t* stride_dout,
const int64_t* stride_a, const int64_t* stride_b,
const int64_t* shape_b, const int64_t numel_a,
const int64_t numel_b, const int ndims)
struct KronGradElemFunctor<dtype::complex<T>> {
KronGradElemFunctor(const dtype::complex<T>* dout,
const dtype::complex<T>* A,
const dtype::complex<T>* B,
dtype::complex<T>* dout_a,
dtype::complex<T>* dout_b,
const int64_t* stride_dout,
const int64_t* stride_a,
const int64_t* stride_b,
const int64_t* shape_b,
const int64_t numel_a,
const int64_t numel_b,
const int ndims)
: dout_(dout),
A_(A),
B_(B),
@@ -211,23 +126,21 @@ struct KronGradElemFunctor<platform::complex<T>> {
if (dout_a_) {
size_t index_out_a = index_a * numel_b_ + index_b;
dout_a_[index_out_a] =
dout_[idx] *
platform::complex<T>(B_[index_b].real, -B_[index_b].imag);
dout_[idx] * dtype::complex<T>(B_[index_b].real, -B_[index_b].imag);
}
if (dout_b_) {
size_t index_out_b = index_b * numel_a_ + index_a;
dout_b_[index_out_b] =
dout_[idx] *
platform::complex<T>(A_[index_a].real, -A_[index_a].imag);
dout_[idx] * dtype::complex<T>(A_[index_a].real, -A_[index_a].imag);
}
}
private:
const platform::complex<T>* dout_;
const platform::complex<T>* A_;
const platform::complex<T>* B_;
platform::complex<T>* dout_a_;
platform::complex<T>* dout_b_;
const dtype::complex<T>* dout_;
const dtype::complex<T>* A_;
const dtype::complex<T>* B_;
dtype::complex<T>* dout_a_;
dtype::complex<T>* dout_b_;
const int64_t* stride_dout_;
const int64_t* stride_a_;
const int64_t* stride_b_;
@@ -237,23 +150,26 @@ struct KronGradElemFunctor<platform::complex<T>> {
const int ndims_;
};
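// A minimal standalone sketch of the rule the complex specialization above
// applies: for out = x * y over complex values, the backward pass computes
// dx = dout * conj(y), and dtype::complex<T>(v.real, -v.imag) is exactly
// conj(v). Plain std::complex here, not Paddle types; values are made up.
#include <complex>
#include <cstdio>

int main() {
  std::complex<double> y(3.0, -1.0), dout(0.5, 0.5);
  std::complex<double> dx = dout * std::conj(y);  // matches the branch above
  std::printf("dx = (%g, %g)\n", dx.real(), dx.imag());  // prints: dx = (1, 2)
  return 0;
}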
template <typename DeviceContext, typename T>
template <typename Context, typename T>
struct KronGradOpFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor& dout,
const framework::Tensor& x, const framework::Tensor& y,
framework::Tensor* dx, framework::Tensor* dy) {
void operator()(const Context& dev_ctx,
const DenseTensor& dout,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* dx,
DenseTensor* dy) {
int ndims = dout.dims().size();
int64_t numel = dout.numel();
int64_t numel_x = x.numel();
int64_t numel_y = y.numel();
const framework::DDim& dim_x = x.dims();
const framework::DDim& dim_y = y.dims();
const framework::DDim& dim_dout = dout.dims();
const phi::DDim& dim_x = x.dims();
const phi::DDim& dim_y = y.dims();
const phi::DDim& dim_dout = dout.dims();
const framework::DDim stride_x = phi::stride(dim_x);
const framework::DDim stride_y = phi::stride(dim_y);
const framework::DDim stride_dout = phi::stride(dim_dout);
const phi::DDim stride_x = phi::stride(dim_x);
const phi::DDim stride_y = phi::stride(dim_y);
const phi::DDim stride_dout = phi::stride(dim_dout);
const int64_t* p_stride_x = nullptr;
const int64_t* p_stride_y = nullptr;
@@ -266,8 +182,8 @@ struct KronGradOpFunctor {
thrust::device_vector<int64_t> d_shape_y(ndims);
thrust::copy(stride_x.Get(), stride_x.Get() + ndims, d_stride_x.begin());
thrust::copy(stride_y.Get(), stride_y.Get() + ndims, d_stride_y.begin());
thrust::copy(stride_dout.Get(), stride_dout.Get() + ndims,
d_stride_dout.begin());
thrust::copy(
stride_dout.Get(), stride_dout.Get() + ndims, d_stride_dout.begin());
thrust::copy(dim_y.Get(), dim_y.Get() + ndims, d_shape_y.begin());
p_stride_x = thrust::raw_pointer_cast(d_stride_x.data());
@@ -282,134 +198,98 @@ struct KronGradOpFunctor {
#endif
// dout_x: dout * kron(ones(X), Y) re-arranged in shape (numel_x, numel_y)
// dout_y: dout * kron(X, ones(Y)) re-arranged in shape (numel_y, numel_x)
framework::Tensor dout_x;
DenseTensor dout_x;
T* p_dout_x = nullptr;
if (dx) {
dout_x.mutable_data<T>({numel_x, numel_y}, dev_ctx.GetPlace());
dout_x.Resize({numel_x, numel_y});
dev_ctx.template Alloc<T>(&dout_x);
p_dout_x = dout_x.data<T>();
}
framework::Tensor dout_y;
DenseTensor dout_y;
T* p_dout_y = nullptr;
if (dy) {
dout_y.mutable_data<T>({numel_y, numel_x}, dev_ctx.GetPlace());
dout_y.Resize({numel_y, numel_x});
dev_ctx.template Alloc<T>(&dout_y);
p_dout_y = dout_y.data<T>();
}
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
KronGradElemFunctor<T> func(dout.data<T>(), x.data<T>(), y.data<T>(),
p_dout_x, p_dout_y, p_stride_dout, p_stride_x,
p_stride_y, p_shape_y, numel_x, numel_y, ndims);
funcs::ForRange<Context> for_range(dev_ctx, numel);
KronGradElemFunctor<T> func(dout.data<T>(),
x.data<T>(),
y.data<T>(),
p_dout_x,
p_dout_y,
p_stride_dout,
p_stride_x,
p_stride_y,
p_shape_y,
numel_x,
numel_y,
ndims);
for_range(func);
// reduce_sum along axis 1
#if defined(__NVCC__) || defined(__HIPCC__)
auto stream = dev_ctx.stream(); // it is a cuda device_context
if (dx) {
TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx, dout_x, dx, kps::IdentityFunctor<T>(), {1}, stream);
funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx, dout_x, dx, kps::IdentityFunctor<T>(), {1});
}
if (dy) {
TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx, dout_y, dy, kps::IdentityFunctor<T>(), {1}, stream);
funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx, dout_y, dy, kps::IdentityFunctor<T>(), {1});
}
#else
auto* place = dev_ctx.eigen_device();
Eigen::array<int, 1> reduce_dim = {1};
if (dx) {
auto eigen_dout_x = framework::EigenMatrix<T>::Reshape(dout_x, 1);
auto eigen_vec_dx = framework::EigenVector<T>::Flatten(*dx);
auto eigen_dout_x = EigenMatrix<T>::Reshape(dout_x, 1);
auto eigen_vec_dx = EigenVector<T>::Flatten(*dx);
eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim);
}
if (dy) {
auto eigen_dout_y = framework::EigenMatrix<T>::Reshape(dout_y, 1);
auto eigen_vec_dy = framework::EigenVector<T>::Flatten(*dy);
auto eigen_dout_y = EigenMatrix<T>::Reshape(dout_y, 1);
auto eigen_vec_dy = EigenVector<T>::Flatten(*dy);
eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim);
}
#endif
}
};
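// To make the two-step backward concrete, a hand-worked 1-D sketch with
// plain arrays standing in for DenseTensor (shapes and values are made up):
// KronGradElemFunctor scatters dout * y into a (numel_x, numel_y) buffer,
// then the reduction along axis 1 yields dx.
#include <cstdio>

int main() {
  const int numel_x = 2, numel_y = 3;
  double y[numel_y] = {10, 20, 30};
  double dout[numel_x * numel_y] = {1, 1, 1, 1, 1, 1};
  // Step 1: scatter; with 1-D strides the index math collapses to div/mod.
  double dout_x[numel_x][numel_y] = {};
  for (int idx = 0; idx < numel_x * numel_y; ++idx) {
    dout_x[idx / numel_y][idx % numel_y] = dout[idx] * y[idx % numel_y];
  }
  // Step 2: reduce_sum along axis 1, as the ReduceKernel/Eigen paths do.
  for (int i = 0; i < numel_x; ++i) {
    double dx_i = 0;
    for (int j = 0; j < numel_y; ++j) dx_i += dout_x[i][j];
    std::printf("dx[%d] = %g\n", i, dx_i);  // 60 for each element of x
  }
  return 0;
}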
inline framework::Tensor UnsqueezeTo(const framework::Tensor& src, int ndims) {
const framework::DDim& shape = src.dims();
int rank = shape.size();
framework::Tensor res;
res.ShareDataWith(src);
PADDLE_ENFORCE_LE(
rank, ndims,
platform::errors::InvalidArgument(
"The input Tensor's rank should be less than or equal to ndims"
"Received input Tensor's rank = %d, ndims = %d",
rank, ndims));
if (rank < ndims) {
std::vector<int64_t> new_dim(ndims, 1);
for (int i = ndims - rank; i < ndims; i++) {
new_dim[i] = shape[i - ndims + rank];
}
res.Resize(phi::make_ddim(new_dim));
template <typename T, typename Context>
void KronGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* y_grad) {
if (x_grad) {
ctx.template Alloc<T>(x_grad);
}
return res;
}
template <typename DeviceContext, typename T>
class KronKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
int ndims = out->dims().size();
framework::Tensor xx = UnsqueezeTo(*x, ndims);
framework::Tensor yy = UnsqueezeTo(*y, ndims);
KronOpFunctor<DeviceContext, T> func;
func(dev_ctx, xx, yy, out);
if (y_grad) {
ctx.template Alloc<T>(y_grad);
}
};
template <typename DeviceContext, typename T>
class KronGradKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
}
int ndims = dout->dims().size();
framework::Tensor xx = UnsqueezeTo(*x, ndims);
framework::Tensor yy = UnsqueezeTo(*y, ndims);
framework::Tensor* pdxx = nullptr;
framework::Tensor* pdyy = nullptr;
framework::Tensor dxx;
framework::Tensor dyy;
if (dx) {
dxx = UnsqueezeTo(*dx, ndims);
pdxx = &dxx;
}
if (dy) {
dyy = UnsqueezeTo(*dy, ndims);
pdyy = &dyy;
}
int ndims = out_grad.dims().size();
DenseTensor xx = UnsqueezeTo(x, ndims);
DenseTensor yy = UnsqueezeTo(y, ndims);
DenseTensor* pdxx = nullptr;
DenseTensor* pdyy = nullptr;
DenseTensor dxx;
DenseTensor dyy;
if (x_grad) {
dxx = UnsqueezeTo(*x_grad, ndims);
pdxx = &dxx;
}
KronGradOpFunctor<DeviceContext, T> func;
func(dev_ctx, *dout, xx, yy, pdxx, pdyy);
if (y_grad) {
dyy = UnsqueezeTo(*y_grad, ndims);
pdyy = &dyy;
}
};
} // namespace operators
} // namespace paddle
KronGradOpFunctor<Context, T> func;
func(ctx, out_grad, xx, yy, pdxx, pdyy);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "thrust/device_vector.h"
#endif
namespace phi {
inline DenseTensor UnsqueezeTo(const DenseTensor& src, int ndims) {
const phi::DDim& shape = src.dims();
int rank = shape.size();
DenseTensor res;
res.ShareDataWith(src);
PADDLE_ENFORCE_LE(
rank,
ndims,
errors::InvalidArgument(
"The input Tensor's rank should be less than or equal to ndims"
"Received input Tensor's rank = %d, ndims = %d",
rank,
ndims));
if (rank < ndims) {
std::vector<int64_t> new_dim(ndims, 1);
for (int i = ndims - rank; i < ndims; i++) {
new_dim[i] = shape[i - ndims + rank];
}
res.Resize(phi::make_ddim(new_dim));
}
return res;
}
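// The shape arithmetic in UnsqueezeTo is easiest to check in isolation:
// left-pad the shape with 1s until it has ndims dimensions. UnsqueezeShape
// below is a hypothetical helper for illustration, not part of phi; the
// real function shares the tensor's data and only resizes its dims.
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<int64_t> UnsqueezeShape(const std::vector<int64_t>& shape,
                                    int ndims) {
  int rank = static_cast<int>(shape.size());
  std::vector<int64_t> new_dim(ndims, 1);
  for (int i = ndims - rank; i < ndims; ++i) {
    new_dim[i] = shape[i - ndims + rank];  // same loop as above
  }
  return new_dim;
}

int main() {
  for (int64_t d : UnsqueezeShape({3, 4}, 4)) {
    std::printf("%lld ", static_cast<long long>(d));  // prints: 1 1 3 4
  }
  return 0;
}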
template <typename T>
struct KronElemFunctor {
KronElemFunctor(const T* a,
const T* b,
T* out,
const int64_t* shape_b,
const int64_t* stride_a,
const int64_t* stride_b,
const int64_t* stride_out,
int ndims)
: a_(a),
b_(b),
out_(out),
shape_b_(shape_b),
stride_a_(stride_a),
stride_b_(stride_b),
stride_out_(stride_out),
ndims_(ndims) {}
HOSTDEVICE void operator()(int64_t idx) const {
// computes one element of the output
int64_t index = idx;
int64_t index_a = 0;
int64_t index_b = 0;
for (int i = 0; i < ndims_; i++) {
auto pos_i = index / stride_out_[i];
index = index % stride_out_[i];
auto pos_ai = pos_i / shape_b_[i];
auto pos_bi = pos_i % shape_b_[i];
index_a += stride_a_[i] * pos_ai;
index_b += stride_b_[i] * pos_bi;
}
out_[idx] = a_[index_a] * b_[index_b];
}
private:
const T* a_;
const T* b_;
T* out_;
const int64_t* shape_b_;
const int64_t* stride_a_;
const int64_t* stride_b_;
const int64_t* stride_out_;
const int ndims_;
};
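// The stride walk in KronElemFunctor::operator() can be verified by hand.
// A standalone 1-D sketch (made-up values; every stride is 1, so the
// decomposition reduces to pos_a = idx / len(b), pos_b = idx % len(b)):
#include <cstdint>
#include <cstdio>

int main() {
  const double a[2] = {1, 2};
  const double b[3] = {10, 20, 30};
  const int64_t shape_b = 3;
  for (int64_t idx = 0; idx < 6; ++idx) {
    std::printf("%g ", a[idx / shape_b] * b[idx % shape_b]);
  }
  // prints: 10 20 30 20 40 60, i.e. kron(a, b)
  return 0;
}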
template <typename Context, typename T>
struct KronOpFunctor {
void operator()(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
int ndims = out->dims().size();
int64_t numel = out->numel();
const phi::DDim& dim_x = x.dims();
const phi::DDim& dim_y = y.dims();
const phi::DDim& dim_out = out->dims();
const phi::DDim stride_x = phi::stride(dim_x);
const phi::DDim stride_y = phi::stride(dim_y);
const phi::DDim stride_out = phi::stride(dim_out);
const int64_t *p_stride_x = nullptr, *p_stride_y = nullptr,
*p_stride_out = nullptr, *p_shape_y = nullptr;
#if defined(__NVCC__) || defined(__HIPCC__)
thrust::device_vector<int64_t> d_stride_x(ndims);
thrust::device_vector<int64_t> d_stride_y(ndims);
thrust::device_vector<int64_t> d_stride_out(ndims);
thrust::device_vector<int64_t> d_shape_y(ndims);
thrust::copy(stride_x.Get(), stride_x.Get() + ndims, d_stride_x.begin());
thrust::copy(stride_y.Get(), stride_y.Get() + ndims, d_stride_y.begin());
thrust::copy(
stride_out.Get(), stride_out.Get() + ndims, d_stride_out.begin());
thrust::copy(dim_y.Get(), dim_y.Get() + ndims, d_shape_y.begin());
p_stride_x = thrust::raw_pointer_cast(d_stride_x.data());
p_stride_y = thrust::raw_pointer_cast(d_stride_y.data());
p_stride_out = thrust::raw_pointer_cast(d_stride_out.data());
p_shape_y = thrust::raw_pointer_cast(d_shape_y.data());
#else
p_stride_x = stride_x.Get();
p_stride_y = stride_y.Get();
p_stride_out = stride_out.Get();
p_shape_y = dim_y.Get();
#endif
funcs::ForRange<Context> for_range(dev_ctx, numel);
KronElemFunctor<T> functor(x.data<T>(),
y.data<T>(),
out->data<T>(),
p_shape_y,
p_stride_x,
p_stride_y,
p_stride_out,
ndims);
for_range(functor);
}
};
template <typename T, typename Context>
void KronKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
ctx.template Alloc<T>(out);
int ndims = out->dims().size();
DenseTensor xx = UnsqueezeTo(x, ndims);
DenseTensor yy = UnsqueezeTo(y, ndims);
KronOpFunctor<Context, T> func;
func(ctx, xx, yy, out);
}
} // namespace phi
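// When testing a migrated kernel, a naive reference implementation is handy.
// A sketch of 2-D kron with plain loops (row-major layout assumed; this is
// illustrative test code, not part of the Paddle sources):
#include <cstdio>
#include <vector>

// out[i1*rb + i2][j1*cb + j2] = a[i1][j1] * b[i2][j2], row-major.
std::vector<double> KronRef(const std::vector<double>& a, int ra, int ca,
                            const std::vector<double>& b, int rb, int cb) {
  std::vector<double> out(static_cast<size_t>(ra) * rb * ca * cb);
  const int out_cols = ca * cb;
  for (int i1 = 0; i1 < ra; ++i1)
    for (int j1 = 0; j1 < ca; ++j1)
      for (int i2 = 0; i2 < rb; ++i2)
        for (int j2 = 0; j2 < cb; ++j2)
          out[(i1 * rb + i2) * out_cols + (j1 * cb + j2)] =
              a[i1 * ca + j1] * b[i2 * cb + j2];
  return out;
}

int main() {
  auto out = KronRef({1, 2}, 1, 2, {0, 1}, 1, 2);  // kron([1 2], [0 1])
  for (double v : out) std::printf("%g ", v);      // prints: 0 1 0 2
  return 0;
}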
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void KronGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* y_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void KronKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature KronGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("kron_grad",
{"X", "Y", GradVarName("Out")},
{},
{GradVarName("X"), GradVarName("Y")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(kron_grad, phi::KronGradOpArgumentMapping);
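// The mapping above binds fluid variable names to the phi kernel arguments.
// GradVarName appends Paddle's gradient suffix, so the signature resolves to
// inputs {"X", "Y", "Out@GRAD"} and outputs {"X@GRAD", "Y@GRAD"}. A sketch
// with an illustrative stand-in for the real helper:
#include <cstdio>
#include <string>

std::string GradVarName(const std::string& name) { return name + "@GRAD"; }

int main() {
  std::printf("%s %s\n", GradVarName("Out").c_str(), GradVarName("X").c_str());
  return 0;  // prints: Out@GRAD X@GRAD
}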