提交 ad33154b 编写于 作者: L liuqi

BUG: fix adapter bug when output of net is used by other ops.

1. Use uint32_t for average pooling.
2. Fix destruct command queue bug.
上级 0a9f331c
...@@ -221,6 +221,13 @@ MaceStatus NetDefAdapter::AdaptNetDef( ...@@ -221,6 +221,13 @@ MaceStatus NetDefAdapter::AdaptNetDef(
&op_output_data_format, &op_output_data_format,
target_net_def)); target_net_def));
} }
input_size = op_def.input_size();
for (int i = 0; i < input_size; ++i) {
if (output_map.count(op_def.input(i)) == 1) {
output_map.at(op_def.input(i)).consumer_op_indices.push_back(
target_net_def->op_size());
}
}
int output_size = op_def.output_size(); int output_size = op_def.output_size();
for (int out_idx = 0; out_idx < output_size; ++out_idx) { for (int out_idx = 0; out_idx < output_size; ++out_idx) {
...@@ -277,6 +284,15 @@ MaceStatus NetDefAdapter::AdaptNetDef( ...@@ -277,6 +284,15 @@ MaceStatus NetDefAdapter::AdaptNetDef(
output_op_def->set_output(i, t_output_name); output_op_def->set_output(i, t_output_name);
} }
} }
for (int idx : internal_output_info.consumer_op_indices) {
auto consumer_op_def = target_net_def->mutable_op(idx);
int input_size = consumer_op_def->input_size();
for (int i = 0; i < input_size; ++i) {
if (consumer_op_def->input(i) == output_info.name()) {
consumer_op_def->set_input(i, t_output_name);
}
}
}
auto transformed_op_def = target_net_def->add_op(); auto transformed_op_def = target_net_def->add_op();
OpenCLUtil::BuildTransformOpDef( OpenCLUtil::BuildTransformOpDef(
t_output_name, t_output_name,
...@@ -415,12 +431,10 @@ MaceStatus NetDefAdapter::AdaptDataFormat( ...@@ -415,12 +431,10 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
} }
src_df = output_map->at(op_def->input(i)).data_format; src_df = output_map->at(op_def->input(i)).data_format;
dst_df = inputs_data_format[i]; dst_df = inputs_data_format[i];
if (src_df == DataFormat::NONE if (src_df != DataFormat::NONE
|| dst_df == DataFormat::NONE && dst_df != DataFormat::NONE
|| output_map->at(op_def->input(i)).shape.size() != 4) { && output_map->at(op_def->input(i)).shape.size() == 4
continue; && src_df != dst_df) {
}
if (src_df != dst_df) {
std::string transformed_name = TransformedName(op_def->input(i), std::string transformed_name = TransformedName(op_def->input(i),
"data_format", static_cast<int>(dst_df)); "data_format", static_cast<int>(dst_df));
if (transformed_set->count(transformed_name) == 0) { if (transformed_set->count(transformed_name) == 0) {
...@@ -462,6 +476,9 @@ MaceStatus NetDefAdapter::AdaptDataFormat( ...@@ -462,6 +476,9 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
SetProtoArg<int>(transpose_op_def, SetProtoArg<int>(transpose_op_def,
OutputMemoryTypeTagName(), OutputMemoryTypeTagName(),
target_mem_type); target_mem_type);
// update tensor consumer information
output_map->at(op_def->input(i)).consumer_op_indices.push_back(
target_net_def->op_size() - 1);
// update output information map // update output information map
output_map->emplace( output_map->emplace(
...@@ -546,6 +563,10 @@ MaceStatus NetDefAdapter::AdaptMemoryType( ...@@ -546,6 +563,10 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
OutputMemoryTypeTagName(), OutputMemoryTypeTagName(),
dst_mem_type); dst_mem_type);
// update tensor consumer information
output_map->at(op_def->input(i)).consumer_op_indices.push_back(
target_net_def->op_size() - 1);
// update output information map // update output information map
output_map->emplace( output_map->emplace(
transformed_name, transformed_name,
......
...@@ -78,13 +78,14 @@ class NetDefAdapter { ...@@ -78,13 +78,14 @@ class NetDefAdapter {
const std::vector<index_t> &shape, const std::vector<index_t> &shape,
int op_idx) int op_idx)
: mem_type(mem_type), dtype(dtype), data_format(data_format), : mem_type(mem_type), dtype(dtype), data_format(data_format),
shape(shape), op_idx(op_idx) {} shape(shape), op_idx(op_idx), consumer_op_indices() {}
MemoryType mem_type; MemoryType mem_type;
DataType dtype; DataType dtype;
DataFormat data_format; DataFormat data_format;
std::vector<index_t> shape; // tensor shape std::vector<index_t> shape; // tensor shape
int op_idx; // operation which generate the tensor int op_idx; // operation which generate the tensor
std::vector<int> consumer_op_indices;
}; };
typedef std::unordered_map<std::string, InternalOutputInfo> TensorInfoMap; typedef std::unordered_map<std::string, InternalOutputInfo> TensorInfoMap;
......
...@@ -434,7 +434,9 @@ OpenCLRuntime::OpenCLRuntime( ...@@ -434,7 +434,9 @@ OpenCLRuntime::OpenCLRuntime(
} }
OpenCLRuntime::~OpenCLRuntime() { OpenCLRuntime::~OpenCLRuntime() {
command_queue_->finish(); if (command_queue_ != nullptr) {
command_queue_->finish();
}
built_program_map_.clear(); built_program_map_.clear();
// We need to control the destruction order, which has dependencies // We need to control the destruction order, which has dependencies
command_queue_.reset(); command_queue_.reset();
......
...@@ -426,8 +426,8 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase { ...@@ -426,8 +426,8 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase {
(in_h_end - in_h_begin) * (in_w_end - in_w_begin); (in_h_end - in_h_begin) * (in_w_end - in_w_begin);
MACE_CHECK(block_size > 0); MACE_CHECK(block_size > 0);
std::vector<uint16_t> average_buffer(channels); std::vector<uint32_t> average_buffer(channels);
uint16_t *avg_buffer = average_buffer.data(); uint32_t *avg_buffer = average_buffer.data();
std::fill_n(avg_buffer, channels, 0); std::fill_n(avg_buffer, channels, 0);
for (index_t ih = in_h_begin; ih < in_h_end; ++ih) { for (index_t ih = in_h_begin; ih < in_h_end; ++ih) {
for (index_t iw = in_w_begin; iw < in_w_end; ++iw) { for (index_t iw = in_w_begin; iw < in_w_end; ++iw) {
...@@ -436,20 +436,34 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase { ...@@ -436,20 +436,34 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase {
index_t c = 0; index_t c = 0;
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
for (; c <= channels - 16; c += 16) { for (; c <= channels - 16; c += 16) {
uint16x8_t avg_vec[2]; uint16x8_t tmp_avg[2];
avg_vec[0] = vld1q_u16(avg_buffer + c);
avg_vec[1] = vld1q_u16(avg_buffer + c + 8);
uint8x16_t in_vec = vld1q_u8(in_ptr + c); uint8x16_t in_vec = vld1q_u8(in_ptr + c);
avg_vec[0] = vaddw_u8(avg_vec[0], vget_low_u8(in_vec)); tmp_avg[0] = vmovl_u8(vget_low_u8(in_vec));
avg_vec[1] = vaddw_u8(avg_vec[1], vget_high_u8(in_vec)); tmp_avg[1] = vmovl_u8(vget_high_u8(in_vec));
vst1q_u16(avg_buffer + c, avg_vec[0]); uint32x4_t avg_vec[4];
vst1q_u16(avg_buffer + c + 8, avg_vec[1]); avg_vec[0] = vld1q_u32(avg_buffer + c);
avg_vec[1] = vld1q_u32(avg_buffer + c + 4);
avg_vec[2] = vld1q_u32(avg_buffer + c + 8);
avg_vec[3] = vld1q_u32(avg_buffer + c + 12);
avg_vec[0] = vaddw_u16(avg_vec[0], vget_low_u16(tmp_avg[0]));
avg_vec[1] = vaddw_u16(avg_vec[1], vget_high_u16(tmp_avg[0]));
avg_vec[2] = vaddw_u16(avg_vec[2], vget_low_u16(tmp_avg[1]));
avg_vec[3] = vaddw_u16(avg_vec[3], vget_high_u16(tmp_avg[1]));
vst1q_u32(avg_buffer + c, avg_vec[0]);
vst1q_u32(avg_buffer + c + 4, avg_vec[1]);
vst1q_u32(avg_buffer + c + 8, avg_vec[2]);
vst1q_u32(avg_buffer + c + 12, avg_vec[3]);
} }
for (; c <= channels - 8; c += 8) { for (; c <= channels - 8; c += 8) {
uint16x8_t avg_vec = vld1q_u16(avg_buffer + c);
uint8x8_t in_vec = vld1_u8(in_ptr + c); uint8x8_t in_vec = vld1_u8(in_ptr + c);
avg_vec = vaddw_u8(avg_vec, in_vec); uint16x8_t tmp_avg = vmovl_u8(in_vec);
vst1q_u16(avg_buffer + c, avg_vec); uint32x4_t avg_vec[2];
avg_vec[0] = vld1q_u32(avg_buffer + c);
avg_vec[1] = vld1q_u32(avg_buffer + c + 4);
avg_vec[0] = vaddw_u16(avg_vec[0], vget_low_u16(tmp_avg));
avg_vec[1] = vaddw_u16(avg_vec[1], vget_high_u16(tmp_avg));
vst1q_u32(avg_buffer + c, avg_vec[0]);
vst1q_u32(avg_buffer + c + 4, avg_vec[1]);
} }
#endif #endif
for (; c < channels; ++c) { for (; c < channels; ++c) {
......
...@@ -1489,6 +1489,11 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1489,6 +1489,11 @@ class Transformer(base_converter.ConverterInterface):
new_output_name = self.output_name_map[op.output[i]] new_output_name = self.output_name_map[op.output[i]]
self._quantize_activation_info[new_output_name] = \ self._quantize_activation_info[new_output_name] = \
self._quantize_activation_info[op.output[i]] self._quantize_activation_info[op.output[i]]
if op.output[i] in self._consumers:
for consumer_op in self._consumers[op.output[i]]:
self.replace(consumer_op.input,
op.output[i],
new_output_name)
op.output[i] = new_output_name op.output[i] = new_output_name
data_type_arg = ConverterUtil.get_arg( data_type_arg = ConverterUtil.get_arg(
......
...@@ -60,7 +60,7 @@ cc_test( ...@@ -60,7 +60,7 @@ cc_test(
deps = [ deps = [
":mace_api_test_header", ":mace_api_test_header",
"//mace/libmace", "//mace/libmace",
"//mace/ops:test", "//test/ccutils",
"@gtest//:gtest_main", "@gtest//:gtest_main",
], ],
) )
...@@ -97,7 +97,7 @@ cc_test( ...@@ -97,7 +97,7 @@ cc_test(
deps = [ deps = [
":mace_api_test_header", ":mace_api_test_header",
"//mace/libmace", "//mace/libmace",
"//mace/ops:test", "//test/ccutils",
"@gtest//:gtest_main", "@gtest//:gtest_main",
], ],
) )
...@@ -133,7 +133,7 @@ cc_test( ...@@ -133,7 +133,7 @@ cc_test(
linkstatic = 1, linkstatic = 1,
deps = [ deps = [
"//mace/libmace", "//mace/libmace",
"//mace/ops:test", "//test/ccutils",
"@gtest//:gtest_main", "@gtest//:gtest_main",
], ],
) )
...@@ -169,7 +169,7 @@ cc_test( ...@@ -169,7 +169,7 @@ cc_test(
linkstatic = 1, linkstatic = 1,
deps = [ deps = [
"//mace/libmace", "//mace/libmace",
"//mace/ops:test", "//test/ccutils",
"@gtest//:gtest_main", "@gtest//:gtest_main",
], ],
) )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册