提交 67e5c972 编写于 作者: 叶剑武

Merge branch 'op_selection_bug' into 'master'

fix bug on op selection

See merge request !1124
......@@ -114,8 +114,7 @@ OpInitContext::OpInitContext(Workspace *ws, Device *device)
: ws_(ws), device_(device) {}
Operation::Operation(OpConstructContext *context)
: operator_def_(context->operator_def())
{}
: operator_def_(context->operator_def()) {}
MaceStatus Operation::Init(OpInitContext *context) {
Workspace *ws = context->workspace();
......@@ -316,9 +315,11 @@ std::unique_ptr<Operation> OpRegistryBase::CreateOperation(
MACE_CHECK(registry_.count(op_type) != 0,
op_type, " operation is not registered.");
auto key_dtype =
(device_type == DeviceType::GPU && dtype == DT_HALF) ? DT_FLOAT : dtype;
std::string key = OpKeyBuilder(op_type)
.Device(device_type)
.TypeConstraint("T", dtype == DT_HALF ? DT_FLOAT : dtype)
.TypeConstraint("T", key_dtype)
.Build();
if (registry_.at(op_type)->creators.count(key) == 0) {
LOG(FATAL) << "Key not registered: " << key;
......@@ -339,13 +340,13 @@ OpConditionBuilder &OpConditionBuilder::SetDevicePlacerFunc(
return *this;
}
OpConditionBuilder& OpConditionBuilder::SetInputMemoryTypeSetter(
OpConditionBuilder &OpConditionBuilder::SetInputMemoryTypeSetter(
OpRegistrationInfo::MemoryTypeSetter setter) {
memory_type_setter_ = setter;
return *this;
}
OpConditionBuilder& OpConditionBuilder::SetInputsDataFormatSelector(
OpConditionBuilder &OpConditionBuilder::SetInputsDataFormatSelector(
OpRegistrationInfo::DataFormatSelector selector) {
data_format_selector_ = selector;
return *this;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册