提交 b1b13f8f 编写于 作者: Y Yu Yang

Update Interface

上级 ecf23ce5
...@@ -29,10 +29,10 @@ static bool AllInSet(const std::vector<std::string>& names, ...@@ -29,10 +29,10 @@ static bool AllInSet(const std::vector<std::string>& names,
return true; return true;
} }
static std::vector<int> InSetIdx(const std::vector<std::string>& names, static std::vector<size_t> InSetIdx(
const std::string& suffix, const std::vector<std::string>& names, const std::string& suffix,
const std::unordered_set<std::string>& set) { const std::unordered_set<std::string>& set) {
std::vector<int> ret_val; std::vector<size_t> ret_val;
ret_val.reserve(names.size()); ret_val.reserve(names.size());
for (size_t i = 0; i < names.size(); ++i) { for (size_t i = 0; i < names.size(); ++i) {
if (set.find(names[i] + suffix) != set.end()) { if (set.find(names[i] + suffix) != set.end()) {
...@@ -78,7 +78,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl( ...@@ -78,7 +78,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
} }
extern std::shared_ptr<OperatorBase> Backward( extern std::shared_ptr<OperatorBase> Backward(
const std::shared_ptr<OperatorBase>& forwardOp, const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars) { const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_set<std::string> no_grad_names; std::unordered_set<std::string> no_grad_names;
no_grad_names.reserve(no_grad_vars.size()); no_grad_names.reserve(no_grad_vars.size());
...@@ -87,7 +87,7 @@ extern std::shared_ptr<OperatorBase> Backward( ...@@ -87,7 +87,7 @@ extern std::shared_ptr<OperatorBase> Backward(
no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX()); no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
} }
int uid = 0; int uid = 0;
return BackwardImpl(*forwardOp, no_grad_names, uid); return BackwardImpl(forwardOp, no_grad_names, uid);
} }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -25,7 +25,7 @@ namespace framework { ...@@ -25,7 +25,7 @@ namespace framework {
* @return * @return
*/ */
extern std::shared_ptr<OperatorBase> Backward( extern std::shared_ptr<OperatorBase> Backward(
const std::shared_ptr<OperatorBase>& forwardOp, const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars); const std::unordered_set<std::string>& no_grad_vars);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册