Unverified commit befa78ea authored by Liu-xiandong, committed by GitHub

[phi] move matrix_power op (#40231)

* [phi] move matrix_power op

* MatrixInverse fluid -> phi

* modify the CMake to fix compile bug

* delete useless comment

* mutable memory -> phi Alloc

* modify the include file

* modify the include file

* fix bug in CI compiler
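
In short, the CPU/GPU kernels move from the fluid operator-kernel macros to phi kernel registration; a minimal before/after sketch, taken from the hunks below:

// Before (fluid):
REGISTER_OP_CPU_KERNEL(
    matrix_power,
    ops::MatrixPowerKernel<paddle::platform::CPUDeviceContext, float>,
    ops::MatrixPowerKernel<paddle::platform::CPUDeviceContext, double>);

// After (phi):
PD_REGISTER_KERNEL(
    matrix_power, CPU, ALL_LAYOUT, phi::MatrixPowerKernel, float, double) {}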
Parent 857069f3
@@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/matrix_power_op.h"
#include <memory>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace paddle {
namespace operators {
@@ -119,13 +122,3 @@ REGISTER_OPERATOR(matrix_power, ops::MatrixPowerOp, ops::MatrixPowerOpMaker,
ops::MatrixPowerGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(matrix_power_grad, ops::MatrixPowerGradOp);
REGISTER_OP_CPU_KERNEL(
matrix_power,
ops::MatrixPowerKernel<paddle::platform::CPUDeviceContext, float>,
ops::MatrixPowerKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
matrix_power_grad,
ops::MatrixPowerGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MatrixPowerGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/matrix_inverse.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
struct IdentityMatrixFunctor {
IdentityMatrixFunctor(const int m, T* output) : m_(m), output_(output) {}
HOSTDEVICE void operator()(size_t index) const {
const int row = index / m_ % m_;
const int col = index % m_;
output_[index] = col == row ? static_cast<T>(1) : static_cast<T>(0);
}
const int m_;
T* output_;
};
template <typename DeviceContext, typename T>
void MatrixPowerFunction(const Tensor* X, const int n, Tensor* Out,
const paddle::framework::ExecutionContext& ctx) {
const auto& x_dims = X->dims();
const int x_ndim = x_dims.size();
T* out_data = Out->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, X->numel());
if (n == 0) {
// Out = Identity Matrix
IdentityMatrixFunctor<T> functor(x_dims[x_ndim - 1], out_data);
for_range(functor);
return;
}
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
Tensor new_x = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
int new_n = n;
if (n > 0) {
// newX = X
framework::TensorCopy(*X, ctx.GetPlace(), dev_ctx, &new_x);
} else {
// newX = X^{-1}, n = -n
phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
mat_inv(dev_ctx, *X, &new_x);
new_n = -n;
}
if (new_n == 1) {
framework::TensorCopy(new_x, ctx.GetPlace(), dev_ctx, Out);
return;
}
auto no_trans_desc = phi::funcs::CreateMatrixDescriptor(x_dims, 0, false);
if (new_n == 2) {
// Out = newX * newX
Out->mutable_data<T>(ctx.GetPlace());
blas.MatMul(new_x, no_trans_desc, new_x, no_trans_desc, static_cast<T>(1),
Out, static_cast<T>(0));
return;
} else if (new_n == 3) {
// Out = (newX * newX) * newX
// Note: C[i] matrices in MatMul must not overlap, i.e. the individual
// gemm operations must be computable independently; otherwise,
// undefined behavior is expected.
Tensor temp = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
blas.MatMul(new_x, no_trans_desc, new_x, no_trans_desc, static_cast<T>(1),
&temp, static_cast<T>(0));
blas.MatMul(temp, no_trans_desc, new_x, no_trans_desc, static_cast<T>(1),
Out, static_cast<T>(0));
return;
} else if (new_n == 4) {
// Out = (newX * newX) * (newX * newX)
Tensor temp = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
blas.MatMul(new_x, no_trans_desc, new_x, no_trans_desc, static_cast<T>(1),
&temp, static_cast<T>(0));
blas.MatMul(temp, no_trans_desc, temp, no_trans_desc, static_cast<T>(1),
Out, static_cast<T>(0));
return;
}
// Calculate Out = newX^{n} for abs(n) > 4 using O(log n) matrix multiplications
int bit = 0;
Tensor z = Tensor(X->dtype());
bool out_inited = false;
Tensor temp_out = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
Tensor temp_z = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
while (new_n > 0) {
bit = new_n & 0x1;
new_n >>= 1;
if (z.IsInitialized()) {
blas.MatMul(z, no_trans_desc, z, no_trans_desc, static_cast<T>(1),
&temp_z, static_cast<T>(0));
framework::TensorCopy(temp_z, ctx.GetPlace(), dev_ctx, &z);
} else {
z = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
framework::TensorCopy(new_x, ctx.GetPlace(), dev_ctx, &z);
}
if (bit == 1) {
if (out_inited == true) {
blas.MatMul(*Out, no_trans_desc, z, no_trans_desc, static_cast<T>(1),
&temp_out, static_cast<T>(0));
framework::TensorCopy(temp_out, ctx.GetPlace(), dev_ctx, Out);
} else {
framework::TensorCopy(z, ctx.GetPlace(), dev_ctx, Out);
out_inited = true;
}
}
}
return;
}
template <typename DeviceContext, typename T>
class MatrixPowerKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
const Tensor* X = ctx.Input<Tensor>("X");
Tensor* Out = ctx.Output<Tensor>("Out");
int n = ctx.Attr<int>("n");
const auto& x_dims = X->dims();
const int x_ndim = x_dims.size();
PADDLE_ENFORCE_EQ(
x_dims[x_ndim - 2], x_dims[x_ndim - 1],
platform::errors::InvalidArgument(
"The inner-most 2 dimensions of Input(X) should be equal."
"X's shape[-2] = %d and shape[-1] = %d.",
x_dims[x_ndim - 2], x_dims[x_ndim - 1]));
MatrixPowerFunction<DeviceContext, T>(X, n, Out, ctx);
}
};
template <typename DeviceContext, typename T>
void MatrixPowerGradFunction(const Tensor* X, const Tensor* Out,
const Tensor* dOut, const int n, Tensor* dX,
const paddle::framework::ExecutionContext& ctx) {
dX->mutable_data<T>(ctx.GetPlace());
const auto& x_dims = X->dims();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
if (n == 0) {
// \nabla X = O
phi::funcs::SetConstant<DeviceContext, T> zero;
zero(dev_ctx, dX, static_cast<T>(0));
return;
} else if (n == 1) {
// \nabla X = \nabla Out
framework::TensorCopy(*dOut, ctx.GetPlace(), dev_ctx, dX);
return;
}
auto trans_desc = phi::funcs::CreateMatrixDescriptor(x_dims, 0, true);
auto no_trans_desc = phi::funcs::CreateMatrixDescriptor(x_dims, 0, false);
if (n == -1) {
// \nabla X = Out^{T} * \nabla Out * Out^{T}
Tensor temp_dx =
ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
blas.MatMul(*Out, trans_desc, *dOut, no_trans_desc, static_cast<T>(-1),
&temp_dx, static_cast<T>(0));
blas.MatMul(temp_dx, no_trans_desc, *Out, trans_desc, static_cast<T>(1), dX,
static_cast<T>(0));
return;
}
Tensor new_x = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
int new_n = n;
if (n > 0) {
// newX = X
framework::TensorCopy(*X, ctx.GetPlace(), dev_ctx, &new_x);
} else {
// newX = X^{-1}, n = -n
phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
mat_inv(dev_ctx, *X, &new_x);
new_n = -n;
}
// Use the chain rule below to compute \nabla newX^{n}
// First, get newX^{0}, newX^{1}, ..., newX^{n - 1};
// note that newX^{0} can be omitted
std::vector<std::shared_ptr<Tensor>> tensor_list(new_n - 1);
tensor_list[0] = std::make_shared<Tensor>(new_x);
int index = 1;
while (index < new_n - 1) {
tensor_list[index] = std::make_shared<Tensor>(
ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx));
blas.MatMul(*tensor_list[index - 1], no_trans_desc, new_x, no_trans_desc,
static_cast<T>(1), tensor_list[index].get(), static_cast<T>(0));
index++;
}
// Second, \nabla newX = \sum_{i = 0}^{n - 1} (newX^{T})^{i}
//                       * \nabla Out
//                       * (newX^{T})^{n - i - 1}
Tensor dx_new = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
blas.MatMul(*tensor_list[new_n - 2], trans_desc, *dOut, no_trans_desc,
static_cast<T>(1), &dx_new, static_cast<T>(0));
Tensor da_an_minus1 =
ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
blas.MatMul(*dOut, no_trans_desc, *tensor_list[new_n - 2], trans_desc,
static_cast<T>(1), &da_an_minus1, static_cast<T>(0));
blas.AXPY(X->numel(), static_cast<T>(1), da_an_minus1.data<T>(),
dx_new.data<T>());
int start = 0;
while (start < new_n - 2) {
Tensor a_da = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
Tensor a_da_a = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
blas.MatMul(*tensor_list[start], trans_desc, *dOut, no_trans_desc,
static_cast<T>(1), &a_da, static_cast<T>(0));
blas.MatMul(a_da, no_trans_desc, *tensor_list[new_n - 3 - start],
trans_desc, static_cast<T>(1), &a_da_a, static_cast<T>(0));
blas.AXPY(X->numel(), static_cast<T>(1), a_da_a.data<T>(),
dx_new.data<T>());
start++;
}
if (n > 0) {
// \nabla X = \nabla newX
framework::TensorCopy(dx_new, ctx.GetPlace(), dev_ctx, dX);
} else {
// \nabla X = newX^{T} * \nabla newX * newX^{T}
Tensor temp_dx =
ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
blas.MatMul(new_x, trans_desc, dx_new, no_trans_desc, static_cast<T>(-1),
&temp_dx, static_cast<T>(0));
blas.MatMul(temp_dx, no_trans_desc, new_x, trans_desc, static_cast<T>(1),
dX, static_cast<T>(0));
}
return;
}
template <typename DeviceContext, typename T>
class MatrixPowerGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* X = ctx.Input<Tensor>("X");
const Tensor* Out = ctx.Input<Tensor>("Out");
const Tensor* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
const int n = ctx.Attr<int>("n");
Tensor* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
MatrixPowerGradFunction<DeviceContext, T>(X, Out, dOut, n, dX, ctx);
}
};
} // namespace operators
} // namespace paddle
@@ -27,7 +27,7 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
# Some kernels depend on targets that are not commonly used.
# Such targets are not suitable as common dependencies.
# In these cases, the kernels need to be declared manually here.
set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel eigh_kernel segment_pool_kernel segment_pool_grad_kernel)
set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel eigh_kernel segment_pool_kernel segment_pool_grad_kernel matrix_power_kernel matrix_power_grad_kernel)
kernel_library(math_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel)
kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
@@ -38,6 +38,8 @@ kernel_library(put_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_k
kernel_library(put_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
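# matrix_power uses phi::funcs::MatrixInverseFunctor for negative exponents,
# hence the extra matrix_inverse dependency below.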
kernel_library(matrix_power_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse)
kernel_library(matrix_power_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse)
kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function)
kernel_library(segment_pool_kernel DEPS ${COMMON_KERNEL_DEPS} segment_pooling)
kernel_library(segment_pool_grad_kernel DEPS ${COMMON_KERNEL_DEPS} segment_pooling)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,16 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/matrix_power_op.h"
#include "paddle/phi/kernels/matrix_power_grad_kernel.h"
#include "paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h"
namespace ops = paddle::operators;
namespace plf = paddle::platform;
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
REGISTER_OP_CUDA_KERNEL(matrix_power,
ops::MatrixPowerKernel<plf::CUDADeviceContext, float>,
ops::MatrixPowerKernel<plf::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
matrix_power_grad,
ops::MatrixPowerGradKernel<plf::CUDADeviceContext, float>,
ops::MatrixPowerGradKernel<plf::CUDADeviceContext, double>);
PD_REGISTER_KERNEL(matrix_power_grad,
CPU,
ALL_LAYOUT,
phi::MatrixPowerGradKernel,
float,
double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/matrix_power_kernel.h"
#include "paddle/phi/kernels/impl/matrix_power_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(
matrix_power, CPU, ALL_LAYOUT, phi::MatrixPowerKernel, float, double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/matrix_power_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h"
PD_REGISTER_KERNEL(matrix_power_grad,
GPU,
ALL_LAYOUT,
phi::MatrixPowerGradKernel,
float,
double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/matrix_power_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/matrix_power_kernel_impl.h"
PD_REGISTER_KERNEL(
matrix_power, GPU, ALL_LAYOUT, phi::MatrixPowerKernel, float, double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/matrix_inverse.h"
namespace phi {
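// Gradient of Out = X^n: n == 0 gives a zero gradient, n == 1 passes dOut
// through, n == -1 gives -Out^T * dOut * Out^T, and the general case applies
// the chain rule over the product newX * newX * ... * newX, where newX is X
// for n > 0 and X^{-1} for n < 0.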
template <typename Context, typename T>
void MatrixPowerGradFunction(const DenseTensor* X,
const DenseTensor* Out,
const DenseTensor* dOut,
const int n,
DenseTensor* dX,
const Context& ctx) {
ctx.template Alloc<T>(dX);
const auto& x_dims = X->dims();
auto blas = phi::funcs::GetBlas<Context, T>(ctx);
if (n == 0) {
// \nabla X = O
phi::funcs::SetConstant<Context, T> zero;
zero(ctx, dX, static_cast<T>(0));
return;
} else if (n == 1) {
// \nabla X = \nabla Out
paddle::framework::TensorCopy(*dOut, ctx.GetPlace(), ctx, dX);
return;
}
auto trans_desc = phi::funcs::CreateMatrixDescriptor(x_dims, 0, true);
auto no_trans_desc = phi::funcs::CreateMatrixDescriptor(x_dims, 0, false);
if (n == -1) {
// \nabla X = Out^{T} * \nabla Out * Out^{T}
DenseTensor temp_dx;
temp_dx.Resize(X->dims());
ctx.template Alloc<T>(&temp_dx);
blas.MatMul(*Out,
trans_desc,
*dOut,
no_trans_desc,
static_cast<T>(-1),
&temp_dx,
static_cast<T>(0));
blas.MatMul(temp_dx,
no_trans_desc,
*Out,
trans_desc,
static_cast<T>(1),
dX,
static_cast<T>(0));
return;
}
DenseTensor new_x;
new_x.Resize(X->dims());
ctx.template Alloc<T>(&new_x);
int new_n = n;
if (n > 0) {
// newX = X
paddle::framework::TensorCopy(*X, ctx.GetPlace(), ctx, &new_x);
} else {
// newX = X^{-1}, n = -n
phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
mat_inv(ctx, *X, &new_x);
new_n = -n;
}
// Use the chain rule below to compute \nabla newX^{n}
// First, get newX^{0}, newX^{1}, ..., newX^{n - 1};
// note that newX^{0} can be omitted
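// tensor_list[i] ends up holding newX^{i + 1}, so tensor_list[new_n - 2] is
// newX^{new_n - 1}.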
std::vector<std::shared_ptr<DenseTensor>> tensor_list(new_n - 1);
tensor_list[0] = std::make_shared<DenseTensor>(new_x);
int index = 1;
while (index < new_n - 1) {
DenseTensor tensor_list_index;
tensor_list_index.Resize(X->dims());
ctx.template Alloc<T>(&tensor_list_index);
tensor_list[index] = std::make_shared<DenseTensor>(tensor_list_index);
blas.MatMul(*tensor_list[index - 1],
no_trans_desc,
new_x,
no_trans_desc,
static_cast<T>(1),
tensor_list[index].get(),
static_cast<T>(0));
index++;
}
// Second, \nabla newX = \sum_{i = 0}^{n - 1} (newX^{T})^{i}
//                       * \nabla Out
//                       * (newX^{T})^{n - i - 1}
DenseTensor dx_new;
dx_new.Resize(X->dims());
ctx.template Alloc<T>(&dx_new);
blas.MatMul(*tensor_list[new_n - 2],
trans_desc,
*dOut,
no_trans_desc,
static_cast<T>(1),
&dx_new,
static_cast<T>(0));
DenseTensor da_an_minus1;
da_an_minus1.Resize(X->dims());
ctx.template Alloc<T>(&da_an_minus1);
blas.MatMul(*dOut,
no_trans_desc,
*tensor_list[new_n - 2],
trans_desc,
static_cast<T>(1),
&da_an_minus1,
static_cast<T>(0));
blas.AXPY(
X->numel(), static_cast<T>(1), da_an_minus1.data<T>(), dx_new.data<T>());
int start = 0;
while (start < new_n - 2) {
DenseTensor a_da;
a_da.Resize(X->dims());
ctx.template Alloc<T>(&a_da);
DenseTensor a_da_a;
a_da_a.Resize(X->dims());
ctx.template Alloc<T>(&a_da_a);
blas.MatMul(*tensor_list[start],
trans_desc,
*dOut,
no_trans_desc,
static_cast<T>(1),
&a_da,
static_cast<T>(0));
blas.MatMul(a_da,
no_trans_desc,
*tensor_list[new_n - 3 - start],
trans_desc,
static_cast<T>(1),
&a_da_a,
static_cast<T>(0));
blas.AXPY(
X->numel(), static_cast<T>(1), a_da_a.data<T>(), dx_new.data<T>());
start++;
}
if (n > 0) {
// \nabla X = \nabla newX
paddle::framework::TensorCopy(dx_new, ctx.GetPlace(), ctx, dX);
} else {
// \nabla X = newX^{T} * \nabla newX * newX^{T}
DenseTensor temp_dx;
temp_dx.Resize(X->dims());
ctx.template Alloc<T>(&temp_dx);
blas.MatMul(new_x,
trans_desc,
dx_new,
no_trans_desc,
static_cast<T>(-1),
&temp_dx,
static_cast<T>(0));
blas.MatMul(temp_dx,
no_trans_desc,
new_x,
trans_desc,
static_cast<T>(1),
dX,
static_cast<T>(0));
}
return;
}
template <typename T, typename Context>
void MatrixPowerGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
int n,
DenseTensor* x_grad) {
auto X = &x;
auto Out = &out;
auto dOut = &out_grad;
auto dX = x_grad;
MatrixPowerGradFunction<Context, T>(X, Out, dOut, n, dX, ctx);
}
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/matrix_inverse.h"
namespace phi {
template <typename T>
struct IdentityMatrixFunctor {
IdentityMatrixFunctor(const int m, T* output) : m_(m), output_(output) {}
HOSTDEVICE void operator()(size_t index) const {
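// The flattened index covers a batch of m x m matrices: index % m_ is the
// column and index / m_ % m_ the row within the current matrix.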
const int row = index / m_ % m_;
const int col = index % m_;
output_[index] = col == row ? static_cast<T>(1) : static_cast<T>(0);
}
const int m_;
T* output_;
};
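// Computes Out = X^n for (batches of) square matrices: n == 0 yields the
// identity, n < 0 first inverts X, small exponents (|n| <= 4) use explicit
// products, and larger exponents fall back to binary exponentiation.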
template <typename Context, typename T>
void MatrixPowerFunction(const DenseTensor* X,
const int n,
DenseTensor* Out,
const Context& ctx) {
const auto& x_dims = X->dims();
const int x_ndim = x_dims.size();
T* out_data = ctx.template Alloc<T>(Out);
phi::funcs::ForRange<Context> for_range(ctx, X->numel());
if (n == 0) {
// Out = Identity Matrix
IdentityMatrixFunctor<T> functor(x_dims[x_ndim - 1], out_data);
for_range(functor);
return;
}
auto blas = phi::funcs::GetBlas<Context, T>(ctx);
DenseTensor new_x;
new_x.Resize(X->dims());
ctx.template Alloc<T>(&new_x);
int new_n = n;
if (n > 0) {
// newX = X
paddle::framework::TensorCopy(*X, ctx.GetPlace(), ctx, &new_x);
} else {
// newX = X^{-1}, n = -n
phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
mat_inv(ctx, *X, &new_x);
new_n = -n;
}
if (new_n == 1) {
paddle::framework::TensorCopy(new_x, ctx.GetPlace(), ctx, Out);
return;
}
auto no_trans_desc = phi::funcs::CreateMatrixDescriptor(x_dims, 0, false);
if (new_n == 2) {
// Out = newX * newX
ctx.template Alloc<T>(Out);
blas.MatMul(new_x,
no_trans_desc,
new_x,
no_trans_desc,
static_cast<T>(1),
Out,
static_cast<T>(0));
return;
} else if (new_n == 3) {
// Out = (newX * newX) * newX
// Note: C[i] matrices in MatMul must not overlap, i.e. the individual
// gemm operations must be computable independently; otherwise,
// undefined behavior is expected.
DenseTensor temp;
temp.Resize(X->dims());
ctx.template Alloc<T>(&temp);
blas.MatMul(new_x,
no_trans_desc,
new_x,
no_trans_desc,
static_cast<T>(1),
&temp,
static_cast<T>(0));
blas.MatMul(temp,
no_trans_desc,
new_x,
no_trans_desc,
static_cast<T>(1),
Out,
static_cast<T>(0));
return;
} else if (new_n == 4) {
// Out = (newX * newX) * (newX * newX)
DenseTensor temp;
temp.Resize(X->dims());
ctx.template Alloc<T>(&temp);
blas.MatMul(new_x,
no_trans_desc,
new_x,
no_trans_desc,
static_cast<T>(1),
&temp,
static_cast<T>(0));
blas.MatMul(temp,
no_trans_desc,
temp,
no_trans_desc,
static_cast<T>(1),
Out,
static_cast<T>(0));
return;
}
// Calculate Out = newX^{n} for abs(n) > 4 using O(log n) matrix multiplications
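// Binary exponentiation: on the k-th iteration z holds newX^{2^k}; whenever
// bit k of new_n is set, Out is multiplied by z (or initialized to it).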
int bit = 0;
DenseTensor z = DenseTensor(X->dtype());
bool out_inited = false;
DenseTensor temp_out;
temp_out.Resize(X->dims());
ctx.template Alloc<T>(&temp_out);
DenseTensor temp_z;
temp_z.Resize(X->dims());
ctx.template Alloc<T>(&temp_z);
while (new_n > 0) {
bit = new_n & 0x1;
new_n >>= 1;
if (z.IsInitialized()) {
blas.MatMul(z,
no_trans_desc,
z,
no_trans_desc,
static_cast<T>(1),
&temp_z,
static_cast<T>(0));
paddle::framework::TensorCopy(temp_z, ctx.GetPlace(), ctx, &z);
} else {
z.Resize(X->dims());
ctx.template Alloc<T>(&z);
paddle::framework::TensorCopy(new_x, ctx.GetPlace(), ctx, &z);
}
if (bit == 1) {
if (out_inited) {
blas.MatMul(*Out,
no_trans_desc,
z,
no_trans_desc,
static_cast<T>(1),
&temp_out,
static_cast<T>(0));
paddle::framework::TensorCopy(temp_out, ctx.GetPlace(), ctx, Out);
} else {
paddle::framework::TensorCopy(z, ctx.GetPlace(), ctx, Out);
out_inited = true;
}
}
}
return;
}
template <typename T, typename Context>
void MatrixPowerKernel(const Context& ctx,
const DenseTensor& x,
int n,
DenseTensor* out) {
const DenseTensor* X = &x;
auto Out = out;
const auto& x_dims = X->dims();
const int x_ndim = x_dims.size();
PADDLE_ENFORCE_EQ(
x_dims[x_ndim - 2],
x_dims[x_ndim - 1],
errors::InvalidArgument(
"The inner-most 2 dimensions of Input(X) should be equal."
"X's shape[-2] = %d and shape[-1] = %d.",
x_dims[x_ndim - 2],
x_dims[x_ndim - 1]));
MatrixPowerFunction<Context, T>(X, n, Out, ctx);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
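// Computes the gradient of matrix_power with respect to x, given the forward
// input x, the forward output out = x^n, and the gradient of out.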
template <typename T, typename Context>
void MatrixPowerGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
int n,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
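// Computes out = x^n for a (batch of) square matrix x; a negative n raises
// the matrix inverse of x to the power -n.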
template <typename T, typename Context>
void MatrixPowerKernel(const Context& ctx,
const DenseTensor& x,
int n,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
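// Maps the fluid matrix_power_grad op's inputs {X, Out, Out@GRAD}, attribute
// n, and output X@GRAD onto the phi matrix_power_grad kernel signature.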
KernelSignature MatrixPowerGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("matrix_power_grad",
{"X", "Out", GradVarName("Out")},
{"n"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(matrix_power_grad,
phi::MatrixPowerGradOpArgumentMapping);