diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 472a671e470c5411750d56f91721d41c4461e3a8..c8fda8e260b09cc45135e8721624d9a0e8855be7 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -31,88 +31,74 @@ static bool AllInSet(const std::vector<std::string>& names,
   return true;
 }
 
-static std::vector<size_t> InSetIdx(
-    const std::vector<std::string>& names, const std::string& suffix,
-    const std::unordered_set<std::string>& set) {
-  std::vector<size_t> ret_val;
-  ret_val.reserve(names.size());
-  for (size_t i = 0; i < names.size(); ++i) {
-    if (set.find(names[i] + suffix) != set.end()) {
-      ret_val.push_back(i);
-    }
-  }
-  return ret_val;
-}
-
-static std::shared_ptr<OperatorBase> EmptyOp() {
+static std::shared_ptr<OperatorBase> NOP() {
   auto net_op = std::make_shared<NetOp>();
-  net_op->type_ = "@EMPTY_OP@";
+  net_op->type_ = "@NOP@";
   net_op->CompleteAddOp();
   return net_op;
 }
 
-/**
- * @brief Backward an operator, implementation
- * @param forwardOp the forward operator
- * @param no_grad_names variable names not calculate for gradient. Like X@GRAD
- * is not needed.
- * @param uniq_id a unique index used inside BackwardImpl, it will be shared
- * through recursive invoke.
- * @return The backward operator. For simple situation, it is a simple operator.
- * For complex situation, it is a NetOp.
- *
- * See Backward.h for details
- */
-static std::shared_ptr<OperatorBase> BackwardImpl(
+// Get the backward operator from a forward operator; a recursive
+// implementation.
+//
+// no_grad_names: the gradient variable names that do not need to be
+// calculated.
+//
+// uniq_id: a unique index shared across the recursive calls to
+// BackwardRecursive. Use `uid = uniq_id++;` to obtain a unique index, and
+// pass `uniq_id` on through the recursion.
+//
+// Returns the backward operator. For a simple situation it is a plain
+// operator; for a complex situation it is a NetOp.
+//
+// See Backward.h for details.
+static std::shared_ptr<OperatorBase> BackwardRecursive(
+    const OperatorBase& forwardOp,
+    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id);
+std::shared_ptr<OperatorBase> BackwardRecursive(
     const OperatorBase& forwardOp,
     std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
-  /**
-   * If all input gradients of forwarding operator do not need to calculate,
-   * just return an EmptyOp. Not return null ptr because EmptyOp does not take
-   * too much time for calculation, but it is useful for simplifying logic.
-   */
+  // If none of the input gradients of the forward operator need to be
+  // calculated, just return a NOP. Do not return a null pointer: a NOP costs
+  // almost nothing to run, and it keeps the calling logic simple.
   if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
                no_grad_names)) {
-    return EmptyOp();
+    return NOP();
   }
 
-  /**
-   * All output gradients of forwarding operator do not need to calculate. Then
-   * all input gradients cannot be computed at all, and we put them into
-   * `no_grad_names` set. Return an EmptyOp.
-   */
+  // If none of the output gradients of the forward operator need to be
+  // calculated, then none of its input gradients can be computed at all; put
+  // them all into the `no_grad_names` set and return a NOP.
   if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(),
                no_grad_names)) {
     for (auto& name : forwardOp.inputs_) {
-      /// Mark all input is not need
+      // Mark all inputs as not needed.
       no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
     }
-    return EmptyOp();
+    return NOP();
   }
 
-  //! Returned gradient network
+  // The returned gradient network.
   auto net = std::make_shared<NetOp>();
 
   if (forwardOp.IsNetOp()) {
-    /// Because forwardOp is a net op, it can static_cast.
+    // Because forwardOp is a net op, the static_cast below is safe.
     auto& forwardNet = static_cast<const NetOp&>(forwardOp);
 
-    //! Map from output gradient variable name to operator's indices in backward
-    //! net. That operator generates that variable.
+    // Map from an output gradient variable name to the indices, in the
+    // backward net, of the operators that generate that variable.
     std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
 
     size_t local_op_id = 0;
-    /// reversely travel forwardNet
+    // Traverse forwardNet in reverse order.
     for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
          ++it, ++local_op_id) {
       auto fwd = *it;
-      auto bwd = BackwardImpl(*fwd, no_grad_names, uniq_id);
+      auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
       net->AddOp(bwd);
       for (auto& out : bwd->outputs_) {
         dup_output_ops[out].emplace_back(local_op_id);
       }
     }
-    /// Get unique ID for this method.
+    // Get a unique ID for this invocation.
     auto uid = uniq_id++;
     // TODO(dzh): more comment
     using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
@@ -145,13 +131,15 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
     }
   } else {
-    //! TODO(fjy)
     std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
     for (std::string& grad_input : grad_op->inputs_) {
       if (no_grad_names.count(grad_input)) {
         std::string prefix = grad_input.substr(
             0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size());
         grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX();
+
+        // If an input gradient of this operator is not calculated, feed it a
+        // zero-filled variable produced by a fill_zeros_like op.
         net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {prefix},
                                         {grad_input}, {}));
       }
@@ -173,8 +161,8 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
   return net;
 }
 
-//! See header for comments
-extern std::shared_ptr<OperatorBase> Backward(
+// See the header for comments.
+std::shared_ptr<OperatorBase> Backward(
     const OperatorBase& forwardOp,
     const std::unordered_set<std::string>& no_grad_vars) {
   std::unordered_set<std::string> no_grad_names;
@@ -184,7 +172,7 @@ extern std::shared_ptr<OperatorBase> Backward(
     no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
   }
   size_t uid = 0;
-  return BackwardImpl(forwardOp, no_grad_names, uid);
+  return BackwardRecursive(forwardOp, no_grad_names, uid);
 }
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h
index d711c7bbb642781d50ccad0c249a38c939e5e31a..c181919dc165cf0b49362f85e22ceb4131bbd387 100644
--- a/paddle/framework/backward.h
+++ b/paddle/framework/backward.h
@@ -18,12 +18,8 @@
 namespace paddle {
 namespace framework {
 
-/**
- * @brief
- * @param forwardOp
- * @param no_grad_vars ignored input name of forward
- * @return
- */
+// Create the backward operator from a forward operator.
+// TODO(yuyang18): Add more API reference comment.
 extern std::shared_ptr<OperatorBase> Backward(
     const OperatorBase& forwardOp,
     const std::unordered_set<std::string>& no_grad_vars);
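
For orientation (not part of the patch): a minimal usage sketch of the `Backward()` entry point declared in backward.h. It assumes an elementwise operator registered under the hypothetical name "add" and reuses the `OpRegistry::CreateOp(type, inputs, outputs, attrs)` form that backward.cc itself uses; treat it as an illustration of the intended call pattern, not as tested code.

    #include "paddle/framework/backward.h"
    #include "paddle/framework/op_registry.h"

    namespace f = paddle::framework;

    void BackwardUsageSketch() {
      // Forward op: Out = X + Y. Assumes an "add" operator is registered.
      auto fwd = f::OpRegistry::CreateOp("add", {"X", "Y"}, {"Out"}, {});

      // Build the backward operator, suppressing the gradient of Y.
      // Internally this inserts "Y" + GRAD_VAR_SUFFIX() into no_grad_names.
      auto bwd = f::Backward(*fwd, {"Y"});

      // bwd produces "X" + GRAD_VAR_SUFFIX(); Y's gradient is never computed.
    }
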
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index ec55661e7990bc8432d0f359bfab4d631b67370f..cb14ef95737430397b9a035c31e9d29e6aff7eb6 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -169,7 +169,6 @@ TEST(Backward, simple_op_grad) {
 
   ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(),
             gop->Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
-  // LOG(INFO) << gop->Output("X" + "@GRAD");
 }
 
 TEST(Backward, simple_op_not_need_grad) {
diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc
index d641bc4adaf8c7a84f5dab37632108d929a64730..79a0e3d7e911b728a7a96ceff573976ba2b2e37f 100644
--- a/paddle/operators/fill_zeros_like_op.cc
+++ b/paddle/operators/fill_zeros_like_op.cc
@@ -21,15 +21,17 @@ namespace operators {
 
 class FillZerosLikeOp : public framework::OperatorWithKernel {
 protected:
-  void InferShape(
-      const std::vector<const framework::Tensor *> &inputs,
-      const std::vector<framework::Tensor *> &outputs) const override {
-    PADDLE_ENFORCE(inputs.size() == 1,
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE(ctx.InputSize() == 1UL,
                    "Input size of FillZerosLikeOp must be one.");
-    PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one.");
-    PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr,
-                   "Outputs of FillZerosLikeOp must all be set.");
-    outputs[0]->Resize(inputs[0]->dims());
+    PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
+                   "Output size of FillZerosLikeOp must be one.");
+    PADDLE_ENFORCE(ctx.InputVar(0) != nullptr,
+                   "Input of FillZerosLikeOp must be set.");
+    PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
+                   "Output of FillZerosLikeOp must be set.");
+    ctx.Output<framework::Tensor>(0)->Resize(
+        ctx.Input<framework::Tensor>(0)->dims());
   }
 };
 
diff --git a/paddle/operators/fill_zeros_like_op.h b/paddle/operators/fill_zeros_like_op.h
index ca44a201f7d54bb05679752b8e39bce3dfb1023f..05272964abd43bdc2bd5c3cae8b128099e1c888c 100644
--- a/paddle/operators/fill_zeros_like_op.h
+++ b/paddle/operators/fill_zeros_like_op.h
@@ -23,8 +23,8 @@ namespace operators {
 template <typename Place, typename T>
 class FillZerosLikeKernel : public framework::OpKernel {
 public:
-  void Compute(const framework::KernelContext& context) const override {
-    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* output = context.Output<framework::Tensor>(0);
     output->mutable_data<T>(context.GetPlace());
     framework::EigenVector<T>::Flatten(*output).setZero();
   }
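
A side note on the kernel interface shown above: with `ExecutionContext`, outputs are fetched through the typed accessor `Output<framework::Tensor>(index)` instead of `Output(index)->GetMutable<...>()`. The sketch below is a hypothetical sibling kernel written against the same interface; the class name, the ones-filling behaviour, and the `setConstant` call are illustrative and not part of this patch (include paths are assumed to match fill_zeros_like_op.h).

    #include "paddle/framework/eigen.h"
    #include "paddle/framework/operator.h"

    namespace paddle {
    namespace operators {

    // Hypothetical kernel, shown only to illustrate the ExecutionContext API.
    template <typename Place, typename T>
    class FillOnesLikeKernel : public framework::OpKernel {
     public:
      void Compute(const framework::ExecutionContext& context) const override {
        auto* output = context.Output<framework::Tensor>(0);  // typed accessor
        output->mutable_data<T>(context.GetPlace());
        framework::EigenVector<T>::Flatten(*output).setConstant(
            static_cast<T>(1));
      }
    };

    }  // namespace operators
    }  // namespace paddle
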
for all steps.") + .SetMultiple(); AddOutput(name.step_scopes, "step scopes"); // Attributes stored in AttributeMap