diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 4d327958843e00fe43657abe074bbbbd1364afbc..68fb469b561618ec838045a76d903bcfb805c06f 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -35,16 +35,10 @@ ExecutionContext::GetEigenDevice() const { std::string OperatorBase::Input(const std::string& name) const { auto& ins = Inputs(name); - switch (ins.size()) { - case 0: - return kEmptyVarName; - case 1: - return ins[0]; - default: - PADDLE_THROW("Op %s input %s should contain only one variable", type_, - name); - return ""; - } + PADDLE_ENFORCE_LE(ins.size(), 1UL, + "Op %s input %s should contain only one variable", type_, + name); + return ins.empty() ? kEmptyVarName : ins[0]; } const std::vector& OperatorBase::Inputs( @@ -57,16 +51,10 @@ const std::vector& OperatorBase::Inputs( std::string OperatorBase::Output(const std::string& name) const { auto& outs = Outputs(name); - switch (outs.size()) { - case 0: - return kEmptyVarName; - case 1: - return outs[0]; - default: - PADDLE_THROW("Op %s output %s should contain only one variable", type_, - name); - return ""; - } + PADDLE_ENFORCE_LE(outs.size(), 1UL, + "Op %s output %s should contain only one variable", type_, + name); + return outs.empty() ? kEmptyVarName : outs[0]; } const std::vector& OperatorBase::Outputs( diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 7759473ef303dbc73abc79dd5b053828742f4951..60ca8b279e75f7ade7183ea8b818708ced6dd05f 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -239,20 +239,12 @@ class InferShapeContext { const Variable* InputVar(const std::string& name) const { auto ipt = op_.Input(name); - if (ipt == kEmptyVarName) { - return nullptr; - } else { - return scope_.FindVar(ipt); - } + return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); } Variable* OutputVar(const std::string& name) const { auto opt = op_.Output(name); - if (opt == kEmptyVarName) { - return nullptr; - } else { - return scope_.FindVar(opt); - } + return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt); } const std::vector MultiInputVar( @@ -262,8 +254,8 @@ class InferShapeContext { res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), [this](const std::string& name) { - return name != kEmptyVarName ? scope_.FindVar(name) - : nullptr; + return name == kEmptyVarName ? nullptr + : scope_.FindVar(name); }); return res; } @@ -274,8 +266,8 @@ class InferShapeContext { res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), [this](const std::string& name) { - return name != kEmptyVarName ? scope_.FindVar(name) - : nullptr; + return name == kEmptyVarName ? nullptr + : scope_.FindVar(name); }); return res; }