// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once template static void basic_trans_mat_to_c4(const type* input, type* output, const int ldin, const int M, const int K, bool pack_k) { const int m_round = (M + 3) / 4 * 4; int k_round = (K + 3) / 4 * 4; if (!pack_k) { k_round = K; } const int m_loop = m_round / 4; type zero_buf[K]; memset(zero_buf, 0, K * sizeof(type)); for (int i = 0; i < m_loop; ++i) { const type* in0 = input + i * 4 * ldin; const type* in1 = in0 + ldin; const type* in2 = in1 + ldin; const type* in3 = in2 + ldin; if (4 * (i + 1) - M > 0) { switch (4 * (i + 1) - M) { case 3: in1 = zero_buf; case 2: in2 = zero_buf; case 1: in3 = zero_buf; default: break; } } for (int j = 0; j < K; ++j) { *output++ = *in0++; *output++ = *in1++; *output++ = *in2++; *output++ = *in3++; } for (int j = K; j < k_round; ++j) { *output++ = static_cast(0); *output++ = static_cast(0); *output++ = static_cast(0); *output++ = static_cast(0); } } } template static void basic_gemm_c4(bool trans_a, bool trans_b, int m, int n, int k, type2 alpha, const type* a, int lda, const type* b, int ldb, type2 beta, type2* c, int ldc, const type2* bias, bool flag_bias = false, bool flag_relu = false) { type2* tmp_c = reinterpret_cast(malloc(m * ldc * sizeof(type2))); memset(tmp_c, 0, m * ldc * sizeof(type2)); #pragma omp parallel for for (int i = 0; i < m; ++i) { auto bias_data = static_cast(0); if (flag_bias) { bias_data = bias[i]; } for (int j = 0; j < n; ++j) { auto sum = static_cast(0); for (int l = 0; l < k; ++l) { type av; type bv; if (trans_a) { av = a[l * lda + i]; } else { av = a[i * lda + l]; } if (trans_b) { bv = b[j * ldb + l]; } else { bv = b[l * ldb + j]; } sum += av * bv; } type2 tmp = alpha * sum + beta * tmp_c[i * ldc + j] + bias_data; if (flag_relu) { tmp_c[i * ldc + j] = tmp > (type2)0 ? tmp : (type2)0; } else { tmp_c[i * ldc + j] = tmp; } } } //! trans c to c4 basic_trans_mat_to_c4(tmp_c, c, ldc, m, n, false); free(tmp_c); } template static void basic_gemm(bool trans_a, bool trans_b, int m, int n, int k, type2 alpha, const type* a, int lda, const type* b, int ldb, type2 beta, type2* c, int ldc, const type2* bias, bool flag_bias = false, bool flag_relu = false) { #pragma omp parallel for for (int i = 0; i < m; ++i) { auto bias_data = static_cast(0); if (flag_bias) { bias_data = bias[i]; } for (int j = 0; j < n; ++j) { auto sum = static_cast(0); for (int l = 0; l < k; ++l) { type av; type bv; if (trans_a) { av = a[l * lda + i]; } else { av = a[i * lda + l]; } if (trans_b) { bv = b[j * ldb + l]; } else { bv = b[l * ldb + j]; } sum += av * bv; } type2 tmp = alpha * sum + beta * c[i * ldc + j] + bias_data; if (flag_relu) { c[i * ldc + j] = tmp > (type2)0 ? tmp : (type2)0; } else { c[i * ldc + j] = tmp; } } } } template static void basic_gemv(int m, int k, const type* a, const type* b, const type2* bias, type2* c, type2 alpha, type2 beta, bool trans_a = false, bool flag_bias = false, int flag_act = false, float six = 6.f, float leakey_relu_alpha = 1.f) { #pragma omp parallel for for (int i = 0; i < m; ++i) { auto bias_data = static_cast(0); if (flag_bias) { bias_data = bias[i]; } auto sum = static_cast(0); for (int j = 0; j < k; ++j) { type av; if (trans_a) { av = a[j * m + i]; } else { av = a[i * k + j]; } sum += av * b[j]; } type2 tmp = alpha * sum + beta * c[i] + bias_data; if (flag_act > 0) { if (flag_act == 1) { // relu c[i] = tmp > (type2)0 ? tmp : (type2)0; } else if (flag_act == 2) { // relu 6 c[i] = tmp > (type2)0 ? tmp : (type2)0; // c[i] = c[i] < six ? c[i] : six; // ut compute } else if (flag_act == 4) { // leakey relu c[i] = tmp < (type2)0 ? (type2)(tmp * leakey_relu_alpha) : tmp; } } else { c[i] = tmp; } } } /** * \brief basic direct convolution function */ //! for float, dtype1 and type2 is float //! for int8, dytpe1 is char, dtype2 is int template static void conv_basic(const Dtype1* din, Dtype2* dout, int num, int chout, int hout, int wout, int chin, int hin, int win, const Dtype1* weights, const Dtype2* bias, int group, int kernel_w, int kernel_h, int stride_w, int stride_h, int dila_w, int dila_h, int pad_w, int pad_h, bool flag_bias, int act_type, float six = 6.f, float scale = 1.f) { Dtype2 beta = 0; auto src_data = din; auto dst_data_ref = dout; auto weights_data = weights; auto with_bias = flag_bias; auto bias_data = bias; int in_num = num; int out_channels = chout; int out_h = hout; int out_w = wout; int in_channel = chin; int in_h = hin; int in_w = win; int out_c_group = out_channels / group; int in_c_group = in_channel / group; for (int n = 0; n < in_num; ++n) { #pragma omp parallel for collapse(4) for (int g = 0; g < group; ++g) { for (int oc = 0; oc < out_c_group; ++oc) { for (int oh = 0; oh < out_h; ++oh) { for (int ow = 0; ow < out_w; ++ow) { int out_idx = n * group * out_c_group * out_h * out_w + g * out_c_group * out_h * out_w + oc * out_h * out_w + oh * out_w + ow; Dtype2 bias_d = with_bias ? (bias_data[g * out_c_group + oc]) : 0; dst_data_ref[out_idx] = bias_d; // + dst_data_ref[out_idx] * beta; for (int ic = 0; ic < in_c_group; ++ic) { for (int kh = 0; kh < kernel_h; ++kh) { for (int kw = 0; kw < kernel_w; ++kw) { int iw = ow * stride_w - pad_w + kw * (dila_w); int ih = oh * stride_h - pad_h + kh * (dila_h); if (iw < 0 || iw >= in_w) continue; if (ih < 0 || ih >= in_h) continue; int iidx = n * in_channel * in_h * in_w + g * in_c_group * in_h * in_w + ic * in_h * in_w + ih * in_w + iw; int widx = g * out_c_group * in_c_group * kernel_h * kernel_w + oc * in_c_group * kernel_h * kernel_w + ic * kernel_h * kernel_w + kh * kernel_w + kw; dst_data_ref[out_idx] += src_data[iidx] * weights_data[widx]; } } } if (act_type > 0) { // 1-relu 2-relu6 4-leakyrelu if (act_type == 1) { dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0 ? dst_data_ref[out_idx] : (Dtype2)0; } else if (act_type == 2) { dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0 ? dst_data_ref[out_idx] : (Dtype2)0; dst_data_ref[out_idx] = dst_data_ref[out_idx] < (Dtype2)six ? dst_data_ref[out_idx] : (Dtype2)six; } else if (act_type == 4) { dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0 ? dst_data_ref[out_idx] : (Dtype2)(dst_data_ref[out_idx] * scale); } else { printf("this act type: %d does not support \n", act_type); } } } } } } } } template static void fill_bias_relu(Dtype* tensor, const Dtype* bias, int channel, int channel_size, bool flag_bias, bool flag_relu) { Dtype* data = tensor; for (int j = 0; j < channel; ++j) { Dtype bias_c = flag_bias ? bias[j] : 0; for (int i = 0; i < channel_size; i++) { data[i] += bias_c; if (flag_relu) { data[i] = data[i] > 0 ? data[i] : 0.f; } } data += channel_size; } } template static void do_relu(Dtype* tensor, int size) { for (int j = 0; j < size; ++j) { tensor[j] = tensor[j] > 0 ? tensor[j] : (Dtype)0; } } inline bool is_a_ge_zero_and_a_lt_b(int a, int b) { return static_cast(a) < static_cast(b); } template static void col2im(const Dtype* data_col, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h0, const int pad_h1, const int pad_w0, const int pad_w1, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, Dtype* data_im) { memset(data_im, 0, height * width * channels * sizeof(Dtype)); const int output_h = (height + pad_h0 + pad_h1 - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; const int output_w = (width + pad_w0 + pad_w1 - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; const int channel_size = height * width; for (int channel = channels; channel--; data_im += channel_size) { for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) { for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) { int input_row = -pad_h0 + kernel_row * dilation_h; for (int output_rows = output_h; output_rows; output_rows--) { if (!is_a_ge_zero_and_a_lt_b(input_row, height)) { data_col += output_w; } else { int input_col = -pad_w0 + kernel_col * dilation_w; for (int output_col = output_w; output_col; output_col--) { if (is_a_ge_zero_and_a_lt_b(input_col, width)) { data_im[input_row * width + input_col] += *data_col; } data_col++; input_col += stride_w; } } input_row += stride_h; } } } } } //! for float, dtype1 and type2 is float //! for int8, dytpe1 is char, dtype2 is int template void deconv_basic(const Dtype1* din, Dtype2* dout, int num, int chout, int hout, int wout, int chin, int hin, int win, const Dtype1* weights, const Dtype2* bias, int group, int kernel_w, int kernel_h, int stride_w, int stride_h, int dila_w, int dila_h, int pad_w0, int pad_w1, int pad_h0, int pad_h1, bool flag_bias, bool flag_relu) { int m = chout * kernel_w * kernel_h / group; int n = hin * win; int k = chin / group; int group_size_in = win * hin * chin / group; int group_size_coldata = m * n; int group_size_weights = chin * chout * kernel_w * kernel_h / (group * group); bool flag_1x1s1p1 = (kernel_w == 1) && (kernel_h == 1) && (stride_h == 1) && (stride_w == 1) && (pad_w0 == 0) && (pad_h0 == 0) && (pad_w1 == 0) && (pad_h1 == 0) && (dila_w == 1) && (dila_h == 1); Dtype2* workspace_ptr = static_cast(malloc(sizeof(float) * m * n * group)); for (int i = 0; i < num; ++i) { const Dtype1* din_batch = din + i * chin * hin * win; Dtype2* dout_batch = dout + i * chout * hout * wout; Dtype2* col_data = workspace_ptr; if (flag_1x1s1p1) { col_data = dout_batch; } memset(col_data, 0, sizeof(Dtype2) * group_size_coldata * group); for (int g = 0; g < group; ++g) { const Dtype1* din_group = din_batch + g * group_size_in; const Dtype1* weights_group = weights + g * group_size_weights; Dtype2* coldata_group = col_data + g * group_size_coldata; basic_gemm(true, false, m, n, k, 1, weights_group, m, din_group, n, 0, coldata_group, n, nullptr, false, (!flag_bias && flag_relu)); } if (!flag_1x1s1p1) { col2im(col_data, chout, hout, wout, kernel_h, kernel_w, pad_h0, pad_h1, pad_w0, pad_w1, stride_h, stride_w, dila_h, dila_w, dout_batch); } //! add bias if (flag_bias) { fill_bias_relu( dout_batch, bias, chout, wout * hout, flag_bias, flag_relu); } } free(workspace_ptr); }