提交 e7f62703 编写于 作者: Q Qiao Longfei 提交者: GitHub

fix InferShapeContext Has interface (#4994)

上级 d0cfbba4
...@@ -327,6 +327,9 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -327,6 +327,9 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool HasInput(const std::string& name) const override { bool HasInput(const std::string& name) const override {
const std::vector<std::string>& input_names = op_.Input(name); const std::vector<std::string>& input_names = op_.Input(name);
auto length = input_names.size(); auto length = input_names.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, PADDLE_ENFORCE_EQ(length, 1UL,
"Input(%s) should have only one value, " "Input(%s) should have only one value, "
"but it have %d now", "but it have %d now",
...@@ -337,6 +340,9 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -337,6 +340,9 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool HasOutput(const std::string& name) const override { bool HasOutput(const std::string& name) const override {
const std::vector<std::string>& output_names = op_.Output(name); const std::vector<std::string>& output_names = op_.Output(name);
auto length = output_names.size(); auto length = output_names.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, PADDLE_ENFORCE_EQ(length, 1UL,
"Output(%s) should have only one value, " "Output(%s) should have only one value, "
"but it have %d now", "but it have %d now",
...@@ -346,7 +352,9 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -346,7 +352,9 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool HasInputs(const std::string& name) const override { bool HasInputs(const std::string& name) const override {
const std::vector<std::string>& input_names = op_.Input(name); const std::vector<std::string>& input_names = op_.Input(name);
PADDLE_ENFORCE(!input_names.empty(), "Inputs(%s) length is 0", name); if (input_names.empty()) {
return false;
}
for (auto& input : input_names) { for (auto& input : input_names) {
if (!block_.HasVar(input)) return false; if (!block_.HasVar(input)) return false;
} }
...@@ -355,7 +363,9 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -355,7 +363,9 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool HasOutputs(const std::string& name) const override { bool HasOutputs(const std::string& name) const override {
const std::vector<std::string>& output_names = op_.Output(name); const std::vector<std::string>& output_names = op_.Output(name);
PADDLE_ENFORCE(!output_names.empty(), "Inputs(%s) length is 0", name); if (output_names.empty()) {
return false;
}
for (auto& output : output_names) { for (auto& output : output_names) {
if (!block_.HasVar(output)) return false; if (!block_.HasVar(output)) return false;
} }
...@@ -421,13 +431,27 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -421,13 +431,27 @@ class RuntimeInferShapeContext : public InferShapeContext {
: op_(op), scope_(scope) {} : op_(op), scope_(scope) {}
bool HasInput(const std::string& name) const override { bool HasInput(const std::string& name) const override {
auto ipt = op_.Input(name); auto& ins = Inputs(name);
size_t length = ins.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs",
name);
auto ipt = ins[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr; return var != nullptr;
} }
bool HasOutput(const std::string& name) const override { bool HasOutput(const std::string& name) const override {
auto ipt = op_.Output(name); auto& outs = Outputs(name);
size_t length = outs.size();
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs",
name);
auto ipt = outs[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr; return var != nullptr;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册