提交 67e5c972 编写于 作者: 叶剑武

Merge branch 'op_selection_bug' into 'master'

fix bug on op selection

See merge request !1124
...@@ -114,8 +114,7 @@ OpInitContext::OpInitContext(Workspace *ws, Device *device) ...@@ -114,8 +114,7 @@ OpInitContext::OpInitContext(Workspace *ws, Device *device)
: ws_(ws), device_(device) {} : ws_(ws), device_(device) {}
Operation::Operation(OpConstructContext *context) Operation::Operation(OpConstructContext *context)
: operator_def_(context->operator_def()) : operator_def_(context->operator_def()) {}
{}
MaceStatus Operation::Init(OpInitContext *context) { MaceStatus Operation::Init(OpInitContext *context) {
Workspace *ws = context->workspace(); Workspace *ws = context->workspace();
...@@ -142,7 +141,7 @@ MaceStatus Operation::Init(OpInitContext *context) { ...@@ -142,7 +141,7 @@ MaceStatus Operation::Init(OpInitContext *context) {
} else { } else {
output_type = static_cast<DataType>( output_type = static_cast<DataType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>( ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*operator_def_, "T", static_cast<int>(DT_FLOAT))); *operator_def_, "T", static_cast<int>(DT_FLOAT)));
} }
outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor( outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor(
output_str, context->device()->allocator(), output_type))); output_str, context->device()->allocator(), output_type)));
...@@ -316,9 +315,11 @@ std::unique_ptr<Operation> OpRegistryBase::CreateOperation( ...@@ -316,9 +315,11 @@ std::unique_ptr<Operation> OpRegistryBase::CreateOperation(
MACE_CHECK(registry_.count(op_type) != 0, MACE_CHECK(registry_.count(op_type) != 0,
op_type, " operation is not registered."); op_type, " operation is not registered.");
auto key_dtype =
(device_type == DeviceType::GPU && dtype == DT_HALF) ? DT_FLOAT : dtype;
std::string key = OpKeyBuilder(op_type) std::string key = OpKeyBuilder(op_type)
.Device(device_type) .Device(device_type)
.TypeConstraint("T", dtype == DT_HALF ? DT_FLOAT : dtype) .TypeConstraint("T", key_dtype)
.Build(); .Build();
if (registry_.at(op_type)->creators.count(key) == 0) { if (registry_.at(op_type)->creators.count(key) == 0) {
LOG(FATAL) << "Key not registered: " << key; LOG(FATAL) << "Key not registered: " << key;
...@@ -327,7 +328,7 @@ std::unique_ptr<Operation> OpRegistryBase::CreateOperation( ...@@ -327,7 +328,7 @@ std::unique_ptr<Operation> OpRegistryBase::CreateOperation(
} }
OpConditionBuilder::OpConditionBuilder(const std::string &type) OpConditionBuilder::OpConditionBuilder(const std::string &type)
: type_(type) {} : type_(type) {}
const std::string OpConditionBuilder::type() const { const std::string OpConditionBuilder::type() const {
return type_; return type_;
...@@ -339,13 +340,13 @@ OpConditionBuilder &OpConditionBuilder::SetDevicePlacerFunc( ...@@ -339,13 +340,13 @@ OpConditionBuilder &OpConditionBuilder::SetDevicePlacerFunc(
return *this; return *this;
} }
OpConditionBuilder& OpConditionBuilder::SetInputMemoryTypeSetter( OpConditionBuilder &OpConditionBuilder::SetInputMemoryTypeSetter(
OpRegistrationInfo::MemoryTypeSetter setter) { OpRegistrationInfo::MemoryTypeSetter setter) {
memory_type_setter_ = setter; memory_type_setter_ = setter;
return *this; return *this;
} }
OpConditionBuilder& OpConditionBuilder::SetInputsDataFormatSelector( OpConditionBuilder &OpConditionBuilder::SetInputsDataFormatSelector(
OpRegistrationInfo::DataFormatSelector selector) { OpRegistrationInfo::DataFormatSelector selector) {
data_format_selector_ = selector; data_format_selector_ = selector;
return *this; return *this;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册