提交 ad33154b 编写于 作者: L liuqi

BUG: fix adapter bug when output of net is used by other ops.

1. Use uint32_t for average pooling.
2. Fix command queue destruction bug (guard against a null command queue in the OpenCLRuntime destructor).
上级 0a9f331c
......@@ -221,6 +221,13 @@ MaceStatus NetDefAdapter::AdaptNetDef(
&op_output_data_format,
target_net_def));
}
input_size = op_def.input_size();
for (int i = 0; i < input_size; ++i) {
if (output_map.count(op_def.input(i)) == 1) {
output_map.at(op_def.input(i)).consumer_op_indices.push_back(
target_net_def->op_size());
}
}
int output_size = op_def.output_size();
for (int out_idx = 0; out_idx < output_size; ++out_idx) {
......@@ -277,6 +284,15 @@ MaceStatus NetDefAdapter::AdaptNetDef(
output_op_def->set_output(i, t_output_name);
}
}
for (int idx : internal_output_info.consumer_op_indices) {
auto consumer_op_def = target_net_def->mutable_op(idx);
int input_size = consumer_op_def->input_size();
for (int i = 0; i < input_size; ++i) {
if (consumer_op_def->input(i) == output_info.name()) {
consumer_op_def->set_input(i, t_output_name);
}
}
}
auto transformed_op_def = target_net_def->add_op();
OpenCLUtil::BuildTransformOpDef(
t_output_name,
......@@ -415,12 +431,10 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
}
src_df = output_map->at(op_def->input(i)).data_format;
dst_df = inputs_data_format[i];
if (src_df == DataFormat::NONE
|| dst_df == DataFormat::NONE
|| output_map->at(op_def->input(i)).shape.size() != 4) {
continue;
}
if (src_df != dst_df) {
if (src_df != DataFormat::NONE
&& dst_df != DataFormat::NONE
&& output_map->at(op_def->input(i)).shape.size() == 4
&& src_df != dst_df) {
std::string transformed_name = TransformedName(op_def->input(i),
"data_format", static_cast<int>(dst_df));
if (transformed_set->count(transformed_name) == 0) {
......@@ -462,6 +476,9 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
SetProtoArg<int>(transpose_op_def,
OutputMemoryTypeTagName(),
target_mem_type);
// update tensor consumer information
output_map->at(op_def->input(i)).consumer_op_indices.push_back(
target_net_def->op_size() - 1);
// update output information map
output_map->emplace(
......@@ -546,6 +563,10 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
OutputMemoryTypeTagName(),
dst_mem_type);
// update tensor consumer information
output_map->at(op_def->input(i)).consumer_op_indices.push_back(
target_net_def->op_size() - 1);
// update output information map
output_map->emplace(
transformed_name,
......
......@@ -78,13 +78,14 @@ class NetDefAdapter {
const std::vector<index_t> &shape,
int op_idx)
: mem_type(mem_type), dtype(dtype), data_format(data_format),
shape(shape), op_idx(op_idx) {}
shape(shape), op_idx(op_idx), consumer_op_indices() {}
MemoryType mem_type;
DataType dtype;
DataFormat data_format;
std::vector<index_t> shape; // tensor shape
int op_idx; // operation which generate the tensor
std::vector<int> consumer_op_indices;
};
typedef std::unordered_map<std::string, InternalOutputInfo> TensorInfoMap;
......
......@@ -434,7 +434,9 @@ OpenCLRuntime::OpenCLRuntime(
}
OpenCLRuntime::~OpenCLRuntime() {
command_queue_->finish();
if (command_queue_ != nullptr) {
command_queue_->finish();
}
built_program_map_.clear();
// We need to control the destruction order, which has dependencies
command_queue_.reset();
......
......@@ -426,8 +426,8 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase {
(in_h_end - in_h_begin) * (in_w_end - in_w_begin);
MACE_CHECK(block_size > 0);
std::vector<uint16_t> average_buffer(channels);
uint16_t *avg_buffer = average_buffer.data();
std::vector<uint32_t> average_buffer(channels);
uint32_t *avg_buffer = average_buffer.data();
std::fill_n(avg_buffer, channels, 0);
for (index_t ih = in_h_begin; ih < in_h_end; ++ih) {
for (index_t iw = in_w_begin; iw < in_w_end; ++iw) {
......@@ -436,20 +436,34 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase {
index_t c = 0;
#if defined(MACE_ENABLE_NEON)
for (; c <= channels - 16; c += 16) {
uint16x8_t avg_vec[2];
avg_vec[0] = vld1q_u16(avg_buffer + c);
avg_vec[1] = vld1q_u16(avg_buffer + c + 8);
uint16x8_t tmp_avg[2];
uint8x16_t in_vec = vld1q_u8(in_ptr + c);
avg_vec[0] = vaddw_u8(avg_vec[0], vget_low_u8(in_vec));
avg_vec[1] = vaddw_u8(avg_vec[1], vget_high_u8(in_vec));
vst1q_u16(avg_buffer + c, avg_vec[0]);
vst1q_u16(avg_buffer + c + 8, avg_vec[1]);
tmp_avg[0] = vmovl_u8(vget_low_u8(in_vec));
tmp_avg[1] = vmovl_u8(vget_high_u8(in_vec));
uint32x4_t avg_vec[4];
avg_vec[0] = vld1q_u32(avg_buffer + c);
avg_vec[1] = vld1q_u32(avg_buffer + c + 4);
avg_vec[2] = vld1q_u32(avg_buffer + c + 8);
avg_vec[3] = vld1q_u32(avg_buffer + c + 12);
avg_vec[0] = vaddw_u16(avg_vec[0], vget_low_u16(tmp_avg[0]));
avg_vec[1] = vaddw_u16(avg_vec[1], vget_high_u16(tmp_avg[0]));
avg_vec[2] = vaddw_u16(avg_vec[2], vget_low_u16(tmp_avg[1]));
avg_vec[3] = vaddw_u16(avg_vec[3], vget_high_u16(tmp_avg[1]));
vst1q_u32(avg_buffer + c, avg_vec[0]);
vst1q_u32(avg_buffer + c + 4, avg_vec[1]);
vst1q_u32(avg_buffer + c + 8, avg_vec[2]);
vst1q_u32(avg_buffer + c + 12, avg_vec[3]);
}
for (; c <= channels - 8; c += 8) {
uint16x8_t avg_vec = vld1q_u16(avg_buffer + c);
uint8x8_t in_vec = vld1_u8(in_ptr + c);
avg_vec = vaddw_u8(avg_vec, in_vec);
vst1q_u16(avg_buffer + c, avg_vec);
uint16x8_t tmp_avg = vmovl_u8(in_vec);
uint32x4_t avg_vec[2];
avg_vec[0] = vld1q_u32(avg_buffer + c);
avg_vec[1] = vld1q_u32(avg_buffer + c + 4);
avg_vec[0] = vaddw_u16(avg_vec[0], vget_low_u16(tmp_avg));
avg_vec[1] = vaddw_u16(avg_vec[1], vget_high_u16(tmp_avg));
vst1q_u32(avg_buffer + c, avg_vec[0]);
vst1q_u32(avg_buffer + c + 4, avg_vec[1]);
}
#endif
for (; c < channels; ++c) {
......
......@@ -1489,6 +1489,11 @@ class Transformer(base_converter.ConverterInterface):
new_output_name = self.output_name_map[op.output[i]]
self._quantize_activation_info[new_output_name] = \
self._quantize_activation_info[op.output[i]]
if op.output[i] in self._consumers:
for consumer_op in self._consumers[op.output[i]]:
self.replace(consumer_op.input,
op.output[i],
new_output_name)
op.output[i] = new_output_name
data_type_arg = ConverterUtil.get_arg(
......
......@@ -60,7 +60,7 @@ cc_test(
deps = [
":mace_api_test_header",
"//mace/libmace",
"//mace/ops:test",
"//test/ccutils",
"@gtest//:gtest_main",
],
)
......@@ -97,7 +97,7 @@ cc_test(
deps = [
":mace_api_test_header",
"//mace/libmace",
"//mace/ops:test",
"//test/ccutils",
"@gtest//:gtest_main",
],
)
......@@ -133,7 +133,7 @@ cc_test(
linkstatic = 1,
deps = [
"//mace/libmace",
"//mace/ops:test",
"//test/ccutils",
"@gtest//:gtest_main",
],
)
......@@ -169,7 +169,7 @@ cc_test(
linkstatic = 1,
deps = [
"//mace/libmace",
"//mace/ops:test",
"//test/ccutils",
"@gtest//:gtest_main",
],
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册