diff --git a/mace/core/operator.cc b/mace/core/operator.cc index 883bc1eb828faaeeda015402d1f9f40059f28d5c..28b3bc4636135aedd8e009396f699847a3f24a9c 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -114,8 +114,7 @@ OpInitContext::OpInitContext(Workspace *ws, Device *device) : ws_(ws), device_(device) {} Operation::Operation(OpConstructContext *context) - : operator_def_(context->operator_def()) -{} + : operator_def_(context->operator_def()) {} MaceStatus Operation::Init(OpInitContext *context) { Workspace *ws = context->workspace(); @@ -142,7 +141,7 @@ MaceStatus Operation::Init(OpInitContext *context) { } else { output_type = static_cast( ProtoArgHelper::GetOptionalArg( - *operator_def_, "T", static_cast(DT_FLOAT))); + *operator_def_, "T", static_cast(DT_FLOAT))); } outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor( output_str, context->device()->allocator(), output_type))); @@ -316,9 +315,11 @@ std::unique_ptr OpRegistryBase::CreateOperation( MACE_CHECK(registry_.count(op_type) != 0, op_type, " operation is not registered."); + auto key_dtype = + (device_type == DeviceType::GPU && dtype == DT_HALF) ? DT_FLOAT : dtype; std::string key = OpKeyBuilder(op_type) .Device(device_type) - .TypeConstraint("T", dtype == DT_HALF ? DT_FLOAT : dtype) + .TypeConstraint("T", key_dtype) .Build(); if (registry_.at(op_type)->creators.count(key) == 0) { LOG(FATAL) << "Key not registered: " << key; @@ -327,7 +328,7 @@ std::unique_ptr OpRegistryBase::CreateOperation( } OpConditionBuilder::OpConditionBuilder(const std::string &type) - : type_(type) {} + : type_(type) {} const std::string OpConditionBuilder::type() const { return type_; @@ -339,13 +340,13 @@ OpConditionBuilder &OpConditionBuilder::SetDevicePlacerFunc( return *this; } -OpConditionBuilder& OpConditionBuilder::SetInputMemoryTypeSetter( +OpConditionBuilder &OpConditionBuilder::SetInputMemoryTypeSetter( OpRegistrationInfo::MemoryTypeSetter setter) { memory_type_setter_ = setter; return *this; } -OpConditionBuilder& OpConditionBuilder::SetInputsDataFormatSelector( +OpConditionBuilder &OpConditionBuilder::SetInputsDataFormatSelector( OpRegistrationInfo::DataFormatSelector selector) { data_format_selector_ = selector; return *this;