From 6084af47ef4358ab2b54636aa41d976f2ea34056 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Thu, 3 May 2018 10:28:16 +0800 Subject: [PATCH] Fix the bug when a input variable of op is dispensable. (#10268) * Fix the bug when a input variable of op is dispensable. * Add HasInputs/Outputs interfaces to OperatorBase. * Remove the unreferenced header file. --- paddle/capi/Matrix.cpp | 2 +- paddle/fluid/framework/operator.cc | 41 +++++++++++++++++++++++++++--- paddle/fluid/framework/operator.h | 2 ++ paddle/fluid/platform/profiler.h | 1 - 4 files changed, 40 insertions(+), 6 deletions(-) diff --git a/paddle/capi/Matrix.cpp b/paddle/capi/Matrix.cpp index 24b002063..733d49cac 100644 --- a/paddle/capi/Matrix.cpp +++ b/paddle/capi/Matrix.cpp @@ -108,7 +108,7 @@ paddle_error paddle_matrix_get_row(paddle_matrix mat, paddle_error paddle_matrix_get_shape(paddle_matrix mat, uint64_t* height, uint64_t* width) { - if (mat == nullptr) return kPD_NULLPTR; + if (mat == nullptr || cast(mat)->mat == nullptr) return kPD_NULLPTR; if (height != nullptr) { *height = cast(mat)->mat->getHeight(); } diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 32576423a..d70f26026 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -93,6 +93,14 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) { RunImpl(scope, place); } +bool OperatorBase::HasInputs(const std::string& name) const { + if (inputs_.find(name) != inputs_.end()) { + return true; + } else { + return false; + } +} + std::string OperatorBase::Input(const std::string& name) const { auto& ins = Inputs(name); PADDLE_ENFORCE_LE(ins.size(), 1UL, @@ -109,6 +117,14 @@ const std::vector& OperatorBase::Inputs( return it->second; } +bool OperatorBase::HasOutputs(const std::string& name) const { + if (outputs_.find(name) != outputs_.end()) { + return true; + } else { + return false; + } +} + std::string OperatorBase::Output(const std::string& name) const { auto& outs = Outputs(name); PADDLE_ENFORCE_LE(outs.size(), 1UL, @@ -220,13 +236,18 @@ void OperatorBase::CheckAllInputOutputSet() const { if (op_info == nullptr || op_info->proto_ == nullptr) return; for (auto& in : op_info->Proto().inputs()) { - PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(), - "Type %s's input %s is not set", Type(), in.name()); + if (!in.dispensable()) { + PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(), + "Operator %s's input, %s, is not set", Type(), in.name()); + } } for (auto& out : op_info->Proto().outputs()) { - PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(), - "Type %s's output %s is not set", Type(), out.name()); + if (!out.dispensable()) { + PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(), + "Operator %s's output, %s, is not set", Type(), + out.name()); + } } } @@ -332,6 +353,9 @@ class RuntimeInferShapeContext : public InferShapeContext { : op_(op), scope_(scope) {} bool HasInput(const std::string& name) const override { + if (!op_.HasInputs(name)) { + return false; + } auto& ins = Inputs(name); size_t length = ins.size(); if (length == 0) { @@ -345,6 +369,9 @@ class RuntimeInferShapeContext : public InferShapeContext { } bool HasOutput(const std::string& name) const override { + if (!op_.HasOutputs(name)) { + return false; + } auto& outs = Outputs(name); size_t length = outs.size(); if (length == 0) { @@ -358,6 +385,9 @@ class RuntimeInferShapeContext : public InferShapeContext { } bool HasInputs(const std::string& name) const override { + if (!op_.HasInputs(name)) { + return false; + } auto inputs = op_.Inputs(name); if (inputs.empty()) { return false; @@ -371,6 +401,9 @@ class RuntimeInferShapeContext : public InferShapeContext { } bool HasOutputs(const std::string& name) const override { + if (!op_.HasOutputs(name)) { + return false; + } auto outputs = op_.Outputs(name); if (outputs.empty()) { return false; diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 826cc57b7..d373c48b1 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -105,6 +105,7 @@ class OperatorBase { const VariableNameMap& Inputs() const { return inputs_; } const VariableNameMap& Outputs() const { return outputs_; } + bool HasInputs(const std::string& name) const; //! Get a input with argument's name described in `op_proto` std::string Input(const std::string& name) const; //! Get a input which has multiple variables. @@ -112,6 +113,7 @@ class OperatorBase { //! Get all inputs variable names std::vector InputVars() const; + bool HasOutputs(const std::string& name) const; //! Get a output with argument's name described in `op_proto` std::string Output(const std::string& name) const; //! Get an output which has multiple variables. diff --git a/paddle/fluid/platform/profiler.h b/paddle/fluid/platform/profiler.h index b07427c8f..428d9ebce 100644 --- a/paddle/fluid/platform/profiler.h +++ b/paddle/fluid/platform/profiler.h @@ -18,7 +18,6 @@ limitations under the License. */ #include #include #include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/profiler.pb.h" namespace paddle { namespace platform { -- GitLab