diff --git a/paddle/fluid/operators/spectral_norm_op.h b/paddle/fluid/operators/spectral_norm_op.h index 897945d18883a7fff175c791e1bc79117ebb0f42..18bf14c64f08dcb70d279a922e47fa280a78e08f 100644 --- a/paddle/fluid/operators/spectral_norm_op.h +++ b/paddle/fluid/operators/spectral_norm_op.h @@ -27,18 +27,18 @@ using Array1 = Eigen::DSizes; using Array2 = Eigen::DSizes; using IndexPair = Eigen::IndexPair; -static inline void ResizeWeight(Tensor* weight_mat, const int dim) { - auto weight_dims = weight_mat->dims(); - int h = 1; - int w = 1; +static inline void CalcMatrixShape(const Tensor& weight, const int dim, int* h, + int* w) { + auto weight_dims = weight.dims(); + *h = 1; + *w = 1; for (int i = 0; i < weight_dims.size(); i++) { if (i <= dim) { - h *= weight_dims[i]; + *h *= weight_dims[i]; } else { - w *= weight_dims[i]; + *w *= weight_dims[i]; } } - *weight_mat = weight_mat->Resize({h, w}); } template @@ -55,42 +55,27 @@ static inline void CalcMatrixSigmaAndNormWeight( const int h = weight->dims()[0]; const int w = weight->dims()[1]; - // LOG(ERROR) << "weight: " << weight_t; - // LOG(ERROR) << "weight_trans: " << weight_trans_t; for (int i = 0; i < power_iters; i++) { - // v_t.device(place) = weight_trans_t.contract(u_t, product_dims); blas.MatMul(*weight, true, *u, false, T(1), v, T(0)); - // LOG(ERROR) << "iter v: " << v_t; auto v_t_norm = v_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast( Array1(w)); - // LOG(ERROR) << "iter v_norm: " << v_t_norm; v_t.device(place) = v_t / (v_t_norm + v_t_norm.constant(eps)); - // LOG(ERROR) << "iter norm v: " << v_t; - // u_t.device(place) = weight_t.contract(v_t, product_dims); blas.MatMul(*weight, false, *v, false, T(1), u, T(0)); - // LOG(ERROR) << "iter u: " << u_t; auto u_t_norm = u_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast( Array1(h)); u_t.device(place) = u_t / (u_t_norm + u_t_norm.constant(eps)); - // LOG(ERROR) << "iter norm u: " << u_t; } - // LOG(ERROR) << "h" << h << "w" << w; - // LOG(ERROR) << "u: " << u_t; - // LOG(ERROR) << "v: " << v_t; Tensor weight_v; weight_v.mutable_data({h, 1}, ctx.GetPlace()); blas.MatMul(*weight, false, *v, false, T(1), &weight_v, T(0)); auto weight_v_t = EigenTensor::From(weight_v); - // LOG(ERROR) << "weight_v: " << weight_v_t; sigma_t.device(place) = (u_t * weight_v_t) .sum() .eval() .reshape(Array2(1, 1)) .broadcast(Array2(h, w)); - // LOG(ERROR) << "weight: " << weight_t; - // LOG(ERROR) << "sigma: " << sigma_t; weight_t.device(place) = weight_t / sigma_t; } @@ -107,29 +92,78 @@ class SpectralNormKernel : public framework::OpKernel { int power_iters = ctx.Attr("power_iters"); float eps = ctx.Attr("eps"); - const int h = weight->dims()[0]; - const int w = weight->dims()[1]; - Tensor weight_mat; + int h, w; + CalcMatrixShape(*weight, dim, &h, &w); TensorCopySync(*weight, ctx.GetPlace(), &weight_mat); - ResizeWeight(&weight_mat, dim); + weight_mat = weight_mat.Resize({h, w}); Tensor sigma; - sigma.mutable_data(weight->dims(), ctx.GetPlace()); + sigma.mutable_data(weight_mat.dims(), ctx.GetPlace()); Tensor uu, vv; TensorCopySync(*u, ctx.GetPlace(), &uu); TensorCopySync(*v, ctx.GetPlace(), &vv); CalcMatrixSigmaAndNormWeight( &sigma, &(uu.Resize({h, 1})), &(vv.Resize({w, 1})), &weight_mat, power_iters, eps, ctx); - TensorCopySync(weight_mat, ctx.GetPlace(), out); + TensorCopySync(weight_mat.Resize(out->dims()), ctx.GetPlace(), out); } }; template class SpectralNormGradKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override {} + void Compute(const framework::ExecutionContext& ctx) const override { + auto& place = *ctx.template device_context().eigen_device(); + auto blas = math::GetBlas(ctx); + auto weight = ctx.Input("Weight"); + auto u = ctx.Input("U"); + auto v = ctx.Input("V"); + auto out_grad = ctx.Input(framework::GradVarName("Out")); + auto weight_grad = ctx.Output(framework::GradVarName("Weight")); + + int dim = ctx.Attr("dim"); + int power_iters = ctx.Attr("power_iters"); + float eps = ctx.Attr("eps"); + + Tensor weight_mat, out_grad_mat; + int h, w; + CalcMatrixShape(*weight, dim, &h, &w); + TensorCopySync(*weight, ctx.GetPlace(), &weight_mat); + TensorCopySync(*out_grad, ctx.GetPlace(), &out_grad_mat); + weight_mat = weight_mat.Resize({h, w}); + out_grad_mat = out_grad_mat.Resize({h, w}); + + Tensor sigma; + sigma.mutable_data(weight_mat.dims(), ctx.GetPlace()); + Tensor uu, vv; + TensorCopySync(*u, ctx.GetPlace(), &uu); + TensorCopySync(*v, ctx.GetPlace(), &vv); + CalcMatrixSigmaAndNormWeight( + &sigma, &(uu.Resize({h, 1})), &(vv.Resize({w, 1})), &weight_mat, + power_iters, eps, ctx); + + Tensor uv; + uv.mutable_data({h, w}, ctx.GetPlace()); + blas.MatMul(uu.Resize({h, 1}), false, vv.Resize({w, 1}), false, T(1), &uv, + T(0)); + + Tensor weight_grad_mat, ones; + weight_grad_mat.mutable_data({h, w}, ctx.GetPlace()); + ones.mutable_data({h, w}, ctx.GetPlace()); + auto weight_grad_mat_t = EigenTensor::From(weight_grad_mat); + auto weight_mat_t = EigenTensor::From(weight_mat); + auto out_grad_mat_t = EigenTensor::From(out_grad_mat); + auto sigma_t = EigenTensor::From(sigma); + auto uv_t = EigenTensor::From(uv); + auto ones_t = EigenTensor::From(ones).setConstant((T)1); + weight_mat_t.device(place) = + weight_mat_t.sum().eval().reshape(Array2(1, 1)).broadcast(Array2(h, w)); + weight_grad_mat_t.device(place) = + out_grad_mat_t * (ones_t - uv_t * weight_mat_t) / sigma_t; + TensorCopySync(weight_grad_mat.Resize(weight_grad->dims()), ctx.GetPlace(), + weight_grad); + } }; } // namespace operators diff --git a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py index 57a1d3ed11785c6dc7281ab49786efded201e6d2..79594b3842e16f35c8ca36503453446a186a3058 100644 --- a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py @@ -44,13 +44,13 @@ def spectral_norm(weight, u, v, dim, power_iters, eps): return (weight_mat / sigma).reshape(weight.shape) -class TestSpectralNormOp(OpTest): +class TestSpectralNormOpNoGrad(OpTest): def setUp(self): self.initTestCase() self.op_type = 'spectral_norm' weight = np.random.random(self.weight_shape).astype('float32') - u = np.random.random(self.u_shape).astype('float32') - v = np.random.random(self.v_shape).astype('float32') + u = np.random.normal(0., 1., self.u_shape).astype('float32') + v = np.random.normal(0., 1., self.v_shape).astype('float32') self.attrs = { "dim": self.dim, @@ -76,7 +76,44 @@ class TestSpectralNormOp(OpTest): self.u_shape = (2, ) self.v_shape = (3, ) self.dim = 0 - self.power_iters = 2 + self.power_iters = 5 + self.eps = 1e-12 + + +class TestSpectralNormOpNoGrad2(TestSpectralNormOpNoGrad): + def initTestCase(self): + self.weight_shape = (2, 3, 3, 3) + self.u_shape = (6, ) + self.v_shape = (9, ) + self.dim = 1 + self.power_iters = 10 + self.eps = 1e-12 + + +class TestSpectralNormOp(TestSpectralNormOpNoGrad): + def test_check_grad_ignore_uv(self): + self.check_grad( + ['Weight'], + 'Out', + no_grad_set=set(["U", "V"]), + max_relative_error=0.1) + + def initTestCase(self): + self.weight_shape = (2, 3) + self.u_shape = (2, ) + self.v_shape = (3, ) + self.dim = 0 + self.power_iters = 0 + self.eps = 1e-12 + + +class TestSpectralNormOp2(TestSpectralNormOp): + def initTestCase(self): + self.weight_shape = (2, 3, 3, 3) + self.u_shape = (6, ) + self.v_shape = (9, ) + self.dim = 1 + self.power_iters = 0 self.eps = 1e-12