From 2ffa643644241b1cecb1a0255dddbfbf1688c16c Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Thu, 3 Mar 2022 11:23:43 +0800 Subject: [PATCH] fix output var may be nullptr and cause segment fault bug (#40079) --- paddle/fluid/framework/operator.cc | 20 ++++++++++++-------- paddle/fluid/imperative/prepared_operator.h | 19 +++++++++++-------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 8ebc64e5f2c..b68748a687c 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -2106,15 +2106,19 @@ void OperatorWithKernel::BuildPhiKernelContext( for (size_t offset = 0; offset < outs_vector.size(); ++offset) { phi::TensorBase* tensor_out = nullptr; auto* var = outs_vector[offset]; - if (var->template IsType()) { - tensor_out = var->template GetMutable(); - } else if (var->template IsType()) { - tensor_out = var->template GetMutable(); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported output `%s` type when call pt kernel.", - framework::ToTypeName(var->Type()))); + + if (var) { + if (var->template IsType()) { + tensor_out = var->template GetMutable(); + } else if (var->template IsType()) { + tensor_out = var->template GetMutable(); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported output `%s` type when call pt kernel.", + framework::ToTypeName(var->Type()))); + } } + pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out); } diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 3b5762720e7..30dbe07d7af 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -314,15 +314,18 @@ void BuildDygraphPhiKernelContext( phi::TensorBase* tensor_out = nullptr; auto* var = outs_vector[offset]->MutableVar(); - if (var->template IsType()) { - tensor_out = var->template GetMutable(); - } else if (var->template IsType()) { - tensor_out = var->template GetMutable(); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported output `%s` type when call pt kernel.", - framework::ToTypeName(var->Type()))); + if (var) { + if (var->template IsType()) { + tensor_out = var->template GetMutable(); + } else if (var->template IsType()) { + tensor_out = var->template GetMutable(); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported output `%s` type when call pt kernel.", + framework::ToTypeName(var->Type()))); + } } + kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out); } kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i); -- GitLab