提交 87699388 编写于 作者: X Xin Pan

convert more interface to avoid scope

test=develop
上级 8c19f0bf
...@@ -134,12 +134,46 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -134,12 +134,46 @@ class CompileTimeInferShapeContext : public InferShapeContext {
return res; return res;
} }
DDim GetInputDim(const std::string &name) const override {
const std::vector<std::string> &arg_names = Inputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Input(%s) should hold one element, but now it holds %d",
name, arg_names.size());
return this->GetDim(arg_names[0]);
}
std::vector<DDim> GetInputsDim(const std::string &name) const override {
const std::vector<std::string> &arg_names = Inputs(name);
return GetDims(arg_names);
}
bool IsRuntime() const override; bool IsRuntime() const override;
protected: protected:
proto::VarType::Type GetVarType(const std::string &name) const override; proto::VarType::Type GetVarType(const std::string &name) const override;
DDim GetDim(const std::string &name) const override; DDim GetDim(const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
DDim res;
try {
auto shape = var->GetShape();
res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape);
} catch (...) {
VLOG(5) << "GetDim of variable " << name << " error";
std::rethrow_exception(std::current_exception());
}
return res;
}
std::vector<DDim> GetDims(const std::vector<std::string> &names) const {
std::vector<DDim> ret;
ret.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(ret),
[this](const std::string &name) { return this->GetDim(name); });
return ret;
}
void SetDim(const std::string &name, const DDim &dim) override; void SetDim(const std::string &name, const DDim &dim) override;
...@@ -666,20 +700,6 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs( ...@@ -666,20 +700,6 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
return op_.Output(name); return op_.Output(name);
} }
DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
DDim res;
try {
auto shape = var->GetShape();
res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape);
} catch (...) {
VLOG(5) << "GetDim of variable " << name << " error";
std::rethrow_exception(std::current_exception());
}
return res;
}
std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims( std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims(
const std::string &name) const { const std::string &name) const {
auto var = block_.FindVarRecursive(name); auto var = block_.FindVarRecursive(name);
......
...@@ -709,9 +709,21 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -709,9 +709,21 @@ class RuntimeInferShapeContext : public InferShapeContext {
return res; return res;
} }
DDim GetInputDim(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name);
PADDLE_ENFORCE_EQ(vars.size(), 1UL,
"Input(%s) should hold one element, but now it holds %d",
name, vars.size());
return this->GetDim(vars[0]);
}
std::vector<DDim> GetInputsDim(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name);
return GetDims(vars);
}
protected: protected:
DDim GetDim(const std::string& name) const override { DDim GetDim(Variable* var) const {
Variable* var = scope_.FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var); PADDLE_ENFORCE_NOT_NULL(var);
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims(); return var->Get<LoDTensor>().dims();
...@@ -719,12 +731,20 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -719,12 +731,20 @@ class RuntimeInferShapeContext : public InferShapeContext {
return var->Get<SelectedRows>().GetCompleteDims(); return var->Get<SelectedRows>().GetCompleteDims();
} else { } else {
PADDLE_THROW( PADDLE_THROW(
"Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's " "Only LoDTensor/SelectedRows support 'GetDim', but Variables "
"type_id is %s.", "type_id is %s.",
name, var->Type().name()); var->Type().name());
} }
} }
std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const {
std::vector<DDim> ret;
ret.reserve(vars.size());
std::transform(vars.begin(), vars.end(), std::back_inserter(ret),
[this](Variable* var) { return this->GetDim(var); });
return ret;
}
std::vector<DDim> GetRepeatedDims(const std::string& name) const override { std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
PADDLE_THROW("Only compile time support this method"); PADDLE_THROW("Only compile time support this method");
} }
......
...@@ -22,20 +22,6 @@ limitations under the License. */ ...@@ -22,20 +22,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
DDim InferShapeContext::GetInputDim(const std::string &name) const {
const std::vector<std::string> &arg_names = Inputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Input(%s) should hold one element, but now it holds %d",
name, arg_names.size());
return this->GetDim(arg_names[0]);
}
std::vector<DDim> InferShapeContext::GetInputsDim(
const std::string &name) const {
const std::vector<std::string> &arg_names = Inputs(name);
return GetDims(arg_names);
}
std::vector<DDim> InferShapeContext::GetReaderDims( std::vector<DDim> InferShapeContext::GetReaderDims(
const std::string &name) const { const std::string &name) const {
const std::vector<std::string> &arg_names = Inputs(name); const std::vector<std::string> &arg_names = Inputs(name);
...@@ -46,12 +32,6 @@ std::vector<DDim> InferShapeContext::GetReaderDims( ...@@ -46,12 +32,6 @@ std::vector<DDim> InferShapeContext::GetReaderDims(
return this->GetRepeatedDims(arg_names[0]); return this->GetRepeatedDims(arg_names[0]);
} }
DDim InferShapeContext::GetInputsElementDim(const std::string &name,
int idx) const {
const std::vector<std::string> &names = Inputs(name);
return this->GetDim(names[idx]);
}
void InferShapeContext::SetOutputDim(const std::string &name, const DDim &dim) { void InferShapeContext::SetOutputDim(const std::string &name, const DDim &dim) {
auto &arg_names = Outputs(name); auto &arg_names = Outputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
...@@ -76,16 +56,6 @@ void InferShapeContext::SetReaderDims(const std::string &name, ...@@ -76,16 +56,6 @@ void InferShapeContext::SetReaderDims(const std::string &name,
return this->SetRepeatedDims(arg_names[0], dims); return this->SetRepeatedDims(arg_names[0], dims);
} }
std::vector<DDim> InferShapeContext::GetDims(
const std::vector<std::string> &names) const {
std::vector<DDim> ret;
ret.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(ret),
[this](const std::string &name) { return this->GetDim(name); });
return ret;
}
void InferShapeContext::SetDims(const std::vector<std::string> &names, void InferShapeContext::SetDims(const std::vector<std::string> &names,
const std::vector<DDim> &dims) { const std::vector<DDim> &dims) {
size_t length = names.size(); size_t length = names.size();
......
...@@ -41,10 +41,9 @@ class InferShapeContext { ...@@ -41,10 +41,9 @@ class InferShapeContext {
virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasInputs(const std::string &name) const = 0;
virtual bool HasOutputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0;
virtual DDim GetInputDim(const std::string &name) const; virtual DDim GetInputDim(const std::string &name) const = 0;
virtual std::vector<DDim> GetInputsDim(const std::string &name) const; virtual std::vector<DDim> GetInputsDim(const std::string &name) const = 0;
virtual std::vector<DDim> GetReaderDims(const std::string &name) const; virtual std::vector<DDim> GetReaderDims(const std::string &name) const;
virtual DDim GetInputsElementDim(const std::string &name, int idx) const;
virtual void SetOutputDim(const std::string &name, const DDim &dim); virtual void SetOutputDim(const std::string &name, const DDim &dim);
virtual void SetOutputsDim(const std::string &name, virtual void SetOutputsDim(const std::string &name,
...@@ -79,14 +78,11 @@ class InferShapeContext { ...@@ -79,14 +78,11 @@ class InferShapeContext {
const std::vector<DDim> &dims); const std::vector<DDim> &dims);
protected: protected:
virtual DDim GetDim(const std::string &name) const = 0;
virtual void SetDim(const std::string &name, const DDim &dim) = 0; virtual void SetDim(const std::string &name, const DDim &dim) = 0;
virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0; virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0;
virtual void SetRepeatedDims(const std::string &name, virtual void SetRepeatedDims(const std::string &name,
const std::vector<DDim> &dims) = 0; const std::vector<DDim> &dims) = 0;
std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
std::vector<proto::VarType::Type> GetVarTypes( std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<std::string> &names) const; const std::vector<std::string> &names) const;
......
...@@ -408,7 +408,7 @@ class WhileGradOpShapeInference : public framework::InferShapeBase { ...@@ -408,7 +408,7 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
if (pg_ig_names[i] == framework::kEmptyVarName) { if (pg_ig_names[i] == framework::kEmptyVarName) {
continue; continue;
} }
auto dims = ctx->GetInputsElementDim(kX, i); auto dims = ctx->GetInputsDim(kX)[i];
if (var_types[i] == framework::proto::VarType::LOD_TENSOR) { if (var_types[i] == framework::proto::VarType::LOD_TENSOR) {
names_to_set.push_back(pg_ig_names[i]); names_to_set.push_back(pg_ig_names[i]);
dims_to_set.push_back(dims); dims_to_set.push_back(dims);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册