diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 9d7fe1f5ba293227e67cf6bfcd566a1247c567ed..79a452b616b0109f8d137669cd111b16d5839287 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -327,6 +327,9 @@ class CompileTimeInferShapeContext : public InferShapeContext { bool HasInput(const std::string& name) const override { const std::vector& input_names = op_.Input(name); auto length = input_names.size(); + if (length == 0) { + return false; + } PADDLE_ENFORCE_EQ(length, 1UL, "Input(%s) should have only one value, " "but it have %d now", @@ -337,6 +340,9 @@ class CompileTimeInferShapeContext : public InferShapeContext { bool HasOutput(const std::string& name) const override { const std::vector& output_names = op_.Output(name); auto length = output_names.size(); + if (length == 0) { + return false; + } PADDLE_ENFORCE_EQ(length, 1UL, "Output(%s) should have only one value, " "but it have %d now", @@ -346,7 +352,9 @@ class CompileTimeInferShapeContext : public InferShapeContext { bool HasInputs(const std::string& name) const override { const std::vector& input_names = op_.Input(name); - PADDLE_ENFORCE(!input_names.empty(), "Inputs(%s) length is 0", name); + if (input_names.empty()) { + return false; + } for (auto& input : input_names) { if (!block_.HasVar(input)) return false; } @@ -355,7 +363,9 @@ class CompileTimeInferShapeContext : public InferShapeContext { bool HasOutputs(const std::string& name) const override { const std::vector& output_names = op_.Output(name); - PADDLE_ENFORCE(!output_names.empty(), "Inputs(%s) length is 0", name); + if (output_names.empty()) { + return false; + } for (auto& output : output_names) { if (!block_.HasVar(output)) return false; } @@ -421,13 +431,27 @@ class RuntimeInferShapeContext : public InferShapeContext { : op_(op), scope_(scope) {} bool HasInput(const std::string& name) const override { - auto ipt = op_.Input(name); + auto& ins = Inputs(name); + size_t length = ins.size(); + if (length == 0) { + return false; + } + PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs", + name); + auto ipt = ins[0]; auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); return var != nullptr; } bool HasOutput(const std::string& name) const override { - auto ipt = op_.Output(name); + auto& outs = Outputs(name); + size_t length = outs.size(); + if (length == 0) { + return false; + } + PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs", + name); + auto ipt = outs[0]; auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); return var != nullptr; }