提交 b5078c21 编写于 作者: M minqiyang

Make infer var type virtual

test=develop
上级 9041b238
......@@ -31,86 +31,89 @@ class InferVarTypeContext {
InferVarTypeContext(const OpDesc* op, BlockDesc* block)
: op_(op), block_(block) {}
Attribute GetAttr(const std::string& name) const {
virtual ~InferVarTypeContext() {}
virtual Attribute GetAttr(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->GetAttr(name);
}
inline bool HasVar(const std::string& name) const {
virtual bool HasVar(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindVarRecursive(name) != nullptr;
}
inline bool HasInput(const std::string& name) const {
virtual bool HasInput(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Inputs().count(name) > 0;
}
inline bool HasOutput(const std::string& name) const {
virtual bool HasOutput(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Outputs().count(name) > 0;
}
inline const std::vector<std::string>& Input(const std::string& name) const {
virtual const std::vector<std::string>& Input(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Input(name);
}
inline const std::vector<std::string>& Output(const std::string& name) const {
virtual const std::vector<std::string>& Output(
const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Output(name);
}
inline proto::VarType::Type GetType(const std::string& name) const {
virtual proto::VarType::Type GetType(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetType();
}
inline void SetType(const std::string& name, proto::VarType::Type type) {
virtual void SetType(const std::string& name, proto::VarType::Type type) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetType(type);
}
inline proto::VarType::Type GetDataType(const std::string& name) const {
virtual proto::VarType::Type GetDataType(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetDataType();
}
inline void SetDataType(const std::string& name, proto::VarType::Type type) {
virtual void SetDataType(const std::string& name, proto::VarType::Type type) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetDataType(type);
}
inline std::vector<proto::VarType::Type> GetDataTypes(
virtual std::vector<proto::VarType::Type> GetDataTypes(
const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetDataTypes();
}
inline void SetDataTypes(
virtual void SetDataTypes(
const std::string& name,
const std::vector<proto::VarType::Type>& multiple_data_type) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetDataTypes(multiple_data_type);
}
inline std::vector<int64_t> GetShape(const std::string& name) const {
virtual std::vector<int64_t> GetShape(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetShape();
}
inline void SetShape(const std::string& name,
const std::vector<int64_t>& dims) {
virtual void SetShape(const std::string& name,
const std::vector<int64_t>& dims) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetShape(dims);
}
inline int32_t GetLoDLevel(const std::string& name) const {
virtual int32_t GetLoDLevel(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetLoDLevel();
}
inline void SetLoDLevel(const std::string& name, int32_t lod_level) {
virtual void SetLoDLevel(const std::string& name, int32_t lod_level) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetLoDLevel(lod_level);
}
......
......@@ -377,12 +377,10 @@ class PyLayer {
class PYBIND11_HIDDEN RuntimeInferVarTypeContext
: public framework::InferVarTypeContext {
public:
RuntimeInferVarTypeContext(imperative::OpBase* op,
const imperative::VarBasePtrMap* inputs,
RuntimeInferVarTypeContext(const imperative::VarBasePtrMap* inputs,
imperative::VarBasePtrMap* outputs,
const framework::AttributeMap* attrs_map)
: InferVarTypeContext(nullptr, nullptr),
op_(op),
inputs_(inputs),
outputs_(outputs),
attrs_(attrs_map),
......@@ -406,83 +404,86 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext
}
}
framework::Attribute GetAttr(const std::string& name) const {
virtual ~RuntimeInferVarTypeContext() {}
framework::Attribute GetAttr(const std::string& name) const override {
PADDLE_ENFORCE_NOT_NULL(attrs_);
return attrs_->at(name);
}
inline bool HasVar(const std::string& name) const {
bool HasVar(const std::string& name) const override {
return var_set_.count(name) > 0;
}
inline bool HasInput(const std::string& name) const {
bool HasInput(const std::string& name) const override {
PADDLE_ENFORCE_NOT_NULL(inputs_);
return inputs_->count(name) > 0;
}
inline bool HasOutput(const std::string& name) const {
bool HasOutput(const std::string& name) const override {
PADDLE_ENFORCE_NOT_NULL(outputs_);
return outputs_->count(name) > 0;
}
inline const std::vector<std::string>& Input(const std::string& name) const {
const std::vector<std::string>& Input(
const std::string& name) const override {
return input_names_.at(name);
}
inline const std::vector<std::string>& Output(const std::string& name) const {
const std::vector<std::string>& Output(
const std::string& name) const override {
return output_names_.at(name);
}
inline framework::proto::VarType::Type GetType(
const std::string& name) const {
framework::proto::VarType::Type GetType(
const std::string& name) const override {
return var_set_.at(name)->DType();
}
inline void SetType(const std::string& name,
framework::proto::VarType::Type type) {
void SetType(const std::string& name,
framework::proto::VarType::Type type) override {
var_set_[name]->SetDType(type);
}
inline framework::proto::VarType::Type GetDataType(
const std::string& name) const {
framework::proto::VarType::Type GetDataType(
const std::string& name) const override {
return var_set_.at(name)->DType();
}
inline void SetDataType(const std::string& name,
framework::proto::VarType::Type type) {
void SetDataType(const std::string& name,
framework::proto::VarType::Type type) override {
var_set_[name]->SetDType(type);
}
inline std::vector<framework::proto::VarType::Type> GetDataTypes(
const std::string& name) const {
std::vector<framework::proto::VarType::Type> GetDataTypes(
const std::string& name) const override {
PADDLE_THROW("GetDataTypes is not supported in runtime InferVarType");
}
inline void SetDataTypes(
const std::string& name,
const std::vector<framework::proto::VarType::Type>& multiple_data_type) {
void SetDataTypes(const std::string& name,
const std::vector<framework::proto::VarType::Type>&
multiple_data_type) override {
PADDLE_THROW("SetDataTypes is not supported in runtime InferVarType");
}
inline std::vector<int64_t> GetShape(const std::string& name) const {
std::vector<int64_t> GetShape(const std::string& name) const override {
PADDLE_THROW("Do not handle Shape in runtime InferVarType");
}
inline void SetShape(const std::string& name,
const std::vector<int64_t>& dims) {
void SetShape(const std::string& name,
const std::vector<int64_t>& dims) override {
PADDLE_THROW("Do not handle Shape in runtime InferVarType");
}
inline int32_t GetLoDLevel(const std::string& name) const {
int32_t GetLoDLevel(const std::string& name) const override {
PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
}
inline void SetLoDLevel(const std::string& name, int32_t lod_level) {
void SetLoDLevel(const std::string& name, int32_t lod_level) override {
PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
}
private:
imperative::OpBase* op_;
const imperative::VarBasePtrMap* inputs_;
imperative::VarBasePtrMap* outputs_;
const framework::AttributeMap* attrs_;
......
......@@ -230,7 +230,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
outvars_name_map, attrs_map);
if (info.infer_var_type_) {
RuntimeInferVarTypeContext infer_var_type_ctx(op, &inputs, &outputs,
RuntimeInferVarTypeContext infer_var_type_ctx(&inputs, &outputs,
&attrs_map);
info.infer_var_type_(infer_var_type_ctx);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册