From ff7fdb7d705a34e224561cb53933b5477fef644b Mon Sep 17 00:00:00 2001
From: fengjiayi
Date: Tue, 3 Oct 2017 10:28:52 -0700
Subject: [PATCH] Add `CreateBackwardOp` function

---
 paddle/framework/backward.cc    | 56 +++++++++++++++++++++++++++++++++
 paddle/framework/backward.h     |  5 +++
 paddle/framework/op_desc.cc     | 20 ++++++++++++
 paddle/framework/op_desc.h      |  7 +++++
 paddle/framework/op_registry.cc |  5 +++
 paddle/framework/op_registry.h  |  2 ++
 6 files changed, 95 insertions(+)

diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 0ec18de5b8..1b4c5c025e 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -222,5 +222,61 @@ std::unique_ptr<OperatorBase> Backward(
   return BackwardRecursive(forwardOp, no_grad_names, uid);
 }
 
+// ==================================== //
+
+static bool AllGradInSet(const std::vector<std::string>& names,
+                         const std::unordered_set<std::string>& set) {
+  for (const std::string& name : names) {
+    if (!set.count(GradVarName(name))) {
+      return false;
+    }
+  }
+  return true;
+}
+
+std::vector<OpDescBind> CreateBackwardOps(
+    const OpDescBind& op_desc, std::unordered_set<std::string>& no_grad_vars) {
+  std::vector<OpDescBind> grad_op_descs;
+  // All input gradients of forwarding operator do not need to be calculated.
+  if (AllGradInSet(op_desc.InputNames(), no_grad_vars)) {
+    return grad_op_descs;  // empty vector
+  }
+  // All output gradients of forwarding operator do not need to be calculated.
+  const std::vector<std::string>& outputs = op_desc.OutputNames();
+  if (AllGradInSet(outputs, no_grad_vars)) {
+    for (const std::string& name : outputs) {
+      no_grad_vars.insert(GradVarName(name));
+    }
+    return grad_op_descs;  // empty vector
+  }
+
+  grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc);
+
+  std::vector<OpDescBind> fill_zeros_ops;
+  for (OpDescBind& desc : grad_op_descs) {
+    for (const std::string& in_name : desc.InputNames()) {
+      if (no_grad_vars.count(in_name)) {
+        std::string prefix = in_name.substr(
+            0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
+        std::string new_name = prefix + kZeroVarSuffix;
+        desc.Rename(in_name, new_name);
+        OpDescBind op_desc_bind(
+            "fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {});
+        fill_zeros_ops.push_back(op_desc_bind);
+      }
+    }
+    for (const std::string& out_name : desc.OutputNames()) {
+      if (no_grad_vars.count(out_name)) {
+        desc.Rename(out_name, kEmptyVarName);
+      }
+    }
+  }
+  grad_op_descs.insert(grad_op_descs.begin(), fill_zeros_ops.begin(),
+                       fill_zeros_ops.end());
+
+  // TODO(fengjiayi): RNN op
+  return grad_op_descs;
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h
index 1ecf69881b..6aeddafb41 100644
--- a/paddle/framework/backward.h
+++ b/paddle/framework/backward.h
@@ -23,5 +23,10 @@ namespace framework {
 extern std::unique_ptr<OperatorBase> Backward(
     const OperatorBase& forwardOp,
     const std::unordered_set<std::string>& no_grad_vars);
+
+extern void AppendBackwardOps(
+    BlockDescBind& block_desc,
+    const std::unordered_set<std::string>& no_grad_vars);
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc
index 0c12c55dc0..e98f8f1154 100644
--- a/paddle/framework/op_desc.cc
+++ b/paddle/framework/op_desc.cc
@@ -18,6 +18,15 @@ limitations under the License.
 */
 
 namespace paddle {
 namespace framework {
 
+OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
+                       const VariableNameMap &outputs,
+                       const AttributeMap &attrs) {
+  op_desc_.set_type(type);
+  inputs_ = inputs;
+  outputs_ = outputs;
+  attrs_ = attrs;
+}
+
 OpDesc *OpDescBind::Proto() {
   Sync();
   return &op_desc_;
@@ -112,6 +121,17 @@ const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap()
   return attrs_;
 }
 
+void OpDescBind::Rename(const std::string &old_name, const std::string &new_name) {
+  for (auto &input : inputs_) {
+    std::replace(input.second.begin(), input.second.end(), old_name, new_name);
+  }
+  for (auto &output : outputs_) {
+    std::replace(output.second.begin(), output.second.end(), old_name,
+                 new_name);
+  }
+  need_update_ = true;
+}
+
 void OpDescBind::Sync() {
   if (need_update_) {
     this->op_desc_.mutable_inputs()->Clear();
diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h
index 0cf7d13971..a32e6d03f7 100644
--- a/paddle/framework/op_desc.h
+++ b/paddle/framework/op_desc.h
@@ -26,6 +26,11 @@ class BlockDescBind;
 
 class OpDescBind {
  public:
+  OpDescBind() {}
+
+  OpDescBind(const std::string &type, const VariableNameMap &inputs,
+             const VariableNameMap &outputs, const AttributeMap &attrs);
+
   OpDesc *Proto();
 
   std::string Type() const { return op_desc_.type(); }
@@ -67,6 +72,8 @@ class OpDescBind {
   int GetBlockAttr(const std::string &name) const;
 
+  void Rename(const std::string &old_name, const std::string &new_name);
+
   // Only be used in C++
   const std::unordered_map<std::string, Attribute> &GetAttrMap() const;
diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc
index b0e85dd49f..fe3228ce5b 100644
--- a/paddle/framework/op_registry.cc
+++ b/paddle/framework/op_registry.cc
@@ -57,5 +57,10 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateGradOp(const OperatorBase& op) {
   return std::unique_ptr<OperatorBase>(BuildGradOp(&op));
 }
 
+std::vector<OpDescBind> OpRegistry::CreateGradOpDescs(const OpDescBind& op_desc) {
+  auto& info =
+      OpInfoMap::Instance().Get(op_desc.Type());
+  return info.grad_op_maker_(op_desc);
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 4ee2c7d275..c80b6e9630 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -68,6 +68,8 @@ class OpRegistry {
   static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
 
   static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op);
+
+  static std::vector<OpDescBind> CreateGradOpDescs(const OpDescBind& op_desc);
 };
 
 class Registrar {
-- 
GitLab