diff --git a/paddle/fluid/framework/var_type_inference.h b/paddle/fluid/framework/var_type_inference.h index b4b7be619abd9ccf854405b5168041bc947401c6..5dd08442c20322cef9a4b17f9e6cae98fa3e8066 100644 --- a/paddle/fluid/framework/var_type_inference.h +++ b/paddle/fluid/framework/var_type_inference.h @@ -31,86 +31,89 @@ class InferVarTypeContext { InferVarTypeContext(const OpDesc* op, BlockDesc* block) : op_(op), block_(block) {} - Attribute GetAttr(const std::string& name) const { + virtual ~InferVarTypeContext() {} + + virtual Attribute GetAttr(const std::string& name) const { PADDLE_ENFORCE_NOT_NULL(op_); return op_->GetAttr(name); } - inline bool HasVar(const std::string& name) const { + virtual bool HasVar(const std::string& name) const { PADDLE_ENFORCE_NOT_NULL(block_); return block_->FindVarRecursive(name) != nullptr; } - inline bool HasInput(const std::string& name) const { + virtual bool HasInput(const std::string& name) const { PADDLE_ENFORCE_NOT_NULL(op_); return op_->Inputs().count(name) > 0; } - inline bool HasOutput(const std::string& name) const { + virtual bool HasOutput(const std::string& name) const { PADDLE_ENFORCE_NOT_NULL(op_); return op_->Outputs().count(name) > 0; } - inline const std::vector& Input(const std::string& name) const { + virtual const std::vector& Input(const std::string& name) const { PADDLE_ENFORCE_NOT_NULL(op_); return op_->Input(name); } - inline const std::vector& Output(const std::string& name) const { + virtual const std::vector& Output( + const std::string& name) const { PADDLE_ENFORCE_NOT_NULL(op_); return op_->Output(name); } - inline proto::VarType::Type GetType(const std::string& name) const { + virtual proto::VarType::Type GetType(const std::string& name) const { PADDLE_ENFORCE_NOT_NULL(block_); return block_->FindRecursiveOrCreateVar(name).GetType(); } - inline void SetType(const std::string& name, proto::VarType::Type type) { + virtual void SetType(const std::string& name, proto::VarType::Type type) { PADDLE_ENFORCE_NOT_NULL(block_); block_->FindRecursiveOrCreateVar(name).SetType(type); } - inline proto::VarType::Type GetDataType(const std::string& name) const { + virtual proto::VarType::Type GetDataType(const std::string& name) const { PADDLE_ENFORCE_NOT_NULL(block_); return block_->FindRecursiveOrCreateVar(name).GetDataType(); } - inline void SetDataType(const std::string& name, proto::VarType::Type type) { + virtual void SetDataType(const std::string& name, proto::VarType::Type type) { PADDLE_ENFORCE_NOT_NULL(block_); block_->FindRecursiveOrCreateVar(name).SetDataType(type); } - inline std::vector GetDataTypes( + virtual std::vector GetDataTypes( const std::string& name) const { PADDLE_ENFORCE_NOT_NULL(block_); return block_->FindRecursiveOrCreateVar(name).GetDataTypes(); } - inline void SetDataTypes( + virtual void SetDataTypes( const std::string& name, const std::vector& multiple_data_type) { PADDLE_ENFORCE_NOT_NULL(block_); block_->FindRecursiveOrCreateVar(name).SetDataTypes(multiple_data_type); } - inline std::vector GetShape(const std::string& name) const { + virtual std::vector GetShape(const std::string& name) const { PADDLE_ENFORCE_NOT_NULL(block_); return block_->FindRecursiveOrCreateVar(name).GetShape(); } - inline void SetShape(const std::string& name, - const std::vector& dims) { + virtual void SetShape(const std::string& name, + const std::vector& dims) { PADDLE_ENFORCE_NOT_NULL(block_); block_->FindRecursiveOrCreateVar(name).SetShape(dims); } - inline int32_t GetLoDLevel(const std::string& name) const { + virtual int32_t GetLoDLevel(const std::string& name) const { PADDLE_ENFORCE_NOT_NULL(block_); return block_->FindRecursiveOrCreateVar(name).GetLoDLevel(); } - inline void SetLoDLevel(const std::string& name, int32_t lod_level) { + virtual void SetLoDLevel(const std::string& name, int32_t lod_level) { PADDLE_ENFORCE_NOT_NULL(block_); block_->FindRecursiveOrCreateVar(name).SetLoDLevel(lod_level); } diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 494988608e73b4eedd30103067e9d47281fd5af2..4ad7d847c182b3c7f93ff0e01cf9307994aab648 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -377,12 +377,10 @@ class PyLayer { class PYBIND11_HIDDEN RuntimeInferVarTypeContext : public framework::InferVarTypeContext { public: - RuntimeInferVarTypeContext(imperative::OpBase* op, - const imperative::VarBasePtrMap* inputs, + RuntimeInferVarTypeContext(const imperative::VarBasePtrMap* inputs, imperative::VarBasePtrMap* outputs, const framework::AttributeMap* attrs_map) : InferVarTypeContext(nullptr, nullptr), - op_(op), inputs_(inputs), outputs_(outputs), attrs_(attrs_map), @@ -406,83 +404,86 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext } } - framework::Attribute GetAttr(const std::string& name) const { + virtual ~RuntimeInferVarTypeContext() {} + + framework::Attribute GetAttr(const std::string& name) const override { PADDLE_ENFORCE_NOT_NULL(attrs_); return attrs_->at(name); } - inline bool HasVar(const std::string& name) const { + bool HasVar(const std::string& name) const override { return var_set_.count(name) > 0; } - inline bool HasInput(const std::string& name) const { + bool HasInput(const std::string& name) const override { PADDLE_ENFORCE_NOT_NULL(inputs_); return inputs_->count(name) > 0; } - inline bool HasOutput(const std::string& name) const { + bool HasOutput(const std::string& name) const override { PADDLE_ENFORCE_NOT_NULL(outputs_); return outputs_->count(name) > 0; } - inline const std::vector& Input(const std::string& name) const { + const std::vector& Input( + const std::string& name) const override { return input_names_.at(name); } - inline const std::vector& Output(const std::string& name) const { + const std::vector& Output( + const std::string& name) const override { return output_names_.at(name); } - inline framework::proto::VarType::Type GetType( - const std::string& name) const { + framework::proto::VarType::Type GetType( + const std::string& name) const override { return var_set_.at(name)->DType(); } - inline void SetType(const std::string& name, - framework::proto::VarType::Type type) { + void SetType(const std::string& name, + framework::proto::VarType::Type type) override { var_set_[name]->SetDType(type); } - inline framework::proto::VarType::Type GetDataType( - const std::string& name) const { + framework::proto::VarType::Type GetDataType( + const std::string& name) const override { return var_set_.at(name)->DType(); } - inline void SetDataType(const std::string& name, - framework::proto::VarType::Type type) { + void SetDataType(const std::string& name, + framework::proto::VarType::Type type) override { var_set_[name]->SetDType(type); } - inline std::vector GetDataTypes( - const std::string& name) const { + std::vector GetDataTypes( + const std::string& name) const override { PADDLE_THROW("GetDataTypes is not supported in runtime InferVarType"); } - inline void SetDataTypes( - const std::string& name, - const std::vector& multiple_data_type) { + void SetDataTypes(const std::string& name, + const std::vector& + multiple_data_type) override { PADDLE_THROW("SetDataTypes is not supported in runtime InferVarType"); } - inline std::vector GetShape(const std::string& name) const { + std::vector GetShape(const std::string& name) const override { PADDLE_THROW("Do not handle Shape in runtime InferVarType"); } - inline void SetShape(const std::string& name, - const std::vector& dims) { + void SetShape(const std::string& name, + const std::vector& dims) override { PADDLE_THROW("Do not handle Shape in runtime InferVarType"); } - inline int32_t GetLoDLevel(const std::string& name) const { + int32_t GetLoDLevel(const std::string& name) const override { PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType"); } - inline void SetLoDLevel(const std::string& name, int32_t lod_level) { + void SetLoDLevel(const std::string& name, int32_t lod_level) override { PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType"); } private: - imperative::OpBase* op_; const imperative::VarBasePtrMap* inputs_; imperative::VarBasePtrMap* outputs_; const framework::AttributeMap* attrs_; diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 7a07ec358dceea46917014ae1457a800ce52f3f1..166883bd6f9fe57cac40c7e0e7649ea55969abf1 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -230,7 +230,7 @@ std::set Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, outvars_name_map, attrs_map); if (info.infer_var_type_) { - RuntimeInferVarTypeContext infer_var_type_ctx(op, &inputs, &outputs, + RuntimeInferVarTypeContext infer_var_type_ctx(&inputs, &outputs, &attrs_map); info.infer_var_type_(infer_var_type_ctx); }