Commit 982a2b5a authored by miraiwk

add UOut and VOut for SpectralNorm

Parent 9373cf5a
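For context: spectral normalization rescales a weight matrix W by an estimate sigma of its largest singular value, maintained by power iteration on a pair of vectors u and v (v <- normalize(W^T u), u <- normalize(W v), sigma = u^T W v, Out = W / sigma). Before this commit the operator refined throwaway copies of U and V and discarded them; the change exposes the refined vectors as new outputs UOut and VOut so the iteration state can persist across steps. A one-step NumPy illustration of the idea (illustrative only, not the operator code):

import numpy as np

eps = 1e-12
W = np.random.randn(4, 3)
u = np.random.randn(4)
v = np.random.randn(3)

# One power-iteration step; more steps sharpen the estimate of sigma.
v = W.T.dot(u)
v = v / (np.linalg.norm(v) + eps)
u = W.dot(v)
u = u / (np.linalg.norm(u) + eps)

sigma = u.dot(W).dot(v)   # estimate of the largest singular value
out = W / sigma           # the "Out" tensor
# u and v now hold the state that the new "UOut"/"VOut" outputs carry.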
@@ -30,6 +30,8 @@ class SpectralNormOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SpectralNorm");
     OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "SpectralNorm");
     OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SpectralNorm");
+    OP_INOUT_CHECK(ctx->HasOutput("UOut"), "Output", "UOut", "SpectralNorm");
+    OP_INOUT_CHECK(ctx->HasOutput("VOut"), "Output", "VOut", "SpectralNorm");
     auto dim_weight = ctx->GetInputDim("Weight");
     auto rank_weight = dim_weight.size();
@@ -88,6 +90,8 @@ class SpectralNormOp : public framework::OperatorWithKernel {
     }
     ctx->SetOutputDim("Out", dim_weight);
+    ctx->SetOutputDim("UOut", dim_u);
+    ctx->SetOutputDim("VOut", dim_v);
     ctx->ShareLoD("Weight", /*->*/ "Out");
   }
@@ -126,6 +130,10 @@ class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Out",
               "The output weight tensor of spectral_norm operator, "
               "This tensor is in same shape with Input(Weight).");
+    AddOutput("UOut",
+              "The updated value of `U`");
+    AddOutput("VOut",
+              "The updated value of `V`");
     AddAttr<int>("dim",
                  "The index of dimension which should be permuted "
@@ -145,7 +153,6 @@ class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker {
                  "the denominator to aviod divide zero. "
                  "Default 1e-12.")
         .SetDefault(1e-12);
     AddComment(R"DOC(
          This layer calculates the spectral normalization value of weight of
          fc, conv1d, conv2d, conv3d layers which should be 2-D, 3-D, 4-D, 5-D
......
@@ -61,13 +61,11 @@ static inline void TransCompute(const int rank, const Tensor& in, Tensor* out,
 }
 template <typename DeviceContext, typename T>
-static inline void CalcMatrixSigmaAndNormWeight(
-    Tensor* sigma, Tensor* u, Tensor* v, Tensor* weight, const int power_iters,
+static inline void UpdateUandV(
+    Tensor* u, Tensor* v, Tensor* weight, const int power_iters,
     const float eps, const framework::ExecutionContext& ctx) {
   auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
   auto blas = math::GetBlas<DeviceContext, T>(ctx);
-  auto sigma_t = EigenTensor<T, 2>::From(*sigma);
-  auto weight_t = EigenTensor<T, 2>::From(*weight);
   auto u_t = EigenTensor<T, 2>::From(*u);
   auto v_t = EigenTensor<T, 2>::From(*v);
@@ -88,6 +86,23 @@ static inline void CalcMatrixSigmaAndNormWeight(
                                           Array1(h));
     u_t.device(place) = u_t / (u_t_norm + u_t_norm.constant(eps));
   }
+}
+
+// CalcMatrixSigmaAndNormWeight will not update u and v
+template <typename DeviceContext, typename T>
+static inline void CalcMatrixSigmaAndNormWeight(
+    Tensor* sigma, const Tensor* u, const Tensor* v,
+    Tensor* weight, const int power_iters,
+    const float eps, const framework::ExecutionContext& ctx) {
+  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+  auto blas = math::GetBlas<DeviceContext, T>(ctx);
+  auto sigma_t = EigenTensor<T, 2>::From(*sigma);
+  auto weight_t = EigenTensor<T, 2>::From(*weight);
+  auto u_t = EigenTensor<T, 2>::From(*u);
+
+  const int h = weight->dims()[0];
+  const int w = weight->dims()[1];
   Tensor weight_v;
   weight_v.mutable_data<T>({h, 1}, ctx.GetPlace());
   blas.MatMul(*weight, false, *v, false, T(1), &weight_v, T(0));
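The refactor above splits the old helper in two: UpdateUandV runs the power iterations and mutates u and v, while the new CalcMatrixSigmaAndNormWeight takes them as const inputs and only computes sigma and the normalized weight. In NumPy terms the split looks roughly like this (a sketch with illustrative names, not the kernel code):

import numpy as np

def update_u_and_v(u, v, W, power_iters, eps=1e-12):
    # Mutating half: refine the singular-vector estimates by power iteration.
    for _ in range(power_iters):
        v = W.T.dot(u)
        v = v / (np.linalg.norm(v) + eps)
        u = W.dot(v)
        u = u / (np.linalg.norm(u) + eps)
    return u, v

def calc_sigma_and_norm_weight(u, v, W):
    # Read-only half: u and v are consumed, never modified.
    sigma = u.dot(W).dot(v)
    return W / sigma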
@@ -109,6 +124,8 @@ class SpectralNormKernel : public framework::OpKernel<T> {
     auto u = ctx.Input<Tensor>("U");
     auto v = ctx.Input<Tensor>("V");
     auto out = ctx.Output<Tensor>("Out");
+    auto u_out = ctx.Output<Tensor>("UOut");
+    auto v_out = ctx.Output<Tensor>("VOut");
     int dim = ctx.Attr<int>("dim");
     int power_iters = ctx.Attr<int>("power_iters");
@@ -144,11 +161,13 @@ class SpectralNormKernel : public framework::OpKernel<T> {
     Tensor sigma;
     sigma.mutable_data<T>(weight_mat.dims(), ctx.GetPlace());
-    Tensor uu, vv;
-    TensorCopySync(*u, ctx.GetPlace(), &uu);
-    TensorCopySync(*v, ctx.GetPlace(), &vv);
+    TensorCopySync(*u, ctx.GetPlace(), u_out);
+    TensorCopySync(*v, ctx.GetPlace(), v_out);
+    UpdateUandV<DeviceContext, T>(
+        &(u_out->Resize({h, 1})), &(v_out->Resize({w, 1})), &weight_mat,
+        power_iters, eps, ctx);
     CalcMatrixSigmaAndNormWeight<DeviceContext, T>(
-        &sigma, &(uu.Resize({h, 1})), &(vv.Resize({w, 1})), &weight_mat,
+        &sigma, &(u_out->Resize({h, 1})), &(v_out->Resize({w, 1})), &weight_mat,
         power_iters, eps, ctx);
     if (dim != 0) {
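With this split, the forward kernel's data flow is: copy U and V into the UOut/VOut output buffers, refine those buffers in place, then compute sigma from them. Roughly, reusing the sketch functions defined after the previous hunk:

def spectral_norm_forward(W, u, v, power_iters, eps=1e-12):
    u_out, v_out = u.copy(), v.copy()   # UOut/VOut start as copies of U/V
    u_out, v_out = update_u_and_v(u_out, v_out, W, power_iters, eps)
    out = calc_sigma_and_norm_weight(u_out, v_out, W)
    return out, u_out, v_out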
@@ -180,8 +199,8 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
     auto weight = ctx.Input<Tensor>("Weight");
-    auto u = ctx.Input<Tensor>("U");
-    auto v = ctx.Input<Tensor>("V");
+    auto u_out = ctx.Input<Tensor>("UOut");
+    auto v_out = ctx.Input<Tensor>("VOut");
     auto out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto weight_grad = ctx.Output<Tensor>(framework::GradVarName("Weight"));
@@ -189,8 +208,12 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
     int power_iters = ctx.Attr<int>("power_iters");
     float eps = ctx.Attr<float>("eps");
-    const int h = u->dims()[0];
-    const int w = v->dims()[0];
+    const int h = u_out->dims()[0];
+    const int w = v_out->dims()[0];
+
+    Tensor u_mat, v_mat;
+    TensorCopySync(*u_out, ctx.GetPlace(), &u_mat);
+    TensorCopySync(*v_out, ctx.GetPlace(), &v_mat);
     Tensor weight_mat, out_grad_mat;
     auto dims = weight->dims();
@@ -225,16 +248,14 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
     Tensor sigma;
     sigma.mutable_data<T>(weight_mat.dims(), ctx.GetPlace());
-    Tensor uu, vv;
-    TensorCopySync(*u, ctx.GetPlace(), &uu);
-    TensorCopySync(*v, ctx.GetPlace(), &vv);
     CalcMatrixSigmaAndNormWeight<DeviceContext, T>(
-        &sigma, &(uu.Resize({h, 1})), &(vv.Resize({w, 1})), &weight_mat,
+        &sigma, &(u_mat.Resize({h, 1})), &(v_mat.Resize({w, 1})), &weight_mat,
         power_iters, eps, ctx);
     Tensor uv;
     uv.mutable_data<T>({h, w}, ctx.GetPlace());
-    blas.MatMul(uu.Resize({h, 1}), false, vv.Resize({w, 1}), false, T(1), &uv,
+    blas.MatMul(u_mat.Resize({h, 1}), false, v_mat.Resize({w, 1}), false, T(1), &uv,
                 T(0));
     Tensor weight_grad_mat;
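For the backward pass: with u and v held fixed, Out = W / sigma and sigma = u^T W v give dL/dW = (dL/dOut - <dL/dOut, Out> u v^T) / sigma, which is why the kernel builds the outer product `uv` above; the u_mat/v_mat copies let it reshape the vectors without touching the saved UOut/VOut variables. A NumPy rendering of that formula (exact only when power_iters is 0, so u and v really are constants):

import numpy as np

def spectral_norm_grad(W, u, v, g_out):
    # dL/dW for Out = W / sigma, sigma = u^T W v, with u and v constant.
    sigma = u.dot(W).dot(v)
    out = W / sigma
    return (g_out - (g_out * out).sum() * np.outer(u, v)) / sigma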
......
@@ -1375,7 +1375,7 @@ class BatchNorm(layers.Layer):
         outputs = {
             "Y": [batch_norm_out],
-            "MeanOut": [mean_out],
+            "MeanOut": [],
             "VarianceOut": [variance_out],
             "SavedMean": [saved_mean],
             "SavedVariance": [saved_variance]
@@ -3031,9 +3031,11 @@ class SpectralNorm(layers.Layer):
         dim(int, optional): The index of dimension which should be permuted to the first before reshaping Input(Weight) to matrix, it should be set as 0 if Input(Weight) is the weight of fc layer, and should be set as 1 if Input(Weight) is the weight of conv layer. Default: 0.
         power_iters(int, optional): The number of power iterations to calculate spectral norm. Default: 1.
         eps(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12.
+        fix_state(bool, optional): Whether to keep the input vectors `u` and `v` fixed. When True, the updated vectors are written to separate `UOut`/`VOut` state parameters; when False, `u` and `v` are updated in place. Default: True.
         name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
         dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
     Returns:
         None
@@ -3055,10 +3057,12 @@ class SpectralNorm(layers.Layer):
                  dim=0,
                  power_iters=1,
                  eps=1e-12,
+                 fix_state=True,
                  dtype='float32'):
         super(SpectralNorm, self).__init__()
         self._power_iters = power_iters
         self._eps = eps
+        self._fix_state = fix_state
         self._dim = dim
         self._dtype = dtype
@@ -3080,10 +3084,31 @@ class SpectralNorm(layers.Layer):
             default_initializer=Normal(0., 1.))
         self.weight_v.stop_gradient = True
+        if fix_state:
+            self.out_weight_u = self.create_parameter(
+                attr=ParamAttr(),
+                shape=[h],
+                dtype=self._dtype,
+                default_initializer=Normal(0., 1.))
+            self.out_weight_u.stop_gradient = True
+            self.out_weight_v = self.create_parameter(
+                attr=ParamAttr(),
+                shape=[w],
+                dtype=self._dtype,
+                default_initializer=Normal(0., 1.))
+            self.out_weight_v.stop_gradient = True
+        else:
+            self.out_weight_u = self.weight_u
+            self.out_weight_v = self.weight_v
 
     def forward(self, weight):
         check_variable_and_dtype(weight, "weight", ['float32', 'float64'],
                                  'SpectralNorm')
-        inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v}
+        inputs = {
+            'Weight': weight,
+            'U': self.weight_u,
+            'V': self.weight_v,
+            'UOut': self.out_weight_u,
+            'VOut': self.out_weight_v,
+        }
         out = self._helper.create_variable_for_type_inference(self._dtype)
         self._helper.append_op(
             type="spectral_norm",
......
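A usage sketch of the extended dygraph layer. The import path and the first constructor argument (the weight shape) follow the released fluid 1.x API; passing fix_state here is an assumption based on this diff:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    weight = fluid.dygraph.to_variable(
        np.random.random((2, 8, 32, 32)).astype('float32'))
    # fix_state=True (the default): weight_u/weight_v stay fixed across calls,
    # and the refined vectors land in the separate UOut/VOut parameters.
    sn = fluid.dygraph.SpectralNorm(weight.shape, dim=1, power_iters=2,
                                    fix_state=True)
    out = sn(weight)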
@@ -44,7 +44,7 @@ def spectral_norm(weight, u, v, dim, power_iters, eps):
         u = u / (u_norm + eps)
 
     sigma = (u * np.matmul(weight_mat, v)).sum()
-    return weight / sigma
+    return weight / sigma, u, v
 
 
 @skip_check_grad_ci(
@@ -63,6 +63,7 @@ class TestSpectralNormOpNoGrad(OpTest):
             "dim": self.dim,
             "power_iters": self.power_iters,
             "eps": self.eps,
+            "fix_state": self.fix_state,
         }
 
         self.inputs = {
@@ -71,9 +72,9 @@ class TestSpectralNormOpNoGrad(OpTest):
             "V": v,
         }
 
-        output = spectral_norm(weight, u, v, self.dim, self.power_iters,
-                               self.eps)
-        self.outputs = {"Out": output}
+        output, new_u, new_v = spectral_norm(weight, u, v, self.dim, self.power_iters,
+                                             self.eps)
+        self.outputs = {"Out": output, "UOut": new_u, "VOut": new_v}
 
     def test_check_output(self):
         self.check_output()
@@ -85,6 +86,7 @@ class TestSpectralNormOpNoGrad(OpTest):
         self.dim = 0
         self.power_iters = 5
         self.eps = 1e-12
+        self.fix_state = True
 
 
 @skip_check_grad_ci(
@@ -99,6 +101,7 @@ class TestSpectralNormOpNoGrad2(TestSpectralNormOpNoGrad):
         self.dim = 1
         self.power_iters = 10
         self.eps = 1e-12
+        self.fix_state = True
 
 
 class TestSpectralNormOp(TestSpectralNormOpNoGrad):
@@ -115,6 +118,7 @@ class TestSpectralNormOp(TestSpectralNormOpNoGrad):
         self.dim = 0
         self.power_iters = 0
         self.eps = 1e-12
+        self.fix_state = True
 
 
 class TestSpectralNormOp2(TestSpectralNormOp):
@@ -125,6 +129,7 @@ class TestSpectralNormOp2(TestSpectralNormOp):
         self.dim = 1
         self.power_iters = 0
         self.eps = 1e-12
+        self.fix_state = True
 
 
 class TestSpectralNormOpError(unittest.TestCase):
......
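A quick sanity check of the new test targets: with power_iters=0 the reference loop never runs, so the expected UOut/VOut equal the input U/V exactly; the gradient-checked cases above use power_iters=0, presumably so u and v behave as constants during check_grad. For instance (paraphrasing the reference helper):

import numpy as np

def reference(W, u, v, power_iters, eps=1e-12):
    # Paraphrase of the updated test helper: returns (out, new_u, new_v).
    for _ in range(power_iters):
        v = W.T.dot(u)
        v = v / (np.linalg.norm(v) + eps)
        u = W.dot(v)
        u = u / (np.linalg.norm(u) + eps)
    sigma = (u * W.dot(v)).sum()
    return W / sigma, u, v

u0, v0 = np.random.randn(6), np.random.randn(9)
_, u_out, v_out = reference(np.random.randn(6, 9), u0, v0, power_iters=0)
assert np.array_equal(u_out, u0) and np.array_equal(v_out, v0)  # state untouched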