diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 310d68d7c1baac231a2f1709af28bfb58ae1a436..0af527c88c753f8545184d4175114c17abfba0a6 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -334,6 +334,32 @@ class RuntimeInferShapeContext : public InferShapeContextBase { return var != nullptr; } + bool HasInputs(const std::string& name) const { + auto inputs = op_.Inputs(name); + if (inputs.size() == 0UL) { + return false; + } + for (auto& input : inputs) { + if (scope_.FindVar(input) == nullptr) { + return false; + } + } + return true; + } + + bool HasOutputs(const std::string& name) const { + auto outputs = op_.Outputs(name); + if (outputs.size() == 0UL) { + return false; + } + for (auto& output : outputs) { + if (scope_.FindVar(output) == nullptr) { + return false; + } + } + return true; + } + DDim GetInputDim(const std::string& name) const { return GetDim(op_.Input(name)); } diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index b07fc788124413f728c713027609d9d2d1c39538..bc8af0eb3ec7e8685eb7d2734e9b8f75372d1309 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -24,6 +24,10 @@ class InferShapeContextBase { virtual ~InferShapeContextBase() {} virtual bool HasInput(const std::string &name) const = 0; virtual bool HasOutput(const std::string &name) const = 0; + + virtual bool HasInputs(const std::string &name) const = 0; + virtual bool HasOutputs(const std::string &name) const = 0; + virtual framework::DDim GetInputDim(const std::string &name) const = 0; std::vector GetInputsDim(const std::string &name) const { const std::vector &names = Inputs(name);