From e32e306821fc8ffd79ccbe6f9c090d1ad217fd56 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 26 Jul 2017 19:37:10 +0800 Subject: [PATCH] Develop backward building precess of single op --- paddle/framework/backward.cc | 23 +++++++++++++++++++++-- paddle/framework/operator.h | 3 +++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index d8653b5dd68..1531cb53f9d 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -12,8 +12,9 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include -#include +#include "paddle/framework/backward.h" +#include "paddle/framework/net.h" +#include "paddle/framework/op_registry.h" namespace paddle { namespace framework { @@ -71,6 +72,24 @@ static std::shared_ptr BackwardImpl( //! TODO(dzh) } else { //! TODO(fjy) + std::shared_ptr grad_op = OpRegistry::CreateGradOp(forwardOp); + for (std::string& grad_input : grad_op->inputs_) { + if (no_grad_names.count(grad_input)) { + std::string prefix = grad_input.substr( + 0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size()); + grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX(); + std::vector fill_zeros_in = {prefix}; + std::vector fill_zeros_out = {grad_input}; + net.AddOp(OpRegistry::CreateOp("fill_zeros_like", fill_zeros_in, + fill_zeros_out, AttributeMap())); + } + } + for (std::string& grad_output : grad_op->output_) { + if (no_grad_names.count(grad_output)) { + grad_output = OperatorBase::EMPTY_VAR_NAME(); + } + } + net.AddOp(grad_op); } net->CompleteAddOp(); diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 65fddb68112..c2cd21a0806 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -67,6 +67,9 @@ class OperatorBase { /// e.g. Variable "x@GRAD" is the gradient of varibale "x". static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; } + /// Variables with this suffix are supposed to be filled up with zeros. + static std::string ZERO_VAR_SUFFIX() { return "@ZERO"; } + virtual ~OperatorBase() {} template -- GitLab