diff --git a/mace/core/net.cc b/mace/core/net.cc index f629e2068da0015a3744aa6a4b2ecc6c309e3739..ea5ea504b6019fc35e930d69addcd5af502c76d0 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -43,7 +43,7 @@ SerialNet::SerialNet(const std::shared_ptr op_registry, // TODO(liuqi): refactor based on PB const int op_device = ArgumentHelper::GetSingleArgument( - operator_def, "device", -1); + operator_def, "device", static_cast(device_type_)); if (op_device == type) { VLOG(3) << "Creating operator " << operator_def.name() << "(" << operator_def.type() << ")"; diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc index 3867cfe36f3b5606f793de0043cda6fad79429f6..46ae0b48a49aa461fb0efee21826a8bd5fc443aa 100644 --- a/mace/core/workspace.cc +++ b/mace/core/workspace.cc @@ -139,7 +139,7 @@ void Workspace::CreateOutputTensorBuffer(const NetDef &net_def, // TODO(liuqi): refactor based on PB const int op_device = ArgumentHelper::GetSingleArgument( - op, "device", -1); + op, "device", static_cast(device_type)); if (op_device == device_type && !op.mem_id().empty()) { const DataType op_dtype = static_cast( ArgumentHelper::GetSingleArgument( @@ -175,7 +175,7 @@ void Workspace::CreateOutputTensorBuffer(const NetDef &net_def, // TODO(liuqi): refactor based on PB const int op_device = ArgumentHelper::GetSingleArgument( - op, "device", -1); + op, "device", static_cast(device_type)); if (op_device == device_type && !op.mem_id().empty()) { auto mem_ids = op.mem_id(); int count = mem_ids.size();