提交 3395bf7a 编写于 作者: Y Yu Yang

Remove duplicated method in OpDesc

上级 d068777a
...@@ -44,12 +44,12 @@ class GradOpDescMakerBase { ...@@ -44,12 +44,12 @@ class GradOpDescMakerBase {
return ToGradNames(fwd_op_.Output(name)); return ToGradNames(fwd_op_.Output(name));
} }
std::vector<std::string> InputParamNames() const { std::vector<std::string> InputNames() const {
return this->fwd_op_.InputParamNames(); return this->fwd_op_.InputNames();
} }
std::vector<std::string> OutputParamNames() const { std::vector<std::string> OutputNames() const {
return this->fwd_op_.OutputParamNames(); return this->fwd_op_.OutputNames();
} }
std::vector<std::string> Input(const std::string& name) const { std::vector<std::string> Input(const std::string& name) const {
...@@ -96,12 +96,12 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker { ...@@ -96,12 +96,12 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
OpDescBind grad; OpDescBind grad;
grad.SetType(this->GradOpType()); grad.SetType(this->GradOpType());
for (auto& input_param : this->InputParamNames()) { for (auto& input_param : this->InputNames()) {
grad.SetInput(input_param, this->Input(input_param)); grad.SetInput(input_param, this->Input(input_param));
grad.SetOutput(GradVarName(input_param), this->InputGrad(input_param)); grad.SetOutput(GradVarName(input_param), this->InputGrad(input_param));
} }
for (auto& output_param : this->OutputParamNames()) { for (auto& output_param : this->OutputNames()) {
grad.SetInput(output_param, this->Output(output_param)); grad.SetInput(output_param, this->Output(output_param));
grad.SetInput(GradVarName(output_param), this->OutputGrad(output_param)); grad.SetInput(GradVarName(output_param), this->OutputGrad(output_param));
} }
......
...@@ -31,15 +31,6 @@ const std::vector<std::string> &OpDescBind::Input( ...@@ -31,15 +31,6 @@ const std::vector<std::string> &OpDescBind::Input(
return it->second; return it->second;
} }
std::vector<std::string> OpDescBind::InputNames() const {
std::vector<std::string> retv;
retv.reserve(this->inputs_.size());
for (auto &ipt : this->inputs_) {
retv.push_back(ipt.first);
}
return retv;
}
void OpDescBind::SetInput(const std::string &param_name, void OpDescBind::SetInput(const std::string &param_name,
const std::vector<std::string> &args) { const std::vector<std::string> &args) {
need_update_ = true; need_update_ = true;
...@@ -54,15 +45,6 @@ const std::vector<std::string> &OpDescBind::Output( ...@@ -54,15 +45,6 @@ const std::vector<std::string> &OpDescBind::Output(
return it->second; return it->second;
} }
std::vector<std::string> OpDescBind::OutputNames() const {
std::vector<std::string> retv;
retv.reserve(this->outputs_.size());
for (auto &ipt : this->outputs_) {
retv.push_back(ipt.first);
}
return retv;
}
void OpDescBind::SetOutput(const std::string &param_name, void OpDescBind::SetOutput(const std::string &param_name,
const std::vector<std::string> &args) { const std::vector<std::string> &args) {
need_update_ = true; need_update_ = true;
......
...@@ -35,15 +35,11 @@ class OpDescBind { ...@@ -35,15 +35,11 @@ class OpDescBind {
const std::vector<std::string> &Input(const std::string &name) const; const std::vector<std::string> &Input(const std::string &name) const;
std::vector<std::string> InputNames() const;
void SetInput(const std::string &param_name, void SetInput(const std::string &param_name,
const std::vector<std::string> &args); const std::vector<std::string> &args);
const std::vector<std::string> &Output(const std::string &name) const; const std::vector<std::string> &Output(const std::string &name) const;
std::vector<std::string> OutputNames() const;
void SetOutput(const std::string &param_name, void SetOutput(const std::string &param_name,
const std::vector<std::string> &args); const std::vector<std::string> &args);
...@@ -71,10 +67,8 @@ class OpDescBind { ...@@ -71,10 +67,8 @@ class OpDescBind {
// Only be used in C++ // Only be used in C++
void SetAttrMap(const AttributeMap &attr_map); void SetAttrMap(const AttributeMap &attr_map);
std::vector<std::string> InputParamNames() const { return MapKeys(inputs_); } std::vector<std::string> InputNames() const { return MapKeys(inputs_); }
std::vector<std::string> OutputParamNames() const { std::vector<std::string> OutputNames() const { return MapKeys(outputs_); }
return MapKeys(outputs_);
}
private: private:
template <typename MapType> template <typename MapType>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册