提交 b5830970 编写于 作者: L luxuhui

fix bug on op selection

N/A
Signed-off-by: NLuxuhui <luxuhui@xiaomi.com>
上级 7fa03dd2
......@@ -114,8 +114,7 @@ OpInitContext::OpInitContext(Workspace *ws, Device *device)
: ws_(ws), device_(device) {}
Operation::Operation(OpConstructContext *context)
: operator_def_(context->operator_def())
{}
: operator_def_(context->operator_def()) {}
MaceStatus Operation::Init(OpInitContext *context) {
Workspace *ws = context->workspace();
......@@ -142,7 +141,7 @@ MaceStatus Operation::Init(OpInitContext *context) {
} else {
output_type = static_cast<DataType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*operator_def_, "T", static_cast<int>(DT_FLOAT)));
*operator_def_, "T", static_cast<int>(DT_FLOAT)));
}
outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor(
output_str, context->device()->allocator(), output_type)));
......@@ -316,9 +315,11 @@ std::unique_ptr<Operation> OpRegistryBase::CreateOperation(
MACE_CHECK(registry_.count(op_type) != 0,
op_type, " operation is not registered.");
auto key_dtype =
(device_type == DeviceType::GPU && dtype == DT_HALF) ? DT_FLOAT : dtype;
std::string key = OpKeyBuilder(op_type)
.Device(device_type)
.TypeConstraint("T", dtype == DT_HALF ? DT_FLOAT : dtype)
.TypeConstraint("T", key_dtype)
.Build();
if (registry_.at(op_type)->creators.count(key) == 0) {
LOG(FATAL) << "Key not registered: " << key;
......@@ -327,7 +328,7 @@ std::unique_ptr<Operation> OpRegistryBase::CreateOperation(
}
OpConditionBuilder::OpConditionBuilder(const std::string &type)
: type_(type) {}
: type_(type) {}
const std::string OpConditionBuilder::type() const {
return type_;
......@@ -339,13 +340,13 @@ OpConditionBuilder &OpConditionBuilder::SetDevicePlacerFunc(
return *this;
}
OpConditionBuilder& OpConditionBuilder::SetInputMemoryTypeSetter(
OpConditionBuilder &OpConditionBuilder::SetInputMemoryTypeSetter(
OpRegistrationInfo::MemoryTypeSetter setter) {
memory_type_setter_ = setter;
return *this;
}
OpConditionBuilder& OpConditionBuilder::SetInputsDataFormatSelector(
OpConditionBuilder &OpConditionBuilder::SetInputsDataFormatSelector(
OpRegistrationInfo::DataFormatSelector selector) {
data_format_selector_ = selector;
return *this;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册