From 2d8c7e310a23f62f3a9aad7823beec8f714bd368 Mon Sep 17 00:00:00 2001 From: Yancey Date: Sat, 30 Sep 2017 12:10:21 +0800 Subject: [PATCH] Add Inputs/Outputs check interface (#4438) Add multiple Inputs/Outputs check interface --- paddle/framework/operator.h | 26 ++++++++++++++++++++++++++ paddle/framework/shape_inference.h | 4 ++++ 2 files changed, 30 insertions(+) diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 310d68d7c1..0af527c88c 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -334,6 +334,32 @@ class RuntimeInferShapeContext : public InferShapeContextBase { return var != nullptr; } + bool HasInputs(const std::string& name) const { + auto inputs = op_.Inputs(name); + if (inputs.size() == 0UL) { + return false; + } + for (auto& input : inputs) { + if (scope_.FindVar(input) == nullptr) { + return false; + } + } + return true; + } + + bool HasOutputs(const std::string& name) const { + auto outputs = op_.Outputs(name); + if (outputs.size() == 0UL) { + return false; + } + for (auto& output : outputs) { + if (scope_.FindVar(output) == nullptr) { + return false; + } + } + return true; + } + DDim GetInputDim(const std::string& name) const { return GetDim(op_.Input(name)); } diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index b07fc78812..bc8af0eb3e 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -24,6 +24,10 @@ class InferShapeContextBase { virtual ~InferShapeContextBase() {} virtual bool HasInput(const std::string &name) const = 0; virtual bool HasOutput(const std::string &name) const = 0; + + virtual bool HasInputs(const std::string &name) const = 0; + virtual bool HasOutputs(const std::string &name) const = 0; + virtual framework::DDim GetInputDim(const std::string &name) const = 0; std::vector GetInputsDim(const std::string &name) const { const std::vector &names = Inputs(name); -- GitLab