diff --git a/paddle/fluid/operators/spectral_norm_op.cu b/paddle/fluid/operators/spectral_norm_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..634d5b310baa383e5cd764305402efca7ca76017 --- /dev/null +++ b/paddle/fluid/operators/spectral_norm_op.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/operators/spectral_norm_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + spectral_norm, + ops::SpectralNormKernel, + ops::SpectralNormKernel); +REGISTER_OP_CUDA_KERNEL( + spectral_norm_grad, + ops::SpectralNormGradKernel, + ops::SpectralNormGradKernel); diff --git a/paddle/fluid/operators/spectral_norm_op.h b/paddle/fluid/operators/spectral_norm_op.h index 876dacf3bb29117f7ca5adcde679422f45e624b3..897945d18883a7fff175c791e1bc79117ebb0f42 100644 --- a/paddle/fluid/operators/spectral_norm_op.h +++ b/paddle/fluid/operators/spectral_norm_op.h @@ -46,47 +46,51 @@ static inline void CalcMatrixSigmaAndNormWeight( Tensor* sigma, Tensor* u, Tensor* v, Tensor* weight, const int power_iters, const float eps, const framework::ExecutionContext& ctx) { auto& place = *ctx.template device_context().eigen_device(); + auto blas = math::GetBlas(ctx); auto sigma_t = EigenTensor::From(*sigma); auto weight_t = EigenTensor::From(*weight); - auto u_t = EigenTensor::From(*u); - auto v_t = EigenTensor::From(*v); + auto u_t = EigenTensor::From(*u); + auto v_t = EigenTensor::From(*v); const int h = weight->dims()[0]; const int w = weight->dims()[1]; - Eigen::array perm = {1, 0}; - Eigen::array product_dims = {IndexPair(1, 0)}; - auto weight_trans_t = weight_t.shuffle(perm); - LOG(ERROR) << "weight: " << weight_t; - LOG(ERROR) << "weight_trans: " << weight_trans_t; + // LOG(ERROR) << "weight: " << weight_t; + // LOG(ERROR) << "weight_trans: " << weight_trans_t; for (int i = 0; i < power_iters; i++) { - v_t.device(place) = weight_trans_t.contract(u_t, product_dims); - LOG(ERROR) << "iter v: " << v_t; + // v_t.device(place) = weight_trans_t.contract(u_t, product_dims); + blas.MatMul(*weight, true, *u, false, T(1), v, T(0)); + // LOG(ERROR) << "iter v: " << v_t; auto v_t_norm = v_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast( Array1(w)); - LOG(ERROR) << "iter v_norm: " << v_t_norm; + // LOG(ERROR) << "iter v_norm: " << v_t_norm; v_t.device(place) = v_t / (v_t_norm + v_t_norm.constant(eps)); - LOG(ERROR) << "iter norm v: " << v_t; - u_t.device(place) = weight_t.contract(v_t, product_dims); - LOG(ERROR) << "iter u: " << u_t; + // LOG(ERROR) << "iter norm v: " << v_t; + // u_t.device(place) = weight_t.contract(v_t, product_dims); + blas.MatMul(*weight, false, *v, false, T(1), u, T(0)); + // LOG(ERROR) << "iter u: " << u_t; auto u_t_norm = u_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast( Array1(h)); u_t.device(place) = u_t / (u_t_norm + u_t_norm.constant(eps)); - LOG(ERROR) << "iter norm u: " << u_t; + // LOG(ERROR) << "iter norm u: " << u_t; } - LOG(ERROR) << "h" << h << "w" << w; - LOG(ERROR) << "u: " << u_t; - LOG(ERROR) << "v: " << v_t; - LOG(ERROR) << "weight_v: " << weight_t.contract(v_t, product_dims); - sigma_t.device(place) = (u_t * weight_t.contract(v_t, product_dims)) + // LOG(ERROR) << "h" << h << "w" << w; + // LOG(ERROR) << "u: " << u_t; + // LOG(ERROR) << "v: " << v_t; + Tensor weight_v; + weight_v.mutable_data({h, 1}, ctx.GetPlace()); + blas.MatMul(*weight, false, *v, false, T(1), &weight_v, T(0)); + auto weight_v_t = EigenTensor::From(weight_v); + // LOG(ERROR) << "weight_v: " << weight_v_t; + sigma_t.device(place) = (u_t * weight_v_t) .sum() .eval() .reshape(Array2(1, 1)) .broadcast(Array2(h, w)); - LOG(ERROR) << "weight: " << weight_t; - LOG(ERROR) << "sigma: " << sigma_t; + // LOG(ERROR) << "weight: " << weight_t; + // LOG(ERROR) << "sigma: " << sigma_t; weight_t.device(place) = weight_t / sigma_t; } @@ -103,6 +107,9 @@ class SpectralNormKernel : public framework::OpKernel { int power_iters = ctx.Attr("power_iters"); float eps = ctx.Attr("eps"); + const int h = weight->dims()[0]; + const int w = weight->dims()[1]; + Tensor weight_mat; TensorCopySync(*weight, ctx.GetPlace(), &weight_mat); ResizeWeight(&weight_mat, dim); @@ -113,7 +120,8 @@ class SpectralNormKernel : public framework::OpKernel { TensorCopySync(*u, ctx.GetPlace(), &uu); TensorCopySync(*v, ctx.GetPlace(), &vv); CalcMatrixSigmaAndNormWeight( - &sigma, &uu, &vv, &weight_mat, power_iters, eps, ctx); + &sigma, &(uu.Resize({h, 1})), &(vv.Resize({w, 1})), &weight_mat, + power_iters, eps, ctx); TensorCopySync(weight_mat, ctx.GetPlace(), out); } }; diff --git a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py index 2d7ff16aa664fb2d13574886950e6ee1dccabcd0..57a1d3ed11785c6dc7281ab49786efded201e6d2 100644 --- a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py @@ -21,17 +21,36 @@ from op_test import OpTest from paddle.fluid import core +def spectral_norm(weight, u, v, dim, power_iters, eps): + h = w = 1 + for i, d in enumerate(weight.shape): + if i <= dim: + h *= d + else: + w *= d + weight_mat = weight.reshape((h, w)) + + u = u.reshape((h, 1)) + v = v.reshape((w, 1)) + for i in range(power_iters): + v = np.matmul(weight_mat.T, u) + v_norm = np.sqrt((v * v).sum()) + v = v / (v_norm + eps) + u = np.matmul(weight_mat, v) + u_norm = np.sqrt((u * u).sum()) + u = u / (u_norm + eps) + + sigma = (u * np.matmul(weight_mat, v)).sum() + return (weight_mat / sigma).reshape(weight.shape) + + class TestSpectralNormOp(OpTest): def setUp(self): self.initTestCase() self.op_type = 'spectral_norm' - # weight = np.random.random(self.weight_shape).astype('float32') - # u = np.random.random(self.u_shape).astype('float32') - # v = np.random.random(self.u_shape).astype('float32') - weight = np.ones(self.weight_shape).astype('float32') - weight[1, :] = 2. - u = np.ones(self.u_shape).astype('float32') - v = np.ones(self.v_shape).astype('float32') + weight = np.random.random(self.weight_shape).astype('float32') + u = np.random.random(self.u_shape).astype('float32') + v = np.random.random(self.v_shape).astype('float32') self.attrs = { "dim": self.dim, @@ -45,8 +64,9 @@ class TestSpectralNormOp(OpTest): "V": v, } - output = weight - self.outputs = {"Out": weight, } + output = spectral_norm(weight, u, v, self.dim, self.power_iters, + self.eps) + self.outputs = {"Out": output} def test_check_output(self): self.check_output() @@ -56,7 +76,7 @@ class TestSpectralNormOp(OpTest): self.u_shape = (2, ) self.v_shape = (3, ) self.dim = 0 - self.power_iters = 1 + self.power_iters = 2 self.eps = 1e-12