Commit 7db044ac authored by hjchen2

Refactor batch norm, and fix deconv memory leak

Parent 06a6f6a9
...@@ -22,6 +22,8 @@ const char *G_OP_TYPE_BATCHNORM = "batch_norm";
 const char *G_OP_TYPE_BOX_CODER = "box_coder";
 const char *G_OP_TYPE_CONCAT = "concat";
 const char *G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add";
+const char *G_OP_TYPE_ELEMENTWISE_SUB = "elementwise_sub";
+const char *G_OP_TYPE_ELEMENTWISE_MUL = "elementwise_mul";
 const char *G_OP_TYPE_FILL_CONSTANT = "fill_constant";
 const char *G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu";
 const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU = "fusion_conv_add_prelu";
...@@ -67,7 +69,6 @@ const char *G_OP_TYPE_CRF = "crf_decoding";
 const char *G_OP_TYPE_BILINEAR_INTERP = "bilinear_interp";
 const char *G_OP_TYPE_FLATTEN = "flatten";
 const char *G_OP_TYPE_SHAPE = "shape";
-const char *G_OP_TYPE_ELEMENTWISE_MUL = "elementwise_mul";
 const char *G_OP_TYPE_SUM = "sum";
 const char *G_OP_TYPE_TOP_K = "top_k";
 const char *G_OP_TYPE_CAST = "cast";
...@@ -102,6 +103,8 @@ std::unordered_map<
     {G_OP_TYPE_SIGMOID, {{"X"}, {"Out"}}},
     {G_OP_TYPE_MUL, {{"X"}, {"Out"}}},
     {G_OP_TYPE_ELEMENTWISE_ADD, {{"X", "Y"}, {"Out"}}},
+    {G_OP_TYPE_ELEMENTWISE_SUB, {{"X", "Y"}, {"Out"}}},
+    {G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
     {G_OP_TYPE_POOL2D, {{"X"}, {"Out"}}},
     {G_OP_TYPE_BATCHNORM, {{"X"}, {"Y"}}},
     {G_OP_TYPE_LRN, {{"X"}, {"Out"}}},
...@@ -146,7 +149,6 @@ std::unordered_map<
     {G_OP_TYPE_SUM, {{"X"}, {"Out"}}},
     {G_OP_TYPE_TOP_K, {{"X"}, {"Out", "Indices"}}},
     {G_OP_TYPE_CAST, {{"X"}, {"Out"}}},
-    {G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
     {G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}},
     {G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}},
     {G_OP_TYPE_FUSION_DEQUANT_BN, {{"X", "Scale"}, {"Out"}}},
......
...@@ -112,6 +112,8 @@ extern const char *G_OP_TYPE_BATCHNORM;
 extern const char *G_OP_TYPE_BOX_CODER;
 extern const char *G_OP_TYPE_CONCAT;
 extern const char *G_OP_TYPE_ELEMENTWISE_ADD;
+extern const char *G_OP_TYPE_ELEMENTWISE_SUB;
+extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
 extern const char *G_OP_TYPE_FUSION_CONV_ADD_RELU;
 extern const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU;
 extern const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU;
...@@ -149,7 +151,6 @@ extern const char *G_OP_TYPE_FUSION_CONV_BN;
 extern const char *G_OP_TYPE_CONV_TRANSPOSE;
 extern const char *G_OP_TYPE_PRELU;
 extern const char *G_OP_TYPE_SUM;
-extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
 extern const char *G_OP_TYPE_TOP_K;
 extern const char *G_OP_TYPE_CAST;
......
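The tables touched above map each operator name to the parameter keys it reads and writes; the new elementwise_sub and elementwise_mul entries follow elementwise_add in declaring two inputs ("X", "Y") and one output ("Out"). Below is a minimal self-contained C++ sketch of that kind of registry, using standard-library types only; the map name and entries are illustrative, not paddle-mobile's actual declaration.

#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Operator name -> ({input keys}, {output keys}), mirroring the table in op_types.cpp.
using IOKeys = std::pair<std::vector<std::string>, std::vector<std::string>>;

int main() {
  std::unordered_map<std::string, IOKeys> op_io_keys = {
      {"elementwise_add", {{"X", "Y"}, {"Out"}}},
      {"elementwise_sub", {{"X", "Y"}, {"Out"}}},
      {"elementwise_mul", {{"X", "Y"}, {"Out"}}},
      {"batch_norm", {{"X"}, {"Y"}}},
  };

  const auto &io = op_io_keys.at("elementwise_sub");
  std::cout << "elementwise_sub inputs:";
  for (const auto &k : io.first) std::cout << ' ' << k;
  std::cout << "  outputs:";
  for (const auto &k : io.second) std::cout << ' ' << k;
  std::cout << '\n';
  return 0;
}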
...@@ -18,283 +18,63 @@ limitations under the License. */
 #include <cmath>
 #include "operators/op_param.h"
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#include <arm_neon.h>
+#endif  // __ARM_NEON__

 namespace paddle_mobile {
 namespace operators {

 template <typename P>
 void BatchnormCompute(const BatchNormParam<CPU> &param) {
+  const float epsilon = param.Epsilon();
+  const float *mean_ptr = param.InputMean()->data<float>();
+  const float *variance_ptr = param.InputVariance()->data<float>();
+  const float *scale_ptr = param.InputScale()->data<float>();
+  const float *bias_ptr = param.InputBias()->data<float>();
+
+  const framework::Tensor *input = param.InputX();
+  const float *input_ptr = input->data<float>();
+  framework::Tensor *output = param.OutputY();
+  float *output_ptr = output->mutable_data<float>();
+  size_t spatial_size = output->dims()[2] * output->dims()[3];
+  int channels = output->dims()[1];
+
+  #pragma omp parallel for collapse(2)
+  for (int batch = 0; batch < output->dims()[0]; ++batch) {
+    for (int c = 0; c < channels; ++c) {
+      float inv_scale = 1.f / (std::sqrt(variance_ptr[c] + epsilon));
+      float bias = bias_ptr[c] - inv_scale * scale_ptr[c] * mean_ptr[c];
+      float scale = inv_scale * scale_ptr[c];
+      size_t offset = (batch * channels + c) * spatial_size;
+      const float *x = input_ptr + offset;
+      float *y = output_ptr + offset;
+      size_t remain = spatial_size;
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+      int loop = spatial_size >> 4;
+      remain = spatial_size & 0xF;
+      float32x4_t __scale = vdupq_n_f32(scale);
+      float32x4_t __bias = vdupq_n_f32(bias);
+      for (int k = 0; k < loop; ++k, x += 16, y += 16) {
+        float32x4_t r0 = vld1q_f32(x);
+        float32x4_t r1 = vld1q_f32(x + 4);
+        float32x4_t r2 = vld1q_f32(x + 8);
+        float32x4_t r3 = vld1q_f32(x + 12);
+        r0 = vmlaq_f32(__bias, __scale, r0);
+        r1 = vmlaq_f32(__bias, __scale, r1);
+        r2 = vmlaq_f32(__bias, __scale, r2);
+        r3 = vmlaq_f32(__bias, __scale, r3);
+        vst1q_f32(y, r0);
+        vst1q_f32(y + 4, r1);
+        vst1q_f32(y + 8, r2);
+        vst1q_f32(y + 12, r3);
+      }
+#endif  // __ARM_NEON__
+      for (int k = 0; k < remain; ++k) {
+        y[k] = scale * x[k] + bias;
+      }
+    }
+  }
-  const Tensor *input_x = param.InputX();
-  auto input_x_ptr = input_x->data<float>();
-  const auto &x_dims = input_x->dims();
-  const int N = x_dims[0];
-  const int C = x_dims[1];
-  const int H = x_dims[2];
-  const int W = x_dims[3];
-  const int stride0 = C * H * W;
-  const int stride1 = H * W;
-  const int stride2 = W;
-  Tensor *out = param.OutputY();
-  auto out_ptr = out->mutable_data<float>();
-  const float epsilon = param.Epsilon();
-  const Tensor *mean = param.InputMean();
-  const Tensor *variance = param.InputVariance();
-  const Tensor *scale = param.InputScale();
-  const Tensor *bias = param.InputBias();
-  auto mean_ptr = mean->data<float>();
-  auto variance_ptr = variance->data<float>();
-  auto scale_ptr = scale->data<float>();
-  auto bias_ptr = bias->data<float>();
-
-  // Tensor inv_std;
-  // auto inv_std_ptr = inv_std.mutable_data<float>(make_ddim({C}));
-  PADDLE_MOBILE_ENFORCE(C == variance->numel(),
-                        "C must equal to variance.numel()");
-
-  int HXW = H * W;
-
-#if __ARM_NEON
-#if __aarch64__
-  float *inv_std_ptr = new float[C];
-  for (int i = 0; i < C; i++) {
-    inv_std_ptr[i] =
-        1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
-  }
-
-  Tensor new_scale;
-  auto new_scale_ptr = new_scale.mutable_data<float>(framework::make_ddim({C}));
-  Tensor new_bias;
-  auto new_bias_ptr = new_bias.mutable_data<float>(framework::make_ddim({C}));
-
-  /// ((x - est_mean) * (inv_var) * scale + bias equal to
-  /// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
-  for (int i = 0; i < C; i++) {
-    new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
-    new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
-    {
-      for (int n = 0; n < N; n++) {
-        for (int h = 0; h < H; h++) {
-          int tmp_index = n * stride0 + i * stride1 + h * stride2;
-          for (int w = 0; w < W; w++) {
-            int index = tmp_index + w;
-            out_ptr[index] =
-                input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i];
-          }
-        }
-      }
-    }
-  }
delete[] inv_std_ptr;
#else
if (HXW > 32) {
int NXC = N * C;
float *inv_std_ptr = new float[NXC * 4];
float *volatile new_scale_ptr = new float[NXC * 4];
float *volatile new_bias_ptr = new float[NXC * 4];
/// std = (var + epsilon).sqrt();
/// inv_std = 1 / std;
for (int i = 0; i < C * 4; i += 4) {
int index = i / 4;
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[index] + epsilon), 0.5));
inv_std_ptr[i + 1] = inv_std_ptr[i];
inv_std_ptr[i + 2] = inv_std_ptr[i];
inv_std_ptr[i + 3] = inv_std_ptr[i];
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[index];
new_scale_ptr[i + 1] = new_scale_ptr[i];
new_scale_ptr[i + 2] = new_scale_ptr[i];
new_scale_ptr[i + 3] = new_scale_ptr[i];
new_bias_ptr[i] =
bias_ptr[index] - mean_ptr[index] * inv_std_ptr[i] * scale_ptr[index];
new_bias_ptr[i + 1] = new_bias_ptr[i];
new_bias_ptr[i + 2] = new_bias_ptr[i];
new_bias_ptr[i + 3] = new_bias_ptr[i];
}
for (int j = C * 4; j < NXC * 4; ++j) {
new_scale_ptr[j] = new_scale_ptr[j - C * 4];
new_bias_ptr[j] = new_bias_ptr[j - C * 4];
}
asm volatile(
"subs %[N], %[N], #1 \n\t"
"blt end_n_%= \n\t"
"loop_n_%=: \n\t"
"subs %[C], %[C], #1 \n\t"
"blt end_c_%= \n\t"
"loop_c_%=: \n\t"
"vld1.32 {q9}, [%[new_scale_ptr]]! \n\t"
"vld1.32 {q10}, [%[new_bias_ptr]]! \n\t"
"mov r6, %[HXW] \n\t"
"subs r6, r6, #32 \n\t"
"blt end_hw_%= \n\t"
"loop_hw_%=: \n\t"
"vld1.32 {q1, q2}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q3, q4}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q5, q6}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q7, q8}, [%[input_x_ptr]]! \n\t"
"vmul.f32 q1, q1, q9 \n\t"
"vmul.f32 q2, q2, q9 \n\t"
"vmul.f32 q3, q3, q9 \n\t"
"vmul.f32 q4, q4, q9 \n\t"
"vmul.f32 q5, q5, q9 \n\t"
"vmul.f32 q6, q6, q9 \n\t"
"vmul.f32 q7, q7, q9 \n\t"
"vmul.f32 q8, q8, q9 \n\t"
"vadd.f32 q1, q1, q10 \n\t"
"vadd.f32 q2, q2, q10 \n\t"
"vadd.f32 q3, q3, q10 \n\t"
"vadd.f32 q4, q4, q10 \n\t"
"vadd.f32 q5, q5, q10 \n\t"
"vadd.f32 q6, q6, q10 \n\t"
"vadd.f32 q7, q7, q10 \n\t"
"vadd.f32 q8, q8, q10 \n\t"
"vst1.32 {q1, q2}, [%[out_ptr]]! \n\t"
"vst1.32 {q3, q4}, [%[out_ptr]]! \n\t"
"vst1.32 {q5, q6}, [%[out_ptr]]! \n\t"
"vst1.32 {q7, q8}, [%[out_ptr]]! \n\t"
"subs r6, r6, #32 \n\t"
"bge loop_hw_%= \n\t"
"end_hw_%=: \n\t"
"cmp r6, #0 \n\t"
"bge end_remainder_%= \n\t"
"mov r5, #4 \n\t"
"mul r6, r6, r5 \n\t"
"add %[input_x_ptr], %[input_x_ptr], r6 \n\t"
"vld1.32 {q1, q2}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q3, q4}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q5, q6}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q7, q8}, [%[input_x_ptr]]! \n\t"
"vmul.f32 q1, q1, q9 \n\t"
"vmul.f32 q2, q2, q9 \n\t"
"vmul.f32 q3, q3, q9 \n\t"
"vmul.f32 q4, q4, q9 \n\t"
"vmul.f32 q5, q5, q9 \n\t"
"vmul.f32 q6, q6, q9 \n\t"
"vmul.f32 q7, q7, q9 \n\t"
"vmul.f32 q8, q8, q9 \n\t"
"vadd.f32 q1, q1, q10 \n\t"
"vadd.f32 q2, q2, q10 \n\t"
"vadd.f32 q3, q3, q10 \n\t"
"vadd.f32 q4, q4, q10 \n\t"
"vadd.f32 q5, q5, q10 \n\t"
"vadd.f32 q6, q6, q10 \n\t"
"vadd.f32 q7, q7, q10 \n\t"
"vadd.f32 q8, q8, q10 \n\t"
"add %[out_ptr], %[out_ptr], r6 \n\t"
"vst1.32 {q1, q2}, [%[out_ptr]]! \n\t"
"vst1.32 {q3, q4}, [%[out_ptr]]! \n\t"
"vst1.32 {q5, q6}, [%[out_ptr]]! \n\t"
"vst1.32 {q7, q8}, [%[out_ptr]]! \n\t"
"end_remainder_%=: \n\t"
"subs %[C], %[C], #1 \n\t"
"bge loop_c_%= \n\t"
"end_c_%=: \n\t"
"subs %[N], %[N], #1 \n\t"
"bge loop_n_%= \n\t"
"end_n_%=: \n\t"
:
: [input_x_ptr] "r"(input_x_ptr), [out_ptr] "r"(out_ptr),
[new_scale_ptr] "r"(new_scale_ptr), [new_bias_ptr] "r"(new_bias_ptr),
[N] "r"(N), [C] "r"(C), [HXW] "r"(HXW)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
"q10", "r5", "r6");
delete[] inv_std_ptr;
delete[] new_scale_ptr;
delete[] new_bias_ptr;
} else {
float *inv_std_ptr = new float[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
Tensor new_scale;
auto new_scale_ptr =
new_scale.mutable_data<float>(framework::make_ddim({C}));
Tensor new_bias;
auto new_bias_ptr = new_bias.mutable_data<float>(framework::make_ddim({C}));
/// ((x - est_mean) * (inv_var) * scale + bias equal to
/// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] =
bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
{
for (int n = 0; n < N; n++) {
for (int h = 0; h < H; h++) {
int tmp_index = n * stride0 + i * stride1 + h * stride2;
for (int w = 0; w < W; w++) {
int index = tmp_index + w;
out_ptr[index] =
input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i];
}
}
}
}
}
delete[] inv_std_ptr;
}
#endif
#else
float *inv_std_ptr = new float[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
Tensor new_scale;
auto new_scale_ptr = new_scale.mutable_data<float>(framework::make_ddim({C}));
Tensor new_bias;
auto new_bias_ptr = new_bias.mutable_data<float>(framework::make_ddim({C}));
/// ((x - est_mean) * (inv_var) * scale + bias equal to
/// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
{
for (int n = 0; n < N; n++) {
for (int h = 0; h < H; h++) {
int tmp_index = n * stride0 + i * stride1 + h * stride2;
for (int w = 0; w < W; w++) {
int index = tmp_index + w;
out_ptr[index] =
input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i];
}
}
}
}
}
delete[] inv_std_ptr;
#endif
 }
 }  // namespace operators
......
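The refactored kernel above works because batch normalization can be folded into a single per-channel affine transform: with inv_scale = 1 / sqrt(variance + epsilon), the expression gamma * (x - mean) * inv_scale + beta equals scale * x + bias, where scale = gamma * inv_scale and bias = beta - scale * mean. Folding the constants once per channel leaves one multiply-add per element, which is what the NEON path issues via vmlaq_f32. A minimal standalone C++ check of that identity (illustrative values, not paddle-mobile code):

#include <cmath>
#include <cstdio>

int main() {
  // Per-channel batch-norm statistics (made-up example values).
  const float mean = 0.5f, variance = 2.0f, gamma = 1.5f, beta = -0.3f;
  const float epsilon = 1e-5f;

  // Fold the statistics into one affine transform, as the new kernel does.
  const float inv_scale = 1.f / std::sqrt(variance + epsilon);
  const float scale = gamma * inv_scale;
  const float bias = beta - scale * mean;

  const float xs[4] = {-1.0f, 0.0f, 0.5f, 2.0f};
  for (float x : xs) {
    const float direct = gamma * (x - mean) * inv_scale + beta;  // textbook form
    const float folded = scale * x + bias;                       // folded form
    std::printf("x=%5.2f  direct=%9.6f  folded=%9.6f\n", x, direct, folded);
  }
  return 0;
}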
...@@ -294,11 +294,6 @@ void MultiClassNMSCompute(const MultiClassNMSParam<CPU>& param) {
       }
     }
   }
-  //  framework::LoD lod;
-  //  lod.emplace_back(batch_starts);
-  //
-  //  outs->set_lod(lod);
 }
 }  // namespace operators
......
...@@ -56,14 +56,13 @@ void MatMul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
   int N = dim_out[1];
   int K = (!trans_a) ? dim_a[1] : dim_a[0];
   Gemm gemm;
   if (trans_a) {
+    framework::Tensor matrix_trans;
     int numel = matrix_a.numel();
     int m = matrix_a.dims()[0];
     int n = matrix_a.dims()[1];
     float *tmp = (float *)(matrix_a.data<float>());  // NOLINT
-    float *a = static_cast<float *>(
-        paddle_mobile::memory::Alloc(sizeof(float) * numel));
+    float *a = matrix_trans.mutable_data<float>(matrix_a.dims());
     int index = 0;
     for (int j = 0; j < n; j++) {
       for (int i = 0; i < m; i++) {
......
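The MatMul change above addresses the deconv leak named in the commit message: the transposed copy of matrix_a used to live in a buffer obtained from a raw paddle_mobile::memory::Alloc, whereas it is now backed by a local framework::Tensor (matrix_trans), so the storage is tied to the tensor's lifetime and released when it goes out of scope. A small generic C++ sketch of the same ownership idea, deliberately using standard-library types rather than paddle-mobile's API (transpose_raw and transpose_owned are hypothetical helpers for illustration):

#include <cstdio>
#include <vector>

// Leak-prone pattern: the caller must remember to delete[] the returned buffer.
float *transpose_raw(const float *src, int m, int n) {
  float *dst = new float[static_cast<size_t>(m) * n];
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) dst[j * m + i] = src[i * n + j];
  return dst;  // ownership is implicit; forgetting delete[] leaks the buffer
}

// RAII pattern: the container owns the buffer and frees it automatically.
std::vector<float> transpose_owned(const float *src, int m, int n) {
  std::vector<float> dst(static_cast<size_t>(m) * n);
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) dst[j * m + i] = src[i * n + j];
  return dst;  // storage is released when the vector is destroyed
}

int main() {
  const float a[6] = {1, 2, 3, 4, 5, 6};  // 2 x 3 matrix, row-major

  float *t1 = transpose_raw(a, 2, 3);
  std::printf("raw:   t1[1] = %g\n", t1[1]);
  delete[] t1;  // easy to forget, which is the kind of leak being fixed

  std::vector<float> t2 = transpose_owned(a, 2, 3);
  std::printf("owned: t2[1] = %g\n", t2[1]);
  return 0;  // t2 cleans up after itself
}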