未验证 提交 2ffa6436 编写于 作者: C chentianyu03 提交者: GitHub

fix output var may be nullptr and cause segment fault bug (#40079)

上级 31d3d857
...@@ -2106,15 +2106,19 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2106,15 +2106,19 @@ void OperatorWithKernel::BuildPhiKernelContext(
for (size_t offset = 0; offset < outs_vector.size(); ++offset) { for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
phi::TensorBase* tensor_out = nullptr; phi::TensorBase* tensor_out = nullptr;
auto* var = outs_vector[offset]; auto* var = outs_vector[offset];
if (var->template IsType<framework::LoDTensor>()) {
tensor_out = var->template GetMutable<framework::LoDTensor>(); if (var) {
} else if (var->template IsType<phi::SelectedRows>()) { if (var->template IsType<framework::LoDTensor>()) {
tensor_out = var->template GetMutable<phi::SelectedRows>(); tensor_out = var->template GetMutable<framework::LoDTensor>();
} else { } else if (var->template IsType<phi::SelectedRows>()) {
PADDLE_THROW(platform::errors::Unimplemented( tensor_out = var->template GetMutable<phi::SelectedRows>();
"Unsupported output `%s` type when call pt kernel.", } else {
framework::ToTypeName(var->Type()))); PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported output `%s` type when call pt kernel.",
framework::ToTypeName(var->Type())));
}
} }
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out); pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} }
......
...@@ -314,15 +314,18 @@ void BuildDygraphPhiKernelContext( ...@@ -314,15 +314,18 @@ void BuildDygraphPhiKernelContext(
phi::TensorBase* tensor_out = nullptr; phi::TensorBase* tensor_out = nullptr;
auto* var = outs_vector[offset]->MutableVar(); auto* var = outs_vector[offset]->MutableVar();
if (var->template IsType<phi::DenseTensor>()) { if (var) {
tensor_out = var->template GetMutable<phi::DenseTensor>(); if (var->template IsType<phi::DenseTensor>()) {
} else if (var->template IsType<phi::SelectedRows>()) { tensor_out = var->template GetMutable<phi::DenseTensor>();
tensor_out = var->template GetMutable<phi::SelectedRows>(); } else if (var->template IsType<phi::SelectedRows>()) {
} else { tensor_out = var->template GetMutable<phi::SelectedRows>();
PADDLE_THROW(platform::errors::Unimplemented( } else {
"Unsupported output `%s` type when call pt kernel.", PADDLE_THROW(platform::errors::Unimplemented(
framework::ToTypeName(var->Type()))); "Unsupported output `%s` type when call pt kernel.",
framework::ToTypeName(var->Type())));
}
} }
kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out); kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
} }
kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i); kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册