提交 f978beb1 编写于 作者: S sunsuodong

fix ScheduleNode and fill parser

上级 9730e1f4
...@@ -231,7 +231,7 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *> ...@@ -231,7 +231,7 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *>
const std::vector<tensor::Tensor *> &out_tensors, const std::vector<tensor::Tensor *> &out_tensors,
const mindspore::lite::PrimitiveC *primitive, const schema::CNode *cnode) { const mindspore::lite::PrimitiveC *primitive, const schema::CNode *cnode) {
MS_ASSERT(nullptr != primitive); MS_ASSERT(nullptr != primitive);
auto data_type = in_tensors.front()->data_type(); TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors);
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())}; kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())};
if (context_->device_ctx_.type == DT_GPU) { if (context_->device_ctx_.type == DT_GPU) {
desc.arch = kernel::KERNEL_ARCH::kGPU; desc.arch = kernel::KERNEL_ARCH::kGPU;
...@@ -271,6 +271,16 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *> ...@@ -271,6 +271,16 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *>
return nullptr; return nullptr;
} }
TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector<tensor::Tensor *> &in_tensors) {
for (const auto &tensor : in_tensors) {
auto dtype = tensor->data_type();
if (dtype == kNumberTypeFloat32 || dtype == kNumberTypeFloat16 || dtype == kNumberTypeInt8) {
return dtype;
}
}
return kNumberTypeFloat32;
}
void Scheduler::SetKernelTensorDataType(kernel::LiteKernel *kernel) { void Scheduler::SetKernelTensorDataType(kernel::LiteKernel *kernel) {
if (kernel->desc().arch != kernel::KERNEL_ARCH::kCPU) { if (kernel->desc().arch != kernel::KERNEL_ARCH::kCPU) {
return; return;
......
...@@ -47,6 +47,7 @@ class Scheduler { ...@@ -47,6 +47,7 @@ class Scheduler {
void ConstructSubgraphs(std::vector<kernel::LiteKernel *> *kernels); void ConstructSubgraphs(std::vector<kernel::LiteKernel *> *kernels);
kernel::LiteKernel *CreateSubKernel(const std::vector<kernel::LiteKernel *> &kernels, kernel::KERNEL_ARCH arch); kernel::LiteKernel *CreateSubKernel(const std::vector<kernel::LiteKernel *> &kernels, kernel::KERNEL_ARCH arch);
TypeId GetFirstFp32Fp16OrInt8Type(const std::vector<tensor::Tensor *> &in_tensors);
void SetKernelTensorDataType(kernel::LiteKernel *kernel); void SetKernelTensorDataType(kernel::LiteKernel *kernel);
protected: protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册