未验证 提交 2ffa6436 编写于 作者: C chentianyu03 提交者: GitHub

fix output var may be nullptr and cause segment fault bug (#40079)

上级 31d3d857
...@@ -2106,6 +2106,8 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2106,6 +2106,8 @@ void OperatorWithKernel::BuildPhiKernelContext(
for (size_t offset = 0; offset < outs_vector.size(); ++offset) { for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
phi::TensorBase* tensor_out = nullptr; phi::TensorBase* tensor_out = nullptr;
auto* var = outs_vector[offset]; auto* var = outs_vector[offset];
if (var) {
if (var->template IsType<framework::LoDTensor>()) { if (var->template IsType<framework::LoDTensor>()) {
tensor_out = var->template GetMutable<framework::LoDTensor>(); tensor_out = var->template GetMutable<framework::LoDTensor>();
} else if (var->template IsType<phi::SelectedRows>()) { } else if (var->template IsType<phi::SelectedRows>()) {
...@@ -2115,6 +2117,8 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2115,6 +2117,8 @@ void OperatorWithKernel::BuildPhiKernelContext(
"Unsupported output `%s` type when call pt kernel.", "Unsupported output `%s` type when call pt kernel.",
framework::ToTypeName(var->Type()))); framework::ToTypeName(var->Type())));
} }
}
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out); pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} }
......
...@@ -314,6 +314,7 @@ void BuildDygraphPhiKernelContext( ...@@ -314,6 +314,7 @@ void BuildDygraphPhiKernelContext(
phi::TensorBase* tensor_out = nullptr; phi::TensorBase* tensor_out = nullptr;
auto* var = outs_vector[offset]->MutableVar(); auto* var = outs_vector[offset]->MutableVar();
if (var) {
if (var->template IsType<phi::DenseTensor>()) { if (var->template IsType<phi::DenseTensor>()) {
tensor_out = var->template GetMutable<phi::DenseTensor>(); tensor_out = var->template GetMutable<phi::DenseTensor>();
} else if (var->template IsType<phi::SelectedRows>()) { } else if (var->template IsType<phi::SelectedRows>()) {
...@@ -323,6 +324,8 @@ void BuildDygraphPhiKernelContext( ...@@ -323,6 +324,8 @@ void BuildDygraphPhiKernelContext(
"Unsupported output `%s` type when call pt kernel.", "Unsupported output `%s` type when call pt kernel.",
framework::ToTypeName(var->Type()))); framework::ToTypeName(var->Type())));
} }
}
kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out); kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
} }
kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i); kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册