未验证 提交 768e50c9 编写于 作者: L Lin Manhui 提交者: GitHub

[PHI] Move spectral_norm to phi (#44577)

* Add kernel declarations

* Copy kernel implementation code

* Transfer implementation code

* Fix: Move out_grad to first

* Register new kernels

* Remove old kernels

* Move out_grad to last

* Fix bugs

* Transfer infermeta

* Add yaml files

* Add blank line

* Fix code style

* Optimize directory structure
Co-authored-by: NBobholamovic <linmanhui@baidu.com>
上级 a90b8dc1
......@@ -9,12 +9,14 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/spectral_norm_op.h"
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle {
namespace operators {
......@@ -24,82 +26,6 @@ class SpectralNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "SpectralNorm");
OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SpectralNorm");
OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "SpectralNorm");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SpectralNorm");
auto dim_weight = ctx->GetInputDim("Weight");
auto rank_weight = dim_weight.size();
PADDLE_ENFORCE_GE(rank_weight,
2,
platform::errors::InvalidArgument(
"The rank of Input(Weights) should be greater equal "
"than 2, but received Weight rank(%d)",
rank_weight));
PADDLE_ENFORCE_LE(rank_weight,
5,
platform::errors::InvalidArgument(
"The rank of Input(Weights) should be less equal "
"than 5, but received Weight rank(%d)",
rank_weight));
int dim = ctx->Attrs().Get<int>("dim");
int power_iters = ctx->Attrs().Get<int>("power_iters");
auto dim_valid = dim == 0 || dim == 1;
PADDLE_ENFORCE_EQ(
dim_valid,
true,
platform::errors::InvalidArgument(
"Attr(dim) can only be 0 or 1, but received %d", dim));
PADDLE_ENFORCE_GE(
power_iters,
0,
platform::errors::InvalidArgument(
"Attr(power_iters) should be greater equal then 0, but received %d",
power_iters));
int h = dim_weight[dim];
int w = 1;
for (int i = 0; i < rank_weight; i++) {
if (i != dim) {
w *= dim_weight[i];
}
}
auto dim_u = ctx->GetInputDim("U");
auto dim_v = ctx->GetInputDim("V");
if (ctx->IsRuntime() || (dim_u[0] > 0 && h > 0)) {
PADDLE_ENFORCE_EQ(dim_u[0],
h,
platform::errors::InvalidArgument(
"Input(U) dimension[0] should be equal to "
"Input(Weight) dimension[Attr(dim)], but received "
"U dimension[0](%d) != Weight dimension[%d](%d)",
dim_u[0],
dim,
h));
}
if (ctx->IsRuntime() || (dim_v[0] > 0 && w > 0)) {
PADDLE_ENFORCE_EQ(
dim_v[0],
w,
platform::errors::InvalidArgument(
"Input(V) dimension[0] should be equal to the product of "
"Input(Weight) dimension except dimension[Attr(dim)], but "
"received V dimension[0](%d) != product of Input(Weight) "
"dimension(%d)",
dim_v[0],
w));
}
ctx->SetOutputDim("Out", dim_weight);
ctx->ShareLoD("Weight", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -219,26 +145,6 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(
ctx->HasInput("Weight"), "Input", "Weight", "SpectralNormGrad");
OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SpectralNormGrad");
OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "SpectralNormGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out@GRAD",
"SpectralNormGrad");
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")),
true,
platform::errors::NotFound("Input(Out@GRAD) should not be null"));
auto dim_x = ctx->GetInputDim("Weight");
if (ctx->HasOutput(framework::GradVarName("Weight"))) {
ctx->SetOutputDim(framework::GradVarName("Weight"), dim_x);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
......@@ -250,15 +156,20 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(spectral_norm,
SpectralNormInferMetaFunctor,
PD_INFER_META(phi::SpectralNormInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(spectral_norm_grad,
SpectralNormGradInferMetaFunctor,
PD_INFER_META(phi::SpectralNormGradInferMeta));
REGISTER_OPERATOR(spectral_norm,
ops::SpectralNormOp,
ops::SpectralNormOpMaker,
ops::SpectralNormGradOpMaker<paddle::framework::OpDesc>,
ops::SpectralNormGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(spectral_norm_grad, ops::SpectralNormOpGrad);
REGISTER_OP_CPU_KERNEL(spectral_norm,
ops::SpectralNormKernel<phi::CPUContext, float>,
ops::SpectralNormKernel<phi::CPUContext, double>);
REGISTER_OP_CPU_KERNEL(spectral_norm_grad,
ops::SpectralNormGradKernel<phi::CPUContext, float>,
ops::SpectralNormGradKernel<phi::CPUContext, double>);
ops::SpectralNormGradOpMaker<paddle::imperative::OpBase>,
SpectralNormInferMetaFunctor);
REGISTER_OPERATOR(spectral_norm_grad,
ops::SpectralNormOpGrad,
SpectralNormGradInferMetaFunctor);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/spectral_norm_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
spectral_norm,
ops::SpectralNormKernel<paddle::platform::CUDADeviceContext, float>,
ops::SpectralNormKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
spectral_norm_grad,
ops::SpectralNormGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SpectralNormGradKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
template <typename T,
size_t D,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor;
using Array1 = Eigen::DSizes<int64_t, 1>;
using Array2 = Eigen::DSizes<int64_t, 2>;
using IndexPair = Eigen::IndexPair<int>;
template <typename DeviceContext, typename T>
static inline void TransCompute(const int rank,
const Tensor& in,
Tensor* out,
const std::vector<int>& perm,
const DeviceContext& dev_ctx) {
if (rank <= 1 || rank > 5) {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Weight rank of SpectralNorm should be in range [2, 5], but got %d.",
rank));
}
switch (rank) {
case 2:
phi::funcs::Transpose<DeviceContext, T, 2> trans2;
trans2(dev_ctx, in, out, perm);
break;
case 3:
phi::funcs::Transpose<DeviceContext, T, 3> trans3;
trans3(dev_ctx, in, out, perm);
break;
case 4:
phi::funcs::Transpose<DeviceContext, T, 4> trans4;
trans4(dev_ctx, in, out, perm);
break;
case 5:
phi::funcs::Transpose<DeviceContext, T, 5> trans5;
trans5(dev_ctx, in, out, perm);
break;
default:
break;
}
}
template <typename DeviceContext, typename T>
static inline void CalcMatrixSigmaAndNormWeight(
Tensor* sigma,
Tensor* u,
Tensor* v,
Tensor* weight,
const int power_iters,
const float eps,
const framework::ExecutionContext& ctx) {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
auto sigma_t = EigenTensor<T, 2>::From(*sigma);
auto weight_t = EigenTensor<T, 2>::From(*weight);
auto u_t = EigenTensor<T, 2>::From(*u);
auto v_t = EigenTensor<T, 2>::From(*v);
const int h = weight->dims()[0];
const int w = weight->dims()[1];
for (int i = 0; i < power_iters; i++) {
// V = W^T * U / ||W^T * U||_2
blas.MatMul(*weight, true, *u, false, T(1), v, T(0));
auto v_t_norm =
v_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
Array1(w));
v_t.device(place) = v_t / (v_t_norm + v_t_norm.constant(eps));
// U = W^T * V / ||W^T * V||_2
blas.MatMul(*weight, false, *v, false, T(1), u, T(0));
auto u_t_norm =
u_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
Array1(h));
u_t.device(place) = u_t / (u_t_norm + u_t_norm.constant(eps));
}
Tensor weight_v;
weight_v.mutable_data<T>({h, 1}, ctx.GetPlace());
blas.MatMul(*weight, false, *v, false, T(1), &weight_v, T(0));
auto weight_v_t = EigenTensor<T, 2>::From(weight_v);
sigma_t.device(place) = (u_t * weight_v_t)
.sum()
.eval()
.reshape(Array2(1, 1))
.broadcast(Array2(h, w));
weight_t.device(place) = weight_t / sigma_t;
}
template <typename DeviceContext, typename T>
class SpectralNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto weight = ctx.Input<Tensor>("Weight");
auto u = ctx.Input<Tensor>("U");
auto v = ctx.Input<Tensor>("V");
auto out = ctx.Output<Tensor>("Out");
int dim = ctx.Attr<int>("dim");
int power_iters = ctx.Attr<int>("power_iters");
float eps = ctx.Attr<float>("eps");
const int h = u->dims()[0];
const int w = v->dims()[0];
Tensor weight_mat;
auto dims = weight->dims();
const int rank = dims.size();
std::vector<int> real_dims;
if (dim != 0) {
std::vector<int> perm;
perm.push_back(dim);
real_dims.push_back(dims[dim]);
for (int i = 0; i < rank; i++) {
if (i != dim) {
perm.push_back(i);
real_dims.push_back(dims[i]);
}
}
weight_mat.mutable_data<T>(phi::make_ddim(real_dims), ctx.GetPlace());
TransCompute<DeviceContext, T>(rank, *weight, &weight_mat, perm, dev_ctx);
} else {
for (int i = 0; i < rank; i++) {
real_dims.push_back(i);
}
paddle::framework::TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
}
weight_mat = weight_mat.Resize({h, w});
Tensor sigma;
sigma.mutable_data<T>(weight_mat.dims(), ctx.GetPlace());
Tensor uu, vv;
paddle::framework::TensorCopySync(*u, ctx.GetPlace(), &uu);
paddle::framework::TensorCopySync(*v, ctx.GetPlace(), &vv);
CalcMatrixSigmaAndNormWeight<DeviceContext, T>(&sigma,
&(uu.Resize({h, 1})),
&(vv.Resize({w, 1})),
&weight_mat,
power_iters,
eps,
ctx);
if (dim != 0) {
std::vector<int> perm;
for (int i = 0; i < rank; i++) {
if (i < dim) {
perm.push_back(i + 1);
} else if (i == dim) {
perm.push_back(0);
} else {
perm.push_back(i);
}
}
out->mutable_data<T>(dims, ctx.GetPlace());
TransCompute<DeviceContext, T>(
rank,
weight_mat.Resize(phi::make_ddim(real_dims)),
out,
perm,
dev_ctx);
} else {
paddle::framework::TensorCopySync(
weight_mat.Resize(dims), ctx.GetPlace(), out);
}
}
};
template <typename DeviceContext, typename T>
class SpectralNormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
auto weight = ctx.Input<Tensor>("Weight");
auto u = ctx.Input<Tensor>("U");
auto v = ctx.Input<Tensor>("V");
auto out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto weight_grad = ctx.Output<Tensor>(framework::GradVarName("Weight"));
int dim = ctx.Attr<int>("dim");
int power_iters = ctx.Attr<int>("power_iters");
float eps = ctx.Attr<float>("eps");
const int h = u->dims()[0];
const int w = v->dims()[0];
Tensor weight_mat, out_grad_mat;
auto dims = weight->dims();
const int rank = dims.size();
std::vector<int> real_dims;
if (dim != 0) {
std::vector<int> perm;
perm.push_back(dim);
real_dims.push_back(dims[dim]);
for (int i = 0; i < rank; i++) {
if (i != dim) {
perm.push_back(i);
real_dims.push_back(dims[i]);
}
}
weight_mat.mutable_data<T>(phi::make_ddim(real_dims), ctx.GetPlace());
out_grad_mat.mutable_data<T>(phi::make_ddim(real_dims), ctx.GetPlace());
TransCompute<DeviceContext, T>(rank, *weight, &weight_mat, perm, dev_ctx);
TransCompute<DeviceContext, T>(
rank, *out_grad, &out_grad_mat, perm, dev_ctx);
} else {
for (int i = 0; i < rank; i++) {
real_dims.push_back(i);
}
paddle::framework::TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
paddle::framework::TensorCopySync(
*out_grad, ctx.GetPlace(), &out_grad_mat);
}
weight_mat = weight_mat.Resize({h, w});
out_grad_mat = out_grad_mat.Resize({h, w});
Tensor sigma;
sigma.mutable_data<T>(weight_mat.dims(), ctx.GetPlace());
Tensor uu, vv;
paddle::framework::TensorCopySync(*u, ctx.GetPlace(), &uu);
paddle::framework::TensorCopySync(*v, ctx.GetPlace(), &vv);
CalcMatrixSigmaAndNormWeight<DeviceContext, T>(&sigma,
&(uu.Resize({h, 1})),
&(vv.Resize({w, 1})),
&weight_mat,
power_iters,
eps,
ctx);
Tensor uv;
uv.mutable_data<T>({h, w}, ctx.GetPlace());
blas.MatMul(
uu.Resize({h, 1}), false, vv.Resize({w, 1}), false, T(1), &uv, T(0));
Tensor weight_grad_mat;
weight_grad_mat.mutable_data<T>({h, w}, ctx.GetPlace());
auto weight_grad_mat_t = EigenTensor<T, 2>::From(weight_grad_mat);
auto weight_mat_t = EigenTensor<T, 2>::From(weight_mat);
auto out_grad_mat_t = EigenTensor<T, 2>::From(out_grad_mat);
auto sigma_t = EigenTensor<T, 2>::From(sigma);
auto uv_t = EigenTensor<T, 2>::From(uv);
weight_mat_t.device(place) =
weight_mat_t.sum().eval().reshape(Array2(1, 1)).broadcast(Array2(h, w));
weight_grad_mat_t.device(place) =
out_grad_mat_t * (out_grad_mat_t.constant(1.0) - uv_t * weight_mat_t) /
sigma_t;
if (dim != 0) {
std::vector<int> perm;
for (int i = 0; i < rank; i++) {
if (i < dim) {
perm.push_back(i + 1);
} else if (i == dim) {
perm.push_back(0);
} else {
perm.push_back(i);
}
}
weight_grad->mutable_data<T>(dims, ctx.GetPlace());
TransCompute<DeviceContext, T>(
rank,
weight_grad_mat.Resize(phi::make_ddim(real_dims)),
weight_grad,
perm,
dev_ctx);
} else {
paddle::framework::TensorCopySync(
weight_grad_mat.Resize(dims), ctx.GetPlace(), weight_grad);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -2152,6 +2152,16 @@
use_gpudnn : true
backward : softmax_grad
- api : spectral_norm
args : (Tensor weight, Tensor u, Tensor v, int dim, int power_iters, float eps)
output : Tensor
infer_meta :
func : SpectralNormInferMeta
kernel :
func : spectralnorm
data_type : weight
backward : spectral_norm_grad
- api : split
args : (Tensor x, IntArray num_or_sections, Scalar(int) axis)
output : Tensor[]
......
......@@ -2037,6 +2037,16 @@
func : softmax_grad
use_gpudnn : true
- backward_api : spectral_norm_grad
forward : spectral_norm (Tensor weight, Tensor u, Tensor v, int dim, int power_iters, float eps) -> Tensor(out)
args : (Tensor weight, Tensor u, Tensor v, Tensor out_grad, int dim, int power_iters, float eps)
output : Tensor(weight_grad)
infer_meta :
func : SpectralNormGradInferMeta
kernel :
func : spectral_norm_grad
data_type : out_grad
- backward_api : split_grad
forward : split (Tensor x, IntArray num_or_sections, Scalar axis) -> Tensor[](out)
args : (Tensor[] out_grad, Scalar axis = -1)
......
......@@ -685,6 +685,21 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index,
}
}
void SpectralNormGradInferMeta(const MetaTensor& weight,
const MetaTensor& u,
const MetaTensor& v,
const MetaTensor& out_grad,
int dim,
int power_iters,
float eps,
MetaTensor* weight_grad) {
auto dim_x = weight.dims();
if (weight_grad) {
weight_grad->set_dims(dim_x);
weight_grad->set_dtype(out_grad.dtype());
}
}
void StackGradInferMeta(const MetaTensor& out_grad,
int axis,
std::vector<MetaTensor*> x_grad) {
......
......@@ -288,6 +288,15 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index,
MetaTensor* x_grad,
MetaTensor* updates_grad);
void SpectralNormGradInferMeta(const MetaTensor& weight,
const MetaTensor& u,
const MetaTensor& v,
const MetaTensor& out_grad,
int dim,
int power_iters,
float eps,
MetaTensor* weight_grad);
void StackGradInferMeta(const MetaTensor& out_grad,
int axis,
std::vector<MetaTensor*> x_grad);
......
......@@ -1088,6 +1088,83 @@ void ScatterNdAddInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void SpectralNormInferMeta(const MetaTensor& weight,
const MetaTensor& u,
const MetaTensor& v,
int dim,
int power_iters,
float eps,
MetaTensor* out,
MetaConfig config) {
auto dim_weight = weight.dims();
auto rank_weight = dim_weight.size();
PADDLE_ENFORCE_GE(rank_weight,
2,
errors::InvalidArgument(
"The rank of Input(Weights) should be greater equal "
"than 2, but received Weight rank(%d)",
rank_weight));
PADDLE_ENFORCE_LE(
rank_weight,
5,
errors::InvalidArgument("The rank of Input(Weights) should be less equal "
"than 5, but received Weight rank(%d)",
rank_weight));
auto dim_valid = dim == 0 || dim == 1;
PADDLE_ENFORCE_EQ(dim_valid,
true,
errors::InvalidArgument(
"Attr(dim) can only be 0 or 1, but received %d", dim));
PADDLE_ENFORCE_GE(
power_iters,
0,
errors::InvalidArgument(
"Attr(power_iters) should be greater equal then 0, but received %d",
power_iters));
int h = dim_weight[dim];
int w = 1;
for (int i = 0; i < rank_weight; i++) {
if (i != dim) {
w *= dim_weight[i];
}
}
auto dim_u = u.dims();
auto dim_v = v.dims();
if (config.is_runtime || (dim_u[0] > 0 && h > 0)) {
PADDLE_ENFORCE_EQ(dim_u[0],
h,
errors::InvalidArgument(
"Input(U) dimension[0] should be equal to "
"Input(Weight) dimension[Attr(dim)], but received "
"U dimension[0](%d) != Weight dimension[%d](%d)",
dim_u[0],
dim,
h));
}
if (config.is_runtime || (dim_v[0] > 0 && w > 0)) {
PADDLE_ENFORCE_EQ(
dim_v[0],
w,
errors::InvalidArgument(
"Input(V) dimension[0] should be equal to the product of "
"Input(Weight) dimension except dimension[Attr(dim)], but "
"received V dimension[0](%d) != product of Input(Weight) "
"dimension(%d)",
dim_v[0],
w));
}
if (out) {
out->set_dims(dim_weight);
out->set_dtype(weight.dtype());
out->share_lod(weight);
}
}
void ViterbiDecodeInferMeta(const MetaTensor& input,
const MetaTensor& transition,
const MetaTensor& length,
......
......@@ -170,6 +170,15 @@ void ScatterNdAddInferMeta(const MetaTensor& x,
const MetaTensor& updates,
MetaTensor* out);
void SpectralNormInferMeta(const MetaTensor& weight,
const MetaTensor& u,
const MetaTensor& v,
int dim,
int power_iters,
float eps,
MetaTensor* out,
MetaConfig config = MetaConfig());
void ViterbiDecodeInferMeta(const MetaTensor& input,
const MetaTensor& transition,
const MetaTensor& length,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h"
#include "paddle/phi/kernels/spectral_norm_grad_kernel.h"
PD_REGISTER_KERNEL(spectral_norm_grad,
CPU,
ALL_LAYOUT,
phi::SpectralNormGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/spectral_norm_kernel_impl.h"
#include "paddle/phi/kernels/spectral_norm_kernel.h"
PD_REGISTER_KERNEL(
spectral_norm, CPU, ALL_LAYOUT, phi::SpectralNormKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h"
#include "paddle/phi/kernels/spectral_norm_grad_kernel.h"
PD_REGISTER_KERNEL(spectral_norm_grad,
GPU,
ALL_LAYOUT,
phi::SpectralNormGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/spectral_norm_kernel_impl.h"
#include "paddle/phi/kernels/spectral_norm_kernel.h"
PD_REGISTER_KERNEL(
spectral_norm, GPU, ALL_LAYOUT, phi::SpectralNormKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/impl/spectral_norm_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void SpectralNormGradKernel(const Context& dev_ctx,
const DenseTensor& weight,
const DenseTensor& u,
const DenseTensor& v,
const DenseTensor& out_grad,
int dim,
int power_iters,
float eps,
DenseTensor* weight_grad) {
auto& place = *dev_ctx.eigen_device();
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
const int h = u.dims()[0];
const int w = v.dims()[0];
DenseTensor weight_mat, out_grad_mat;
auto dims = weight.dims();
const int rank = dims.size();
std::vector<int> real_dims;
if (dim != 0) {
std::vector<int> perm;
perm.push_back(dim);
real_dims.push_back(dims[dim]);
for (int i = 0; i < rank; i++) {
if (i != dim) {
perm.push_back(i);
real_dims.push_back(dims[i]);
}
}
weight_mat.Resize(phi::make_ddim(real_dims));
dev_ctx.template Alloc<T>(&weight_mat);
out_grad_mat.Resize(phi::make_ddim(real_dims));
dev_ctx.template Alloc<T>(&out_grad_mat);
TransCompute2DTo5D<Context, T>(dev_ctx, weight, rank, perm, &weight_mat);
TransCompute2DTo5D<Context, T>(
dev_ctx, out_grad, rank, perm, &out_grad_mat);
} else {
for (int i = 0; i < rank; i++) {
real_dims.push_back(i);
}
phi::Copy(dev_ctx, weight, dev_ctx.GetPlace(), true, &weight_mat);
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), true, &out_grad_mat);
}
weight_mat = weight_mat.Resize({h, w});
out_grad_mat = out_grad_mat.Resize({h, w});
DenseTensor sigma;
sigma.Resize(weight_mat.dims());
dev_ctx.template Alloc<T>(&sigma);
DenseTensor uu, vv;
phi::Copy(dev_ctx, u, dev_ctx.GetPlace(), true, &uu);
phi::Copy(dev_ctx, v, dev_ctx.GetPlace(), true, &vv);
CalcMatrixSigmaAndNormWeight<Context, T>(dev_ctx,
&weight_mat,
&(uu.Resize({h, 1})),
&(vv.Resize({w, 1})),
&sigma,
power_iters,
eps);
DenseTensor uv;
uv.Resize({h, w});
dev_ctx.template Alloc<T>(&uv);
blas.MatMul(
uu.Resize({h, 1}), false, vv.Resize({w, 1}), false, T(1), &uv, T(0));
DenseTensor weight_grad_mat;
weight_grad_mat.Resize({h, w});
dev_ctx.template Alloc<T>(&weight_grad_mat);
auto weight_grad_mat_t = EigenTensor<T, 2>::From(weight_grad_mat);
auto weight_mat_t = EigenTensor<T, 2>::From(weight_mat);
auto out_grad_mat_t = EigenTensor<T, 2>::From(out_grad_mat);
auto sigma_t = EigenTensor<T, 2>::From(sigma);
auto uv_t = EigenTensor<T, 2>::From(uv);
weight_mat_t.device(place) =
weight_mat_t.sum().eval().reshape(Array2(1, 1)).broadcast(Array2(h, w));
weight_grad_mat_t.device(place) =
out_grad_mat_t * (out_grad_mat_t.constant(1.0) - uv_t * weight_mat_t) /
sigma_t;
if (dim != 0) {
std::vector<int> perm;
for (int i = 0; i < rank; i++) {
if (i < dim) {
perm.push_back(i + 1);
} else if (i == dim) {
perm.push_back(0);
} else {
perm.push_back(i);
}
}
weight_grad->Resize(dims);
dev_ctx.template Alloc<T>(weight_grad);
TransCompute2DTo5D<Context, T>(
dev_ctx,
weight_grad_mat.Resize(phi::make_ddim(real_dims)),
rank,
perm,
weight_grad);
} else {
phi::Copy(dev_ctx,
weight_grad_mat.Resize(dims),
dev_ctx.GetPlace(),
true,
weight_grad);
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
using Array1 = Eigen::DSizes<int64_t, 1>;
using Array2 = Eigen::DSizes<int64_t, 2>;
using IndexPair = Eigen::IndexPair<int>;
template <typename Context, typename T>
static inline void TransCompute2DTo5D(const Context& dev_ctx,
const DenseTensor& in,
const int rank,
const std::vector<int>& perm,
DenseTensor* out) {
if (rank <= 1 || rank > 5) {
PADDLE_THROW(phi::errors::Fatal(
"Weight rank of SpectralNorm should be in range [2, 5], but got %d.",
rank));
}
switch (rank) {
case 2:
phi::funcs::Transpose<Context, T, 2> trans2;
trans2(dev_ctx, in, out, perm);
break;
case 3:
phi::funcs::Transpose<Context, T, 3> trans3;
trans3(dev_ctx, in, out, perm);
break;
case 4:
phi::funcs::Transpose<Context, T, 4> trans4;
trans4(dev_ctx, in, out, perm);
break;
case 5:
phi::funcs::Transpose<Context, T, 5> trans5;
trans5(dev_ctx, in, out, perm);
break;
default:
break;
}
}
template <typename Context, typename T>
static inline void CalcMatrixSigmaAndNormWeight(const Context& dev_ctx,
DenseTensor* weight,
DenseTensor* u,
DenseTensor* v,
DenseTensor* sigma,
const int power_iters,
const float eps) {
auto& place = *dev_ctx.eigen_device();
auto blas = funcs::GetBlas<Context, T>(dev_ctx);
auto sigma_t = EigenTensor<T, 2>::From(*sigma);
auto weight_t = EigenTensor<T, 2>::From(*weight);
auto u_t = EigenTensor<T, 2>::From(*u);
auto v_t = EigenTensor<T, 2>::From(*v);
const int h = weight->dims()[0];
const int w = weight->dims()[1];
for (int i = 0; i < power_iters; i++) {
// V = W^T * U / ||W^T * U||_2
blas.MatMul(*weight, true, *u, false, T(1), v, T(0));
auto v_t_norm =
v_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
Array1(w));
v_t.device(place) = v_t / (v_t_norm + v_t_norm.constant(eps));
// U = W^T * V / ||W^T * V||_2
blas.MatMul(*weight, false, *v, false, T(1), u, T(0));
auto u_t_norm =
u_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
Array1(h));
u_t.device(place) = u_t / (u_t_norm + u_t_norm.constant(eps));
}
DenseTensor weight_v;
weight_v.Resize({h, 1});
dev_ctx.template Alloc<T>(&weight_v);
blas.MatMul(*weight, false, *v, false, T(1), &weight_v, T(0));
auto weight_v_t = EigenTensor<T, 2>::From(weight_v);
sigma_t.device(place) = (u_t * weight_v_t)
.sum()
.eval()
.reshape(Array2(1, 1))
.broadcast(Array2(h, w));
weight_t.device(place) = weight_t / sigma_t;
}
template <typename T, typename Context>
void SpectralNormKernel(const Context& dev_ctx,
const DenseTensor& weight,
const DenseTensor& u,
const DenseTensor& v,
int dim,
int power_iters,
float eps,
DenseTensor* out) {
const int h = u.dims()[0];
const int w = v.dims()[0];
DenseTensor weight_mat;
auto dims = weight.dims();
const int rank = dims.size();
std::vector<int> real_dims;
if (dim != 0) {
std::vector<int> perm;
perm.push_back(dim);
real_dims.push_back(dims[dim]);
for (int i = 0; i < rank; i++) {
if (i != dim) {
perm.push_back(i);
real_dims.push_back(dims[i]);
}
}
weight_mat.Resize(phi::make_ddim(real_dims));
dev_ctx.template Alloc<T>(&weight_mat);
TransCompute2DTo5D<Context, T>(dev_ctx, weight, rank, perm, &weight_mat);
} else {
for (int i = 0; i < rank; i++) {
real_dims.push_back(i);
}
phi::Copy(dev_ctx, weight, dev_ctx.GetPlace(), true, &weight_mat);
}
weight_mat = weight_mat.Resize({h, w});
DenseTensor sigma;
sigma.Resize(weight_mat.dims());
dev_ctx.template Alloc<T>(&sigma);
DenseTensor uu, vv;
phi::Copy(dev_ctx, u, dev_ctx.GetPlace(), true, &uu);
phi::Copy(dev_ctx, v, dev_ctx.GetPlace(), true, &vv);
CalcMatrixSigmaAndNormWeight<Context, T>(dev_ctx,
&weight_mat,
&(uu.Resize({h, 1})),
&(vv.Resize({w, 1})),
&sigma,
power_iters,
eps);
if (dim != 0) {
std::vector<int> perm;
for (int i = 0; i < rank; i++) {
if (i < dim) {
perm.push_back(i + 1);
} else if (i == dim) {
perm.push_back(0);
} else {
perm.push_back(i);
}
}
out->Resize(dims);
dev_ctx.template Alloc<T>(out);
TransCompute2DTo5D<Context, T>(
dev_ctx, weight_mat.Resize(phi::make_ddim(real_dims)), rank, perm, out);
} else {
phi::Copy(dev_ctx, weight_mat.Resize(dims), dev_ctx.GetPlace(), true, out);
}
}
} // namespace phi
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SpectrumNormGradKernel(const Context& dev_ctx,
const DenseTensor& weight,
const DenseTensor& u,
const DenseTensor& v,
const DenseTensor& out_grad,
int dim,
int power_iters,
float eps,
DenseTensor* weight_grad);
} // namespace phi
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SpectrumNormKernel(const Context& dev_ctx,
const DenseTensor& weight,
const DenseTensor& u,
const DenseTensor& v,
int dim,
int power_iters,
float eps,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SpectralNormOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("spectral_norm",
{"Weight", "U", "V"},
{"dim", "power_iters", "eps"},
{"Out"});
}
KernelSignature SpectralNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("spectral_norm_grad",
{"Weight", "U", "V", "Out@GRAD"},
{"dim", "power_iters", "eps"},
{"Weight@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(spectral_norm, phi::SpectralNormOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(spectral_norm_grad,
phi::SpectralNormGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册