// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>

#include "lite/core/context.h"
#include "lite/core/target_wrapper.h"
#include "lite/operators/op_params.h"

namespace paddle {
namespace lite {
namespace arm {
namespace math {

26 27 28
/// conv 3x3s1
size_t conv3x3s1_direct_workspace_size(const operators::ConvParam& param,
                                       ARMContext* ctx);
Y
Yan Chunwei 已提交
29 30 31 32 33 34 35 36 37 38 39 40
void conv_3x3s1_direct_fp32(const float* din,
                            float* dout,
                            int num,
                            int chout,
                            int hout,
                            int wout,
                            int chin,
                            int hin,
                            int win,
                            const float* weights,
                            const float* bias,
                            const operators::ConvParam& param,
41
                            ARMContext* ctx);
Y
Yan Chunwei 已提交
42

43
template <typename Dtype>
Y
Yan Chunwei 已提交
44
void conv_3x3s1_direct_int8(const int8_t* din,
45
                            Dtype* dout,
Y
Yan Chunwei 已提交
46 47 48 49 50 51 52 53
                            int num,
                            int chout,
                            int hout,
                            int wout,
                            int chin,
                            int hin,
                            int win,
                            const int8_t* weights,
54
                            const float* bias,
Y
Yan Chunwei 已提交
55
                            const operators::ConvParam& param,
56
                            ARMContext* ctx,
Y
Yan Chunwei 已提交
57 58
                            const float* scale);

59 60 61
/// conv3x3s2
size_t conv3x3s2_direct_workspace_size(const operators::ConvParam& param,
                                       ARMContext* ctx);
Y
Yan Chunwei 已提交
62 63 64 65 66 67 68 69 70 71 72 73
void conv_3x3s2_direct_fp32(const float* din,
                            float* dout,
                            int num,
                            int chout,
                            int hout,
                            int wout,
                            int chin,
                            int hin,
                            int win,
                            const float* weights,
                            const float* bias,
                            const operators::ConvParam& param,
74
                            ARMContext* ctx);
Y
Yan Chunwei 已提交
75 76 77

int conv_3x3s2_direct_int8_c_num();

78
template <typename Dtype>
Y
Yan Chunwei 已提交
79
void conv_3x3s2_direct_int8(const int8_t* din,
80
                            Dtype* dout,
Y
Yan Chunwei 已提交
81 82 83 84 85 86 87 88
                            int num,
                            int chout,
                            int hout,
                            int wout,
                            int chin,
                            int hin,
                            int win,
                            const int8_t* weights,
89
                            const float* bias,
Y
Yan Chunwei 已提交
90
                            const operators::ConvParam& param,
91
                            ARMContext* ctx,
Y
Yan Chunwei 已提交
92 93
                            const float* scale);

94 95
void conv_1x5s1_direct(const float* din,
                       float* dout,
Y
Yan Chunwei 已提交
96 97 98 99 100 101 102
                       int num,
                       int chout,
                       int hout,
                       int wout,
                       int chin,
                       int hin,
                       int win,
103 104
                       const float* weights,
                       const float* bias,
Y
Yan Chunwei 已提交
105 106 107 108 109 110 111 112 113 114 115
                       int group,
                       int kernel_w,
                       int kernel_h,
                       int stride_w,
                       int stride_h,
                       int dila_w,
                       int dila_h,
                       int pad_w,
                       int pad_h,
                       bool flag_bias,
                       bool flag_relu,
116
                       ARMContext& ctx);  // NOLINT
Y
Yan Chunwei 已提交
117

118 119
void conv_5x1s1_direct(const float* din,
                       float* dout,
Y
Yan Chunwei 已提交
120 121 122 123 124 125 126
                       int num,
                       int chout,
                       int hout,
                       int wout,
                       int chin,
                       int hin,
                       int win,
127 128
                       const float* weights,
                       const float* bias,
Y
Yan Chunwei 已提交
129 130 131 132 133 134 135 136 137 138 139
                       int group,
                       int kernel_w,
                       int kernel_h,
                       int stride_w,
                       int stride_h,
                       int dila_w,
                       int dila_h,
                       int pad_w,
                       int pad_h,
                       bool flag_bias,
                       bool flag_relu,
140
                       ARMContext& ctx);  // NOLINT
Y
Yan Chunwei 已提交
141 142 143 144 145 146 147 148 149 150 151 152 153

void conv1x1s1_gemm(const float* din,
                    float* dout,
                    int num,
                    int chout,
                    int hout,
                    int wout,
                    int chin,
                    int hin,
                    int win,
                    const float* weights,
                    const float* bias,
                    const operators::ConvParam& param,
154
                    ARMContext* ctx);
Y
Yan Chunwei 已提交
155

156
template <typename Dtype>
Y
Yan Chunwei 已提交
157
void conv1x1s1_gemm_int8(const int8_t* din,
158
                         Dtype* dout,
Y
Yan Chunwei 已提交
159 160 161 162 163 164 165 166
                         int num,
                         int chout,
                         int hout,
                         int wout,
                         int chin,
                         int hin,
                         int win,
                         const int8_t* weights,
167
                         const float* bias,
Y
Yan Chunwei 已提交
168
                         const operators::ConvParam& param,
169 170
                         ARMContext* ctx,
                         const float* scale);
Y
Yan Chunwei 已提交
171 172 173 174 175 176 177 178 179 180 181 182 183

void conv_im2col_gemm(const float* din,
                      float* dout,
                      int num,
                      int chout,
                      int hout,
                      int wout,
                      int chin,
                      int hin,
                      int win,
                      const float* weights,
                      const float* bias,
                      const operators::ConvParam& param,
184
                      ARMContext* ctx);
Y
Yan Chunwei 已提交
185

186
template <typename Dtype>
Y
Yan Chunwei 已提交
187
void conv_im2col_gemm_int8(const int8_t* din,
188
                           Dtype* dout,
Y
Yan Chunwei 已提交
189 190 191 192 193 194 195 196
                           int num,
                           int chout,
                           int hout,
                           int wout,
                           int chin,
                           int hin,
                           int win,
                           const int8_t* weights,
197
                           const float* bias,
Y
Yan Chunwei 已提交
198
                           const operators::ConvParam& param,
199 200
                           ARMContext* ctx,
                           const float* scale);
Y
Yan Chunwei 已提交
201

202 203 204
/// depthwise conv
void conv_depthwise_3x3_fp32(const void* din,
                             void* dout,
Y
Yan Chunwei 已提交
205
                             int num,
206 207 208 209 210 211 212 213
                             int ch_out,
                             int h_out,
                             int w_out,
                             int ch_in,
                             int h_in,
                             int w_in,
                             const void* weights,
                             const float* bias,
Y
Yan Chunwei 已提交
214
                             const operators::ConvParam& param,
215
                             ARMContext* ctx,
Y
Yan Chunwei 已提交
216 217
                             const float* scale);

218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249
void conv_depthwise_3x3_int8_fp32(const void* din,
                                  void* dout,
                                  int num,
                                  int ch_out,
                                  int h_out,
                                  int w_out,
                                  int ch_in,
                                  int h_in,
                                  int w_in,
                                  const void* weights,
                                  const float* bias,
                                  const operators::ConvParam& param,
                                  ARMContext* ctx,
                                  const float* scale);

void conv_depthwise_3x3_int8_int8(const void* din,
                                  void* dout,
                                  int num,
                                  int ch_out,
                                  int h_out,
                                  int w_out,
                                  int ch_in,
                                  int h_in,
                                  int w_in,
                                  const void* weights,
                                  const float* bias,
                                  const operators::ConvParam& param,
                                  ARMContext* ctx,
                                  const float* scale);

void conv_depthwise_5x5_fp32(const void* din,
                             void* dout,
Y
Yan Chunwei 已提交
250
                             int num,
251 252 253 254 255 256 257 258
                             int ch_out,
                             int h_out,
                             int w_out,
                             int ch_in,
                             int h_in,
                             int w_in,
                             const void* weights,
                             const float* bias,
Y
Yan Chunwei 已提交
259
                             const operators::ConvParam& param,
260
                             ARMContext* ctx,
Y
Yan Chunwei 已提交
261 262
                             const float* scale);

263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293
void conv_depthwise_5x5_int8_fp32(const void* din,
                                  void* dout,
                                  int num,
                                  int ch_out,
                                  int h_out,
                                  int w_out,
                                  int ch_in,
                                  int h_in,
                                  int w_in,
                                  const void* weights,
                                  const float* bias,
                                  const operators::ConvParam& param,
                                  ARMContext* ctx,
                                  const float* scale);

void conv_depthwise_5x5_int8_int8(const void* din,
                                  void* dout,
                                  int num,
                                  int ch_out,
                                  int h_out,
                                  int w_out,
                                  int ch_in,
                                  int h_in,
                                  int w_in,
                                  const void* weights,
                                  const float* bias,
                                  const operators::ConvParam& param,
                                  ARMContext* ctx,
                                  const float* scale);

/// winograd conv, only support 3x3s1
Y
Yan Chunwei 已提交
294 295 296 297 298 299 300 301 302 303 304 305
void conv_winograd3x3(const float* din,
                      float* dout,
                      int num,
                      int chout,
                      int hout,
                      int wout,
                      int chin,
                      int hin,
                      int win,
                      const float* weights,
                      const float* bias,
                      const operators::ConvParam& param,
306
                      ARMContext* ctx);
Y
Yan Chunwei 已提交
307 308 309 310 311 312 313 314 315 316

void winograd_transform_weights(
    void* dout, const void* din, int ch_out, int ch_in, void* work_space);

void fill_bias(float* tensor, const float* bias, int channel, int channel_size);

void fill_bias_int8(int* tensor,
                    const int* bias,
                    int channel,
                    int channel_size);
T
TianXiaogang 已提交
317
// new winograd
Y
Yan Chunwei 已提交
318

T
TianXiaogang 已提交
319 320 321
void weight_trans_c4_8x8(
    float* dest, const float* src, int ic, int oc, void* workspace);
void weight_trans_c4_4x4(
T
TianXiaogang 已提交
322 323 324 325 326 327 328 329 330 331 332 333 334 335
    float* dest, const float* src, int ic, int oc, void* workspace);
void conv_compute_6x6_3x3(const float* input,
                          float* output,
                          int num,
                          int chout,
                          int hout,
                          int wout,
                          int chin,
                          int hin,
                          int win,
                          const float* weight,
                          const float* bias,
                          const operators::ConvParam& param,
                          ARMContext* ctx);
T
TianXiaogang 已提交
336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361
void conv_compute_2x2_3x3(const float* input,
                          float* output,
                          int num,
                          int chout,
                          int hout,
                          int wout,
                          int chin,
                          int hin,
                          int win,
                          const float* weight,
                          const float* bias,
                          const operators::ConvParam& param,
                          ARMContext* ctx);
void conv_compute_2x2_3x3_small(const float* input,
                                float* output,
                                int num,
                                int chout,
                                int hout,
                                int wout,
                                int chin,
                                int hin,
                                int win,
                                const float* weight,
                                const float* bias,
                                const operators::ConvParam& param,
                                ARMContext* ctx);
362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390
void input_trans_c8_4x4_int8(const int8_t* src,
                             int src_stride,
                             int src_h_stride,
                             int16_t* dest,
                             int dest_stride,
                             int dest_h_stride);
void output_trans_c8_post_2x4_int8(const int32_t* src,
                                   int src_stride,
                                   int src_h_stride,
                                   int32_t* dest,
                                   int dest_stride,
                                   int dest_h_stride);
void weight_trans_c8_4x4_int8(
    int16_t* dest, const int8_t* src, int ic, int oc, void* workspace);
template <typename Dtype>
void conv_compute_2x2_3x3_int8(const int8_t* input,
                               Dtype* output,
                               int num,
                               int chout,
                               int hout,
                               int wout,
                               int chin,
                               int hin,
                               int win,
                               const int16_t* weight,
                               const float* bias,
                               const float* scale,
                               const operators::ConvParam& param,
                               ARMContext* ctx);
391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408

template <typename Dtype>
void im2col(const Dtype* data_im,
            int channels,
            int height,
            int width,
            int kernel_h,
            int kernel_w,
            int pad_top,
            int pad_bottom,
            int pad_left,
            int pad_right,
            int stride_h,
            int stride_w,
            int dilation_h,
            int dilation_w,
            Dtype* data_col);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle