提交 d7a1e40e 编写于 作者: Y Yu Yang

Simple Implementation

上级 fd8df080
......@@ -35,16 +35,10 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
switch (ins.size()) {
case 0:
return kEmptyVarName;
case 1:
return ins[0];
default:
PADDLE_THROW("Op %s input %s should contain only one variable", type_,
name);
return "";
}
PADDLE_ENFORCE_LE(ins.size(), 1UL,
"Op %s input %s should contain only one variable", type_,
name);
return ins.empty() ? kEmptyVarName : ins[0];
}
const std::vector<std::string>& OperatorBase::Inputs(
......@@ -57,16 +51,10 @@ const std::vector<std::string>& OperatorBase::Inputs(
std::string OperatorBase::Output(const std::string& name) const {
auto& outs = Outputs(name);
switch (outs.size()) {
case 0:
return kEmptyVarName;
case 1:
return outs[0];
default:
PADDLE_THROW("Op %s output %s should contain only one variable", type_,
name);
return "";
}
PADDLE_ENFORCE_LE(outs.size(), 1UL,
"Op %s output %s should contain only one variable", type_,
name);
return outs.empty() ? kEmptyVarName : outs[0];
}
const std::vector<std::string>& OperatorBase::Outputs(
......
......@@ -239,20 +239,12 @@ class InferShapeContext {
const Variable* InputVar(const std::string& name) const {
auto ipt = op_.Input(name);
if (ipt == kEmptyVarName) {
return nullptr;
} else {
return scope_.FindVar(ipt);
}
return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
}
Variable* OutputVar(const std::string& name) const {
auto opt = op_.Output(name);
if (opt == kEmptyVarName) {
return nullptr;
} else {
return scope_.FindVar(opt);
}
return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
}
const std::vector<const Variable*> MultiInputVar(
......@@ -262,8 +254,8 @@ class InferShapeContext {
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) {
return name != kEmptyVarName ? scope_.FindVar(name)
: nullptr;
return name == kEmptyVarName ? nullptr
: scope_.FindVar(name);
});
return res;
}
......@@ -274,8 +266,8 @@ class InferShapeContext {
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) {
return name != kEmptyVarName ? scope_.FindVar(name)
: nullptr;
return name == kEmptyVarName ? nullptr
: scope_.FindVar(name);
});
return res;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册