From 3b58574ba9fb5d007a0c82d87ea631a18698f169 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 1 Aug 2017 16:18:36 +0800 Subject: [PATCH] add check in OPeratorContext Input/Output --- paddle/framework/operator.cc | 6 ++++-- paddle/framework/operator.h | 40 +++++++++++++++++++++++++----------- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 9bf60b7b116..c08c6bba592 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -52,7 +52,8 @@ std::vector OperatorBase::Inputs(const std::string& name) const { PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr"); auto input_format = GetAttr>("input_format"); auto offset = in_out_idxs_->at(name); - PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= inputs_.size(), + PADDLE_ENFORCE(input_format.at(static_cast(offset) + 1) <= + static_cast(inputs_.size()), "Input Out Of Range"); return std::vector{ @@ -78,7 +79,8 @@ std::vector OperatorBase::Outputs(const std::string& name) const { PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr"); auto output_format = GetAttr>("output_format"); auto offset = in_out_idxs_->at(name); - PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= outputs_.size(), + PADDLE_ENFORCE(output_format.at(static_cast(offset) + 1) <= + static_cast(outputs_.size()), "Output Out of Range"); return std::vector{ outputs_.begin() + output_format.at(offset), diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index ef1521b83bb..ff518265a45 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -108,11 +108,11 @@ class OperatorContext { size_t OutputSize() const { return op_.outputs_.size(); } - const Variable* InputVar(const size_t& index) const { + const Variable* InputVar(const size_t index) const { return scope_->GetVariable(op_.inputs_.at(index)); } - Variable* OutputVar(const size_t& index) const { + Variable* OutputVar(const size_t index) const { return scope_->GetVariable(op_.outputs_.at(index)); } @@ -146,23 +146,31 @@ class OperatorContext { } template - const T* Input(const size_t& index) const { - return &(InputVar(index)->Get()); + const T* Input(const size_t index) const { + auto var = InputVar(index); + PADDLE_ENFORCE(var != nullptr, "Input(%d) should not be nullptr", index); + return &var->Get(); } template - T* Output(const size_t& index) const { - return OutputVar(index)->GetMutable(); + T* Output(const size_t index) const { + auto var = OutputVar(index); + PADDLE_ENFORCE(var != nullptr, "Output(%d) should not be nullptr", index); + return var->GetMutable(); } template const T* Input(const std::string& name) const { - return &(InputVar(name)->Get()); + auto var = InputVar(name); + PADDLE_ENFORCE(var != nullptr, "Input(%s) should not be nullptr", name); + return &var->Get(); } template T* Output(const std::string& name) const { - return OutputVar(name)->GetMutable(); + auto var = OutputVar(name); + PADDLE_ENFORCE(var != nullptr, "Output(%s) should not be nullptr", name); + return var->GetMutable(); } template @@ -171,8 +179,12 @@ class OperatorContext { std::vector res; res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { - return &scope_->GetVariable(name)->Get(); + [&](const std::string& sub_name) { + auto var = scope_->GetVariable(sub_name); + PADDLE_ENFORCE(var != nullptr, + "MultiInput(%s:%s) should not be nullptr", + name, sub_name); + return &var->Get(); }); return res; } @@ -183,8 +195,12 @@ class OperatorContext { std::vector res; res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { - return scope_->GetVariable(name)->GetMutable(); + [&](const std::string& sub_name) { + auto var = scope_->GetVariable(sub_name); + PADDLE_ENFORCE(var != nullptr, + "MultiOutput(%s:%s) should not be nullptr", + name, sub_name); + return var->GetMutable(); }); return res; } -- GitLab