Commit d2783979 authored by S sunsuodong

fix_matmul_create

Parent e4d2f2fd
...
@@ -35,30 +35,15 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::tensor::Tenso
   auto input_tensor = inputs.at(kInputIndex);
   auto data_type = input_tensor->data_type();
   kernel::LiteKernel *kernel = nullptr;
-  switch (data_type) {
-    case kNumberTypeInt8:
-    case kNumberTypeUInt8: {
-      kernel = new (std::nothrow) MatmulInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
-      if (kernel == nullptr) {
-        MS_LOG(ERROR) << "kernel is nullptr.";
-        return nullptr;
-      }
-      break;
-    }
-    case kNumberTypeFloat32: {
-      kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive);
-      if (kernel == nullptr) {
-        MS_LOG(ERROR) << "kernel is nullptr.";
-        return nullptr;
-      }
-      break;
-    }
-    default:
-      break;
+  if (data_type == kNumberTypeInt8 || data_type == kNumberTypeUInt8) {
+    kernel = new (std::nothrow) MatmulInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
+  } else {
+    kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive);
+  }
+  if (kernel == nullptr) {
+    MS_LOG(ERROR) << "kernel is nullptr.";
+    return nullptr;
   }
   auto ret = kernel->Init();
   if (ret != RET_OK) {
     delete kernel;
...
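
For context, this hunk replaces the switch over data_type with an if/else, so the nullptr check and the Init()/delete error handling now cover every kernel type; under the old default branch, kernel stayed nullptr and was still passed to kernel->Init(). Below is a minimal, self-contained sketch of the same pattern using hypothetical names (DataType, Kernel, Int8Kernel, Fp32Kernel, CreateKernel), not the actual MindSpore Lite classes:

```cpp
// Sketch of the creation pattern adopted by this commit: choose the concrete
// kernel by data type, then run a single nullptr check and a single
// Init()/delete error path that cover every branch. All names are stand-ins.
#include <cstdio>
#include <new>

enum class DataType { kInt8, kUInt8, kFloat32 };

struct Kernel {
  virtual ~Kernel() = default;
  virtual int Init() { return 0; }  // 0 plays the role of RET_OK
};
struct Int8Kernel : Kernel {};
struct Fp32Kernel : Kernel {};

Kernel *CreateKernel(DataType data_type) {
  Kernel *kernel = nullptr;
  if (data_type == DataType::kInt8 || data_type == DataType::kUInt8) {
    kernel = new (std::nothrow) Int8Kernel();
  } else {
    kernel = new (std::nothrow) Fp32Kernel();
  }
  if (kernel == nullptr) {  // one check covers both allocation branches
    std::fprintf(stderr, "kernel is nullptr.\n");
    return nullptr;
  }
  if (kernel->Init() != 0) {  // mirrors the Init()/delete handling in the diff
    delete kernel;
    return nullptr;
  }
  return kernel;
}

int main() {
  Kernel *kernel = CreateKernel(DataType::kFloat32);
  delete kernel;
  return 0;
}
```

Because new (std::nothrow) reports allocation failure as a null pointer rather than an exception, funneling both branches through one post-allocation check keeps the error handling in a single place.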