From fd8df0806dd07af4fed9ae40ffe5ec571639d2c4 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sat, 2 Sep 2017 20:37:48 -0700 Subject: [PATCH] Make operator Input/Output can return nullptr --- paddle/framework/operator.cc | 32 +++++++++++++++-------- paddle/framework/operator.h | 50 ++++++++++++++++++++---------------- 2 files changed, 50 insertions(+), 32 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 7abbde610f1..4d327958843 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -33,12 +33,18 @@ ExecutionContext::GetEigenDevice() const { } #endif -const std::string& OperatorBase::Input(const std::string& name) const { +std::string OperatorBase::Input(const std::string& name) const { auto& ins = Inputs(name); - PADDLE_ENFORCE_EQ(ins.size(), 1UL, - "Op %s input %s should contain only one variable", type_, - name); - return ins[0]; + switch (ins.size()) { + case 0: + return kEmptyVarName; + case 1: + return ins[0]; + default: + PADDLE_THROW("Op %s input %s should contain only one variable", type_, + name); + return ""; + } } const std::vector& OperatorBase::Inputs( @@ -49,12 +55,18 @@ const std::vector& OperatorBase::Inputs( return it->second; } -const std::string& OperatorBase::Output(const std::string& name) const { +std::string OperatorBase::Output(const std::string& name) const { auto& outs = Outputs(name); - PADDLE_ENFORCE_EQ(outs.size(), 1UL, - "Op %s output %s should contain only one variable", type_, - name); - return outs[0]; + switch (outs.size()) { + case 0: + return kEmptyVarName; + case 1: + return outs[0]; + default: + PADDLE_THROW("Op %s output %s should contain only one variable", type_, + name); + return ""; + } } const std::vector& OperatorBase::Outputs( diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 8397570d26f..7759473ef30 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -95,12 +95,12 @@ class OperatorBase { const VariableNameMap& Inputs() const { return inputs_; } const VariableNameMap& Outputs() const { return outputs_; } //! Get a input with argument's name described in `op_proto` - const std::string& Input(const std::string& name) const; + std::string Input(const std::string& name) const; //! Get a input which has multiple variables. const std::vector& Inputs(const std::string& name) const; //! Get a output with argument's name described in `op_proto` - const std::string& Output(const std::string& name) const; + std::string Output(const std::string& name) const; //! Get an output which has multiple variables. //! TODO add a vector_view to prevent memory copy. const std::vector& Outputs(const std::string& name) const; @@ -238,11 +238,21 @@ class InferShapeContext { } const Variable* InputVar(const std::string& name) const { - return scope_.FindVar(op_.Input(name)); + auto ipt = op_.Input(name); + if (ipt == kEmptyVarName) { + return nullptr; + } else { + return scope_.FindVar(ipt); + } } Variable* OutputVar(const std::string& name) const { - return scope_.FindVar(op_.Output(name)); + auto opt = op_.Output(name); + if (opt == kEmptyVarName) { + return nullptr; + } else { + return scope_.FindVar(opt); + } } const std::vector MultiInputVar( @@ -250,9 +260,11 @@ class InferShapeContext { auto names = op_.Inputs(name); std::vector res; res.reserve(names.size()); - std::transform( - names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { return scope_.FindVar(name); }); + std::transform(names.begin(), names.end(), std::back_inserter(res), + [this](const std::string& name) { + return name != kEmptyVarName ? scope_.FindVar(name) + : nullptr; + }); return res; } @@ -260,24 +272,24 @@ class InferShapeContext { auto names = op_.Outputs(name); std::vector res; res.reserve(names.size()); - std::transform( - names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { return scope_.FindVar(name); }); + std::transform(names.begin(), names.end(), std::back_inserter(res), + [this](const std::string& name) { + return name != kEmptyVarName ? scope_.FindVar(name) + : nullptr; + }); return res; } template const T* Input(const std::string& name) const { auto* var = InputVar(name); - PADDLE_ENFORCE_NOT_NULL(var, "Input(%s) should not be nullptr", name); - return &var->Get(); + return var == nullptr ? nullptr : &var->Get(); } template T* Output(const std::string& name) const { auto var = OutputVar(name); - PADDLE_ENFORCE_NOT_NULL(var, "Output(%s) should not be nullptr", name); - return var->GetMutable(); + return var == nullptr ? nullptr : var->GetMutable(); } template @@ -288,10 +300,7 @@ class InferShapeContext { std::transform(names.begin(), names.end(), std::back_inserter(res), [&](const std::string& sub_name) { auto var = scope_.FindVar(sub_name); - PADDLE_ENFORCE_NOT_NULL( - var, "MultiInput(%s:%s) should not be nullptr", name, - sub_name); - return &var->Get(); + return var == nullptr ? nullptr : &var->Get(); }); return res; } @@ -304,10 +313,7 @@ class InferShapeContext { std::transform(names.begin(), names.end(), std::back_inserter(res), [&](const std::string& sub_name) { auto var = scope_.FindVar(sub_name); - PADDLE_ENFORCE_NOT_NULL( - var, "MultiOutput(%s:%s) should not be nullptr.", name, - sub_name); - return var->GetMutable(); + return var == nullptr ? nullptr : var->GetMutable(); }); return res; } -- GitLab