提交 cc3c8e2f 编写于 作者: 卢旭辉

Merge branch 'bug-fix' into 'm0.11'

Fix some bugs.

See merge request !1110
...@@ -164,6 +164,7 @@ MaceStatus NetDefAdapter::AdaptNetDef( ...@@ -164,6 +164,7 @@ MaceStatus NetDefAdapter::AdaptNetDef(
input_info->set_dims(j, input_shape[j]); input_info->set_dims(j, input_shape[j]);
} }
} }
tensor_shape_map.emplace(input_info->name(), input_shape);
output_map.emplace(input_info->name(), InternalOutputInfo( output_map.emplace(input_info->name(), InternalOutputInfo(
mem_type, input_info->data_type(), mem_type, input_info->data_type(),
input_data_format, input_shape, -1)); input_data_format, input_shape, -1));
...@@ -220,6 +221,13 @@ MaceStatus NetDefAdapter::AdaptNetDef( ...@@ -220,6 +221,13 @@ MaceStatus NetDefAdapter::AdaptNetDef(
&op_output_data_format, &op_output_data_format,
target_net_def)); target_net_def));
} }
input_size = op_def.input_size();
for (int i = 0; i < input_size; ++i) {
if (output_map.count(op_def.input(i)) == 1) {
output_map.at(op_def.input(i)).consumer_op_indices.push_back(
target_net_def->op_size());
}
}
int output_size = op_def.output_size(); int output_size = op_def.output_size();
for (int out_idx = 0; out_idx < output_size; ++out_idx) { for (int out_idx = 0; out_idx < output_size; ++out_idx) {
...@@ -276,6 +284,15 @@ MaceStatus NetDefAdapter::AdaptNetDef( ...@@ -276,6 +284,15 @@ MaceStatus NetDefAdapter::AdaptNetDef(
output_op_def->set_output(i, t_output_name); output_op_def->set_output(i, t_output_name);
} }
} }
for (int idx : internal_output_info.consumer_op_indices) {
auto consumer_op_def = target_net_def->mutable_op(idx);
int input_size = consumer_op_def->input_size();
for (int i = 0; i < input_size; ++i) {
if (consumer_op_def->input(i) == output_info.name()) {
consumer_op_def->set_input(i, t_output_name);
}
}
}
auto transformed_op_def = target_net_def->add_op(); auto transformed_op_def = target_net_def->add_op();
OpenCLUtil::BuildTransformOpDef( OpenCLUtil::BuildTransformOpDef(
t_output_name, t_output_name,
...@@ -414,12 +431,10 @@ MaceStatus NetDefAdapter::AdaptDataFormat( ...@@ -414,12 +431,10 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
} }
src_df = output_map->at(op_def->input(i)).data_format; src_df = output_map->at(op_def->input(i)).data_format;
dst_df = inputs_data_format[i]; dst_df = inputs_data_format[i];
if (src_df == DataFormat::NONE if (src_df != DataFormat::NONE
|| dst_df == DataFormat::NONE && dst_df != DataFormat::NONE
|| output_map->at(op_def->input(i)).shape.size() != 4) { && output_map->at(op_def->input(i)).shape.size() == 4
continue; && src_df != dst_df) {
}
if (src_df != dst_df) {
std::string transformed_name = TransformedName(op_def->input(i), std::string transformed_name = TransformedName(op_def->input(i),
"data_format", static_cast<int>(dst_df)); "data_format", static_cast<int>(dst_df));
if (transformed_set->count(transformed_name) == 0) { if (transformed_set->count(transformed_name) == 0) {
...@@ -461,6 +476,9 @@ MaceStatus NetDefAdapter::AdaptDataFormat( ...@@ -461,6 +476,9 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
SetProtoArg<int>(transpose_op_def, SetProtoArg<int>(transpose_op_def,
OutputMemoryTypeTagName(), OutputMemoryTypeTagName(),
target_mem_type); target_mem_type);
// update tensor consumer information
output_map->at(op_def->input(i)).consumer_op_indices.push_back(
target_net_def->op_size() - 1);
// update output information map // update output information map
output_map->emplace( output_map->emplace(
...@@ -545,6 +563,10 @@ MaceStatus NetDefAdapter::AdaptMemoryType( ...@@ -545,6 +563,10 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
OutputMemoryTypeTagName(), OutputMemoryTypeTagName(),
dst_mem_type); dst_mem_type);
// update tensor consumer information
output_map->at(op_def->input(i)).consumer_op_indices.push_back(
target_net_def->op_size() - 1);
// update output information map // update output information map
output_map->emplace( output_map->emplace(
transformed_name, transformed_name,
......
...@@ -78,13 +78,14 @@ class NetDefAdapter { ...@@ -78,13 +78,14 @@ class NetDefAdapter {
const std::vector<index_t> &shape, const std::vector<index_t> &shape,
int op_idx) int op_idx)
: mem_type(mem_type), dtype(dtype), data_format(data_format), : mem_type(mem_type), dtype(dtype), data_format(data_format),
shape(shape), op_idx(op_idx) {} shape(shape), op_idx(op_idx), consumer_op_indices() {}
MemoryType mem_type; MemoryType mem_type;
DataType dtype; DataType dtype;
DataFormat data_format; DataFormat data_format;
std::vector<index_t> shape; // tensor shape std::vector<index_t> shape; // tensor shape
int op_idx; // operation which generate the tensor int op_idx; // operation which generate the tensor
std::vector<int> consumer_op_indices;
}; };
typedef std::unordered_map<std::string, InternalOutputInfo> TensorInfoMap; typedef std::unordered_map<std::string, InternalOutputInfo> TensorInfoMap;
......
...@@ -436,7 +436,9 @@ OpenCLRuntime::OpenCLRuntime( ...@@ -436,7 +436,9 @@ OpenCLRuntime::OpenCLRuntime(
} }
OpenCLRuntime::~OpenCLRuntime() { OpenCLRuntime::~OpenCLRuntime() {
command_queue_->finish(); if (command_queue_ != nullptr) {
command_queue_->finish();
}
built_program_map_.clear(); built_program_map_.clear();
// We need to control the destruction order, which has dependencies // We need to control the destruction order, which has dependencies
command_queue_.reset(); command_queue_.reset();
......
...@@ -426,8 +426,8 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase { ...@@ -426,8 +426,8 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase {
(in_h_end - in_h_begin) * (in_w_end - in_w_begin); (in_h_end - in_h_begin) * (in_w_end - in_w_begin);
MACE_CHECK(block_size > 0); MACE_CHECK(block_size > 0);
std::vector<uint16_t> average_buffer(channels); std::vector<uint32_t> average_buffer(channels);
uint16_t *avg_buffer = average_buffer.data(); uint32_t *avg_buffer = average_buffer.data();
std::fill_n(avg_buffer, channels, 0); std::fill_n(avg_buffer, channels, 0);
for (index_t ih = in_h_begin; ih < in_h_end; ++ih) { for (index_t ih = in_h_begin; ih < in_h_end; ++ih) {
for (index_t iw = in_w_begin; iw < in_w_end; ++iw) { for (index_t iw = in_w_begin; iw < in_w_end; ++iw) {
...@@ -436,20 +436,34 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase { ...@@ -436,20 +436,34 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase {
index_t c = 0; index_t c = 0;
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
for (; c <= channels - 16; c += 16) { for (; c <= channels - 16; c += 16) {
uint16x8_t avg_vec[2]; uint16x8_t tmp_avg[2];
avg_vec[0] = vld1q_u16(avg_buffer + c);
avg_vec[1] = vld1q_u16(avg_buffer + c + 8);
uint8x16_t in_vec = vld1q_u8(in_ptr + c); uint8x16_t in_vec = vld1q_u8(in_ptr + c);
avg_vec[0] = vaddw_u8(avg_vec[0], vget_low_u8(in_vec)); tmp_avg[0] = vmovl_u8(vget_low_u8(in_vec));
avg_vec[1] = vaddw_u8(avg_vec[1], vget_high_u8(in_vec)); tmp_avg[1] = vmovl_u8(vget_high_u8(in_vec));
vst1q_u16(avg_buffer + c, avg_vec[0]); uint32x4_t avg_vec[4];
vst1q_u16(avg_buffer + c + 8, avg_vec[1]); avg_vec[0] = vld1q_u32(avg_buffer + c);
avg_vec[1] = vld1q_u32(avg_buffer + c + 4);
avg_vec[2] = vld1q_u32(avg_buffer + c + 8);
avg_vec[3] = vld1q_u32(avg_buffer + c + 12);
avg_vec[0] = vaddw_u16(avg_vec[0], vget_low_u16(tmp_avg[0]));
avg_vec[1] = vaddw_u16(avg_vec[1], vget_high_u16(tmp_avg[0]));
avg_vec[2] = vaddw_u16(avg_vec[2], vget_low_u16(tmp_avg[1]));
avg_vec[3] = vaddw_u16(avg_vec[3], vget_high_u16(tmp_avg[1]));
vst1q_u32(avg_buffer + c, avg_vec[0]);
vst1q_u32(avg_buffer + c + 4, avg_vec[1]);
vst1q_u32(avg_buffer + c + 8, avg_vec[2]);
vst1q_u32(avg_buffer + c + 12, avg_vec[3]);
} }
for (; c <= channels - 8; c += 8) { for (; c <= channels - 8; c += 8) {
uint16x8_t avg_vec = vld1q_u16(avg_buffer + c);
uint8x8_t in_vec = vld1_u8(in_ptr + c); uint8x8_t in_vec = vld1_u8(in_ptr + c);
avg_vec = vaddw_u8(avg_vec, in_vec); uint16x8_t tmp_avg = vmovl_u8(in_vec);
vst1q_u16(avg_buffer + c, avg_vec); uint32x4_t avg_vec[2];
avg_vec[0] = vld1q_u32(avg_buffer + c);
avg_vec[1] = vld1q_u32(avg_buffer + c + 4);
avg_vec[0] = vaddw_u16(avg_vec[0], vget_low_u16(tmp_avg));
avg_vec[1] = vaddw_u16(avg_vec[1], vget_high_u16(tmp_avg));
vst1q_u32(avg_buffer + c, avg_vec[0]);
vst1q_u32(avg_buffer + c + 4, avg_vec[1]);
} }
#endif #endif
for (; c < channels; ++c) { for (; c < channels; ++c) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册