diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index d8653b5dd681603b7261e58de02c6787bcdcebfe..1531cb53f9dc7ace9911a231e928e869cd7eca28 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -12,8 +12,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <paddle/framework/backward.h>
-#include <paddle/framework/net.h>
+#include "paddle/framework/backward.h"
+#include "paddle/framework/net.h"
+#include "paddle/framework/op_registry.h"
 
 namespace paddle {
 namespace framework {
@@ -71,6 +72,24 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
     //! TODO(dzh)
   } else {
     //! TODO(fjy)
+    std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
+    // Rewrite gradient inputs whose originating variable is marked no-grad:
+    // they are renamed to the @ZERO variant and a fill_zeros_like op is
+    // inserted so the consumer still receives a well-defined (all-zero) tensor.
+    for (std::string& grad_input : grad_op->inputs_) {
+      if (no_grad_names.count(grad_input)) {
+        std::string prefix = grad_input.substr(
+            0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size());
+        grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX();
+        std::vector<std::string> fill_zeros_in = {prefix};
+        std::vector<std::string> fill_zeros_out = {grad_input};
+        net->AddOp(OpRegistry::CreateOp("fill_zeros_like", fill_zeros_in,
+                                        fill_zeros_out, AttributeMap()));
+      }
+    }
+    // Gradient outputs for no-grad variables are never consumed; rename them
+    // to the empty placeholder so the op skips computing/writing them.
+    for (std::string& grad_output : grad_op->outputs_) {
+      if (no_grad_names.count(grad_output)) {
+        grad_output = OperatorBase::EMPTY_VAR_NAME();
+      }
+    }
+    net->AddOp(grad_op);
   }
 
   net->CompleteAddOp();
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 65fddb68112d70bda6d1462fd83561cfae657b6d..c2cd21a0806c310ba48917b3871aa7bfced186cf 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -67,6 +67,9 @@ class OperatorBase {
   /// e.g. Variable "x@GRAD" is the gradient of varibale "x".
   static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; }
 
+  /// Variables with this suffix are supposed to be filled up with zeros.
+  static std::string ZERO_VAR_SUFFIX() { return "@ZERO"; }
+
   virtual ~OperatorBase() {}
 
   template <typename T>