提交 c75689bd 编写于 作者: 李寅

Merge branch 'fix-op-test' into 'master'

Fix op test no device bug.

See merge request !487
...@@ -43,7 +43,7 @@ SerialNet::SerialNet(const std::shared_ptr<const OperatorRegistry> op_registry, ...@@ -43,7 +43,7 @@ SerialNet::SerialNet(const std::shared_ptr<const OperatorRegistry> op_registry,
// TODO(liuqi): refactor based on PB // TODO(liuqi): refactor based on PB
const int op_device = const int op_device =
ArgumentHelper::GetSingleArgument<OperatorDef, int>( ArgumentHelper::GetSingleArgument<OperatorDef, int>(
operator_def, "device", -1); operator_def, "device", static_cast<int>(device_type_));
if (op_device == type) { if (op_device == type) {
VLOG(3) << "Creating operator " << operator_def.name() << "(" VLOG(3) << "Creating operator " << operator_def.name() << "("
<< operator_def.type() << ")"; << operator_def.type() << ")";
......
...@@ -139,7 +139,7 @@ void Workspace::CreateOutputTensorBuffer(const NetDef &net_def, ...@@ -139,7 +139,7 @@ void Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
// TODO(liuqi): refactor based on PB // TODO(liuqi): refactor based on PB
const int op_device = const int op_device =
ArgumentHelper::GetSingleArgument<OperatorDef, int>( ArgumentHelper::GetSingleArgument<OperatorDef, int>(
op, "device", -1); op, "device", static_cast<int>(device_type));
if (op_device == device_type && !op.mem_id().empty()) { if (op_device == device_type && !op.mem_id().empty()) {
const DataType op_dtype = static_cast<DataType>( const DataType op_dtype = static_cast<DataType>(
ArgumentHelper::GetSingleArgument<OperatorDef, int>( ArgumentHelper::GetSingleArgument<OperatorDef, int>(
...@@ -175,7 +175,7 @@ void Workspace::CreateOutputTensorBuffer(const NetDef &net_def, ...@@ -175,7 +175,7 @@ void Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
// TODO(liuqi): refactor based on PB // TODO(liuqi): refactor based on PB
const int op_device = const int op_device =
ArgumentHelper::GetSingleArgument<OperatorDef, int>( ArgumentHelper::GetSingleArgument<OperatorDef, int>(
op, "device", -1); op, "device", static_cast<int>(device_type));
if (op_device == device_type && !op.mem_id().empty()) { if (op_device == device_type && !op.mem_id().empty()) {
auto mem_ids = op.mem_id(); auto mem_ids = op.mem_id();
int count = mem_ids.size(); int count = mem_ids.size();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册