From a38c1512437531d429d25254e774c2f9bc29e31e Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 30 Nov 2017 17:20:39 +0800 Subject: [PATCH] Add GetInputsElementDim (#6091) --- paddle/framework/shape_inference.cc | 6 ++++++ paddle/framework/shape_inference.h | 1 + paddle/operators/while_op.cc | 6 +++--- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/paddle/framework/shape_inference.cc b/paddle/framework/shape_inference.cc index 0af41b164f5..2298507471c 100644 --- a/paddle/framework/shape_inference.cc +++ b/paddle/framework/shape_inference.cc @@ -22,6 +22,12 @@ std::vector InferShapeContext::GetInputsDim( return GetDims(names); } +DDim InferShapeContext::GetInputsElementDim(const std::string &name, + int idx) const { + const std::vector &names = Inputs(name); + return this->GetDim(names[idx]); +} + void InferShapeContext::SetOutputsDim( const std::string &name, const std::vector &dims) { auto &names = Outputs(name); diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index 05dc47f06ac..46f2ea84b4b 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -37,6 +37,7 @@ class InferShapeContext { virtual framework::DDim GetInputDim(const std::string &name) const = 0; std::vector GetInputsDim(const std::string &name) const; + DDim GetInputsElementDim(const std::string &name, int idx) const; virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0; void SetOutputsDim(const std::string &name, diff --git a/paddle/operators/while_op.cc b/paddle/operators/while_op.cc index 68b4f770599..59460f6c879 100644 --- a/paddle/operators/while_op.cc +++ b/paddle/operators/while_op.cc @@ -287,7 +287,6 @@ class WhileGradOpShapeInference : public framework::InferShapeBase { auto p_names = ctx->Inputs(kParameters); auto pg_names = ctx->Outputs(kParamGrads); - auto dims = ctx->GetInputsDim(kParameters); auto var_types = ctx->GetInputsVarType(kParameters); std::vector names_to_set; std::vector dims_to_set; @@ -295,13 +294,14 @@ class WhileGradOpShapeInference : public framework::InferShapeBase { if (pg_names[i] == framework::kEmptyVarName) { continue; } + auto dims = ctx->GetInputsElementDim(kParameters, i); if (var_types[i] == framework::VarDesc::LOD_TENSOR) { names_to_set.push_back(pg_names[i]); - dims_to_set.push_back(dims[i]); + dims_to_set.push_back(dims); } else if (var_types[i] == framework::VarDesc::LOD_TENSOR_ARRAY) { // not sure how to set the dim of LOD_TENSOR_ARRAY names_to_set.push_back(pg_names[i]); - dims_to_set.push_back(dims[i]); + dims_to_set.push_back(dims); } } ctx->SetDims(names_to_set, dims_to_set); -- GitLab