提交 4dd61e72 编写于 作者: X Xin Pan

convert GetInputVarPtrs and GetOutputVarPtrs

test=develop
上级 52d3903a
......@@ -110,6 +110,30 @@ class CompileTimeInferShapeContext : public InferShapeContext {
}
}
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string &name) override {
const std::vector<std::string> arg_names = Inputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) {
return block_.FindVarRecursive(name);
});
return res;
}
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string &name) override {
const std::vector<std::string> arg_names = Outputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) {
return block_.FindVarRecursive(name);
});
return res;
}
bool IsRuntime() const override;
protected:
......@@ -124,8 +148,6 @@ class CompileTimeInferShapeContext : public InferShapeContext {
void SetRepeatedDims(const std::string &name,
const std::vector<DDim> &dims) override;
InferShapeVarPtr GetVarPtr(const std::string &name) override;
const OpDesc &op_;
const BlockDesc &block_;
};
......@@ -696,10 +718,5 @@ proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
return block_.FindVarRecursive(name)->GetType();
}
InferShapeVarPtr CompileTimeInferShapeContext::GetVarPtr(
const std::string &name) {
return block_.FindVarRecursive(name);
}
} // namespace framework
} // namespace paddle
......@@ -691,6 +691,25 @@ class RuntimeInferShapeContext : public InferShapeContext {
bool IsRuntime() const override { return true; }
// TODO(paddle-dev): Can this be template?
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) override {
const std::vector<Variable*>& vars = InputVars(name);
std::vector<InferShapeVarPtr> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) override {
const std::vector<Variable*>& vars = OutputVars(name);
std::vector<InferShapeVarPtr> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
protected:
DDim GetDim(const std::string& name) const override {
Variable* var = scope_.FindVar(name);
......@@ -733,11 +752,22 @@ class RuntimeInferShapeContext : public InferShapeContext {
return ToVarType(var->Type());
}
InferShapeVarPtr GetVarPtr(const std::string& name) override {
return scope_.FindVar(name);
private:
const std::vector<Variable*>& InputVars(const std::string& name) const {
auto it = ctx_.inputs.find(name);
PADDLE_ENFORCE(it != ctx_.inputs.end(),
"Operator %s does not have the input %s.", op_.Type(), name);
return it->second;
}
const std::vector<Variable*>& OutputVars(const std::string& name) const {
auto it = ctx_.outputs.find(name);
PADDLE_ENFORCE(it != ctx_.outputs.end(),
"Operator %s does not have the outputs %s.", op_.Type(),
name);
return it->second;
}
private:
const OperatorBase& op_;
const Scope& scope_;
const RuntimeContext& ctx_;
......
......@@ -76,28 +76,6 @@ void InferShapeContext::SetReaderDims(const std::string &name,
return this->SetRepeatedDims(arg_names[0], dims);
}
std::vector<InferShapeVarPtr> InferShapeContext::GetInputVarPtrs(
const std::string &name) {
const std::vector<std::string> arg_names = Inputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(
arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) { return this->GetVarPtr(name); });
return res;
}
std::vector<InferShapeVarPtr> InferShapeContext::GetOutputVarPtrs(
const std::string &name) {
const std::vector<std::string> arg_names = Outputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(
arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) { return this->GetVarPtr(name); });
return res;
}
std::vector<DDim> InferShapeContext::GetDims(
const std::vector<std::string> &names) const {
std::vector<DDim> ret;
......
......@@ -33,22 +33,24 @@ class InferShapeContext {
virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0;
std::vector<proto::VarType::Type> GetInputsVarType(
virtual std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const;
std::vector<proto::VarType::Type> GetOutputsVarType(
virtual std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string &name) const;
virtual bool HasInputs(const std::string &name) const = 0;
virtual bool HasOutputs(const std::string &name) const = 0;
DDim GetInputDim(const std::string &name) const;
std::vector<DDim> GetInputsDim(const std::string &name) const;
std::vector<DDim> GetReaderDims(const std::string &name) const;
DDim GetInputsElementDim(const std::string &name, int idx) const;
virtual DDim GetInputDim(const std::string &name) const;
virtual std::vector<DDim> GetInputsDim(const std::string &name) const;
virtual std::vector<DDim> GetReaderDims(const std::string &name) const;
virtual DDim GetInputsElementDim(const std::string &name, int idx) const;
void SetOutputDim(const std::string &name, const DDim &dim);
void SetOutputsDim(const std::string &name, const std::vector<DDim> &dims);
void SetReaderDims(const std::string &name, const std::vector<DDim> &dims);
virtual void SetOutputDim(const std::string &name, const DDim &dim);
virtual void SetOutputsDim(const std::string &name,
const std::vector<DDim> &dims);
virtual void SetReaderDims(const std::string &name,
const std::vector<DDim> &dims);
virtual AttrReader Attrs() const = 0;
virtual const std::vector<std::string> &Inputs(
......@@ -67,13 +69,14 @@ class InferShapeContext {
virtual bool IsRuntime() const = 0;
std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name);
std::vector<InferShapeVarPtr> GetOutputVarPtrs(const std::string &name);
virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
virtual std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string &name) = 0;
virtual std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string &name) = 0;
// Note: In while op, we need this to be public
void SetDims(const std::vector<std::string> &names,
const std::vector<DDim> &dims);
virtual void SetDims(const std::vector<std::string> &names,
const std::vector<DDim> &dims);
protected:
virtual DDim GetDim(const std::string &name) const = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册