From cb95587feb6e32c8595d02e76e58aa69a96b5035 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Wed, 19 Jul 2017 14:28:29 +0800 Subject: [PATCH] "ignore some gradient of specific op" --- paddle/framework/op_proto.proto | 6 ++++++ paddle/framework/op_registry.h | 16 ++++++++++------ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/paddle/framework/op_proto.proto b/paddle/framework/op_proto.proto index 596b8588e..366c84e53 100644 --- a/paddle/framework/op_proto.proto +++ b/paddle/framework/op_proto.proto @@ -84,6 +84,11 @@ message VarProto { // "temporary_index": [1] // } optional bool temporary = 4 [default=false]; + + // The gradient of operator can be ignored immediately + // e.g. operator AddOp, y = x1 + x2, the gradient of dy/dx1, dy/dx2 + // can be ignored for the future optimized on graph. + optional bool ignore_gradient = 6; } // Op protocol message for 3rd-party language binding. @@ -105,4 +110,5 @@ message OpProto { // The type of that Op. required string type = 5; + } diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 6ba0784f1..dded0ad33 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -74,25 +74,29 @@ class OpProtoAndCheckerMaker { protected: void AddInput(const std::string& name, const std::string& comment, - bool multiple = false) { + bool multiple = false, bool ignore_gradient = false) { auto input = proto_->mutable_inputs()->Add(); *input->mutable_name() = name; *input->mutable_comment() = comment; + *input->set_ignore_gradient(ignore_gradient); input->set_multiple(multiple); if (multiple) { SetHasMultipleInput(); } } - void AddInputs(const std::string& name, const std::string& comment) { - AddInput(name, comment, true); + void AddInputs(const std::string& name, const std::string& comment, + bool ignore_gradient = false) { + AddInput(name, comment, true, ignore_gradient); } void AddOutput(const std::string& name, const std::string& comment, - bool temporary = false, bool multiple = false) { + bool temporary = false, bool multiple = false, + bool ignore_gradient = false) { auto output = proto_->mutable_outputs()->Add(); *output->mutable_name() = name; *output->mutable_comment() = comment; + *output->set_ignore_gradient(ignore_gradient); output->set_multiple(multiple); if (multiple) { SetHasMultipleOutput(); @@ -104,8 +108,8 @@ class OpProtoAndCheckerMaker { } void AddOutputs(const std::string& name, const std::string& comment, - bool temporary = false) { - AddOutput(name, comment, temporary, true); + bool temporary = false, bool ignore_gradient = false) { + AddOutput(name, comment, temporary, true, ignore_gradient); } template -- GitLab