From 0575fd4647bf414662d31c02371a68689273b22c Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 2 Feb 2018 17:31:37 +0800 Subject: [PATCH] simplify shape inference code --- paddle/framework/op_desc.cc | 19 ------------------- paddle/framework/operator.cc | 8 -------- paddle/framework/shape_inference.cc | 23 +++++++++++++++++++---- paddle/framework/shape_inference.h | 8 +++----- 4 files changed, 22 insertions(+), 36 deletions(-) diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index f8df2cf97a..f554c77845 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -39,10 +39,6 @@ class CompileTimeInferShapeContext : public InferShapeContext { bool HasOutputs(const std::string &name) const override; - DDim GetInputDim(const std::string &name) const override; - - void SetOutputDim(const std::string &name, const DDim &dim) override; - AttrReader Attrs() const override; const std::vector &Inputs( @@ -444,21 +440,6 @@ bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const { return true; } -DDim CompileTimeInferShapeContext::GetInputDim(const std::string &name) const { - std::vector ddims = GetInputsDim(name); - auto length = ddims.size(); - PADDLE_ENFORCE_EQ(length, 1UL, - "Input(%s) should have 1 value, " - "but it has %d now", - name, length); - return ddims[0]; -} - -void CompileTimeInferShapeContext::SetOutputDim(const std::string &name, - const DDim &dim) { - SetOutputsDim(name, {dim}); -} - AttrReader CompileTimeInferShapeContext::Attrs() const { return AttrReader(op_.GetAttrMap()); } diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 4e854f54dd..81fa8cf477 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -366,14 +366,6 @@ class RuntimeInferShapeContext : public InferShapeContext { return true; } - DDim GetInputDim(const std::string& name) const override { - return GetDim(op_.Input(name)); - } - - void SetOutputDim(const std::string& name, const DDim& dim) override { - SetDim(op_.Output(name), dim); - } - AttrReader Attrs() const override { return AttrReader(op_.Attrs()); } const std::vector& Inputs( diff --git a/paddle/framework/shape_inference.cc b/paddle/framework/shape_inference.cc index e53cc0cdab..14dba75808 100644 --- a/paddle/framework/shape_inference.cc +++ b/paddle/framework/shape_inference.cc @@ -18,10 +18,18 @@ limitations under the License. */ namespace paddle { namespace framework { +framework::DDim InferShapeContext::GetInputDim(const std::string &name) const { + const std::vector &arg_names = Inputs(name); + PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, + "Input(%s) shoudl holds one element, but now it holds %d", + name, arg_names.size()); + return this->GetDim(arg_names[0]); +} + std::vector InferShapeContext::GetInputsDim( const std::string &name) const { - const std::vector &names = Inputs(name); - return GetDims(names); + const std::vector &arg_names = Inputs(name); + return GetDims(arg_names); } DDim InferShapeContext::GetInputsElementDim(const std::string &name, @@ -30,13 +38,21 @@ DDim InferShapeContext::GetInputsElementDim(const std::string &name, return this->GetDim(names[idx]); } +void InferShapeContext::SetOutputDim(const std::string &name, const DDim &dim) { + auto &arg_names = Outputs(name); + PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, + "Output(%s) shoudl holds one element, but now it holds %d", + name, arg_names.size()); + SetDim(arg_names[0], dim); +} + void InferShapeContext::SetOutputsDim( const std::string &name, const std::vector &dims) { auto &names = Outputs(name); SetDims(names, dims); } -std::vector InferShapeContext::GetDims( +std::vector InferShapeContext::GetDims( const std::vector &names) const { std::vector ret; ret.reserve(names.size()); @@ -45,7 +61,6 @@ std::vector InferShapeContext::GetDims( [this](const std::string &name) { return this->GetDim(name); }); return ret; } - void InferShapeContext::SetDims(const std::vector &names, const std::vector &dims) { size_t length = names.size(); diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index f93319d8f2..77fc9359be 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -35,12 +35,12 @@ class InferShapeContext { virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0; - virtual framework::DDim GetInputDim(const std::string &name) const = 0; + framework::DDim GetInputDim(const std::string &name) const; std::vector GetInputsDim(const std::string &name) const; DDim GetInputsElementDim(const std::string &name, int idx) const; - virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0; + void SetOutputDim(const std::string &name, const DDim &dim); void SetOutputsDim(const std::string &name, const std::vector &dims); @@ -63,9 +63,7 @@ class InferShapeContext { virtual framework::DDim GetDim(const std::string &name) const = 0; virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0; - std::vector GetDims( - const std::vector &names) const; - + std::vector GetDims(const std::vector &names) const; std::vector GetVarTypes( const std::vector &names) const; -- GitLab