diff --git a/paddle/framework/shape_inference.cc b/paddle/framework/shape_inference.cc index 0e17219e4ebafba0e79e6cfbcc3579469917ff8e..a0fa467291bb42c59b65f5efeabe9c2235e15b2a 100644 --- a/paddle/framework/shape_inference.cc +++ b/paddle/framework/shape_inference.cc @@ -18,7 +18,7 @@ limitations under the License. */ namespace paddle { namespace framework { -framework::DDim InferShapeContext::GetInputDim(const std::string &name) const { +DDim InferShapeContext::GetInputDim(const std::string &name) const { const std::vector &arg_names = Inputs(name); PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, "Input(%s) should hold one element, but now it holds %d", @@ -26,7 +26,7 @@ framework::DDim InferShapeContext::GetInputDim(const std::string &name) const { return this->GetDim(arg_names[0]); } -std::vector InferShapeContext::GetInputsDim( +std::vector InferShapeContext::GetInputsDim( const std::string &name) const { const std::vector &arg_names = Inputs(name); return GetDims(arg_names); @@ -46,15 +46,15 @@ void InferShapeContext::SetOutputDim(const std::string &name, const DDim &dim) { SetDim(arg_names[0], dim); } -void InferShapeContext::SetOutputsDim( - const std::string &name, const std::vector &dims) { +void InferShapeContext::SetOutputsDim(const std::string &name, + const std::vector &dims) { auto &names = Outputs(name); SetDims(names, dims); } std::vector InferShapeContext::GetDims( const std::vector &names) const { - std::vector ret; + std::vector ret; ret.reserve(names.size()); std::transform( names.begin(), names.end(), std::back_inserter(ret), @@ -62,7 +62,7 @@ std::vector InferShapeContext::GetDims( return ret; } void InferShapeContext::SetDims(const std::vector &names, - const std::vector &dims) { + const std::vector &dims) { size_t length = names.size(); PADDLE_ENFORCE_EQ(length, dims.size()); for (size_t i = 0; i < length; ++i) { diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index 77fc9359be5eee7b62bff6e65deee7766e1461a3..830f199ed1451538f12fc8dd34fb7b2bfc356a71 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -35,14 +35,13 @@ class InferShapeContext { virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0; - framework::DDim GetInputDim(const std::string &name) const; + DDim GetInputDim(const std::string &name) const; - std::vector GetInputsDim(const std::string &name) const; + std::vector GetInputsDim(const std::string &name) const; DDim GetInputsElementDim(const std::string &name, int idx) const; void SetOutputDim(const std::string &name, const DDim &dim); - void SetOutputsDim(const std::string &name, - const std::vector &dims); + void SetOutputsDim(const std::string &name, const std::vector &dims); virtual AttrReader Attrs() const = 0; virtual const std::vector &Inputs( @@ -57,11 +56,11 @@ class InferShapeContext { // Note: In while op, we need this to be public void SetDims(const std::vector &names, - const std::vector &dims); + const std::vector &dims); protected: - virtual framework::DDim GetDim(const std::string &name) const = 0; - virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0; + virtual DDim GetDim(const std::string &name) const = 0; + virtual void SetDim(const std::string &name, const DDim &dim) = 0; std::vector GetDims(const std::vector &names) const; std::vector GetVarTypes(