From 1fe3ac352a3471e87111cd3f1021fe879fbdf6fe Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 19 Dec 2018 19:13:26 +0800 Subject: [PATCH] move more and fix while test=develop --- paddle/fluid/framework/op_desc.cc | 28 +++++++++++- paddle/fluid/framework/operator.cc | 33 ++++++++++++-- paddle/fluid/framework/shape_inference.cc | 26 ----------- paddle/fluid/framework/shape_inference.h | 9 +--- .../fluid/operators/controlflow/while_op.cc | 43 +++++++++++++------ 5 files changed, 87 insertions(+), 52 deletions(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 4d204aefde..2fe1c94ec0 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -159,6 +159,20 @@ class CompileTimeInferShapeContext : public InferShapeContext { return GetVarTypes(Outputs(name)); } + void SetOutputDim(const std::string &name, const DDim &dim) override { + auto &arg_names = Outputs(name); + PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, + "Output(%s) should hold one element, but now it holds %d", + name, arg_names.size()); + SetDim(arg_names[0], dim); + } + + void SetOutputsDim(const std::string &name, + const std::vector &dims) override { + auto &names = Outputs(name); + SetDims(names, dims); + } + protected: std::vector GetVarTypes( const std::vector &names) const { @@ -196,7 +210,19 @@ class CompileTimeInferShapeContext : public InferShapeContext { return ret; } - void SetDim(const std::string &name, const DDim &dim) override; + void SetDim(const std::string &name, const DDim &dim); + + void SetDims(const std::vector &names, + const std::vector &dims) { + size_t length = names.size(); + PADDLE_ENFORCE_EQ(length, dims.size()); + for (size_t i = 0; i < length; ++i) { + if (names[i] == framework::kEmptyVarName) { + continue; + } + SetDim(names[i], dims[i]); + } + } std::vector GetRepeatedDims(const std::string &name) const override; diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index eb172ca88f..4b520a393f 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -732,6 +732,20 @@ class RuntimeInferShapeContext : public InferShapeContext { return GetVarTypes(OutputVars(name)); } + void SetOutputDim(const std::string& name, const DDim& dim) override { + auto& vars = OutputVars(name); + PADDLE_ENFORCE_EQ(vars.size(), 1UL, + "Output(%s) should hold one element, but now it holds %d", + name, vars.size()); + SetDim(vars[0], dim); + } + + void SetOutputsDim(const std::string& name, + const std::vector& dims) override { + auto& vars = OutputVars(name); + SetDims(vars, dims); + } + protected: DDim GetDim(Variable* var) const { PADDLE_ENFORCE_NOT_NULL(var); @@ -759,15 +773,26 @@ class RuntimeInferShapeContext : public InferShapeContext { PADDLE_THROW("Only compile time support this method"); } - void SetDim(const std::string& name, const DDim& dim) override { - Variable* var = scope_.FindVar(name); + void SetDim(Variable* var, const DDim& dim) { if (var->IsType()) { var->GetMutable()->Resize(dim); } else if (var->IsType()) { var->GetMutable()->set_height(dim[0]); } else { - PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.", - name, var->Type().name()); + PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.", + var->Type().name()); + } + } + + void SetDims(const std::vector& vars, + const std::vector& dims) { + size_t length = vars.size(); + PADDLE_ENFORCE_EQ(length, dims.size()); + for (size_t i = 0; i < length; ++i) { + if (vars[i] == nullptr) { + continue; + } + SetDim(vars[i], dims[i]); } } diff --git a/paddle/fluid/framework/shape_inference.cc b/paddle/fluid/framework/shape_inference.cc index 4e67855b5c..4ac872ac3d 100644 --- a/paddle/fluid/framework/shape_inference.cc +++ b/paddle/fluid/framework/shape_inference.cc @@ -32,20 +32,6 @@ std::vector InferShapeContext::GetReaderDims( return this->GetRepeatedDims(arg_names[0]); } -void InferShapeContext::SetOutputDim(const std::string &name, const DDim &dim) { - auto &arg_names = Outputs(name); - PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, - "Output(%s) should hold one element, but now it holds %d", - name, arg_names.size()); - SetDim(arg_names[0], dim); -} - -void InferShapeContext::SetOutputsDim(const std::string &name, - const std::vector &dims) { - auto &names = Outputs(name); - SetDims(names, dims); -} - void InferShapeContext::SetReaderDims(const std::string &name, const std::vector &dims) { const std::vector &arg_names = Outputs(name); @@ -56,17 +42,5 @@ void InferShapeContext::SetReaderDims(const std::string &name, return this->SetRepeatedDims(arg_names[0], dims); } -void InferShapeContext::SetDims(const std::vector &names, - const std::vector &dims) { - size_t length = names.size(); - PADDLE_ENFORCE_EQ(length, dims.size()); - for (size_t i = 0; i < length; ++i) { - if (names[i] == framework::kEmptyVarName) { - continue; - } - SetDim(names[i], dims[i]); - } -} - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index 415339a01d..824f75b3d3 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -45,9 +45,9 @@ class InferShapeContext { virtual std::vector GetInputsDim(const std::string &name) const = 0; virtual std::vector GetReaderDims(const std::string &name) const; - virtual void SetOutputDim(const std::string &name, const DDim &dim); + virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0; virtual void SetOutputsDim(const std::string &name, - const std::vector &dims); + const std::vector &dims) = 0; virtual void SetReaderDims(const std::string &name, const std::vector &dims); @@ -73,12 +73,7 @@ class InferShapeContext { virtual std::vector GetOutputVarPtrs( const std::string &name) = 0; - // Note: In while op, we need this to be public - virtual void SetDims(const std::vector &names, - const std::vector &dims); - protected: - virtual void SetDim(const std::string &name, const DDim &dim) = 0; virtual std::vector GetRepeatedDims(const std::string &name) const = 0; virtual void SetRepeatedDims(const std::string &name, const std::vector &dims) = 0; diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index 3f75ee956a..48800947fd 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -399,26 +399,41 @@ class WhileGradOpShapeInference : public framework::InferShapeBase { ctx->HasInputs(kOutputs); ctx->HasInputs(framework::GradVarName(kOutputs)); - auto p_names = ctx->Inputs(kX); auto pg_ig_names = ctx->Outputs(kXGRAD); - auto var_types = ctx->GetInputsVarType(kX); - std::vector names_to_set; - std::vector dims_to_set; - for (size_t i = 0; i < p_names.size(); ++i) { + std::vector in_var_ptrs = + ctx->GetInputVarPtrs(kX); + std::vector out_var_ptrs = + ctx->GetOutputVarPtrs(kXGRAD); + PADDLE_ENFORCE(in_var_ptrs.size() == out_var_ptrs.size()); + + for (size_t i = 0; i < in_var_ptrs.size(); ++i) { if (pg_ig_names[i] == framework::kEmptyVarName) { continue; } - auto dims = ctx->GetInputsDim(kX)[i]; - if (var_types[i] == framework::proto::VarType::LOD_TENSOR) { - names_to_set.push_back(pg_ig_names[i]); - dims_to_set.push_back(dims); - } else if (var_types[i] == framework::proto::VarType::LOD_TENSOR_ARRAY) { - // not sure how to set the dim of LOD_TENSOR_ARRAY - names_to_set.push_back(pg_ig_names[i]); - dims_to_set.push_back(dims); + if (ctx->IsRuntime()) { + framework::Variable *in_var = + boost::get(in_var_ptrs[i]); + framework::Variable *out_var = + boost::get(out_var_ptrs[i]); + + auto type = framework::ToVarType(in_var->Type()); + if (type == framework::proto::VarType::LOD_TENSOR) { + out_var->GetMutable()->Resize( + in_var->Get().dims()); + } else if (type == framework::proto::VarType::SELECTED_ROWS) { + out_var->GetMutable()->set_height( + in_var->Get().GetCompleteDims()[0]); + } else if (type == framework::proto::VarType::LOD_TENSOR_ARRAY) { + PADDLE_THROW("WhileGradOp doesn't support type %d", + static_cast(type)); + } + } else { + framework::VarDesc *in_var = + boost::get(in_var_ptrs[i]); + boost::get(out_var_ptrs[i]) + ->SetShape(in_var->GetShape()); } } - ctx->SetDims(names_to_set, dims_to_set); } }; -- GitLab