From 3909108cae29ee035785e7e2fa44f1c7c8bbd9ea Mon Sep 17 00:00:00 2001 From: zhhsplendid Date: Tue, 26 Mar 2019 01:59:25 +0000 Subject: [PATCH] Add SpectralNormGradOpDescMaker Use SpectralNormGradOpDescMaker instead of DefaultGradOpDescMaker to avoid registering useless variables to improve GPU usage. test=develop --- paddle/fluid/operators/spectral_norm_op.cc | 27 +++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc index 357d055756..04f659a465 100644 --- a/paddle/fluid/operators/spectral_norm_op.cc +++ b/paddle/fluid/operators/spectral_norm_op.cc @@ -10,6 +10,9 @@ limitations under the License. */ #include "paddle/fluid/operators/spectral_norm_op.h" + +#include + #include "paddle/fluid/framework/op_registry.h" namespace paddle { @@ -156,6 +159,28 @@ class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker { } }; +class SpectralNormGradOpDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("spectral_norm_grad"); + + op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + op->SetInput("Weight", Input("Weight")); + op->SetInput("U", Input("U")); + op->SetInput("V", Input("V")); + + op->SetOutput(framework::GradVarName("Weight"), InputGrad("Weight")); + + op->SetAttrMap(Attrs()); + + return op; + } +}; + class SpectralNormOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -185,7 +210,7 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OPERATOR(spectral_norm, ops::SpectralNormOp, ops::SpectralNormOpMaker, - paddle::framework::DefaultGradOpDescMaker); + ops::SpectralNormGradOpDescMaker); REGISTER_OPERATOR(spectral_norm_grad, ops::SpectralNormOpGrad); REGISTER_OP_CPU_KERNEL( spectral_norm, -- GitLab