提交 9ace65da 编写于 作者: 李寅 提交者: 赵奇可

Merge branch 'fix-converter-bug' into 'master'

Fix the bug: convert image_to_buffer dtype from float to half.

See merge request !783
......@@ -48,8 +48,6 @@ SerialNet::SerialNet(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
operator_def, "device", static_cast<int>(device_type_));
if (op_device == type) {
VLOG(3) << "Creating operator " << operator_def.name() << "("
<< operator_def.type() << ")";
OperatorDef temp_def(operator_def);
std::unique_ptr<OperatorBase> op(
op_registry->CreateOperator(temp_def, ws, type, mode));
......
......@@ -62,6 +62,8 @@ std::unique_ptr<OperatorBase> OperatorRegistryBase::CreateOperator(
const int op_mode_i = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
operator_def, "mode", static_cast<int>(NetMode::NORMAL));
const NetMode op_mode = static_cast<NetMode>(op_mode_i);
VLOG(3) << "Creating operator " << operator_def.name() << "("
<< operator_def.type() << "<" << dtype << ">" << ")";
if (op_mode == mode) {
return registry_.Create(
OpKeyBuilder(operator_def.type().data())
......
......@@ -63,7 +63,7 @@ __constant sampler_t SAMPLER =
inline float4 do_sigmoid(float4 in) {
// native_func not support half
return native_recip(1.0 + native_exp(-in));
return native_recip(1.0f + native_exp(-in));
}
inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
......
......@@ -104,6 +104,7 @@ class Transformer(base_converter.ConverterInterface):
self._target_data_format = DataFormat.NHWC
self._input_output_added = False
self._opencl_max_image_size = [0, 0]
self._output_op_names = set()
self._quantize_activation_info = {}
self._quantized_tensor = set()
......@@ -1388,6 +1389,7 @@ class Transformer(base_converter.ConverterInterface):
ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_FLOAT)
ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC)
self._output_op_names.add(op_def.name)
self._input_output_added = True
......@@ -1525,7 +1527,8 @@ class Transformer(base_converter.ConverterInterface):
data_type_arg.name = MaceKeyword.mace_op_data_type_str
data_type_arg.i = self._option.data_type
elif data_type_arg.i != self._option.data_type \
and data_type_arg.i == mace_pb2.DT_FLOAT:
and data_type_arg.i == mace_pb2.DT_FLOAT \
and op.name not in self._output_op_names:
data_type_arg.i = self._option.data_type
return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册