diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc
index e7fbf4e6ecdc6d584a76fb544161b59fae359a3f..56856c45b47a8450e28cc8cc65e15bb4e8a7b21c 100644
--- a/paddle/fluid/operators/spectral_norm_op.cc
+++ b/paddle/fluid/operators/spectral_norm_op.cc
@@ -33,19 +33,34 @@ class SpectralNormOp : public framework::OperatorWithKernel {
                    "Output(Out) of SpectralNormOp should not be null.");
 
     auto dim_weight = ctx->GetInputDim("Weight");
-    auto weight_dimsize = dim_weight.size();
-    PADDLE_ENFORCE(weight_dimsize >= 2 && weight_dimsize <= 5,
-                   "The size of dims of Input(Weights) can only be 2, 3,"
+    auto rank_weight = dim_weight.size();
+    PADDLE_ENFORCE(rank_weight >= 2 && rank_weight <= 5,
+                   "The rank of Input(Weights) can only be 2, 3,"
                    "4, 5 for fc, conv1d, conv2d, conv3d layers.");
 
     int dim = ctx->Attrs().Get<int>("dim");
     int power_iters = ctx->Attrs().Get<int>("power_iters");
-    PADDLE_ENFORCE(dim >= 0 && dim < weight_dimsize - 1,
-                   "Attr(dim) should be larger equal 0 and less then the"
-                   "size of dims of Input(Weights) - 1,");
+    PADDLE_ENFORCE(dim == 0 || dim == 1, "Attr(dim) can only be 0 or 1");
     PADDLE_ENFORCE(power_iters >= 0,
                    "Attr(power_iters) should be larger equal then 0");
 
+    int h = dim_weight[dim];
+    int w = 1;
+    for (int i = 0; i < rank_weight; i++) {
+      if (i != dim) {
+        w *= dim_weight[i];
+      }
+    }
+    auto dim_u = ctx->GetInputDim("U");
+    auto dim_v = ctx->GetInputDim("V");
+    PADDLE_ENFORCE_EQ(dim_u[0], h,
+                      "Input(U) dims[0] should be equal to "
+                      "Input(Weight) dims[Attr(dim)]");
+    PADDLE_ENFORCE_EQ(
+        dim_v[0], w,
+        "Input(V) dims[0] should be equal to "
+        "the product of Input(Weight) dims except dims[Attr(dim)]");
+
     ctx->SetOutputDim("Out", dim_weight);
     ctx->ShareLoD("Weight", /*->*/ "Out");
   }
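For context, the new InferShape logic above flattens a rank-2 to rank-5 weight into an h x w matrix: h is the extent of the normalized axis Attr(dim), and w is the product of every other dimension, so Input(U) and Input(V) must be vectors of length h and w. A minimal NumPy sketch of that rule; the helper name is illustrative only, and the shapes are taken from the updated unit tests further below:

    import numpy as np

    def infer_uv_shapes(weight_shape, dim):
        # h is the extent of the normalized axis, w folds all remaining axes,
        # mirroring the shape check added to SpectralNormOp::InferShape.
        h = weight_shape[dim]
        w = int(np.prod(weight_shape)) // h
        return (h,), (w,)

    # A conv-style weight of shape (2, 3, 3, 3) with dim=1 needs U of shape
    # (3,) and V of shape (18,), which is why the test shapes change below.
    print(infer_uv_shapes((2, 3, 3, 3), dim=1))  # ((3,), (18,))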
diff --git a/paddle/fluid/operators/spectral_norm_op.h b/paddle/fluid/operators/spectral_norm_op.h
index 18bf14c64f08dcb70d279a922e47fa280a78e08f..45a3ad8d532a176a39e226cd3ea371699eaa87e8 100644
--- a/paddle/fluid/operators/spectral_norm_op.h
+++ b/paddle/fluid/operators/spectral_norm_op.h
@@ -10,6 +10,7 @@ limitations under the License. */
 
 #pragma once
+#include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/blas.h"
@@ -27,17 +28,33 @@ using Array1 = Eigen::DSizes<int64_t, 1>;
 using Array2 = Eigen::DSizes<int64_t, 2>;
 using IndexPair = Eigen::IndexPair<int>;
 
-static inline void CalcMatrixShape(const Tensor& weight, const int dim, int* h,
-                                   int* w) {
-  auto weight_dims = weight.dims();
-  *h = 1;
-  *w = 1;
-  for (int i = 0; i < weight_dims.size(); i++) {
-    if (i <= dim) {
-      *h *= weight_dims[i];
-    } else {
-      *w *= weight_dims[i];
-    }
+template <typename DeviceContext, typename T>
+static inline void TransCompute(const int rank, const Tensor& in, Tensor* out,
+                                const std::vector<int>& perm,
+                                const DeviceContext& dev_ctx) {
+  if (rank <= 1 || rank > 5) {
+    PADDLE_THROW("Invalid weight rank.");
+  }
+
+  switch (rank) {
+    case 2:
+      math::Transpose<DeviceContext, T, 2> trans2;
+      trans2(dev_ctx, in, out, perm);
+      break;
+    case 3:
+      math::Transpose<DeviceContext, T, 3> trans3;
+      trans3(dev_ctx, in, out, perm);
+      break;
+    case 4:
+      math::Transpose<DeviceContext, T, 4> trans4;
+      trans4(dev_ctx, in, out, perm);
+      break;
+    case 5:
+      math::Transpose<DeviceContext, T, 5> trans5;
+      trans5(dev_ctx, in, out, perm);
+      break;
+    default:
+      break;
   }
 }
 
@@ -83,6 +100,7 @@ template <typename DeviceContext, typename T>
 class SpectralNormKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto weight = ctx.Input<Tensor>("Weight");
     auto u = ctx.Input<Tensor>("U");
     auto v = ctx.Input<Tensor>("V");
@@ -92,10 +110,32 @@ class SpectralNormKernel : public framework::OpKernel<T> {
     int power_iters = ctx.Attr<int>("power_iters");
     float eps = ctx.Attr<float>("eps");
 
+    const int h = u->dims()[0];
+    const int w = v->dims()[0];
+
     Tensor weight_mat;
-    int h, w;
-    CalcMatrixShape(*weight, dim, &h, &w);
-    TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
+    auto dims = weight->dims();
+    const int rank = dims.size();
+    std::vector<int> real_dims;
+    if (dim != 0) {
+      std::vector<int> perm;
+      perm.push_back(dim);
+      real_dims.push_back(dims[dim]);
+      for (int i = 0; i < rank; i++) {
+        if (i != dim) {
+          perm.push_back(i);
+          real_dims.push_back(dims[i]);
+        }
+      }
+      weight_mat.mutable_data<T>(framework::make_ddim(real_dims),
+                                 ctx.GetPlace());
+      TransCompute<DeviceContext, T>(rank, *weight, &weight_mat, perm, dev_ctx);
+    } else {
+      for (int i = 0; i < rank; i++) {
+        real_dims.push_back(dims[i]);
+      }
+      TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
+    }
     weight_mat = weight_mat.Resize({h, w});
 
     Tensor sigma;
@@ -106,7 +146,25 @@ class SpectralNormKernel : public framework::OpKernel<T> {
     CalcMatrixSigmaAndNormWeight<DeviceContext, T>(
         &sigma, &(uu.Resize({h, 1})), &(vv.Resize({w, 1})), &weight_mat,
         power_iters, eps, ctx);
-    TensorCopySync(weight_mat.Resize(out->dims()), ctx.GetPlace(), out);
+
+    if (dim != 0) {
+      std::vector<int> perm;
+      for (int i = 0; i < rank; i++) {
+        if (i < dim) {
+          perm.push_back(i + 1);
+        } else if (i == dim) {
+          perm.push_back(0);
+        } else {
+          perm.push_back(i);
+        }
+      }
+      out->mutable_data<T>(dims, ctx.GetPlace());
+      TransCompute<DeviceContext, T>(
+          rank, weight_mat.Resize(framework::make_ddim(real_dims)), out, perm,
+          dev_ctx);
+    } else {
+      TensorCopySync(weight_mat.Resize(dims), ctx.GetPlace(), out);
+    }
   }
 };
 
@@ -115,6 +173,7 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
     auto weight = ctx.Input<Tensor>("Weight");
     auto u = ctx.Input<Tensor>("U");
@@ -126,11 +185,37 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
     int power_iters = ctx.Attr<int>("power_iters");
     float eps = ctx.Attr<float>("eps");
 
+    const int h = u->dims()[0];
+    const int w = v->dims()[0];
+
     Tensor weight_mat, out_grad_mat;
-    int h, w;
-    CalcMatrixShape(*weight, dim, &h, &w);
-    TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
-    TensorCopySync(*out_grad, ctx.GetPlace(), &out_grad_mat);
+    auto dims = weight->dims();
+    const int rank = dims.size();
+    std::vector<int> real_dims;
+    if (dim != 0) {
+      std::vector<int> perm;
+      perm.push_back(dim);
+      real_dims.push_back(dims[dim]);
+      for (int i = 0; i < rank; i++) {
+        if (i != dim) {
+          perm.push_back(i);
+          real_dims.push_back(dims[i]);
+        }
+      }
+      weight_mat.mutable_data<T>(framework::make_ddim(real_dims),
+                                 ctx.GetPlace());
+      out_grad_mat.mutable_data<T>(framework::make_ddim(real_dims),
+                                   ctx.GetPlace());
+      TransCompute<DeviceContext, T>(rank, *weight, &weight_mat, perm, dev_ctx);
+      TransCompute<DeviceContext, T>(rank, *out_grad, &out_grad_mat, perm,
+                                     dev_ctx);
+    } else {
+      for (int i = 0; i < rank; i++) {
+        real_dims.push_back(dims[i]);
+      }
+      TensorCopySync(*weight, ctx.GetPlace(), &weight_mat);
+      TensorCopySync(*out_grad, ctx.GetPlace(), &out_grad_mat);
+    }
     weight_mat = weight_mat.Resize({h, w});
     out_grad_mat = out_grad_mat.Resize({h, w});
 
@@ -148,21 +233,37 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
     blas.MatMul(uu.Resize({h, 1}), false, vv.Resize({w, 1}), false, T(1), &uv,
                 T(0));
 
-    Tensor weight_grad_mat, ones;
+    Tensor weight_grad_mat;
     weight_grad_mat.mutable_data<T>({h, w}, ctx.GetPlace());
-    ones.mutable_data<T>({h, w}, ctx.GetPlace());
     auto weight_grad_mat_t = EigenTensor<T, 2>::From(weight_grad_mat);
     auto weight_mat_t = EigenTensor<T, 2>::From(weight_mat);
     auto out_grad_mat_t = EigenTensor<T, 2>::From(out_grad_mat);
     auto sigma_t = EigenTensor<T, 2>::From(sigma);
     auto uv_t = EigenTensor<T, 2>::From(uv);
-    auto ones_t = EigenTensor<T, 2>::From(ones).setConstant((T)1);
     weight_mat_t.device(place) =
         weight_mat_t.sum().eval().reshape(Array2(1, 1)).broadcast(Array2(h, w));
     weight_grad_mat_t.device(place) =
-        out_grad_mat_t * (ones_t - uv_t * weight_mat_t) / sigma_t;
-    TensorCopySync(weight_grad_mat.Resize(weight_grad->dims()), ctx.GetPlace(),
-                   weight_grad);
+        out_grad_mat_t * (out_grad_mat_t.constant(1.0) - uv_t * weight_mat_t) /
+        sigma_t;
+
+    if (dim != 0) {
+      std::vector<int> perm;
+      for (int i = 0; i < rank; i++) {
+        if (i < dim) {
+          perm.push_back(i + 1);
+        } else if (i == dim) {
+          perm.push_back(0);
+        } else {
+          perm.push_back(i);
+        }
+      }
+      weight_grad->mutable_data<T>(dims, ctx.GetPlace());
+      TransCompute<DeviceContext, T>(
+          rank, weight_grad_mat.Resize(framework::make_ddim(real_dims)),
+          weight_grad, perm, dev_ctx);
+    } else {
+      TensorCopySync(weight_grad_mat.Resize(dims), ctx.GetPlace(), weight_grad);
+    }
   }
 };
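The new kernels replace the old CalcMatrixShape reshape with an explicit permutation: the normalized axis is moved to the front (TransCompute), the remaining axes are folded into one column dimension, and after normalization the inverse permutation restores the original layout. A rough NumPy equivalent of that round trip, purely illustrative (the helper names here are not part of the patch):

    import numpy as np

    def to_matrix(weight, dim):
        # Move axis `dim` to the front, then fold the remaining axes into one.
        perm = [dim] + [i for i in range(weight.ndim) if i != dim]
        permuted = weight.transpose(perm)
        return permuted.reshape(permuted.shape[0], -1), permuted.shape, perm

    def from_matrix(mat, real_shape, perm):
        # Inverse of to_matrix: un-flatten, then apply the inverse permutation.
        return mat.reshape(real_shape).transpose(np.argsort(perm))

    w = np.random.rand(2, 3, 3, 3).astype('float32')
    mat, real_shape, perm = to_matrix(w, dim=1)      # mat.shape == (3, 18)
    assert np.allclose(from_matrix(mat, real_shape, perm), w)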
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 6be0df46993351eb0093ca3607b59dab6d10dd77..2eb18e447f789c538ae7e0d1334c343b89f36b7c 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -94,6 +94,7 @@ __all__ = [
     'multiplex',
     'layer_norm',
     'group_norm',
+    'spectral_norm',
     'softmax_with_cross_entropy',
     'smooth_l1',
     'one_hot',
@@ -3347,6 +3348,80 @@ def group_norm(input,
     return helper.append_activation(group_norm_out)
 
 
+@templatedoc()
+def spectral_norm(weight,
+                  dim=0,
+                  power_iters=1,
+                  eps=1e-12,
+                  u_attr=None,
+                  v_attr=None,
+                  name=None):
+    """
+    **Spectral Normalization Layer**
+
+    Refer to `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .
+
+    Args:
+        weight(${weight_type}): ${weight_comment}
+        dim(${dim_type}): ${dim_comment}
+        power_iters(${power_iters_type}): ${power_iters_comment}
+        eps(${eps_type}): ${eps_comment}
+        u_attr(ParamAttr|None): The parameter attribute for vector u in
+            spectral calculations, set None to use default attribute, which
+            generates random values in normal distribution N(0, 1). Default: None.
+        v_attr(ParamAttr|None): The parameter attribute for vector v in
+            spectral calculations, set None to use default attribute, which
+            generates random values in normal distribution N(0, 1). Default: None.
+        name (str): The name of this layer. It is optional.
+
+    Returns:
+        Variable: A tensor variable of weight after spectral normalization.
+
+    Examples:
+
+        >>> weight = fluid.layers.data(name='weight', shape=[8, 32, 32],
+        >>>                            dtype='float32')
+        >>> x = fluid.layers.spectral_norm(weight=weight, dim=1, power_iters=2)
+    """
+    helper = LayerHelper('spectral_norm', **locals())
+    dtype = weight.dtype
+
+    # create input and parameters
+    inputs = {'Weight': weight}
+    input_shape = weight.shape
+    h = input_shape[dim]
+    w = 1
+    for i, d in enumerate(input_shape):
+        if i != dim:
+            w *= d
+
+    u = helper.create_parameter(
+        attr=u_attr if u_attr else ParamAttr(),
+        shape=[h],
+        dtype=dtype,
+        default_initializer=Normal(0., 1.))
+    u.stop_gradient = True
+    inputs['U'] = u
+
+    v = helper.create_parameter(
+        attr=v_attr if v_attr else ParamAttr(),
+        shape=[w],
+        dtype=dtype,
+        default_initializer=Normal(0., 1.))
+    v.stop_gradient = True
+    inputs['V'] = v
+
+    # create output
+    out = helper.create_variable(dtype=dtype)
+
+    helper.append_op(
+        type="spectral_norm",
+        inputs=inputs,
+        outputs={"Out": out},
+        attrs={
+            "dim": dim,
+            "power_iters": power_iters,
+            "eps": eps,
+        })
+
+    return out
+
+
 def conv2d_transpose(input,
                      num_filters,
                      output_size=None,
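Before the unit-test changes, it may help to see the quantity the op estimates in isolation: power iteration on the flattened (h, w) matrix converges to its largest singular value, which is the sigma the weight is divided by. A standalone NumPy sanity sketch (not part of the patch; the function name is illustrative), comparable against np.linalg.svd:

    import numpy as np

    def sigma_by_power_iteration(weight_mat, iters=50, eps=1e-12):
        h, w = weight_mat.shape
        u = np.random.normal(size=(h, 1))
        v = np.random.normal(size=(w, 1))
        for _ in range(iters):
            # Same update order as the reference function in the test below:
            # v from W^T u, then u from W v, each L2-normalized.
            v = np.matmul(weight_mat.T, u)
            v = v / (np.sqrt((v * v).sum()) + eps)
            u = np.matmul(weight_mat, v)
            u = u / (np.sqrt((u * u).sum()) + eps)
        return float((u * np.matmul(weight_mat, v)).sum())

    mat = np.random.rand(3, 18)
    print(sigma_by_power_iteration(mat))            # approx. sigma_max
    print(np.linalg.svd(mat, compute_uv=False)[0])  # reference value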
diff --git a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py
index 79594b3842e16f35c8ca36503453446a186a3058..549ed486d710616a8edf4028c94d7ba9642e6a91 100644
--- a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py
@@ -22,13 +22,17 @@ from paddle.fluid import core
 
 
 def spectral_norm(weight, u, v, dim, power_iters, eps):
-    h = w = 1
-    for i, d in enumerate(weight.shape):
-        if i <= dim:
-            h *= d
-        else:
-            w *= d
-    weight_mat = weight.reshape((h, w))
+    shape = weight.shape
+    weight_mat = weight.copy()
+    h = shape[dim]
+    w = np.prod(shape) // h
+    if dim != 0:
+        perm = [dim] + [d for d in range(len(shape)) if d != dim]
+        weight_mat = weight_mat.transpose(perm)
+        real_shape = weight_mat.shape
+    else:
+        real_shape = shape
+    weight_mat = weight_mat.reshape((h, w))
 
     u = u.reshape((h, 1))
     v = v.reshape((w, 1))
@@ -41,7 +45,7 @@ def spectral_norm(weight, u, v, dim, power_iters, eps):
         u = u / (u_norm + eps)
 
     sigma = (u * np.matmul(weight_mat, v)).sum()
-    return (weight_mat / sigma).reshape(weight.shape)
+    return weight / sigma
 
 
 class TestSpectralNormOpNoGrad(OpTest):
@@ -83,8 +87,8 @@ class TestSpectralNormOpNoGrad(OpTest):
 class TestSpectralNormOpNoGrad2(TestSpectralNormOpNoGrad):
     def initTestCase(self):
         self.weight_shape = (2, 3, 3, 3)
-        self.u_shape = (6, )
-        self.v_shape = (9, )
+        self.u_shape = (3, )
+        self.v_shape = (18, )
         self.dim = 1
         self.power_iters = 10
         self.eps = 1e-12
@@ -110,8 +114,8 @@ class TestSpectralNormOp(TestSpectralNormOpNoGrad):
 class TestSpectralNormOp2(TestSpectralNormOp):
     def initTestCase(self):
        self.weight_shape = (2, 3, 3, 3)
-        self.u_shape = (6, )
-        self.v_shape = (9, )
+        self.u_shape = (3, )
+        self.v_shape = (18, )
         self.dim = 1
         self.power_iters = 0
         self.eps = 1e-12
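As a final sanity check on the intended semantics (again only a sketch, outside the patch): dividing the weight by the estimated sigma should drive the spectral norm of the flattened matrix to 1 when power_iters is large, which is effectively what the tests above assert through the reference function.

    import numpy as np

    np.random.seed(0)
    shape, dim = (2, 3, 3, 3), 1
    weight = np.random.rand(*shape).astype('float64')

    # Flatten exactly as the reference does: dim to the front, rest folded.
    perm = [dim] + [i for i in range(len(shape)) if i != dim]
    mat = weight.transpose(perm).reshape(shape[dim], -1)

    sigma = np.linalg.svd(mat, compute_uv=False)[0]
    normed = weight / sigma  # the limit of spectral_norm() for many iterations
    check = normed.transpose(perm).reshape(shape[dim], -1)
    print(np.linalg.svd(check, compute_uv=False)[0])  # ~= 1.0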