/** * \file dnn/src/arm_common/conv_bias/postprocess_helper.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once #include "megdnn/basic_types.h" #include "src/arm_common/elemwise_helper/kimpl/op_base.h" #include "src/arm_common/elemwise_op.h" #include "src/fallback/conv_bias/opr_impl.h" namespace { #define CONCAT_OP(_name) megdnn::arm_common::_name #define CONCAT_NL(_name) megdnn::NonlineMode::_name #define CB(_caller, _op, _mode) \ case _mode: \ _caller(_op); \ break; #define DEFAULT \ default: \ megdnn_throw("unsupported nolinemode"); \ break; #define HANDLE_IDENTITY() \ case megdnn::NonlineMode::IDENTITY: \ break; #define FOR_NONLINEAR_UNARY(_op) \ megdnn::arm_common::OpCallerUnary<_op, megdnn::arm_common::VEC>:: \ run(static_cast(conv_dst_ptr), \ reinterpret_cast(dst_ptr), bias_type, dst_type, \ N* OC* OH* OW* pack_oc_size); #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ megdnn::arm_common:: \ OpCallerBinary<_op, megdnn::arm_common::VEC_BCAST101>::run( \ static_cast(conv_dst_ptr), \ reinterpret_cast(bias_ptr), \ reinterpret_cast(dst_ptr), bias_type, bias_type, \ dst_type, N, OC, OH* OW); #define FOR_NONLINEAR_BINARY_BROADCAST_NCHW44(_op) \ megdnn::arm_common::OpCallerBinary<_op, \ megdnn::arm_common::VEC_BCAST101x4>:: \ run(static_cast(conv_dst_ptr), \ reinterpret_cast(bias_ptr), \ reinterpret_cast(dst_ptr), bias_type, bias_type, \ dst_type, N, OC, OH* OW, pack_oc_size); #define FOR_NONLINEAR_BINARY(_op) \ megdnn::arm_common:: \ OpCallerBinary<_op, megdnn::arm_common::VEC_VEC>::run( \ static_cast(conv_dst_ptr), \ reinterpret_cast(bias_ptr), \ reinterpret_cast(dst_ptr), bias_type, bias_type, \ dst_type, N* OC* OH* OW* pack_oc_size); #define FOR_BIAS(_mode) \ switch (_mode) { \ case megdnn::BiasMode::NO_BIAS: \ FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY) \ break; \ case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ if (pack_oc_size == 1) { \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ } else { \ megdnn_assert(pack_oc_size == 4, \ "Only support nchw44 in ARM"); \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ } \ break; \ case megdnn::BiasMode::BIAS: \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY) \ break; \ default: \ megdnn_throw("no quantized unsupported biasmode"); \ break; \ } #define FOR_NONLINEAR(_caller) \ switch (nonlineMode) { \ CB(_caller, CONCAT_OP(AddOp), CONCAT_NL(IDENTITY)) \ CB(_caller, CONCAT_OP(FuseAddReluOp), CONCAT_NL(RELU)) \ CB(_caller, CONCAT_OP(FuseAddSigmoidOp), CONCAT_NL(SIGMOID)) \ CB(_caller, CONCAT_OP(FuseAddHSwishOp), CONCAT_NL(H_SWISH)) \ DEFAULT \ } #define FOR_NONLINEAR_NOBIAS(_caller) \ switch (nonlineMode) { \ HANDLE_IDENTITY() \ CB(_caller, CONCAT_OP(ReluOp), CONCAT_NL(RELU)) \ CB(_caller, CONCAT_OP(SigmoidOp), CONCAT_NL(SIGMOID)) \ CB(_caller, CONCAT_OP(HSwishOp), CONCAT_NL(H_SWISH)) \ DEFAULT \ } template struct PostProcess { static void run(void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { FOR_BIAS(bias_mode) } }; template struct PostProcess { static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { MEGDNN_MARK_USED_VAR(conv_dst_ptr); MEGDNN_MARK_USED_VAR(bias_ptr); MEGDNN_MARK_USED_VAR(dst_ptr); MEGDNN_MARK_USED_VAR(bias_mode); MEGDNN_MARK_USED_VAR(nonlineMode); MEGDNN_MARK_USED_VAR(bias_type); MEGDNN_MARK_USED_VAR(dst_type); MEGDNN_MARK_USED_VAR(N); MEGDNN_MARK_USED_VAR(OC); MEGDNN_MARK_USED_VAR(OH); MEGDNN_MARK_USED_VAR(OW); MEGDNN_MARK_USED_VAR(pack_oc_size); megdnn_assert(bias_mode == megdnn::BiasMode::NO_BIAS && nonlineMode == megdnn::NonlineMode::IDENTITY); } }; #undef FOR_NONLINEAR_UNARY #undef FOR_NONLINEAR_BINARY_BROADCAST #undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW44 #undef FOR_NONLINEAR_BINARY #undef FOR_NONLINEAR_NOBIAS #undef FOR_NONLINEAR #undef FOR_BIAS #undef HANDLE_IDENTITY #define FOR_NONLINEAR_UNARY(_op) \ megdnn::arm_common::OpCallerUnary< \ _op, \ megdnn::arm_common::VEC>::run(static_cast(conv_dst_ptr), \ reinterpret_cast(dst_ptr), \ bias_type, dst_type, \ N* OC* OH* OW* pack_oc_size); #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ megdnn::arm_common::OpCallerBinary<_op, \ megdnn::arm_common::VEC_BCAST101>:: \ run(static_cast(conv_dst_ptr), \ reinterpret_cast(bias_ptr), \ reinterpret_cast(dst_ptr), bias_type, bias_type, \ dst_type, N, OC, OH* OW); #define FOR_NONLINEAR_BINARY_BROADCAST_NCHW44(_op) \ megdnn::arm_common::OpCallerBinary<_op, \ megdnn::arm_common::VEC_BCAST101x4>:: \ run(static_cast(conv_dst_ptr), \ reinterpret_cast(bias_ptr), \ reinterpret_cast(dst_ptr), bias_type, bias_type, \ dst_type, N, OC, OH* OW, pack_oc_size); #define HANDLE_IDENTITY(_caller, _op) \ case megdnn::NonlineMode::IDENTITY: \ _caller(_op) break; #define FOR_NONLINEAR(_caller) \ switch (nonlineMode) { \ HANDLE_IDENTITY(_caller, CONCAT_OP(AddOp)) \ CB(_caller, CONCAT_OP(FuseAddReluOp), CONCAT_NL(RELU)) \ CB(_caller, CONCAT_OP(FuseAddHSwishOp), CONCAT_NL(H_SWISH)) \ DEFAULT \ } #define FOR_NONLINEAR_NOBIAS(_caller) \ switch (nonlineMode) { \ HANDLE_IDENTITY(_caller, CONCAT_OP(TypeCvtOp)) \ CB(_caller, CONCAT_OP(ReluOp), CONCAT_NL(RELU)) \ CB(_caller, CONCAT_OP(HSwishOp), CONCAT_NL(H_SWISH)) \ DEFAULT \ } #define FOR_BIAS(_bias_mode) \ switch (_bias_mode) { \ case megdnn::BiasMode::NO_BIAS: \ FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY); \ break; \ case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ if (pack_oc_size == 1) { \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ } else { \ megdnn_assert(pack_oc_size == 4, \ "Only support nchw44 in ARM"); \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ } \ break; \ default: \ megdnn_throw("quantized unsupported biasmode"); \ break; \ } template struct PostProcess { static void run(void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { FOR_BIAS(bias_mode); } }; #undef FOR_NONLINEAR_UNARY #undef FOR_NONLINEAR_BINARY_BROADCAST #undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW44 #undef FOR_NONLINEAR_BINARY #undef FOR_NONLINEAR_NOBIAS #undef FOR_NONLINEAR #undef FOR_BIAS #undef CB #undef CONCAT_OP #undef CONCAT_NL #undef DEFAULT #undef HANDLE_IDENTITY #define DISPATCH_CONV_WINOGRAD_NONLINE(_midout_tag, cb, _bias_id, _src_type, \ _dst_type, _bmode, _nonline_mode, ...) \ switch (_nonline_mode) { \ case param::ConvBias::NonlineMode::IDENTITY: { \ MIDOUT_BEGIN(_midout_tag, _bias_id, 0) { \ cb(_bmode, NoneOp<_src_type MEGDNN_COMMA _dst_type>, \ __VA_ARGS__); \ } \ MIDOUT_END(); \ break; \ } \ case param::ConvBias::NonlineMode::RELU: { \ MIDOUT_BEGIN(_midout_tag, _bias_id, 1) { \ cb(_bmode, ReluOp<_src_type MEGDNN_COMMA _dst_type>, \ __VA_ARGS__); \ } \ MIDOUT_END(); \ break; \ } \ case param::ConvBias::NonlineMode::SIGMOID: { \ MIDOUT_BEGIN(_midout_tag, _bias_id, 2) { \ cb(_bmode, SigmoidOp<_src_type MEGDNN_COMMA _dst_type>, \ __VA_ARGS__); \ } \ MIDOUT_END(); \ break; \ } \ case param::ConvBias::NonlineMode::H_SWISH: { \ MIDOUT_BEGIN(_midout_tag, _bias_id, 3) { \ cb(_bmode, HSwishOp<_src_type MEGDNN_COMMA _dst_type>, \ __VA_ARGS__); \ } \ MIDOUT_END(); \ break; \ } \ default: \ megdnn_assert(0); \ break; \ } #define DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED(_midout_tag, cb, _bias_id, \ _src_type, _dst_type, _bmode, \ _nonline_mode, ...) \ switch (_nonline_mode) { \ case param::ConvBias::NonlineMode::IDENTITY: { \ MIDOUT_BEGIN(_midout_tag, _bias_id, 0) { \ cb(_bmode, TypeCvtOp<_src_type MEGDNN_COMMA _dst_type>, \ __VA_ARGS__); \ } \ MIDOUT_END(); \ break; \ } \ case param::ConvBias::NonlineMode::RELU: { \ MIDOUT_BEGIN(_midout_tag, _bias_id, 1) { \ cb(_bmode, ReluOp<_src_type MEGDNN_COMMA _dst_type>, \ __VA_ARGS__); \ } \ MIDOUT_END(); \ break; \ } \ default: \ megdnn_assert(0); \ break; \ } #define DISPATCH_CONV_WINOGRAD_BIAS(_midout_tag, cb, _src_type, _dst_type, \ _bmode, _nonline_mode, ...) \ switch (_bmode) { \ case BiasMode::BIAS: { \ DISPATCH_CONV_WINOGRAD_NONLINE(_midout_tag, cb, 0, _src_type, \ _dst_type, BiasMode::BIAS, \ _nonline_mode, __VA_ARGS__) \ break; \ } \ case BiasMode::NO_BIAS: { \ DISPATCH_CONV_WINOGRAD_NONLINE(_midout_tag, cb, 1, _src_type, \ _dst_type, BiasMode::NO_BIAS, \ _nonline_mode, __VA_ARGS__) \ break; \ } \ case BiasMode::BROADCAST_CHANNEL_BIAS: { \ DISPATCH_CONV_WINOGRAD_NONLINE(_midout_tag, cb, 2, _src_type, \ _dst_type, \ BiasMode::BROADCAST_CHANNEL_BIAS, \ _nonline_mode, __VA_ARGS__) \ break; \ } \ default: \ megdnn_assert(0); \ break; \ } #define DISPATCH_CONV_WINOGRAD_BIAS_QUANTIZED( \ _midout_tag, cb, _src_type, _dst_type, _bmode, _nonline_mode, ...) \ switch (_bmode) { \ case BiasMode::BIAS: { \ DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED( \ _midout_tag, cb, 0, _src_type, _dst_type, BiasMode::BIAS, \ _nonline_mode, __VA_ARGS__) \ break; \ } \ case BiasMode::NO_BIAS: { \ DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED( \ _midout_tag, cb, 1, _src_type, _dst_type, \ BiasMode::NO_BIAS, _nonline_mode, __VA_ARGS__) \ break; \ } \ case BiasMode::BROADCAST_CHANNEL_BIAS: { \ DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED( \ _midout_tag, cb, 2, _src_type, _dst_type, \ BiasMode::BROADCAST_CHANNEL_BIAS, _nonline_mode, \ __VA_ARGS__) \ break; \ } \ default: \ megdnn_assert(0); \ break; \ } } // namespace