Commit 7db044ac authored by hjchen2

Refactor batch norm, and fix deconv memory leak

Parent 06a6f6a9
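
For reference, the refactored kernel in the batch norm diff below relies on the standard inference-time folding of batch norm into a per-channel affine transform: y = scale * x + bias, with scale = gamma / sqrt(var + eps) and bias = beta - mean * scale. The following is a minimal scalar sketch of that transform for NCHW data (a hedged illustration: the free function and its parameter names are hypothetical, while the actual kernel works on paddle_mobile tensors and adds OpenMP and NEON on top of the same math).

// Sketch only: inference-time batch norm folded into per-channel scale/bias.
// Layout is NCHW; function and argument names are illustrative.
#include <cmath>
#include <cstddef>

void batchnorm_inference(const float *x, float *y, std::size_t n,
                         std::size_t c, std::size_t spatial,
                         const float *mean, const float *var,
                         const float *gamma, const float *beta, float eps) {
  for (std::size_t b = 0; b < n; ++b) {
    for (std::size_t ch = 0; ch < c; ++ch) {
      // Fold mean/var/gamma/beta into one scale and one bias per channel.
      const float inv_std = 1.f / std::sqrt(var[ch] + eps);
      const float scale = gamma[ch] * inv_std;
      const float bias = beta[ch] - mean[ch] * scale;
      const float *xc = x + (b * c + ch) * spatial;
      float *yc = y + (b * c + ch) * spatial;
      for (std::size_t k = 0; k < spatial; ++k) {
        yc[k] = scale * xc[k] + bias;
      }
    }
  }
}

Folding keeps the inner loop to a single multiply-add per element, which is also what the NEON path in the new kernel vectorizes.
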
@@ -22,6 +22,8 @@ const char *G_OP_TYPE_BATCHNORM = "batch_norm";
const char *G_OP_TYPE_BOX_CODER = "box_coder";
const char *G_OP_TYPE_CONCAT = "concat";
const char *G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add";
+const char *G_OP_TYPE_ELEMENTWISE_SUB = "elementwise_sub";
+const char *G_OP_TYPE_ELEMENTWISE_MUL = "elementwise_mul";
const char *G_OP_TYPE_FILL_CONSTANT = "fill_constant";
const char *G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu";
const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU = "fusion_conv_add_prelu";
@@ -67,7 +69,6 @@ const char *G_OP_TYPE_CRF = "crf_decoding";
const char *G_OP_TYPE_BILINEAR_INTERP = "bilinear_interp";
const char *G_OP_TYPE_FLATTEN = "flatten";
const char *G_OP_TYPE_SHAPE = "shape";
-const char *G_OP_TYPE_ELEMENTWISE_MUL = "elementwise_mul";
const char *G_OP_TYPE_SUM = "sum";
const char *G_OP_TYPE_TOP_K = "top_k";
const char *G_OP_TYPE_CAST = "cast";
@@ -102,6 +103,8 @@ std::unordered_map<
{G_OP_TYPE_SIGMOID, {{"X"}, {"Out"}}},
{G_OP_TYPE_MUL, {{"X"}, {"Out"}}},
{G_OP_TYPE_ELEMENTWISE_ADD, {{"X", "Y"}, {"Out"}}},
+{G_OP_TYPE_ELEMENTWISE_SUB, {{"X", "Y"}, {"Out"}}},
+{G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_POOL2D, {{"X"}, {"Out"}}},
{G_OP_TYPE_BATCHNORM, {{"X"}, {"Y"}}},
{G_OP_TYPE_LRN, {{"X"}, {"Out"}}},
@@ -146,7 +149,6 @@ std::unordered_map<
{G_OP_TYPE_SUM, {{"X"}, {"Out"}}},
{G_OP_TYPE_TOP_K, {{"X"}, {"Out", "Indices"}}},
{G_OP_TYPE_CAST, {{"X"}, {"Out"}}},
-{G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}},
{G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_FUSION_DEQUANT_BN, {{"X", "Scale"}, {"Out"}}},
......
@@ -112,6 +112,8 @@ extern const char *G_OP_TYPE_BATCHNORM;
extern const char *G_OP_TYPE_BOX_CODER;
extern const char *G_OP_TYPE_CONCAT;
extern const char *G_OP_TYPE_ELEMENTWISE_ADD;
+extern const char *G_OP_TYPE_ELEMENTWISE_SUB;
+extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
extern const char *G_OP_TYPE_FUSION_CONV_ADD_RELU;
extern const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU;
extern const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU;
@@ -149,7 +151,6 @@ extern const char *G_OP_TYPE_FUSION_CONV_BN;
extern const char *G_OP_TYPE_CONV_TRANSPOSE;
extern const char *G_OP_TYPE_PRELU;
extern const char *G_OP_TYPE_SUM;
-extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
extern const char *G_OP_TYPE_TOP_K;
extern const char *G_OP_TYPE_CAST;
......
@@ -18,283 +18,63 @@ limitations under the License. */
#include <cmath>
#include "operators/op_param.h"
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif // __ARM_NEON__
namespace paddle_mobile {
namespace operators {
template <typename P>
void BatchnormCompute(const BatchNormParam<CPU> &param) {
-const Tensor *input_x = param.InputX();
-auto input_x_ptr = input_x->data<float>();
-const auto &x_dims = input_x->dims();
-const int N = x_dims[0];
-const int C = x_dims[1];
-const int H = x_dims[2];
-const int W = x_dims[3];
-const int stride0 = C * H * W;
-const int stride1 = H * W;
-const int stride2 = W;
-Tensor *out = param.OutputY();
-auto out_ptr = out->mutable_data<float>();
const float epsilon = param.Epsilon();
-const Tensor *mean = param.InputMean();
-const Tensor *variance = param.InputVariance();
-const Tensor *scale = param.InputScale();
-const Tensor *bias = param.InputBias();
-auto mean_ptr = mean->data<float>();
-auto variance_ptr = variance->data<float>();
-auto scale_ptr = scale->data<float>();
-auto bias_ptr = bias->data<float>();
-// Tensor inv_std;
-// auto inv_std_ptr = inv_std.mutable_data<float>(make_ddim({C}));
-PADDLE_MOBILE_ENFORCE(C == variance->numel(),
-"C must equal to variance.numel()");
-int HXW = H * W;
-#if __ARM_NEON
-#if __aarch64__
-float *inv_std_ptr = new float[C];
-for (int i = 0; i < C; i++) {
-inv_std_ptr[i] =
-1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
-}
-Tensor new_scale;
-auto new_scale_ptr = new_scale.mutable_data<float>(framework::make_ddim({C}));
-Tensor new_bias;
-auto new_bias_ptr = new_bias.mutable_data<float>(framework::make_ddim({C}));
-/// ((x - est_mean) * (inv_var) * scale + bias equal to
-/// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
-for (int i = 0; i < C; i++) {
-new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
-new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
-{
-for (int n = 0; n < N; n++) {
-for (int h = 0; h < H; h++) {
-int tmp_index = n * stride0 + i * stride1 + h * stride2;
-for (int w = 0; w < W; w++) {
-int index = tmp_index + w;
-out_ptr[index] =
-input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i];
-}
-}
+const float *mean_ptr = param.InputMean()->data<float>();
+const float *variance_ptr = param.InputVariance()->data<float>();
+const float *scale_ptr = param.InputScale()->data<float>();
+const float *bias_ptr = param.InputBias()->data<float>();
+const framework::Tensor *input = param.InputX();
+const float *input_ptr = input->data<float>();
+framework::Tensor *output = param.OutputY();
+float *output_ptr = output->mutable_data<float>();
+size_t spatial_size = output->dims()[2] * output->dims()[3];
+int channels = output->dims()[1];
+#pragma omp parallel for collapse(2)
+for (int batch = 0; batch < output->dims()[0]; ++batch) {
+for (int c = 0; c < channels; ++c) {
+float inv_scale = 1.f / (std::sqrt(variance_ptr[c] + epsilon));
+float bias = bias_ptr[c] - inv_scale * scale_ptr[c] * mean_ptr[c];
+float scale = inv_scale * scale_ptr[c];
+size_t offset = (batch * channels + c) * spatial_size;
+const float *x = input_ptr + offset;
+float *y = output_ptr + offset;
+size_t remain = spatial_size;
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+int loop = spatial_size >> 4;
+remain = spatial_size & 0xF;
+float32x4_t __scale = vdupq_n_f32(scale);
+float32x4_t __bias = vdupq_n_f32(bias);
+for (int k = 0; k < loop; ++k, x += 16, y += 16) {
+float32x4_t r0 = vld1q_f32(x);
+float32x4_t r1 = vld1q_f32(x + 4);
+float32x4_t r2 = vld1q_f32(x + 8);
+float32x4_t r3 = vld1q_f32(x + 12);
+r0 = vmlaq_f32(__bias, __scale, r0);
+r1 = vmlaq_f32(__bias, __scale, r1);
+r2 = vmlaq_f32(__bias, __scale, r2);
+r3 = vmlaq_f32(__bias, __scale, r3);
+vst1q_f32(y, r0);
+vst1q_f32(y + 4, r1);
+vst1q_f32(y + 8, r2);
+vst1q_f32(y + 12, r3);
+}
}
}
-delete[] inv_std_ptr;
-#else
-if (HXW > 32) {
-int NXC = N * C;
-float *inv_std_ptr = new float[NXC * 4];
-float *volatile new_scale_ptr = new float[NXC * 4];
-float *volatile new_bias_ptr = new float[NXC * 4];
-/// std = (var + epsilon).sqrt();
-/// inv_std = 1 / std;
-for (int i = 0; i < C * 4; i += 4) {
-int index = i / 4;
-inv_std_ptr[i] =
-1 / static_cast<float>(pow((variance_ptr[index] + epsilon), 0.5));
-inv_std_ptr[i + 1] = inv_std_ptr[i];
-inv_std_ptr[i + 2] = inv_std_ptr[i];
-inv_std_ptr[i + 3] = inv_std_ptr[i];
-new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[index];
-new_scale_ptr[i + 1] = new_scale_ptr[i];
-new_scale_ptr[i + 2] = new_scale_ptr[i];
-new_scale_ptr[i + 3] = new_scale_ptr[i];
-new_bias_ptr[i] =
-bias_ptr[index] - mean_ptr[index] * inv_std_ptr[i] * scale_ptr[index];
-new_bias_ptr[i + 1] = new_bias_ptr[i];
-new_bias_ptr[i + 2] = new_bias_ptr[i];
-new_bias_ptr[i + 3] = new_bias_ptr[i];
-}
-for (int j = C * 4; j < NXC * 4; ++j) {
-new_scale_ptr[j] = new_scale_ptr[j - C * 4];
-new_bias_ptr[j] = new_bias_ptr[j - C * 4];
-}
-asm volatile(
-"subs %[N], %[N], #1 \n\t"
-"blt end_n_%= \n\t"
-"loop_n_%=: \n\t"
-"subs %[C], %[C], #1 \n\t"
-"blt end_c_%= \n\t"
-"loop_c_%=: \n\t"
-"vld1.32 {q9}, [%[new_scale_ptr]]! \n\t"
-"vld1.32 {q10}, [%[new_bias_ptr]]! \n\t"
-"mov r6, %[HXW] \n\t"
-"subs r6, r6, #32 \n\t"
-"blt end_hw_%= \n\t"
-"loop_hw_%=: \n\t"
-"vld1.32 {q1, q2}, [%[input_x_ptr]]! \n\t"
-"vld1.32 {q3, q4}, [%[input_x_ptr]]! \n\t"
-"vld1.32 {q5, q6}, [%[input_x_ptr]]! \n\t"
-"vld1.32 {q7, q8}, [%[input_x_ptr]]! \n\t"
-"vmul.f32 q1, q1, q9 \n\t"
-"vmul.f32 q2, q2, q9 \n\t"
-"vmul.f32 q3, q3, q9 \n\t"
-"vmul.f32 q4, q4, q9 \n\t"
-"vmul.f32 q5, q5, q9 \n\t"
-"vmul.f32 q6, q6, q9 \n\t"
-"vmul.f32 q7, q7, q9 \n\t"
-"vmul.f32 q8, q8, q9 \n\t"
-"vadd.f32 q1, q1, q10 \n\t"
-"vadd.f32 q2, q2, q10 \n\t"
-"vadd.f32 q3, q3, q10 \n\t"
-"vadd.f32 q4, q4, q10 \n\t"
-"vadd.f32 q5, q5, q10 \n\t"
-"vadd.f32 q6, q6, q10 \n\t"
-"vadd.f32 q7, q7, q10 \n\t"
-"vadd.f32 q8, q8, q10 \n\t"
-"vst1.32 {q1, q2}, [%[out_ptr]]! \n\t"
-"vst1.32 {q3, q4}, [%[out_ptr]]! \n\t"
-"vst1.32 {q5, q6}, [%[out_ptr]]! \n\t"
-"vst1.32 {q7, q8}, [%[out_ptr]]! \n\t"
-"subs r6, r6, #32 \n\t"
-"bge loop_hw_%= \n\t"
-"end_hw_%=: \n\t"
-"cmp r6, #0 \n\t"
-"bge end_remainder_%= \n\t"
-"mov r5, #4 \n\t"
-"mul r6, r6, r5 \n\t"
-"add %[input_x_ptr], %[input_x_ptr], r6 \n\t"
-"vld1.32 {q1, q2}, [%[input_x_ptr]]! \n\t"
-"vld1.32 {q3, q4}, [%[input_x_ptr]]! \n\t"
-"vld1.32 {q5, q6}, [%[input_x_ptr]]! \n\t"
-"vld1.32 {q7, q8}, [%[input_x_ptr]]! \n\t"
-"vmul.f32 q1, q1, q9 \n\t"
-"vmul.f32 q2, q2, q9 \n\t"
-"vmul.f32 q3, q3, q9 \n\t"
-"vmul.f32 q4, q4, q9 \n\t"
-"vmul.f32 q5, q5, q9 \n\t"
-"vmul.f32 q6, q6, q9 \n\t"
-"vmul.f32 q7, q7, q9 \n\t"
-"vmul.f32 q8, q8, q9 \n\t"
-"vadd.f32 q1, q1, q10 \n\t"
-"vadd.f32 q2, q2, q10 \n\t"
-"vadd.f32 q3, q3, q10 \n\t"
-"vadd.f32 q4, q4, q10 \n\t"
-"vadd.f32 q5, q5, q10 \n\t"
-"vadd.f32 q6, q6, q10 \n\t"
-"vadd.f32 q7, q7, q10 \n\t"
-"vadd.f32 q8, q8, q10 \n\t"
-"add %[out_ptr], %[out_ptr], r6 \n\t"
-"vst1.32 {q1, q2}, [%[out_ptr]]! \n\t"
-"vst1.32 {q3, q4}, [%[out_ptr]]! \n\t"
-"vst1.32 {q5, q6}, [%[out_ptr]]! \n\t"
-"vst1.32 {q7, q8}, [%[out_ptr]]! \n\t"
-"end_remainder_%=: \n\t"
-"subs %[C], %[C], #1 \n\t"
-"bge loop_c_%= \n\t"
-"end_c_%=: \n\t"
-"subs %[N], %[N], #1 \n\t"
-"bge loop_n_%= \n\t"
-"end_n_%=: \n\t"
-:
-: [input_x_ptr] "r"(input_x_ptr), [out_ptr] "r"(out_ptr),
-[new_scale_ptr] "r"(new_scale_ptr), [new_bias_ptr] "r"(new_bias_ptr),
-[N] "r"(N), [C] "r"(C), [HXW] "r"(HXW)
-: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
-"q10", "r5", "r6");
-delete[] inv_std_ptr;
-delete[] new_scale_ptr;
-delete[] new_bias_ptr;
-} else {
-float *inv_std_ptr = new float[C];
-for (int i = 0; i < C; i++) {
-inv_std_ptr[i] =
-1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
-}
-Tensor new_scale;
-auto new_scale_ptr =
-new_scale.mutable_data<float>(framework::make_ddim({C}));
-Tensor new_bias;
-auto new_bias_ptr = new_bias.mutable_data<float>(framework::make_ddim({C}));
-/// ((x - est_mean) * (inv_var) * scale + bias equal to
-/// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
-for (int i = 0; i < C; i++) {
-new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
-new_bias_ptr[i] =
-bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
-{
-for (int n = 0; n < N; n++) {
-for (int h = 0; h < H; h++) {
-int tmp_index = n * stride0 + i * stride1 + h * stride2;
-for (int w = 0; w < W; w++) {
-int index = tmp_index + w;
-out_ptr[index] =
-input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i];
-}
-}
-}
-}
-}
-delete[] inv_std_ptr;
-}
-#endif
-#else
-float *inv_std_ptr = new float[C];
-for (int i = 0; i < C; i++) {
-inv_std_ptr[i] =
-1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
-}
-Tensor new_scale;
-auto new_scale_ptr = new_scale.mutable_data<float>(framework::make_ddim({C}));
-Tensor new_bias;
-auto new_bias_ptr = new_bias.mutable_data<float>(framework::make_ddim({C}));
-/// ((x - est_mean) * (inv_var) * scale + bias equal to
-/// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
-for (int i = 0; i < C; i++) {
-new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
-new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
-{
-for (int n = 0; n < N; n++) {
-for (int h = 0; h < H; h++) {
-int tmp_index = n * stride0 + i * stride1 + h * stride2;
-for (int w = 0; w < W; w++) {
-int index = tmp_index + w;
-out_ptr[index] =
-input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i];
-}
-}
+#endif // __ARM_NEON__
+for (int k = 0; k < remain; ++k) {
+y[k] = scale * x[k] + bias;
+}
}
}
-delete[] inv_std_ptr;
-#endif
}
} // namespace operators
......
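
The new kernel above vectorizes that per-channel affine with NEON: the main loop processes 16 floats per iteration through vmlaq_f32(bias, scale, x), a lane-wise multiply-accumulate, and a scalar tail handles the remaining spatial_size % 16 elements. A reduced sketch of the block-plus-remainder pattern follows (a hypothetical free function, 4 floats per step to keep it short; it falls back to the plain loop when NEON is unavailable).

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif
#include <cstddef>

// y[k] = scale * x[k] + bias for k in [0, len): NEON main loop plus scalar tail.
void scale_bias(const float *x, float *y, std::size_t len, float scale,
                float bias) {
  std::size_t k = 0;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
  const float32x4_t vscale = vdupq_n_f32(scale);
  const float32x4_t vbias = vdupq_n_f32(bias);
  for (; k + 4 <= len; k += 4) {
    // vmlaq_f32(a, b, c) computes a + b * c in each lane.
    vst1q_f32(y + k, vmlaq_f32(vbias, vscale, vld1q_f32(x + k)));
  }
#endif
  for (; k < len; ++k) {
    y[k] = scale * x[k] + bias;
  }
}
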
@@ -294,11 +294,6 @@ void MultiClassNMSCompute(const MultiClassNMSParam<CPU>& param) {
}
}
}
-// framework::LoD lod;
-// lod.emplace_back(batch_starts);
-//
-// outs->set_lod(lod);
}
} // namespace operators
......
@@ -56,14 +56,13 @@ void MatMul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
int N = dim_out[1];
int K = (!trans_a) ? dim_a[1] : dim_a[0];
Gemm gemm;
if (trans_a) {
+framework::Tensor matrix_trans;
int numel = matrix_a.numel();
int m = matrix_a.dims()[0];
int n = matrix_a.dims()[1];
float *tmp = (float *)(matrix_a.data<float>()); // NOLINT
-float *a = static_cast<float *>(
-paddle_mobile::memory::Alloc(sizeof(float) * numel));
+float *a = matrix_trans.mutable_data<float>(matrix_a.dims());
int index = 0;
for (int j = 0; j < n; j++) {
for (int i = 0; i < m; i++) {
......
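
The MatMul hunk above appears to be the leak referenced in the commit title: in the trans_a branch, the transposed copy of matrix_a previously lived in a buffer obtained from paddle_mobile::memory::Alloc that was never freed, and it now lives in a local framework::Tensor (matrix_trans) whose storage is released when the tensor goes out of scope. A minimal sketch of the same ownership pattern with standard C++ types (std::vector standing in for the Tensor; the function and parameter names are illustrative):

#include <cstddef>
#include <vector>

// Before (the shape of the removed lines): a raw heap buffer that nobody freed,
// leaking on every call. After: scratch storage owned by an object whose
// destructor releases it when the enclosing scope ends.
void transpose_copy(const float *src, int m, int n, std::vector<float> *dst) {
  dst->resize(static_cast<std::size_t>(m) * static_cast<std::size_t>(n));
  // Column-major walk of src writes a transposed, contiguous copy into dst,
  // mirroring the loop in MatMul's trans_a branch.
  int index = 0;
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      (*dst)[index++] = src[i * n + j];
    }
  }
}
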