提交 9ef8a768 编写于 作者: X Xin Pan

convert more

test=develop
上级 87699388
...@@ -149,8 +149,29 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -149,8 +149,29 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool IsRuntime() const override; bool IsRuntime() const override;
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const override {
return GetVarTypes(Inputs(name));
}
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string &name) const override {
return GetVarTypes(Outputs(name));
}
protected: protected:
proto::VarType::Type GetVarType(const std::string &name) const override; std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<std::string> &names) const {
std::vector<proto::VarType::Type> retv;
retv.resize(names.size());
std::transform(
names.begin(), names.end(), retv.begin(),
std::bind(std::mem_fn(&CompileTimeInferShapeContext::GetVarType), this,
std::placeholders::_1));
return retv;
}
proto::VarType::Type GetVarType(const std::string &name) const;
DDim GetDim(const std::string &name) const { DDim GetDim(const std::string &name) const {
auto var = block_.FindVarRecursive(name); auto var = block_.FindVarRecursive(name);
......
...@@ -722,6 +722,16 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -722,6 +722,16 @@ class RuntimeInferShapeContext : public InferShapeContext {
return GetDims(vars); return GetDims(vars);
} }
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string& name) const override {
return GetVarTypes(InputVars(name));
}
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string& name) const override {
return GetVarTypes(OutputVars(name));
}
protected: protected:
DDim GetDim(Variable* var) const { DDim GetDim(Variable* var) const {
PADDLE_ENFORCE_NOT_NULL(var); PADDLE_ENFORCE_NOT_NULL(var);
...@@ -766,8 +776,17 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -766,8 +776,17 @@ class RuntimeInferShapeContext : public InferShapeContext {
PADDLE_THROW("Only compile time support this method"); PADDLE_THROW("Only compile time support this method");
} }
proto::VarType::Type GetVarType(const std::string& name) const override { std::vector<proto::VarType::Type> GetVarTypes(
auto* var = scope_.FindVar(name); const std::vector<Variable*>& vars) const {
std::vector<proto::VarType::Type> retv;
retv.resize(vars.size());
std::transform(vars.begin(), vars.end(), retv.begin(),
std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType),
this, std::placeholders::_1));
return retv;
}
proto::VarType::Type GetVarType(Variable* var) const {
return ToVarType(var->Type()); return ToVarType(var->Type());
} }
......
...@@ -68,25 +68,5 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names, ...@@ -68,25 +68,5 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names,
} }
} }
std::vector<proto::VarType::Type> InferShapeContext::GetInputsVarType(
const std::string &name) const {
return GetVarTypes(Inputs(name));
}
std::vector<proto::VarType::Type> InferShapeContext::GetOutputsVarType(
const std::string &name) const {
return GetVarTypes(Outputs(name));
}
std::vector<proto::VarType::Type> InferShapeContext::GetVarTypes(
const std::vector<std::string> &names) const {
std::vector<proto::VarType::Type> retv;
retv.resize(names.size());
std::transform(names.begin(), names.end(), retv.begin(),
std::bind(std::mem_fn(&InferShapeContext::GetVarType), this,
std::placeholders::_1));
return retv;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -34,9 +34,9 @@ class InferShapeContext { ...@@ -34,9 +34,9 @@ class InferShapeContext {
virtual bool HasOutput(const std::string &name) const = 0; virtual bool HasOutput(const std::string &name) const = 0;
virtual std::vector<proto::VarType::Type> GetInputsVarType( virtual std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const; const std::string &name) const = 0;
virtual std::vector<proto::VarType::Type> GetOutputsVarType( virtual std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string &name) const; const std::string &name) const = 0;
virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasInputs(const std::string &name) const = 0;
virtual bool HasOutputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0;
...@@ -82,11 +82,6 @@ class InferShapeContext { ...@@ -82,11 +82,6 @@ class InferShapeContext {
virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0; virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0;
virtual void SetRepeatedDims(const std::string &name, virtual void SetRepeatedDims(const std::string &name,
const std::vector<DDim> &dims) = 0; const std::vector<DDim> &dims) = 0;
std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<std::string> &names) const;
virtual proto::VarType::Type GetVarType(const std::string &name) const = 0;
}; };
} // namespace framework } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册