Commit c16e51ef authored by xiebaiyuan

Fix bug in conv add depthwise p0

Parent 489e06d1
@@ -231,6 +231,13 @@ void Executor<Dtype, P>::InitMemory() {
           Get_binary_data(program_.model_path + "/" + var_desc->Name());
       char *data = origin_data;
       LoadMemory(*var_desc, tensor, &data);
+      // DLOG << "----- " << var_desc->Name();
+      // DLOG << "----- " << tensor->dims();
+      // float *pDouble = tensor->template data<float>();
+      // for (int i = 0; i < tensor->numel() && i < 30; ++i) {
+      //   std::cout << pDouble[i] << std::endl;
+      // }
       delete origin_data;
     } else {
       if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) {
...
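The commented-out block in the hunk above dumps the first few values of each tensor right after LoadMemory. If that check is needed regularly, a small reusable helper along the same lines (hypothetical, not part of this patch; it only assumes the framework::Tensor calls already used above, data<T>() and numel()) might look like:

    // Hypothetical debug helper: print the first `limit` float values of a
    // tensor, mirroring the commented-out block in the hunk above.
    #include <algorithm>
    #include <iostream>

    void DumpTensorHead(const framework::Tensor *tensor, int limit = 30) {
      const float *values = tensor->data<float>();
      const int n = std::min(static_cast<int>(tensor->numel()), limit);
      for (int i = 0; i < n; ++i) {
        std::cout << values[i] << std::endl;
      }
    }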
@@ -129,10 +129,13 @@ void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
     // param.Paddings(),
     // param.Filter(), param.Bias(),
     // param.Output(), false);
-    math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(),
-                                 *param.Bias(), true);
+    if (param.Paddings()[0] == 0) {
+      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
+                                 *param.Bias(), true);
+    } else {
+      math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
+                                   param.Output(), *param.Bias(), true);
+    }
   } else {
     ConvAddBasic(param);
   }
...
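The new dispatch matters because padding changes the output geometry: for a 3x3 kernel at stride 2, out = (in + 2 * pad - 3) / 2 + 1, so a padding-1 kernel reads a border and writes an output shape that a padding-0 layer does not have. A quick standalone check of the arithmetic (a sketch, independent of the patch):

    #include <cstdio>

    int main() {
      const int kernel = 3, stride = 2;
      for (int pad : {0, 1}) {
        for (int in : {224, 112, 56}) {
          // Standard conv output-size formula for this kernel/stride.
          int out = (in + 2 * pad - kernel) / stride + 1;
          std::printf("in=%3d pad=%d -> out=%3d\n", in, pad, out);
        }
      }
      return 0;
    }

For in=224 this prints out=111 with pad=0 but out=112 with pad=1, which is why the padding-0 case needs its own kernel.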
@@ -1881,6 +1881,103 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
 #endif
 }
+// Depthwise 3x3 convolution, stride 2, padding 0, with per-channel bias.
+void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter,
+                          Tensor *output, Tensor bias, bool if_bias) {
+#if __ARM_NEON
+  const int batch_size = static_cast<int>(input->dims()[0]);
+  const int input_channel = static_cast<int>(input->dims()[1]);
+  const int input_height = static_cast<int>(input->dims()[2]);
+  const int input_width = static_cast<int>(input->dims()[3]);
+  const int output_height = static_cast<int>(output->dims()[2]);
+  const int output_width = static_cast<int>(output->dims()[3]);
+  const int inhxw = input_height * input_width;
+  const int outhxw = output_height * output_width;
+  float32x4_t zero = vdupq_n_f32(0.0);
+  for (int b = 0; b < batch_size; b++) {
+#pragma omp parallel for
+    for (int c = 0; c < input_channel; c++) {
+      const float *filter_data = filter->data<float>() + c * 9;
+      // Offset by batch as well as channel so batch sizes > 1 are handled.
+      const float *input_data =
+          input->data<float>() + (b * input_channel + c) * inhxw;
+      const float *bias_data = bias.data<float>() + c;
+      float *output_data =
+          output->data<float>() + (b * input_channel + c) * outhxw;
+      float w00 = filter_data[0];
+      float w01 = filter_data[1];
+      float w02 = filter_data[2];
+      float w10 = filter_data[3];
+      float w11 = filter_data[4];
+      float w12 = filter_data[5];
+      float w20 = filter_data[6];
+      float w21 = filter_data[7];
+      float w22 = filter_data[8];
+      // Note: the bias is applied unconditionally below; if_bias is not
+      // consulted (the current call site always passes true).
+      float32x4_t biasv = vld1q_dup_f32(bias_data);
+      for (int i = 0; i < output_height; i += 1) {
+        // Main loop: vld2q_f32 deinterleaves even/odd columns, so val[0]
+        // and val[1] hold the stride-2 samples directly; each iteration
+        // produces three output pixels (lane 3 of out0 is never stored,
+        // so shifting zero into the vextq results is safe).
+        for (int m = 0; m < output_width - 2; m += 3) {
+          float *output_ptr = output_data + i * output_width + m;
+          float32x4x2_t input_buff_top{}, input_buff_mid{}, input_buff_bottom{};
+          float32x4_t in0, in2, in4, tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, out0;
+          input_buff_top =
+              vld2q_f32(input_data + (2 * i) * input_width + (2 * m));
+          input_buff_mid =
+              vld2q_f32(input_data + (2 * i + 1) * input_width + (2 * m));
+          input_buff_bottom =
+              vld2q_f32(input_data + (2 * i + 2) * input_width + (2 * m));
+          in0 = input_buff_top.val[0];
+          tmp0 = input_buff_top.val[1];
+          tmp1 = vextq_f32(in0, zero, 1);
+          in2 = input_buff_mid.val[0];
+          tmp2 = input_buff_mid.val[1];
+          tmp3 = vextq_f32(in2, zero, 1);
+          in4 = input_buff_bottom.val[0];
+          tmp4 = input_buff_bottom.val[1];
+          tmp5 = vextq_f32(in4, zero, 1);
+          out0 = vmulq_n_f32(in0, w00);
+          out0 = vmlaq_n_f32(out0, tmp0, w01);
+          out0 = vmlaq_n_f32(out0, tmp1, w02);
+          out0 = vmlaq_n_f32(out0, in2, w10);
+          out0 = vmlaq_n_f32(out0, tmp2, w11);
+          out0 = vmlaq_n_f32(out0, tmp3, w12);
+          out0 = vmlaq_n_f32(out0, in4, w20);
+          out0 = vmlaq_n_f32(out0, tmp4, w21);
+          out0 = vmlaq_n_f32(out0, tmp5, w22);
+          out0 = vaddq_f32(out0, biasv);
+          vst1q_lane_f32(output_ptr, out0, 0);
+          vst1q_lane_f32(output_ptr + 1, out0, 1);
+          vst1q_lane_f32(output_ptr + 2, out0, 2);
+        }
+        // Recompute where the vectorized loop stopped: m becomes the first
+        // output column it did not cover.
+        int m;
+        for (m = 0; m < output_width - 2; m += 3) {
+        }
+        for (int j = m; j < output_width; j++) {
+          // Padding 0: the 3x3 window for output (i, j) starts at input
+          // (2 * i, 2 * j), matching the vectorized path above.
+          output_data[i * output_width + j] =
+              input_data[(2 * i) * input_width + 2 * j] * w00 +
+              input_data[(2 * i) * input_width + 2 * j + 1] * w01 +
+              input_data[(2 * i) * input_width + 2 * j + 2] * w02 +
+              input_data[(2 * i + 1) * input_width + 2 * j] * w10 +
+              input_data[(2 * i + 1) * input_width + 2 * j + 1] * w11 +
+              input_data[(2 * i + 1) * input_width + 2 * j + 2] * w12 +
+              input_data[(2 * i + 2) * input_width + 2 * j] * w20 +
+              input_data[(2 * i + 2) * input_width + 2 * j + 1] * w21 +
+              input_data[(2 * i + 2) * input_width + 2 * j + 2] * w22;
+          output_data[i * output_width + j] += *bias_data;
+        }
+      }
+    }
+  }
+#endif
+}
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle_mobile
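To validate the NEON kernel above on small shapes, a minimal scalar reference can be compared element-wise against its output. The sketch below (hypothetical names, not part of the patch) implements depthwise 3x3, stride 2, padding 0 with per-channel bias over dense NCHW buffers for a single batch:

    // Scalar reference for depthwise 3x3, stride 2, padding 0 (NCHW, one
    // batch). Output dims: out = (in - 3) / 2 + 1 per spatial axis.
    void DepthwiseConv3x3s2p0Ref(const float *in, const float *filter,
                                 const float *bias, float *out, int channels,
                                 int in_h, int in_w) {
      const int out_h = (in_h - 3) / 2 + 1;
      const int out_w = (in_w - 3) / 2 + 1;
      for (int c = 0; c < channels; ++c) {
        const float *src = in + c * in_h * in_w;
        const float *w = filter + c * 9;  // one 3x3 filter per channel
        float *dst = out + c * out_h * out_w;
        for (int i = 0; i < out_h; ++i) {
          for (int j = 0; j < out_w; ++j) {
            float acc = bias[c];
            for (int kh = 0; kh < 3; ++kh) {
              for (int kw = 0; kw < 3; ++kw) {
                // Padding 0: the window for (i, j) starts at (2*i, 2*j).
                acc += src[(2 * i + kh) * in_w + (2 * j + kw)] * w[kh * 3 + kw];
              }
            }
            dst[i * out_w + j] = acc;
          }
        }
      }
    }

Running both paths over random inputs and asserting near-equality is usually enough to catch boundary and indexing slips in hand-vectorized kernels like this one.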
@@ -43,6 +43,9 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
 void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
                                      Tensor *output, const Tensor *new_scale,
                                      const Tensor *new_bias, bool if_relu);
+void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter,
+                          Tensor *output, Tensor bias, bool if_bias);
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle_mobile