未验证 提交 2ffa6436 编写于 作者: C chentianyu03 提交者: GitHub

fix output var may be nullptr and cause segment fault bug (#40079)

上级 31d3d857
......@@ -2106,15 +2106,19 @@ void OperatorWithKernel::BuildPhiKernelContext(
for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
phi::TensorBase* tensor_out = nullptr;
auto* var = outs_vector[offset];
if (var->template IsType<framework::LoDTensor>()) {
tensor_out = var->template GetMutable<framework::LoDTensor>();
} else if (var->template IsType<phi::SelectedRows>()) {
tensor_out = var->template GetMutable<phi::SelectedRows>();
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported output `%s` type when call pt kernel.",
framework::ToTypeName(var->Type())));
if (var) {
if (var->template IsType<framework::LoDTensor>()) {
tensor_out = var->template GetMutable<framework::LoDTensor>();
} else if (var->template IsType<phi::SelectedRows>()) {
tensor_out = var->template GetMutable<phi::SelectedRows>();
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported output `%s` type when call pt kernel.",
framework::ToTypeName(var->Type())));
}
}
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
}
......
......@@ -314,15 +314,18 @@ void BuildDygraphPhiKernelContext(
phi::TensorBase* tensor_out = nullptr;
auto* var = outs_vector[offset]->MutableVar();
if (var->template IsType<phi::DenseTensor>()) {
tensor_out = var->template GetMutable<phi::DenseTensor>();
} else if (var->template IsType<phi::SelectedRows>()) {
tensor_out = var->template GetMutable<phi::SelectedRows>();
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported output `%s` type when call pt kernel.",
framework::ToTypeName(var->Type())));
if (var) {
if (var->template IsType<phi::DenseTensor>()) {
tensor_out = var->template GetMutable<phi::DenseTensor>();
} else if (var->template IsType<phi::SelectedRows>()) {
tensor_out = var->template GetMutable<phi::SelectedRows>();
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported output `%s` type when call pt kernel.",
framework::ToTypeName(var->Type())));
}
}
kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
}
kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册