提交 d7a1e40e 编写于 作者: Y Yu Yang

Simple Implementation

上级 fd8df080
...@@ -35,16 +35,10 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ...@@ -35,16 +35,10 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
std::string OperatorBase::Input(const std::string& name) const { std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name); auto& ins = Inputs(name);
switch (ins.size()) { PADDLE_ENFORCE_LE(ins.size(), 1UL,
case 0: "Op %s input %s should contain only one variable", type_,
return kEmptyVarName;
case 1:
return ins[0];
default:
PADDLE_THROW("Op %s input %s should contain only one variable", type_,
name); name);
return ""; return ins.empty() ? kEmptyVarName : ins[0];
}
} }
const std::vector<std::string>& OperatorBase::Inputs( const std::vector<std::string>& OperatorBase::Inputs(
...@@ -57,16 +51,10 @@ const std::vector<std::string>& OperatorBase::Inputs( ...@@ -57,16 +51,10 @@ const std::vector<std::string>& OperatorBase::Inputs(
std::string OperatorBase::Output(const std::string& name) const { std::string OperatorBase::Output(const std::string& name) const {
auto& outs = Outputs(name); auto& outs = Outputs(name);
switch (outs.size()) { PADDLE_ENFORCE_LE(outs.size(), 1UL,
case 0: "Op %s output %s should contain only one variable", type_,
return kEmptyVarName;
case 1:
return outs[0];
default:
PADDLE_THROW("Op %s output %s should contain only one variable", type_,
name); name);
return ""; return outs.empty() ? kEmptyVarName : outs[0];
}
} }
const std::vector<std::string>& OperatorBase::Outputs( const std::vector<std::string>& OperatorBase::Outputs(
......
...@@ -239,20 +239,12 @@ class InferShapeContext { ...@@ -239,20 +239,12 @@ class InferShapeContext {
const Variable* InputVar(const std::string& name) const { const Variable* InputVar(const std::string& name) const {
auto ipt = op_.Input(name); auto ipt = op_.Input(name);
if (ipt == kEmptyVarName) { return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return nullptr;
} else {
return scope_.FindVar(ipt);
}
} }
Variable* OutputVar(const std::string& name) const { Variable* OutputVar(const std::string& name) const {
auto opt = op_.Output(name); auto opt = op_.Output(name);
if (opt == kEmptyVarName) { return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
return nullptr;
} else {
return scope_.FindVar(opt);
}
} }
const std::vector<const Variable*> MultiInputVar( const std::vector<const Variable*> MultiInputVar(
...@@ -262,8 +254,8 @@ class InferShapeContext { ...@@ -262,8 +254,8 @@ class InferShapeContext {
res.reserve(names.size()); res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) { [this](const std::string& name) {
return name != kEmptyVarName ? scope_.FindVar(name) return name == kEmptyVarName ? nullptr
: nullptr; : scope_.FindVar(name);
}); });
return res; return res;
} }
...@@ -274,8 +266,8 @@ class InferShapeContext { ...@@ -274,8 +266,8 @@ class InferShapeContext {
res.reserve(names.size()); res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) { [this](const std::string& name) {
return name != kEmptyVarName ? scope_.FindVar(name) return name == kEmptyVarName ? nullptr
: nullptr; : scope_.FindVar(name);
}); });
return res; return res;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册