提交 81fc7774 编写于 作者: Q qiaolongfei

optimize infershape context

上级 455436e5
...@@ -323,74 +323,76 @@ class CompileTimeInferShapeContext : public InferShapeContextBase { ...@@ -323,74 +323,76 @@ class CompileTimeInferShapeContext : public InferShapeContextBase {
CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block) CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block)
: op_(op), block_(block) {} : op_(op), block_(block) {}
bool HasInput(const std::string& name) const { bool HasInput(const std::string& name) const override {
const std::vector<std::string>& input_names = op_.Input(name); const std::vector<std::string>& input_names = op_.Input(name);
PADDLE_ENFORCE_EQ(input_names.size(), 1UL, "Inputs(%s) length is not 1", PADDLE_ENFORCE_EQ(input_names.size(), 1UL, "Inputs(%s) length is not 1",
name); name);
return block_.HasVar(input_names[0]); return block_.HasVar(input_names[0]);
} }
bool HasOutput(const std::string& name) const { bool HasOutput(const std::string& name) const override {
const std::vector<std::string>& output_names = op_.Output(name); const std::vector<std::string>& output_names = op_.Output(name);
PADDLE_ENFORCE_EQ(output_names.size(), 1UL, "Outputs(%s) length is not 1", PADDLE_ENFORCE_EQ(output_names.size(), 1UL, "Outputs(%s) length is not 1",
name); name);
return block_.HasVar(output_names[0]); return block_.HasVar(output_names[0]);
} }
bool HasInputs(const std::string& name) const { bool HasInputs(const std::string& name) const override {
const std::vector<std::string>& input_names = op_.Input(name); const std::vector<std::string>& input_names = op_.Input(name);
PADDLE_ENFORCE_GT(input_names.size(), 0UL, "Inputs(%s) length is 0", name); PADDLE_ENFORCE(!input_names.empty(), "Inputs(%s) length is 0", name);
for (auto& input : input_names) { for (auto& input : input_names) {
if (!block_.HasVar(input)) return false; if (!block_.HasVar(input)) return false;
} }
return true; return true;
} }
bool HasOutputs(const std::string& name) const { bool HasOutputs(const std::string& name) const override {
const std::vector<std::string>& output_names = op_.Output(name); const std::vector<std::string>& output_names = op_.Output(name);
PADDLE_ENFORCE_GT(output_names.size(), 0UL, "Inputs(%s) length is 0", name); PADDLE_ENFORCE(!output_names.empty(), "Inputs(%s) length is 0", name);
for (auto& output : output_names) { for (auto& output : output_names) {
if (!block_.HasVar(output)) return false; if (!block_.HasVar(output)) return false;
} }
return true; return true;
} }
DDim GetInputDim(const std::string& name) const { DDim GetInputDim(const std::string& name) const override {
std::vector<DDim> ddims = GetInputsDim(name); std::vector<DDim> ddims = GetInputsDim(name);
PADDLE_ENFORCE_EQ(ddims.size(), 1UL, "Inputs(%s) length is not 1", name); PADDLE_ENFORCE_EQ(ddims.size(), 1UL, "Inputs(%s) length is not 1", name);
return ddims[0]; return ddims[0];
} }
void SetInputDim(const std::string& name, const DDim& dim) { void SetInputDim(const std::string& name, const DDim& dim) override {
SetInputsDim(name, {dim}); SetInputsDim(name, {dim});
} }
DDim GetOutputDim(const std::string& name) const { DDim GetOutputDim(const std::string& name) const override {
std::vector<DDim> ddims = GetOutputsDim(name); std::vector<DDim> ddims = GetOutputsDim(name);
PADDLE_ENFORCE_EQ(ddims.size(), 1UL, "Outputs(%s) length is not 1", name); PADDLE_ENFORCE_EQ(ddims.size(), 1UL, "Outputs(%s) length is not 1", name);
return ddims[0]; return ddims[0];
} }
void SetOutputDim(const std::string& name, const DDim& dim) { void SetOutputDim(const std::string& name, const DDim& dim) override {
SetOutputsDim(name, {dim}); SetOutputsDim(name, {dim});
} }
AttrReader Attrs() const { return AttrReader(op_.GetAttrMap()); } AttrReader Attrs() const override { return AttrReader(op_.GetAttrMap()); }
const std::vector<std::string>& Inputs(const std::string& name) const { const std::vector<std::string>& Inputs(
const std::string& name) const override {
return op_.Input(name); return op_.Input(name);
} }
const std::vector<std::string>& Outputs(const std::string& name) const { const std::vector<std::string>& Outputs(
const std::string& name) const override {
return op_.Output(name); return op_.Output(name);
} }
private: private:
DDim GetDim(const std::string& name) const { DDim GetDim(const std::string& name) const override {
return framework::make_ddim(block_.Var(name)->Shape()); return framework::make_ddim(block_.Var(name)->Shape());
} }
void SetDim(const std::string& name, const DDim& dim) { void SetDim(const std::string& name, const DDim& dim) override {
block_.Var(name)->SetShape(framework::vectorize(dim)); block_.Var(name)->SetShape(framework::vectorize(dim));
} }
...@@ -403,21 +405,21 @@ class RuntimeInferShapeContext : public InferShapeContextBase { ...@@ -403,21 +405,21 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope) RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {} : op_(op), scope_(scope) {}
bool HasInput(const std::string& name) const { bool HasInput(const std::string& name) const override {
auto ipt = op_.Input(name); auto ipt = op_.Input(name);
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr; return var != nullptr;
} }
bool HasOutput(const std::string& name) const { bool HasOutput(const std::string& name) const override {
auto ipt = op_.Output(name); auto ipt = op_.Output(name);
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr; return var != nullptr;
} }
bool HasInputs(const std::string& name) const { bool HasInputs(const std::string& name) const override {
auto inputs = op_.Inputs(name); auto inputs = op_.Inputs(name);
if (inputs.size() == 0UL) { if (inputs.empty()) {
return false; return false;
} }
for (auto& input : inputs) { for (auto& input : inputs) {
...@@ -428,9 +430,9 @@ class RuntimeInferShapeContext : public InferShapeContextBase { ...@@ -428,9 +430,9 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
return true; return true;
} }
bool HasOutputs(const std::string& name) const { bool HasOutputs(const std::string& name) const override {
auto outputs = op_.Outputs(name); auto outputs = op_.Outputs(name);
if (outputs.size() == 0UL) { if (outputs.empty()) {
return false; return false;
} }
for (auto& output : outputs) { for (auto& output : outputs) {
...@@ -441,29 +443,31 @@ class RuntimeInferShapeContext : public InferShapeContextBase { ...@@ -441,29 +443,31 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
return true; return true;
} }
DDim GetInputDim(const std::string& name) const { DDim GetInputDim(const std::string& name) const override {
return GetDim(op_.Input(name)); return GetDim(op_.Input(name));
} }
void SetInputDim(const std::string& name, const DDim& dim) { void SetInputDim(const std::string& name, const DDim& dim) override {
SetDim(op_.Input(name), dim); SetDim(op_.Input(name), dim);
} }
DDim GetOutputDim(const std::string& name) const { DDim GetOutputDim(const std::string& name) const override {
return GetDim(op_.Output(name)); return GetDim(op_.Output(name));
} }
void SetOutputDim(const std::string& name, const DDim& dim) { void SetOutputDim(const std::string& name, const DDim& dim) override {
SetDim(op_.Output(name), dim); SetDim(op_.Output(name), dim);
} }
AttrReader Attrs() const { return AttrReader(op_.Attrs()); } AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
const std::vector<std::string>& Inputs(const std::string& name) const { const std::vector<std::string>& Inputs(
const std::string& name) const override {
return op_.Inputs(name); return op_.Inputs(name);
} }
const std::vector<std::string>& Outputs(const std::string& name) const { const std::vector<std::string>& Outputs(
const std::string& name) const override {
return op_.Outputs(name); return op_.Outputs(name);
} }
...@@ -484,11 +488,11 @@ class RuntimeInferShapeContext : public InferShapeContextBase { ...@@ -484,11 +488,11 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
return t; return t;
} }
DDim GetDim(const std::string& name) const { DDim GetDim(const std::string& name) const override {
return GetTensor<false>(name)->dims(); return GetTensor<false>(name)->dims();
} }
void SetDim(const std::string& name, const DDim& dim) { void SetDim(const std::string& name, const DDim& dim) override {
GetTensor<true>(name)->Resize(dim); GetTensor<true>(name)->Resize(dim);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册