From 4b90e9424fcb2413cb1576d5837d29560f86122b Mon Sep 17 00:00:00 2001 From: liuqi Date: Thu, 17 May 2018 11:24:51 +0800 Subject: [PATCH] Fix op test no device bug. --- mace/core/net.cc | 2 +- mace/core/workspace.cc | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mace/core/net.cc b/mace/core/net.cc index f629e206..ea5ea504 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -43,7 +43,7 @@ SerialNet::SerialNet(const std::shared_ptr op_registry, // TODO(liuqi): refactor based on PB const int op_device = ArgumentHelper::GetSingleArgument( - operator_def, "device", -1); + operator_def, "device", static_cast(device_type_)); if (op_device == type) { VLOG(3) << "Creating operator " << operator_def.name() << "(" << operator_def.type() << ")"; diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc index 3867cfe3..46ae0b48 100644 --- a/mace/core/workspace.cc +++ b/mace/core/workspace.cc @@ -139,7 +139,7 @@ void Workspace::CreateOutputTensorBuffer(const NetDef &net_def, // TODO(liuqi): refactor based on PB const int op_device = ArgumentHelper::GetSingleArgument( - op, "device", -1); + op, "device", static_cast(device_type)); if (op_device == device_type && !op.mem_id().empty()) { const DataType op_dtype = static_cast( ArgumentHelper::GetSingleArgument( @@ -175,7 +175,7 @@ void Workspace::CreateOutputTensorBuffer(const NetDef &net_def, // TODO(liuqi): refactor based on PB const int op_device = ArgumentHelper::GetSingleArgument( - op, "device", -1); + op, "device", static_cast(device_type)); if (op_device == device_type && !op.mem_id().empty()) { auto mem_ids = op.mem_id(); int count = mem_ids.size(); -- GitLab