diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 11690342185ba6046df2b0f9121bb9bd8cd9a073..d8653b5dd681603b7261e58de02c6787bcdcebfe 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -29,10 +29,10 @@ static bool AllInSet(const std::vector& names, return true; } -static std::vector InSetIdx(const std::vector& names, - const std::string& suffix, - const std::unordered_set& set) { - std::vector ret_val; +static std::vector InSetIdx( + const std::vector& names, const std::string& suffix, + const std::unordered_set& set) { + std::vector ret_val; ret_val.reserve(names.size()); for (size_t i = 0; i < names.size(); ++i) { if (set.find(names[i] + suffix) != set.end()) { @@ -78,7 +78,7 @@ static std::shared_ptr BackwardImpl( } extern std::shared_ptr Backward( - const std::shared_ptr& forwardOp, + const OperatorBase& forwardOp, const std::unordered_set& no_grad_vars) { std::unordered_set no_grad_names; no_grad_names.reserve(no_grad_vars.size()); @@ -87,7 +87,7 @@ extern std::shared_ptr Backward( no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX()); } int uid = 0; - return BackwardImpl(*forwardOp, no_grad_names, uid); + return BackwardImpl(forwardOp, no_grad_names, uid); } } // namespace framework } // namespace paddle diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h index e835ef6351102686d68d657fdc5a6a2913ace3e6..d711c7bbb642781d50ccad0c249a38c939e5e31a 100644 --- a/paddle/framework/backward.h +++ b/paddle/framework/backward.h @@ -25,7 +25,7 @@ namespace framework { * @return */ extern std::shared_ptr Backward( - const std::shared_ptr& forwardOp, + const OperatorBase& forwardOp, const std::unordered_set& no_grad_vars); } // namespace framework } // namespace paddle