Commit 6a82c5cf, authored by: M miraiwk

update spectral norm

Parent commit: ec6499bc
......@@ -205,6 +205,8 @@ class SpectralNormGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("Weight", this->Input("Weight"));
op->SetInput("U", this->Input("U"));
op->SetInput("V", this->Input("V"));
op->SetInput("UOut", this->Output("UOut"));
op->SetInput("VOut", this->Output("VOut"));
op->SetOutput(framework::GradVarName("Weight"), this->InputGrad("Weight"));
......
......@@ -64,6 +64,7 @@ template <typename DeviceContext, typename T>
static inline void UpdateUandV(
Tensor* u, Tensor* v, Tensor* weight, const int power_iters,
const float eps, const framework::ExecutionContext& ctx) {
if (power_iters <= 0) return;
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto blas = math::GetBlas<DeviceContext, T>(ctx);
auto u_t = EigenTensor<T, 2>::From(*u);
......@@ -92,8 +93,7 @@ static inline void UpdateUandV(
template <typename DeviceContext, typename T>
static inline void CalcMatrixSigmaAndNormWeight(
Tensor* sigma, const Tensor* u, const Tensor* v,
Tensor* weight, const int power_iters,
const float eps, const framework::ExecutionContext& ctx) {
Tensor* weight, const framework::ExecutionContext& ctx) {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto blas = math::GetBlas<DeviceContext, T>(ctx);
auto sigma_t = EigenTensor<T, 2>::From(*sigma);
......@@ -168,7 +168,7 @@ class SpectralNormKernel : public framework::OpKernel<T> {
power_iters, eps, ctx);
CalcMatrixSigmaAndNormWeight<DeviceContext, T>(
&sigma, &(u_out->Resize({h, 1})), &(v_out->Resize({w, 1})), &weight_mat,
power_iters, eps, ctx);
ctx);
if (dim != 0) {
std::vector<int> perm;
......@@ -205,8 +205,6 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
auto weight_grad = ctx.Output<Tensor>(framework::GradVarName("Weight"));
int dim = ctx.Attr<int>("dim");
int power_iters = ctx.Attr<int>("power_iters");
float eps = ctx.Attr<float>("eps");
const int h = u_out->dims()[0];
const int w = v_out->dims()[0];
......@@ -251,7 +249,7 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
CalcMatrixSigmaAndNormWeight<DeviceContext, T>(
&sigma, &(u_mat.Resize({h, 1})), &(v_mat.Resize({w, 1})), &weight_mat,
power_iters, eps, ctx);
ctx);
Tensor uv;
uv.mutable_data<T>({h, w}, ctx.GetPlace());
......
......@@ -3031,6 +3031,7 @@ class SpectralNorm(layers.Layer):
dim(int, optional): The index of dimension which should be permuted to the first before reshaping Input(Weight) to matrix, it should be set as 0 if Input(Weight) is the weight of fc layer, and should be set as 1 if Input(Weight) is the weight of conv layer. Default: 0.
power_iters(int, optional): The number of power iterations to calculate spectral norm. Default: 1.
eps(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12.
fix_state(bool, optional): whether to keep the power-iteration vectors `u` and `v` fixed; when True, the updated vectors are written to separate output buffers instead of back into `u` and `v`. Default: True.
name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
......@@ -3055,10 +3056,12 @@ class SpectralNorm(layers.Layer):
dim=0,
power_iters=1,
eps=1e-12,
fix_state=True,
dtype='float32'):
super(SpectralNorm, self).__init__()
self._power_iters = power_iters
self._eps = eps
self._fix_state = fix_state
self._dim = dim
self._dtype = dtype
......@@ -3080,10 +3083,31 @@ class SpectralNorm(layers.Layer):
default_initializer=Normal(0., 1.))
self.weight_v.stop_gradient = True
if fix_state:
self.out_weight_u = self.create_parameter(
attr=ParamAttr(),
shape=[h],
dtype=self._dtype,
default_initializer=Normal(0., 1.))
self.out_weight_u.stop_gradient = True
self.out_weight_v = self.create_parameter(
attr=ParamAttr(),
shape=[w],
dtype=self._dtype,
default_initializer=Normal(0., 1.))
self.out_weight_v.stop_gradient = True
else:
self.out_weight_u = self.weight_u
self.out_weight_v = self.weight_v
def forward(self, weight):
check_variable_and_dtype(weight, "weight", ['float32', 'float64'],
'SpectralNorm')
inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v}
inputs = {
'Weight': weight, 'U': self.weight_u, 'V': self.weight_v,
'UOut': self.out_weight_u, 'VOut': self.out_weight_v,
}
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="spectral_norm",
......
......@@ -3720,11 +3720,13 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
# create output
out = helper.create_variable(dtype=dtype)
u_out = helper.create_variable(dtype=dtype)
v_out = helper.create_variable(dtype=dtype)
helper.append_op(
type="spectral_norm",
inputs=inputs,
outputs={"Out": out, },
outputs={"Out": out, "UOut": u_out, "VOut": v_out},
attrs={
"dim": dim,
"power_iters": power_iters,
......
......@@ -132,6 +132,40 @@ class TestSpectralNormOp2(TestSpectralNormOp):
self.fix_state = True
class TestSpectralNormOpFixState(TestSpectralNormOpNoGrad):
    """Gradient check for spectral_norm with fix_state disabled.

    Verifies d(Out)/d(Weight) while excluding the power-iteration
    state vectors U and V from the gradient set.
    """

    def test_check_grad_ignore_uv(self):
        # U and V are power-iteration state, not trainable inputs,
        # so they are excluded from the gradient check.
        self.check_grad(
            ['Weight'],
            'Out',
            no_grad_set={"U", "V"}, )

    def initTestCase(self):
        # 10x12 weight matrix normalized along dim 0 with 3 power iterations.
        self.weight_shape = (10, 12)
        self.u_shape = (10, )
        self.v_shape = (12, )
        self.dim = 0
        self.power_iters = 3
        self.eps = 1e-12
        # NOTE(review): the class name says "FixState" but this sets
        # fix_state = False, while TestSpectralNormOpUpdateState sets True —
        # the two values look swapped; confirm against the op's attr semantics.
        self.fix_state = False
class TestSpectralNormOpUpdateState(TestSpectralNormOpNoGrad):
    """Gradient check for spectral_norm with fix_state enabled.

    Verifies d(Out)/d(Weight) while excluding the power-iteration
    state vectors U and V from the gradient set.
    """

    def test_check_grad_ignore_uv(self):
        # U and V are power-iteration state, not trainable inputs,
        # so they are excluded from the gradient check.
        self.check_grad(
            ['Weight'],
            'Out',
            no_grad_set={"U", "V"}, )

    def initTestCase(self):
        # 10x12 weight matrix normalized along dim 0 with 3 power iterations.
        self.weight_shape = (10, 12)
        self.u_shape = (10, )
        self.v_shape = (12, )
        self.dim = 0
        self.power_iters = 3
        self.eps = 1e-12
        # NOTE(review): the class name says "UpdateState" but this sets
        # fix_state = True, while TestSpectralNormOpFixState sets False —
        # the two values look swapped; confirm against the op's attr semantics.
        self.fix_state = True
class TestSpectralNormOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册