提交 67e5c972 编写于 作者: 叶剑武

Merge branch 'op_selection_bug' into 'master'

fix bug on op selection

See merge request !1124
...@@ -114,8 +114,7 @@ OpInitContext::OpInitContext(Workspace *ws, Device *device) ...@@ -114,8 +114,7 @@ OpInitContext::OpInitContext(Workspace *ws, Device *device)
: ws_(ws), device_(device) {} : ws_(ws), device_(device) {}
Operation::Operation(OpConstructContext *context) Operation::Operation(OpConstructContext *context)
: operator_def_(context->operator_def()) : operator_def_(context->operator_def()) {}
{}
MaceStatus Operation::Init(OpInitContext *context) { MaceStatus Operation::Init(OpInitContext *context) {
Workspace *ws = context->workspace(); Workspace *ws = context->workspace();
...@@ -316,9 +315,11 @@ std::unique_ptr<Operation> OpRegistryBase::CreateOperation( ...@@ -316,9 +315,11 @@ std::unique_ptr<Operation> OpRegistryBase::CreateOperation(
MACE_CHECK(registry_.count(op_type) != 0, MACE_CHECK(registry_.count(op_type) != 0,
op_type, " operation is not registered."); op_type, " operation is not registered.");
auto key_dtype =
(device_type == DeviceType::GPU && dtype == DT_HALF) ? DT_FLOAT : dtype;
std::string key = OpKeyBuilder(op_type) std::string key = OpKeyBuilder(op_type)
.Device(device_type) .Device(device_type)
.TypeConstraint("T", dtype == DT_HALF ? DT_FLOAT : dtype) .TypeConstraint("T", key_dtype)
.Build(); .Build();
if (registry_.at(op_type)->creators.count(key) == 0) { if (registry_.at(op_type)->creators.count(key) == 0) {
LOG(FATAL) << "Key not registered: " << key; LOG(FATAL) << "Key not registered: " << key;
...@@ -339,13 +340,13 @@ OpConditionBuilder &OpConditionBuilder::SetDevicePlacerFunc( ...@@ -339,13 +340,13 @@ OpConditionBuilder &OpConditionBuilder::SetDevicePlacerFunc(
return *this; return *this;
} }
OpConditionBuilder& OpConditionBuilder::SetInputMemoryTypeSetter( OpConditionBuilder &OpConditionBuilder::SetInputMemoryTypeSetter(
OpRegistrationInfo::MemoryTypeSetter setter) { OpRegistrationInfo::MemoryTypeSetter setter) {
memory_type_setter_ = setter; memory_type_setter_ = setter;
return *this; return *this;
} }
OpConditionBuilder& OpConditionBuilder::SetInputsDataFormatSelector( OpConditionBuilder &OpConditionBuilder::SetInputsDataFormatSelector(
OpRegistrationInfo::DataFormatSelector selector) { OpRegistrationInfo::DataFormatSelector selector) {
data_format_selector_ = selector; data_format_selector_ = selector;
return *this; return *this;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册