From d7a1e40e10bbe9778916c61755ae04592ebb19b7 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sun, 3 Sep 2017 11:02:47 -0700 Subject: [PATCH] Simple Implementation --- paddle/framework/operator.cc | 28 ++++++++-------------------- paddle/framework/operator.h | 20 ++++++-------------- 2 files changed, 14 insertions(+), 34 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 4d3279588..68fb469b5 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -35,16 +35,10 @@ ExecutionContext::GetEigenDevice() const { std::string OperatorBase::Input(const std::string& name) const { auto& ins = Inputs(name); - switch (ins.size()) { - case 0: - return kEmptyVarName; - case 1: - return ins[0]; - default: - PADDLE_THROW("Op %s input %s should contain only one variable", type_, - name); - return ""; - } + PADDLE_ENFORCE_LE(ins.size(), 1UL, + "Op %s input %s should contain only one variable", type_, + name); + return ins.empty() ? kEmptyVarName : ins[0]; } const std::vector& OperatorBase::Inputs( @@ -57,16 +51,10 @@ const std::vector& OperatorBase::Inputs( std::string OperatorBase::Output(const std::string& name) const { auto& outs = Outputs(name); - switch (outs.size()) { - case 0: - return kEmptyVarName; - case 1: - return outs[0]; - default: - PADDLE_THROW("Op %s output %s should contain only one variable", type_, - name); - return ""; - } + PADDLE_ENFORCE_LE(outs.size(), 1UL, + "Op %s output %s should contain only one variable", type_, + name); + return outs.empty() ? kEmptyVarName : outs[0]; } const std::vector& OperatorBase::Outputs( diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 7759473ef..60ca8b279 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -239,20 +239,12 @@ class InferShapeContext { const Variable* InputVar(const std::string& name) const { auto ipt = op_.Input(name); - if (ipt == kEmptyVarName) { - return nullptr; - } else { - return scope_.FindVar(ipt); - } + return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); } Variable* OutputVar(const std::string& name) const { auto opt = op_.Output(name); - if (opt == kEmptyVarName) { - return nullptr; - } else { - return scope_.FindVar(opt); - } + return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt); } const std::vector MultiInputVar( @@ -262,8 +254,8 @@ class InferShapeContext { res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), [this](const std::string& name) { - return name != kEmptyVarName ? scope_.FindVar(name) - : nullptr; + return name == kEmptyVarName ? nullptr + : scope_.FindVar(name); }); return res; } @@ -274,8 +266,8 @@ class InferShapeContext { res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), [this](const std::string& name) { - return name != kEmptyVarName ? scope_.FindVar(name) - : nullptr; + return name == kEmptyVarName ? nullptr + : scope_.FindVar(name); }); return res; } -- GitLab