diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 0a3bb586fcc61e80f0358adc21730c279509ef87..ef98558820f5696077c83424d5946caeec00b6fb 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -134,12 +134,46 @@ class CompileTimeInferShapeContext : public InferShapeContext { return res; } + DDim GetInputDim(const std::string &name) const override { + const std::vector &arg_names = Inputs(name); + PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, + "Input(%s) should hold one element, but now it holds %d", + name, arg_names.size()); + return this->GetDim(arg_names[0]); + } + + std::vector GetInputsDim(const std::string &name) const override { + const std::vector &arg_names = Inputs(name); + return GetDims(arg_names); + } + bool IsRuntime() const override; protected: proto::VarType::Type GetVarType(const std::string &name) const override; - DDim GetDim(const std::string &name) const override; + DDim GetDim(const std::string &name) const { + auto var = block_.FindVarRecursive(name); + PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name); + DDim res; + try { + auto shape = var->GetShape(); + res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape); + } catch (...) { + VLOG(5) << "GetDim of variable " << name << " error"; + std::rethrow_exception(std::current_exception()); + } + return res; + } + + std::vector GetDims(const std::vector &names) const { + std::vector ret; + ret.reserve(names.size()); + std::transform( + names.begin(), names.end(), std::back_inserter(ret), + [this](const std::string &name) { return this->GetDim(name); }); + return ret; + } void SetDim(const std::string &name, const DDim &dim) override; @@ -666,20 +700,6 @@ const std::vector &CompileTimeInferShapeContext::Outputs( return op_.Output(name); } -DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const { - auto var = block_.FindVarRecursive(name); - PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name); - DDim res; - try { - auto shape = var->GetShape(); - res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape); - } catch (...) { - VLOG(5) << "GetDim of variable " << name << " error"; - std::rethrow_exception(std::current_exception()); - } - return res; -} - std::vector CompileTimeInferShapeContext::GetRepeatedDims( const std::string &name) const { auto var = block_.FindVarRecursive(name); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 2f418f728fb6a3fe2e393f5b62d71537e97f1a35..2bfe055b4ca0addeb872dcdcd7ea92172d7e5040 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -709,9 +709,21 @@ class RuntimeInferShapeContext : public InferShapeContext { return res; } + DDim GetInputDim(const std::string& name) const override { + const std::vector& vars = InputVars(name); + PADDLE_ENFORCE_EQ(vars.size(), 1UL, + "Input(%s) should hold one element, but now it holds %d", + name, vars.size()); + return this->GetDim(vars[0]); + } + + std::vector GetInputsDim(const std::string& name) const override { + const std::vector& vars = InputVars(name); + return GetDims(vars); + } + protected: - DDim GetDim(const std::string& name) const override { - Variable* var = scope_.FindVar(name); + DDim GetDim(Variable* var) const { PADDLE_ENFORCE_NOT_NULL(var); if (var->IsType()) { return var->Get().dims(); @@ -719,12 +731,20 @@ class RuntimeInferShapeContext : public InferShapeContext { return var->Get().GetCompleteDims(); } else { PADDLE_THROW( - "Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's " + "Only LoDTensor/SelectedRows support 'GetDim', but Variables " "type_id is %s.", - name, var->Type().name()); + var->Type().name()); } } + std::vector GetDims(const std::vector& vars) const { + std::vector ret; + ret.reserve(vars.size()); + std::transform(vars.begin(), vars.end(), std::back_inserter(ret), + [this](Variable* var) { return this->GetDim(var); }); + return ret; + } + std::vector GetRepeatedDims(const std::string& name) const override { PADDLE_THROW("Only compile time support this method"); } diff --git a/paddle/fluid/framework/shape_inference.cc b/paddle/fluid/framework/shape_inference.cc index 0a7cebcc5a2e92334a4cfcbf87d7c7b475fd255e..f274a1b73f47bcd99ffc7a70fc7026fc97427200 100644 --- a/paddle/fluid/framework/shape_inference.cc +++ b/paddle/fluid/framework/shape_inference.cc @@ -22,20 +22,6 @@ limitations under the License. */ namespace paddle { namespace framework { -DDim InferShapeContext::GetInputDim(const std::string &name) const { - const std::vector &arg_names = Inputs(name); - PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, - "Input(%s) should hold one element, but now it holds %d", - name, arg_names.size()); - return this->GetDim(arg_names[0]); -} - -std::vector InferShapeContext::GetInputsDim( - const std::string &name) const { - const std::vector &arg_names = Inputs(name); - return GetDims(arg_names); -} - std::vector InferShapeContext::GetReaderDims( const std::string &name) const { const std::vector &arg_names = Inputs(name); @@ -46,12 +32,6 @@ std::vector InferShapeContext::GetReaderDims( return this->GetRepeatedDims(arg_names[0]); } -DDim InferShapeContext::GetInputsElementDim(const std::string &name, - int idx) const { - const std::vector &names = Inputs(name); - return this->GetDim(names[idx]); -} - void InferShapeContext::SetOutputDim(const std::string &name, const DDim &dim) { auto &arg_names = Outputs(name); PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, @@ -76,16 +56,6 @@ void InferShapeContext::SetReaderDims(const std::string &name, return this->SetRepeatedDims(arg_names[0], dims); } -std::vector InferShapeContext::GetDims( - const std::vector &names) const { - std::vector ret; - ret.reserve(names.size()); - std::transform( - names.begin(), names.end(), std::back_inserter(ret), - [this](const std::string &name) { return this->GetDim(name); }); - return ret; -} - void InferShapeContext::SetDims(const std::vector &names, const std::vector &dims) { size_t length = names.size(); diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index 543696d43b179ed62561f743eb52c50a6bc90a01..6cf9cf3f38608cc22425b1ea8ba71c6fe90580d4 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -41,10 +41,9 @@ class InferShapeContext { virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0; - virtual DDim GetInputDim(const std::string &name) const; - virtual std::vector GetInputsDim(const std::string &name) const; + virtual DDim GetInputDim(const std::string &name) const = 0; + virtual std::vector GetInputsDim(const std::string &name) const = 0; virtual std::vector GetReaderDims(const std::string &name) const; - virtual DDim GetInputsElementDim(const std::string &name, int idx) const; virtual void SetOutputDim(const std::string &name, const DDim &dim); virtual void SetOutputsDim(const std::string &name, @@ -79,14 +78,11 @@ class InferShapeContext { const std::vector &dims); protected: - virtual DDim GetDim(const std::string &name) const = 0; virtual void SetDim(const std::string &name, const DDim &dim) = 0; virtual std::vector GetRepeatedDims(const std::string &name) const = 0; virtual void SetRepeatedDims(const std::string &name, const std::vector &dims) = 0; - std::vector GetDims(const std::vector &names) const; - std::vector GetVarTypes( const std::vector &names) const; diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index e91d9ef7765568a842b31ba682dc1b7e0d8ffa08..3f75ee956ab7aa90e9d226c2d5a1e3f603ff9afd 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -408,7 +408,7 @@ class WhileGradOpShapeInference : public framework::InferShapeBase { if (pg_ig_names[i] == framework::kEmptyVarName) { continue; } - auto dims = ctx->GetInputsElementDim(kX, i); + auto dims = ctx->GetInputsDim(kX)[i]; if (var_types[i] == framework::proto::VarType::LOD_TENSOR) { names_to_set.push_back(pg_ig_names[i]); dims_to_set.push_back(dims);