diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index dde642764fa5dfce11edcef51ad1be11be331fbc..0a3bb586fcc61e80f0358adc21730c279509ef87 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -110,6 +110,30 @@ class CompileTimeInferShapeContext : public InferShapeContext { } } + std::vector GetInputVarPtrs( + const std::string &name) override { + const std::vector arg_names = Inputs(name); + std::vector res; + res.reserve(arg_names.size()); + std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res), + [this](const std::string &name) { + return block_.FindVarRecursive(name); + }); + return res; + } + + std::vector GetOutputVarPtrs( + const std::string &name) override { + const std::vector arg_names = Outputs(name); + std::vector res; + res.reserve(arg_names.size()); + std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res), + [this](const std::string &name) { + return block_.FindVarRecursive(name); + }); + return res; + } + bool IsRuntime() const override; protected: @@ -124,8 +148,6 @@ class CompileTimeInferShapeContext : public InferShapeContext { void SetRepeatedDims(const std::string &name, const std::vector &dims) override; - InferShapeVarPtr GetVarPtr(const std::string &name) override; - const OpDesc &op_; const BlockDesc &block_; }; @@ -696,10 +718,5 @@ proto::VarType::Type CompileTimeInferShapeContext::GetVarType( return block_.FindVarRecursive(name)->GetType(); } -InferShapeVarPtr CompileTimeInferShapeContext::GetVarPtr( - const std::string &name) { - return block_.FindVarRecursive(name); -} - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index e023d165b03c0dc44a97ee3cfcba9dea493a47ad..4ccef3105c6c1724f0b767d657fa39ac5a6ce266 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -691,6 +691,25 @@ class RuntimeInferShapeContext : public InferShapeContext { bool IsRuntime() const override { return true; } + // TODO(paddle-dev): Can this be template? + std::vector GetInputVarPtrs( + const std::string& name) override { + const std::vector& vars = InputVars(name); + std::vector res; + res.reserve(vars.size()); + res.insert(res.begin(), vars.begin(), vars.end()); + return res; + } + + std::vector GetOutputVarPtrs( + const std::string& name) override { + const std::vector& vars = OutputVars(name); + std::vector res; + res.reserve(vars.size()); + res.insert(res.begin(), vars.begin(), vars.end()); + return res; + } + protected: DDim GetDim(const std::string& name) const override { Variable* var = scope_.FindVar(name); @@ -733,11 +752,22 @@ class RuntimeInferShapeContext : public InferShapeContext { return ToVarType(var->Type()); } - InferShapeVarPtr GetVarPtr(const std::string& name) override { - return scope_.FindVar(name); + private: + const std::vector& InputVars(const std::string& name) const { + auto it = ctx_.inputs.find(name); + PADDLE_ENFORCE(it != ctx_.inputs.end(), + "Operator %s does not have the input %s.", op_.Type(), name); + return it->second; + } + + const std::vector& OutputVars(const std::string& name) const { + auto it = ctx_.outputs.find(name); + PADDLE_ENFORCE(it != ctx_.outputs.end(), + "Operator %s does not have the outputs %s.", op_.Type(), + name); + return it->second; } - private: const OperatorBase& op_; const Scope& scope_; const RuntimeContext& ctx_; diff --git a/paddle/fluid/framework/shape_inference.cc b/paddle/fluid/framework/shape_inference.cc index ddff2c7c261746ac9986e79cff3da7e0a9654adc..0a7cebcc5a2e92334a4cfcbf87d7c7b475fd255e 100644 --- a/paddle/fluid/framework/shape_inference.cc +++ b/paddle/fluid/framework/shape_inference.cc @@ -76,28 +76,6 @@ void InferShapeContext::SetReaderDims(const std::string &name, return this->SetRepeatedDims(arg_names[0], dims); } -std::vector InferShapeContext::GetInputVarPtrs( - const std::string &name) { - const std::vector arg_names = Inputs(name); - std::vector res; - res.reserve(arg_names.size()); - std::transform( - arg_names.begin(), arg_names.end(), std::back_inserter(res), - [this](const std::string &name) { return this->GetVarPtr(name); }); - return res; -} - -std::vector InferShapeContext::GetOutputVarPtrs( - const std::string &name) { - const std::vector arg_names = Outputs(name); - std::vector res; - res.reserve(arg_names.size()); - std::transform( - arg_names.begin(), arg_names.end(), std::back_inserter(res), - [this](const std::string &name) { return this->GetVarPtr(name); }); - return res; -} - std::vector InferShapeContext::GetDims( const std::vector &names) const { std::vector ret; diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index d73cca121e41e68f9fb6548117ed91c5cc1415ca..543696d43b179ed62561f743eb52c50a6bc90a01 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -33,22 +33,24 @@ class InferShapeContext { virtual bool HasInput(const std::string &name) const = 0; virtual bool HasOutput(const std::string &name) const = 0; - std::vector GetInputsVarType( + virtual std::vector GetInputsVarType( const std::string &name) const; - std::vector GetOutputsVarType( + virtual std::vector GetOutputsVarType( const std::string &name) const; virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0; - DDim GetInputDim(const std::string &name) const; - std::vector GetInputsDim(const std::string &name) const; - std::vector GetReaderDims(const std::string &name) const; - DDim GetInputsElementDim(const std::string &name, int idx) const; + virtual DDim GetInputDim(const std::string &name) const; + virtual std::vector GetInputsDim(const std::string &name) const; + virtual std::vector GetReaderDims(const std::string &name) const; + virtual DDim GetInputsElementDim(const std::string &name, int idx) const; - void SetOutputDim(const std::string &name, const DDim &dim); - void SetOutputsDim(const std::string &name, const std::vector &dims); - void SetReaderDims(const std::string &name, const std::vector &dims); + virtual void SetOutputDim(const std::string &name, const DDim &dim); + virtual void SetOutputsDim(const std::string &name, + const std::vector &dims); + virtual void SetReaderDims(const std::string &name, + const std::vector &dims); virtual AttrReader Attrs() const = 0; virtual const std::vector &Inputs( @@ -67,13 +69,14 @@ class InferShapeContext { virtual bool IsRuntime() const = 0; - std::vector GetInputVarPtrs(const std::string &name); - std::vector GetOutputVarPtrs(const std::string &name); - virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0; + virtual std::vector GetInputVarPtrs( + const std::string &name) = 0; + virtual std::vector GetOutputVarPtrs( + const std::string &name) = 0; // Note: In while op, we need this to be public - void SetDims(const std::vector &names, - const std::vector &dims); + virtual void SetDims(const std::vector &names, + const std::vector &dims); protected: virtual DDim GetDim(const std::string &name) const = 0;