未验证 提交 71bd0dfa 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #8087 from JiayiFeng/simplify_infershape_code

simplify shape inference code
...@@ -39,10 +39,6 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -39,10 +39,6 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool HasOutputs(const std::string &name) const override; bool HasOutputs(const std::string &name) const override;
DDim GetInputDim(const std::string &name) const override;
void SetOutputDim(const std::string &name, const DDim &dim) override;
AttrReader Attrs() const override; AttrReader Attrs() const override;
const std::vector<std::string> &Inputs( const std::vector<std::string> &Inputs(
...@@ -444,21 +440,6 @@ bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const { ...@@ -444,21 +440,6 @@ bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
return true; return true;
} }
DDim CompileTimeInferShapeContext::GetInputDim(const std::string &name) const {
std::vector<DDim> ddims = GetInputsDim(name);
auto length = ddims.size();
PADDLE_ENFORCE_EQ(length, 1UL,
"Input(%s) should have 1 value, "
"but it has %d now",
name, length);
return ddims[0];
}
void CompileTimeInferShapeContext::SetOutputDim(const std::string &name,
const DDim &dim) {
SetOutputsDim(name, {dim});
}
AttrReader CompileTimeInferShapeContext::Attrs() const { AttrReader CompileTimeInferShapeContext::Attrs() const {
return AttrReader(op_.GetAttrMap()); return AttrReader(op_.GetAttrMap());
} }
......
...@@ -366,14 +366,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -366,14 +366,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
return true; return true;
} }
DDim GetInputDim(const std::string& name) const override {
return GetDim(op_.Input(name));
}
void SetOutputDim(const std::string& name, const DDim& dim) override {
SetDim(op_.Output(name), dim);
}
AttrReader Attrs() const override { return AttrReader(op_.Attrs()); } AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
const std::vector<std::string>& Inputs( const std::vector<std::string>& Inputs(
......
...@@ -18,10 +18,18 @@ limitations under the License. */ ...@@ -18,10 +18,18 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::vector<framework::DDim> InferShapeContext::GetInputsDim( DDim InferShapeContext::GetInputDim(const std::string &name) const {
const std::vector<std::string> &arg_names = Inputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Input(%s) should hold one element, but now it holds %d",
name, arg_names.size());
return this->GetDim(arg_names[0]);
}
std::vector<DDim> InferShapeContext::GetInputsDim(
const std::string &name) const { const std::string &name) const {
const std::vector<std::string> &names = Inputs(name); const std::vector<std::string> &arg_names = Inputs(name);
return GetDims(names); return GetDims(arg_names);
} }
DDim InferShapeContext::GetInputsElementDim(const std::string &name, DDim InferShapeContext::GetInputsElementDim(const std::string &name,
...@@ -30,24 +38,31 @@ DDim InferShapeContext::GetInputsElementDim(const std::string &name, ...@@ -30,24 +38,31 @@ DDim InferShapeContext::GetInputsElementDim(const std::string &name,
return this->GetDim(names[idx]); return this->GetDim(names[idx]);
} }
void InferShapeContext::SetOutputsDim( void InferShapeContext::SetOutputDim(const std::string &name, const DDim &dim) {
const std::string &name, const std::vector<framework::DDim> &dims) { auto &arg_names = Outputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Output(%s) should hold one element, but now it holds %d",
name, arg_names.size());
SetDim(arg_names[0], dim);
}
void InferShapeContext::SetOutputsDim(const std::string &name,
const std::vector<DDim> &dims) {
auto &names = Outputs(name); auto &names = Outputs(name);
SetDims(names, dims); SetDims(names, dims);
} }
std::vector<framework::DDim> InferShapeContext::GetDims( std::vector<DDim> InferShapeContext::GetDims(
const std::vector<std::string> &names) const { const std::vector<std::string> &names) const {
std::vector<framework::DDim> ret; std::vector<DDim> ret;
ret.reserve(names.size()); ret.reserve(names.size());
std::transform( std::transform(
names.begin(), names.end(), std::back_inserter(ret), names.begin(), names.end(), std::back_inserter(ret),
[this](const std::string &name) { return this->GetDim(name); }); [this](const std::string &name) { return this->GetDim(name); });
return ret; return ret;
} }
void InferShapeContext::SetDims(const std::vector<std::string> &names, void InferShapeContext::SetDims(const std::vector<std::string> &names,
const std::vector<framework::DDim> &dims) { const std::vector<DDim> &dims) {
size_t length = names.size(); size_t length = names.size();
PADDLE_ENFORCE_EQ(length, dims.size()); PADDLE_ENFORCE_EQ(length, dims.size());
for (size_t i = 0; i < length; ++i) { for (size_t i = 0; i < length; ++i) {
......
...@@ -35,14 +35,13 @@ class InferShapeContext { ...@@ -35,14 +35,13 @@ class InferShapeContext {
virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasInputs(const std::string &name) const = 0;
virtual bool HasOutputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0;
virtual framework::DDim GetInputDim(const std::string &name) const = 0; DDim GetInputDim(const std::string &name) const;
std::vector<framework::DDim> GetInputsDim(const std::string &name) const; std::vector<DDim> GetInputsDim(const std::string &name) const;
DDim GetInputsElementDim(const std::string &name, int idx) const; DDim GetInputsElementDim(const std::string &name, int idx) const;
virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0; void SetOutputDim(const std::string &name, const DDim &dim);
void SetOutputsDim(const std::string &name, void SetOutputsDim(const std::string &name, const std::vector<DDim> &dims);
const std::vector<framework::DDim> &dims);
virtual AttrReader Attrs() const = 0; virtual AttrReader Attrs() const = 0;
virtual const std::vector<std::string> &Inputs( virtual const std::vector<std::string> &Inputs(
...@@ -57,15 +56,13 @@ class InferShapeContext { ...@@ -57,15 +56,13 @@ class InferShapeContext {
// Note: In while op, we need this to be public // Note: In while op, we need this to be public
void SetDims(const std::vector<std::string> &names, void SetDims(const std::vector<std::string> &names,
const std::vector<framework::DDim> &dims); const std::vector<DDim> &dims);
protected: protected:
virtual framework::DDim GetDim(const std::string &name) const = 0; virtual DDim GetDim(const std::string &name) const = 0;
virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0; virtual void SetDim(const std::string &name, const DDim &dim) = 0;
std::vector<framework::DDim> GetDims(
const std::vector<std::string> &names) const;
std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
std::vector<proto::VarDesc::VarType> GetVarTypes( std::vector<proto::VarDesc::VarType> GetVarTypes(
const std::vector<std::string> &names) const; const std::vector<std::string> &names) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册