提交 4b90e942 编写于 作者: L liuqi

Fix op test no device bug.

上级 911dbb26
......@@ -43,7 +43,7 @@ SerialNet::SerialNet(const std::shared_ptr<const OperatorRegistry> op_registry,
// TODO(liuqi): refactor based on PB
const int op_device =
ArgumentHelper::GetSingleArgument<OperatorDef, int>(
operator_def, "device", -1);
operator_def, "device", static_cast<int>(device_type_));
if (op_device == type) {
VLOG(3) << "Creating operator " << operator_def.name() << "("
<< operator_def.type() << ")";
......
......@@ -139,7 +139,7 @@ void Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
// TODO(liuqi): refactor based on PB
const int op_device =
ArgumentHelper::GetSingleArgument<OperatorDef, int>(
op, "device", -1);
op, "device", static_cast<int>(device_type));
if (op_device == device_type && !op.mem_id().empty()) {
const DataType op_dtype = static_cast<DataType>(
ArgumentHelper::GetSingleArgument<OperatorDef, int>(
......@@ -175,7 +175,7 @@ void Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
// TODO(liuqi): refactor based on PB
const int op_device =
ArgumentHelper::GetSingleArgument<OperatorDef, int>(
op, "device", -1);
op, "device", static_cast<int>(device_type));
if (op_device == device_type && !op.mem_id().empty()) {
auto mem_ids = op.mem_id();
int count = mem_ids.size();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册