Commit 547945e8 authored by: Megvii Engine Team

feat(fallback): support general intrinsic in elemwise in fallback

GitOrigin-RevId: 96ff2e88ccbe89be2003bea3fffc5ee4ed86afd0
Parent: a017bed3
@@ -44,30 +44,29 @@ namespace {
break;
#define FOR_NONLINEAR_UNARY(_op) \
megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::arm_common::VEC>::run( \
megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::VEC>::run( \
static_cast<ctype*>(conv_dst_ptr), reinterpret_cast<ctype*>(dst_ptr), \
bias_type, dst_type, N* OC* OH* OW* pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101>:: \
run(static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \
OC, OH* OW);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common:: \
OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101xX>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
N, OC, OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_VEC>::run( \
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
OH* OW);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
N* OC* OH* OW* pack_oc_size);
#define FOR_BIAS(_mode) \
@@ -168,36 +167,33 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
#undef FOR_BIAS
#undef HANDLE_IDENTITY
#define FOR_NONLINEAR_UNARY(_op) \
#define FOR_NONLINEAR_UNARY(_op) \
megdnn::arm_common::OpCallerUnary<_op<opctype, opdtype>, megdnn::VEC>::run( \
static_cast<opctype*>(conv_dst_ptr), reinterpret_cast<opdtype*>(dst_ptr), \
bias_type, dst_type, N* OC* OH* OW* pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101>:: \
run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
N, OC, OH* OW);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common:: \
OpCallerUnary<_op<opctype, opdtype>, megdnn::arm_common::VEC>::run( \
OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \
static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, dst_type, \
N* OC* OH* OW* pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
megdnn::arm_common::OpCallerBinary< \
_op<opctype, opdtype>, megdnn::arm_common::VEC_BCAST101>:: \
run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
N, OC, OH* OW);
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common::OpCallerBinary< \
_op<opctype, opdtype>, megdnn::arm_common::VEC_BCAST101xX>:: \
run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
N, OC, OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \
megdnn::arm_common::OpCallerBinary< \
_op<opctype, opdtype>, megdnn::arm_common::VEC_BCAST101xX>:: \
run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
N, OC, OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \
megdnn::arm_common:: \
OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \
static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW, pack_oc_size);
#define HANDLE_IDENTITY(_caller, _op) \
case megdnn::NonlineMode::IDENTITY: \
@@ -271,26 +267,25 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
#undef FOR_NONLINEAR
#undef FOR_BIAS
#define FOR_BINARY_BROADCAST(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101>:: \
run(static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \
OC, OH* OW);
#define FOR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common:: \
OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101xX>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
N, OC, OH* OW, pack_oc_size);
#define FOR_BINARY(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_VEC>::run( \
#define FOR_BINARY_BROADCAST(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
OH* OW);
#define FOR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
OH* OW, pack_oc_size);
#define FOR_BINARY(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
N* OC* OH* OW* pack_oc_size);
#define FOR_BIAS(_bias_mode, OH, OW) \
......
@@ -89,163 +89,4 @@ void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) {
fallback::ElemwiseImpl::exec(srcs, dst);
}
ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
KernParam kern_param;
kern_param.broad_cast_type = BcastType::UNKNOWN_BCAST_TYPE;
kern_param.mode = opr->param().mode;
kern_param.handle = opr->handle();
auto is_legal_layout_for_nhwc = [](const TensorLayout& l) {
if (is_vector(l))
return true;
if (l.ndim == 2 && l.stride[1] == 1)
return true;
return false;
};
if ((opr->m_src->size() == 3) && (opr->param().mode == Mode::FUSE_MUL_ADD3)) {
kern_param.ternary_elparam = opr->make_elemwise_op_param<3>();
bool c_is_scalar;
opr->prepare_fma3(kern_param.ternary_elparam, c_is_scalar);
auto &src0 = kern_param.ternary_elparam[0],
&src1 = kern_param.ternary_elparam[1],
&src2 = kern_param.ternary_elparam[2];
BroadcastChannelInfo binfo;
if (is_vector(src0.layout) && is_vector(src1.layout) &&
is_vector(src2.layout)) {
kern_param.broad_cast_type = BcastType::VEC_VEC_VEC;
return kern_param;
}
if (is_vector(src0.layout) && is_vector(src1.layout) && c_is_scalar) {
kern_param.broad_cast_type = BcastType::VEC_VEC_SCALAR;
return kern_param;
}
if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo) &&
src0.layout.eq_layout(src2.layout)) {
kern_param.broad_cast_type = BcastType::BCAST101_VEC_BCAST101;
return kern_param;
}
if (is_vector(src1.layout) &&
(is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
is_broadcastedx_channel_like<8>(src0.layout, binfo)) &&
src0.layout.eq_layout(src2.layout)) {
kern_param.broad_cast_type = BcastType::BCAST101xX_VEC_BCAST101xX;
return kern_param;
}
if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
is_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101_VEC;
return kern_param;
}
if (is_legal_layout_for_nhwc(src1.layout) &&
is_NHWC_broadcasted_channel_like(src0.layout, binfo) &&
src0.layout.eq_layout(src2.layout)) {
kern_param.broad_cast_type = BcastType::BCAST111C_VEC_BCAST111C;
return kern_param;
}
if (is_legal_layout_for_nhwc(src0.layout) &&
src2.layout.eq_layout(src0.layout) &&
is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST111C_VEC;
return kern_param;
}
if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
(is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101xX_VEC;
return kern_param;
}
if (is_vector(src0.layout) && is_vector(src2.layout) &&
is_broadcasted_scalar(src1.layout)) {
kern_param.broad_cast_type = BcastType::VEC_SCALAR_VEC;
return kern_param;
}
if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout) &&
is_broadcasted_scalar(src2.layout)) {
kern_param.broad_cast_type = BcastType::VEC_SCALAR_SCALAR;
return kern_param;
}
} else if (opr->m_src->size() == 2) {
kern_param.binary_elparam = opr->make_elemwise_op_param<2>();
auto &src0 = kern_param.binary_elparam[0], &src1 = kern_param.binary_elparam[1];
BroadcastChannelInfo binfo;
if (is_vector(src0.layout) && is_vector(src1.layout)) {
kern_param.broad_cast_type = BcastType::VEC_VEC;
return kern_param;
}
if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout)) {
kern_param.broad_cast_type = BcastType::VEC_SCALAR;
return kern_param;
}
if (is_vector(src1.layout) && is_broadcasted_scalar(src0.layout)) {
kern_param.broad_cast_type = BcastType::SCALAR_VEC;
return kern_param;
}
if (is_vector(src0.layout) && is_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101;
return kern_param;
}
if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo)) {
kern_param.broad_cast_type = BcastType::BCAST101_VEC;
return kern_param;
}
if (is_vector(src0.layout) && is_broadcasted_3dim_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCASTX0X;
return kern_param;
}
if (is_vector(src1.layout) && is_broadcasted_3dim_like(src0.layout, binfo)) {
kern_param.broad_cast_type = BcastType::BCASTX0X_VEC;
return kern_param;
}
if (is_legal_layout_for_nhwc(src1.layout) &&
is_NHWC_broadcasted_channel_like(src0.layout, binfo)) {
kern_param.broad_cast_type = BcastType::BCAST111C_VEC;
return kern_param;
}
if (is_legal_layout_for_nhwc(src0.layout) &&
is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST111C;
return kern_param;
}
if (is_vector(src0.layout) &&
(is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101xX;
return kern_param;
}
if (is_vector(src1.layout) &&
(is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
is_broadcastedx_channel_like<8>(src0.layout, binfo))) {
kern_param.broad_cast_type = BcastType::BCAST101xX_VEC;
return kern_param;
}
} else if (opr->m_src->size() == 1) {
kern_param.broad_cast_type = BcastType::VEC;
kern_param.unary_elparam = opr->make_elemwise_op_param<1>();
return kern_param;
}
return kern_param;
}
// vim: syntax=cpp.doxygen
@@ -18,22 +18,12 @@ namespace megdnn {
namespace arm_common {
class ElemwiseImpl final : public fallback::ElemwiseImpl {
public:
using fallback::ElemwiseImpl::AlgoBase;
using fallback::ElemwiseImpl::ElemwiseImpl;
void exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) override;
const char* get_algorithm_set_name() const { return "ARM COMMON ELEMWISE"; }
private:
struct KernParam {
BcastType broad_cast_type;
Mode mode;
const TensorND* m_dst;
Handle* handle;
ElemwiseOpParamN<3> ternary_elparam;
ElemwiseOpParamN<2> binary_elparam;
ElemwiseOpParamN<1> unary_elparam;
};
KernParam make_kern_param(ElemwiseImpl* opr);
class AlgoBase;
class AlgoUnary;
class AlgoBinaryVecVec;
class AlgoBinaryVecScalar;
@@ -54,19 +44,6 @@ private:
class AlgoPack;
};
/*!
*
* \brief base class for Elemwise algo
*
*/
class ElemwiseImpl::AlgoBase : public detail::Algorithm {
public:
virtual bool is_available(const KernParam&) const = 0;
virtual void exec(const KernParam&) const = 0;
virtual ~AlgoBase() = default;
uint32_t type() const override { return INVALID_ALGO_TYPE; };
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#define DISPATCH_TYPE(_case) \
if (src0.layout.dtype == dtype::Float32{}) { \
......
@@ -15,10 +15,13 @@
#include "src/arm_common/elemwise_helper/op_binary.h"
#include "src/arm_common/elemwise_helper/op_ternary.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/fallback/elemwise_helper/op_common.h"
namespace megdnn {
namespace arm_common {
using BcastType = megdnn::BcastType;
///////////////////////////////// ParamElemVistor ///////////////////////////
template <typename ctype>
struct ParamElemVisitor;
@@ -99,36 +102,6 @@ cb(__fp16, __fp16, float16x8_t, f16);
#endif
#undef cb
/*!
* \brief broadcast type
* BCAST_x[0]x[1]...: x[i] == !stride[i]
*/
enum BcastType {
VEC,
VEC_VEC,
VEC_BCAST101,
VEC_BCASTX0X,
VEC_BCAST111C,
VEC_BCAST101xX,
VEC_SCALAR,
SCALAR_VEC,
BCAST101_VEC,
BCASTX0X_VEC,
BCAST111C_VEC,
BCAST101xX_VEC,
VEC_VEC_VEC,
VEC_VEC_SCALAR,
BCAST101_VEC_BCAST101,
BCAST111C_VEC_BCAST111C,
BCAST101xX_VEC_BCAST101xX,
VEC_BCAST101_VEC,
VEC_BCAST111C_VEC,
VEC_BCAST101xX_VEC,
VEC_SCALAR_VEC,
VEC_SCALAR_SCALAR,
UNKNOWN_BCAST_TYPE
};
///////////////////////////////// OpCaller /////////////////////////////
template <typename Op, BcastType bcast_type>
struct OpCallerUnary;
......
/**
* \file dnn/src/fallback/elemwise/opr_binary_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* \file dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp
*/
#include "./opr_impl.h"
#include "src/fallback/elemwise/opr_impl.h"
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/utils.h"
......
/**
* \file dnn/src/fallback/elemwise/opr_unary_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* \file dnn/src/fallback/elemwise/fallback_impl/opr_unary_impl.cpp
*/
#include "./opr_impl.h"
#include "src/fallback/elemwise/opr_impl.h"
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/utils.h"
......
This diff has been collapsed.
/**
* \file dnn/src/fallback/elemwise/gi_impl/binary/algo.h
*/
#pragma once
#include "src/fallback/elemwise/opr_impl.h"
namespace megdnn {
namespace fallback {
#define DECL_CB(case) \
class ElemwiseImpl::AlgoBinary##case final : public ElemwiseImpl::AlgoBase { \
mutable std::string m_name; \
AlgoAttribute attribute() const override { \
return AlgoAttribute::REPRODUCIBLE; \
} \
const char* name() const override { \
if (m_name.empty()) { \
m_name = ssprintf("Elemwise::AlgoBinaryCase" #case); \
} \
return m_name.c_str(); \
} \
bool is_available(const KernParam&) const override; \
void exec(const KernParam&) const override; \
};
DECL_CB(VecVec);
DECL_CB(VecScalar);
DECL_CB(VecBcast101);
DECL_CB(VecBcastX0X);
DECL_CB(VecBcast111C);
DECL_CB(VecBcast101xX);
#undef DECL_CB
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise/gi_mathfun.cpp
*
* This file has been modified by Megvii ("Megvii Modifications").
* All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights
* reserved.
*
*/
/* General intrinsic (GI) implementation of sin, cos, exp and log (ported from NEON).
Inspired by Intel Approximate Math library, and based on the
corresponding algorithms of the cephes math library
*/
/* Copyright (C) 2011 Julien Pommier
This software is provided 'as-is', without any express or implied
warranty. In no event will the authors be held liable for any damages
arising from the use of this software.
Permission is granted to anyone to use this software for any purpose,
including commercial applications, and to alter it and redistribute it
freely, subject to the following restrictions:
1. The origin of this software must not be misrepresented; you must not
claim that you wrote the original software. If you use this software
in a product, an acknowledgment in the product documentation would be
appreciated but is not required.
2. Altered source versions must be plainly marked as such, and must not be
misrepresented as being the original software.
3. This notice may not be removed or altered from any source distribution.
(this is the zlib license)
*/
#include "./gi_mathfun.h"
namespace megdnn {
namespace fallback {
#define c_inv_mant_mask ~0x7f800000u
#define c_cephes_SQRTHF 0.707106781186547524
#define c_cephes_log_p0 7.0376836292E-2
#define c_cephes_log_p1 -1.1514610310E-1
#define c_cephes_log_p2 1.1676998740E-1
#define c_cephes_log_p3 -1.2420140846E-1
#define c_cephes_log_p4 +1.4249322787E-1
#define c_cephes_log_p5 -1.6668057665E-1
#define c_cephes_log_p6 +2.0000714765E-1
#define c_cephes_log_p7 -2.4999993993E-1
#define c_cephes_log_p8 +3.3333331174E-1
#define c_cephes_log_q1 -2.12194440e-4
#define c_cephes_log_q2 0.693359375
/**
 * natural logarithm computed for 4 simultaneous floats; returns NaN for x <= 0
*/
v4sf GiLogPsFloat32(v4sf x) {
v4sf one = GiBroadcastFloat32(1);
x = GiMaximumFloat32(
x, GiBroadcastFloat32(0)); /* force flush to zero on denormal values */
v4su invalid_mask = GiLessThanEqFloat32(x, GiBroadcastFloat32(0));
v4si ux = GiReinterpretAsInt32(x);
v4si emm0 = GiShiftRight23Int32(ux);
/* keep only the fractional part */
ux = GiAndInt32(ux, GiBroadcastInt32(c_inv_mant_mask));
ux = GiOrInt32(ux, GiReinterpretAsInt32(GiBroadcastFloat32(0.5f)));
x = GiReintInt32ToFloat32(ux);
emm0 = GiSubtractInt32(emm0, GiBroadcastInt32(0x7f));
v4sf e = GiCastToFloat32(emm0);
e = GiAddFloat32(e, one);
/* part2:
* if( x < SQRTHF ) {
* e -= 1;
* x = x + x - 1.0;
* } else { x = x - 1.0; }
*/
v4su mask = GiLessThanFloat32(x, GiBroadcastFloat32(c_cephes_SQRTHF));
v4sf tmp = GiAndFloat32(x, GiReintUint32ToFloat32(mask));
x = GiSubtractFloat32(x, one);
e = GiSubtractFloat32(e, GiAndFloat32(one, GiReintUint32ToFloat32(mask)));
x = GiAddFloat32(x, tmp);
v4sf z = GiMultiplyFloat32(x, x);
v4sf y = GiBroadcastFloat32(c_cephes_log_p0);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p1), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p2), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p3), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p4), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p5), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p6), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p7), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p8), y, x);
y = GiMultiplyFloat32(y, x);
y = GiMultiplyFloat32(y, z);
y = GiMultiplyAddFloat32(y, e, GiBroadcastFloat32(c_cephes_log_q1));
y = GiMultiplySubFloat32(y, z, GiBroadcastFloat32(0.5f));
x = GiAddFloat32(x, y);
x = GiMultiplyAddFloat32(x, e, GiBroadcastFloat32(c_cephes_log_q2));
x = GiOrFloat32(
x, GiReintUint32ToFloat32(invalid_mask)); // negative arg will be NAN
return x;
}
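/* Hedged usage sketch (illustrative, not part of this commit): assuming
 * GiLoadFloat32 from src/fallback/general_intrinsic/gi_float.h loads four
 * contiguous floats, the helper can be applied buffer-wise: */
static inline void gi_log4_example(const float* in, float* out) {
    v4sf v = GiLoadFloat32(in);              // four lanes at once
    GiStoreFloat32(out, GiLogPsFloat32(v));  // out[i] = log(in[i]); NaN for in[i] <= 0
}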
#define c_exp_hi 88.3762626647949f
#define c_exp_lo -88.3762626647949f
#define c_cephes_LOG2EF 1.44269504088896341
#define c_cephes_exp_C1 0.693359375
#define c_cephes_exp_C2 -2.12194440e-4
#define c_cephes_exp_p0 1.9875691500E-4
#define c_cephes_exp_p1 1.3981999507E-3
#define c_cephes_exp_p2 8.3334519073E-3
#define c_cephes_exp_p3 4.1665795894E-2
#define c_cephes_exp_p4 1.6666665459E-1
#define c_cephes_exp_p5 5.0000001201E-1
/* exp() computed for 4 floats at once */
v4sf GiExpPsFloat32(v4sf x) {
v4sf tmp, fx;
v4sf one = GiBroadcastFloat32(1);
x = GiMinimumFloat32(x, GiBroadcastFloat32(c_exp_hi));
x = GiMaximumFloat32(x, GiBroadcastFloat32(c_exp_lo));
/* express exp(x) as exp(g + n*log(2)) */
fx = GiMultiplyAddFloat32(
GiBroadcastFloat32(0.5f), x, GiBroadcastFloat32(c_cephes_LOG2EF));
/* perform a floorf */
tmp = GiCastToFloat32(GiCastToInt32(fx));
/* if greater, subtract 1 */
v4su mask = GiGreaterThanFloat32(tmp, fx);
v4sf mask_float = GiAndFloat32(GiReintUint32ToFloat32(mask), one);
fx = GiSubtractFloat32(tmp, mask_float);
tmp = GiMultiplyFloat32(fx, GiBroadcastFloat32(c_cephes_exp_C1));
v4sf z = GiMultiplyFloat32(fx, GiBroadcastFloat32(c_cephes_exp_C2));
x = GiSubtractFloat32(x, tmp);
x = GiSubtractFloat32(x, z);
z = GiMultiplyFloat32(x, x);
v4sf y = GiBroadcastFloat32(c_cephes_exp_p0);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p1), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p2), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p3), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p4), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p5), y, x);
y = GiMultiplyAddFloat32(x, y, z);
y = GiAddFloat32(y, one);
/* build 2^n */
v4si mm;
mm = GiCastToInt32(fx);
mm = GiAddInt32(mm, GiBroadcastInt32(0x7f));
mm = GiShiftLeft23Int32(mm);
v4sf pow2n = GiReintInt32ToFloat32(mm);
y = GiMultiplyFloat32(y, pow2n);
return y;
}
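/* Hedged scalar reference (illustrative only, assumes <cmath>): the range
 * reduction above rewrites exp(x) = 2^n * exp(g) with n = round(x / ln 2),
 * so the degree-5 polynomial only has to cover the small remainder g. */
static inline float exp_range_reduction_ref(float x) {
    float n = std::floor(x * 1.44269504f + 0.5f);         // n = round(x / ln 2)
    float g = x - n * 0.693359375f + n * 2.12194440e-4f;  // ln 2 subtracted in two parts
    return std::ldexp(std::exp(g), static_cast<int>(n));  // libm stands in for the polynomial
}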
#define c_minus_cephes_DP1 -0.78515625
#define c_minus_cephes_DP2 -2.4187564849853515625e-4
#define c_minus_cephes_DP3 -3.77489497744594108e-8
#define c_sincof_p0 -1.9515295891E-4
#define c_sincof_p1 8.3321608736E-3
#define c_sincof_p2 -1.6666654611E-1
#define c_coscof_p0 2.443315711809948E-005
#define c_coscof_p1 -1.388731625493765E-003
#define c_coscof_p2 4.166664568298827E-002
#define c_cephes_FOPI 1.27323954473516 // 4 / M_PI
/* evaluation of 4 sines & cosines at once.
The code is the exact rewriting of the cephes sinf function.
Precision is excellent as long as x < 8192 (I did not bother to
take into account the special handling they have for greater values
-- it does not return garbage for arguments over 8192, though, but
the extra precision is missing).
Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
surprising but correct result.
Note also that when you compute sin(x), cos(x) is available at
   almost no extra cost, so both sin_ps_f32 and cos_ps_f32 make use of
   sincos_ps_f32.
*/
void GiSinCosPsFloat32(v4sf x, v4sf* ysin, v4sf* ycos) {
// any x
v4sf y;
v4su emm2;
v4su sign_mask_sin, sign_mask_cos;
sign_mask_sin = GiLessThanFloat32(x, GiBroadcastFloat32(0));
x = GiAbsFloat32(x);
/* scale by 4/Pi */
y = GiMultiplyFloat32(x, GiBroadcastFloat32(c_cephes_FOPI));
/* store the integer part of y in mm0 */
emm2 = GiReinterpretAsUint32(y);
/* j=(j+1) & (~1) (see the cephes sources) */
emm2 = GiAddUint32(emm2, GiBroadcastUint32(1));
emm2 = GiAddUint32(emm2, GiBroadcastUint32(~1));
y = GiReintUint32ToFloat32(emm2);
    /* get the polynomial selection mask
     * there is one polynomial for 0 <= x <= Pi/4
     * and another one for Pi/4 < x <= Pi/2
*
* Both branches will be computed.
*/
v4su poly_mask = GiTestAndSetUint32(emm2, GiBroadcastUint32(2));
/* The magic pass: "Extended precision modular arithmetic"
* x = ((x - y * DP1) - y * DP2) - y * DP3; */
x = GiMultiplyAddFloat32(x, y, GiBroadcastFloat32(c_minus_cephes_DP1));
x = GiMultiplyAddFloat32(x, y, GiBroadcastFloat32(c_minus_cephes_DP2));
x = GiMultiplyAddFloat32(x, y, GiBroadcastFloat32(c_minus_cephes_DP3));
sign_mask_sin =
GiEOrUint32(sign_mask_sin, GiTestAndSetUint32(emm2, GiBroadcastUint32(4)));
sign_mask_cos = GiTestAndSetUint32(
GiSubtractUint32(emm2, GiBroadcastUint32(2)), GiBroadcastUint32(4));
    /* Evaluate the first polynomial (0 <= x <= Pi/4) in y1,
     * and the second polynomial (Pi/4 <= x <= Pi/2) in y2 */
v4sf z = GiMultiplyFloat32(x, x);
v4sf y1, y2;
y1 = GiMultiplyAddFloat32(
GiBroadcastFloat32(c_coscof_p1), z, GiBroadcastFloat32(c_coscof_p0));
y2 = GiMultiplyAddFloat32(
GiBroadcastFloat32(c_sincof_p1), z, GiBroadcastFloat32(c_sincof_p0));
y1 = GiMultiplyAddFloat32(GiBroadcastFloat32(c_coscof_p2), y1, z);
y2 = GiMultiplyAddFloat32(GiBroadcastFloat32(c_sincof_p2), y2, z);
y1 = GiMultiplyFloat32(y1, z);
y2 = GiMultiplyFloat32(y2, z);
y1 = GiMultiplyFloat32(y1, z);
y1 = GiMultiplySubFloat32(y1, z, GiBroadcastFloat32(0.5f));
y2 = GiMultiplyAddFloat32(x, y2, x);
y1 = GiAddFloat32(y1, GiBroadcastFloat32(1));
    /* select the correct result from the two polynomials */
v4sf ys = GiBSLFloat32(poly_mask, y1, y2);
v4sf yc = GiBSLFloat32(poly_mask, y2, y1);
*ysin = GiBSLFloat32(sign_mask_sin, GiNegFloat32(ys), ys);
*ycos = GiBSLFloat32(sign_mask_cos, yc, GiNegFloat32(yc));
}
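/* Hedged usage sketch (illustrative, not part of this commit): one call
 * yields both results while sharing the range reduction; GiLoadFloat32 and
 * GiStoreFloat32 are assumed from src/fallback/general_intrinsic/gi_float.h. */
static inline void gi_sincos4_example(const float* in, float* s, float* c) {
    v4sf vs, vc;
    GiSinCosPsFloat32(GiLoadFloat32(in), &vs, &vc);
    GiStoreFloat32(s, vs);  // s[i] = sin(in[i])
    GiStoreFloat32(c, vc);  // c[i] = cos(in[i])
}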
v4sf GiSinPsFloat32(v4sf x) {
v4sf ysin, ycos;
GiSinCosPsFloat32(x, &ysin, &ycos);
return ysin;
}
v4sf GiCosPsFloat32(v4sf x) {
v4sf ysin, ycos;
GiSinCosPsFloat32(x, &ysin, &ycos);
return ycos;
}
v4sf GiTanPsFloat32(v4sf x) {
v4sf ysin, ycos;
GiSinCosPsFloat32(x, &ysin, &ycos);
return ysin / ycos;
}
#undef c_exp_hi
#undef c_exp_lo
#undef c_cephes_LOG2EF
#undef c_cephes_exp_C1
#undef c_cephes_exp_C2
#undef c_cephes_exp_p0
#undef c_cephes_exp_p1
#undef c_cephes_exp_p2
#undef c_cephes_exp_p3
#undef c_cephes_exp_p4
#undef c_cephes_exp_p5
#undef c_minus_cephes_DP1
#undef c_minus_cephes_DP2
#undef c_minus_cephes_DP3
#undef c_sincof_p0
#undef c_sincof_p1
#undef c_sincof_p2
#undef c_coscof_p0
#undef c_coscof_p1
#undef c_coscof_p2
#undef c_cephes_FOPI
#undef c_inv_mant_mask
#undef c_cephes_SQRTHF
#undef c_cephes_log_p0
#undef c_cephes_log_p1
#undef c_cephes_log_p2
#undef c_cephes_log_p3
#undef c_cephes_log_p4
#undef c_cephes_log_p5
#undef c_cephes_log_p6
#undef c_cephes_log_p7
#undef c_cephes_log_p8
#undef c_cephes_log_q1
#undef c_cephes_log_q2
static const struct {
float lower_range;
float upper_range;
float alpha_9;
float alpha_7;
float alpha_5;
float alpha_3;
float alpha_1;
float beta_10;
float beta_8;
float beta_6;
float beta_4;
float beta_2;
float beta_0;
float one_half;
} sigmoid_constants = {
-18.0f,
18.0f,
4.37031012579801e-11f,
1.15627324459942e-07f,
6.08574864600143e-05f,
8.51377133304701e-03f,
2.48287947061529e-01f,
6.10247389755681e-13f,
5.76102136993427e-09f,
6.29106785017040e-06f,
1.70198817374094e-03f,
1.16817656904453e-01f,
9.93151921023180e-01f,
0.5f,
};
v4sf GiSigmoidPsFloat32(v4sf src) {
auto val = GiMaximumFloat32(GiBroadcastFloat32(sigmoid_constants.lower_range), src);
val = GiMinimumFloat32(GiBroadcastFloat32(sigmoid_constants.upper_range), val);
auto squared = GiMultiplyFloat32(val, val);
auto p = GiMultiplyAddFloat32(
GiBroadcastFloat32(sigmoid_constants.alpha_7), squared,
GiBroadcastFloat32(sigmoid_constants.alpha_9));
p = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.alpha_5), p, squared);
p = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.alpha_3), p, squared);
p = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.alpha_1), p, squared);
p = GiMultiplyFloat32(p, val);
auto q = GiMultiplyAddFloat32(
GiBroadcastFloat32(sigmoid_constants.beta_8), squared,
GiBroadcastFloat32(sigmoid_constants.beta_10));
q = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.beta_6), q, squared);
q = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.beta_4), q, squared);
q = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.beta_2), q, squared);
q = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.beta_0), q, squared);
return GiAddFloat32(
GiDivideFloat32(p, q), GiBroadcastFloat32(sigmoid_constants.one_half));
}
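/* Hedged scalar reference (illustrative only, assumes <algorithm>): the code
 * above evaluates sigmoid(x) ~= 0.5 + x * P(x^2) / Q(x^2), with an odd-degree
 * numerator and an even-degree denominator, on the clamped range [-18, 18]. */
static inline float sigmoid_scalar_ref(float x) {
    x = std::min(std::max(x, sigmoid_constants.lower_range),
                 sigmoid_constants.upper_range);
    const float s = x * x;
    // odd-degree numerator P, Horner form over alpha_9 .. alpha_1
    float p = sigmoid_constants.alpha_9;
    p = p * s + sigmoid_constants.alpha_7;
    p = p * s + sigmoid_constants.alpha_5;
    p = p * s + sigmoid_constants.alpha_3;
    p = p * s + sigmoid_constants.alpha_1;
    p *= x;
    // even-degree denominator Q, Horner form over beta_10 .. beta_0
    float q = sigmoid_constants.beta_10;
    q = q * s + sigmoid_constants.beta_8;
    q = q * s + sigmoid_constants.beta_6;
    q = q * s + sigmoid_constants.beta_4;
    q = q * s + sigmoid_constants.beta_2;
    q = q * s + sigmoid_constants.beta_0;
    return p / q + sigmoid_constants.one_half;
}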
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise/gi_impl/gi_mathfun.h
*/
#pragma once
#include "src/fallback/general_intrinsic/gi_float.h"
#include "src/fallback/general_intrinsic/gi_int.h"
namespace megdnn {
namespace fallback {
typedef GI_FLOAT32_t v4sf; // vector of 4 float
typedef GI_INT32_t v4si; // vector of 4 int32
typedef GI_UINT32_t v4su; // vector of 4 uint32
/**
 * \brief natural logarithm computed for 4 simultaneous floats
 * returns NaN for x <= 0
*/
v4sf GiLogPsFloat32(v4sf x);
//! exp() computed for 4 floats at once
v4sf GiExpPsFloat32(v4sf x);
/**
* \brief evaluation of 4 sines & cosines at once.
*
* The code is the exact rewriting of the cephes sinf function.
* Precision is excellent as long as x < 8192 (I did not bother to
* take into account the special handling they have for greater values
* -- it does not return garbage for arguments over 8192, though, but
* the extra precision is missing).
*
* Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
* surprising but correct result.
*
* Note also that when you compute sin(x), cos(x) is available at
 * almost no extra cost, so both sin_ps_f32 and cos_ps_f32 make use of
 * sincos_ps_f32.
*/
void GiSinCosPsFloat32(v4sf x, v4sf* ysin, v4sf* ycos);
v4sf GiSinPsFloat32(v4sf x);
v4sf GiCosPsFloat32(v4sf x);
v4sf GiTanPsFloat32(v4sf x);
v4sf GiSigmoidPsFloat32(v4sf x);
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
This diff has been collapsed.
/**
* \file dnn/src/fallback/elemwise/gi_impl/ternary/algo.h
*/
#pragma once
#include "src/fallback/elemwise/opr_impl.h"
namespace megdnn {
namespace fallback {
#define DECL_CB(case) \
class ElemwiseImpl::AlgoTernaryFma3##case final : public ElemwiseImpl::AlgoBase { \
mutable std::string m_name; \
AlgoAttribute attribute() const override { \
return AlgoAttribute::REPRODUCIBLE; \
} \
const char* name() const override { \
if (m_name.empty()) { \
m_name = ssprintf("Elemwise::AlgoTernaryFma3" #case); \
} \
return m_name.c_str(); \
} \
bool is_available(const KernParam&) const override; \
void exec(const KernParam&) const override; \
};
DECL_CB(VecVecVec);
DECL_CB(VecVecScalar);
DECL_CB(Bcast101VecBcast101);
DECL_CB(Bcast111CVecBcast111C);
DECL_CB(Bcast101xXVecBcast101xX);
DECL_CB(VecBcast101Vec);
DECL_CB(VecBcast111CVec);
DECL_CB(VecBcast101xXVec);
DECL_CB(VecScalarVec);
DECL_CB(VecScalarScalar);
#undef DECL_CB
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp
*/
#include "src/fallback/elemwise/gi_impl/unary/algo.h"
#include "src/fallback/elemwise_op.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include "midout.h"
MIDOUT_DECL(megdnn_fallback_elemwise_unary)
using namespace megdnn;
using namespace fallback;
bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const {
if (BcastType::VEC != kern_param.broad_cast_type)
return false;
if (kern_param.m_dst->layout.dtype.category() != DTypeCategory::FLOAT &&
(kern_param.mode == Mode::EXP || kern_param.mode == Mode::SIGMOID ||
kern_param.mode == Mode::TANH || kern_param.mode == Mode::FAST_TANH ||
kern_param.mode == Mode::H_SWISH)) {
return false;
}
    //! As `NEGATE` mode is so simple, the code generated by the compiler is
    //! already vectorized and optimized, while other modes such as `ABS`
    //! contain branches, for which the compiler may not generate code as good
    //! as hand-written intrinsics.
if (kern_param.mode == Mode::NEGATE) {
return false;
}
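    //! Hedged illustration (not in the original source): a branch-free loop
    //! such as
    //!     for (size_t i = 0; i < n; ++i) dst[i] = -src[i];
    //! is reliably auto-vectorized, so a hand-written NEGATE kernel buys little.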
auto& elparam = kern_param.unary_elparam;
if (!elparam[0].layout.is_contiguous())
return false;
megdnn_assert(elparam[0].layout.ndim == 1);
auto& src0 = elparam[0];
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \
auto mode = kern_param.mode; \
if (mode == Mode::RELU || mode == Mode::ABS || mode == Mode::SIGMOID || \
mode == Mode::EXP || mode == Mode::TANH || mode == Mode::FAST_TANH || \
mode == Mode::H_SWISH) \
return true;
#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \
auto mode = kern_param.mode; \
if (mode == Mode::RELU || mode == Mode::ABS) \
return true;
DISPATCH_TYPE_FALLBACK("AlgoUnary::is_available"_hash);
return false;
#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT
}
void ElemwiseImpl::AlgoUnary::exec(const KernParam& kern_param) const {
#define DISPATCH_UNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_unary, midout_iv(_case), \
midout_iv(Mode::_mode), midout_iv(_type_midout_id)) { \
thin_function<void(const _type*, _type*, DType, DType, size_t)> run = \
OpCallerUnary<_op<_type, _type>, BcastType::VEC>::run; \
auto kernel = [nr_elems, nr_elems_per_thread, src0, dst_tensor, run]( \
size_t task_id, size_t) { \
size_t offset = task_id * nr_elems_per_thread; \
size_t nr_elems_thread = \
std::min(nr_elems - offset, nr_elems_per_thread); \
run(static_cast<const _type*>(src0.raw_ptr()) + offset, \
static_cast<_type*>(dst_tensor.raw_ptr()) + offset, \
src0.layout.dtype, dst_tensor.layout.dtype, nr_elems_thread); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \
kernel); \
} \
MIDOUT_END(); \
return
auto& elparam = kern_param.unary_elparam;
megdnn_assert(elparam[0].layout.ndim == 1);
auto& src0 = elparam[0];
auto& dst_tensor = *(kern_param.m_dst);
size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle)
->megcore_dispatcher()
->nr_threads();
size_t nr_elems = src0.layout.total_nr_elems();
size_t nr_elems_per_thread = (nr_elems + nr_threads - 1) / nr_threads;
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \
switch (kern_param.mode) { \
DISPATCH_UNARY(RELU, _case, _type, _type_midout_id, ReluOp); \
DISPATCH_UNARY(ABS, _case, _type, _type_midout_id, AbsOp); \
DISPATCH_UNARY(SIGMOID, _case, _type, _type_midout_id, SigmoidOp); \
DISPATCH_UNARY(EXP, _case, _type, _type_midout_id, ExpOp); \
DISPATCH_UNARY(TANH, _case, _type, _type_midout_id, TanhOp); \
DISPATCH_UNARY(FAST_TANH, _case, _type, _type_midout_id, FastTanhOp); \
DISPATCH_UNARY(H_SWISH, _case, _type, _type_midout_id, HSwishOp); \
default: \
megdnn_throw(ssprintf( \
"No avaiable algo find for: %d", \
static_cast<int>(kern_param.mode))); \
}
#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \
switch (kern_param.mode) { \
DISPATCH_UNARY(RELU, _case, _type, _type_midout_id, ReluOp); \
DISPATCH_UNARY(ABS, _case, _type, _type_midout_id, AbsOp); \
default: \
megdnn_throw(ssprintf( \
"No avaiable algo find for: %d", \
static_cast<int>(kern_param.mode))); \
}
DISPATCH_TYPE_FALLBACK("AlgoUnary::exec"_hash);
#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT
#undef DISPATCH_UNARY
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise/gi_impl/unary/algo.h
*/
#pragma once
#include "src/fallback/elemwise/opr_impl.h"
namespace megdnn {
namespace fallback {
class ElemwiseImpl::AlgoUnary final : public ElemwiseImpl::AlgoBase {
mutable std::string m_name;
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override {
if (m_name.empty()) {
m_name = ssprintf("Elemwise::AlgoUnary");
}
return m_name.c_str();
}
bool is_available(const KernParam&) const override;
void exec(const KernParam&) const override;
};
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -12,6 +12,9 @@
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/utils.h"
#include "src/fallback//elemwise/gi_impl/unary/algo.h"
#include "src/fallback/elemwise/gi_impl/binary/algo.h"
#include "src/fallback/elemwise/gi_impl/ternary/algo.h"
#include "src/naive/handle.h"
#include "midout.h"
@@ -21,13 +24,22 @@ MIDOUT_DECL(megdnn_fallback_elemwise_exec_UNARY_FLOAT)
MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_INT)
MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_FLOAT)
namespace megdnn {
namespace fallback {
using namespace megdnn;
using namespace fallback;
void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) {
if (!dst.layout.is_contiguous()) {
return naive::ElemwiseForwardImpl::exec(srcs, dst);
}
if (!exec_gi_intrinsic(srcs, dst)) {
return exec_fallback(srcs, dst);
}
}
void ElemwiseImpl::exec_fallback(const TensorNDArray& srcs, _megdnn_tensor_out dst) {
if (!dst.layout.is_contiguous()) {
return naive::ElemwiseForwardImpl::exec(srcs, dst);
}
m_src = &srcs;
m_dst = &dst;
@@ -82,7 +94,229 @@ void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) {
naive::ElemwiseForwardImpl::exec(srcs, dst);
}
} // namespace fallback
} // namespace megdnn
class ElemwiseImpl::AlgoPack {
#if !(MEGDNN_AARCH64 || MEGDNN_ARMV7)
AlgoUnary algo_unary;
AlgoBinaryVecVec algo_binary_vec_vec;
AlgoBinaryVecScalar algo_binary_vec_sca;
AlgoBinaryVecBcast101 algo_binary_vec_bcast101;
AlgoBinaryVecBcastX0X algo_binary_vec_bcastX0X;
AlgoBinaryVecBcast111C algo_binary_vec_bcast110;
AlgoBinaryVecBcast101xX algo_binary_VEC_BCAST101xX;
AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec;
AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca;
AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101;
AlgoTernaryFma3Bcast111CVecBcast111C algo_ternaryfma3_bcast110_vec_bcast110;
AlgoTernaryFma3Bcast101xXVecBcast101xX algo_ternaryfma3_bcast101xX_vec_bcast101xX;
AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec;
AlgoTernaryFma3VecBcast111CVec algo_ternaryfma3_vec_bcast110_vec;
AlgoTernaryFma3VecBcast101xXVec algo_ternaryfma3_vec_bcast101xX_vec;
AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec;
AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca;
#endif
public:
AlgoPack() {
#if !(MEGDNN_AARCH64 || MEGDNN_ARMV7)
all_algos.emplace_back(&algo_unary);
all_algos.emplace_back(&algo_binary_vec_vec);
all_algos.emplace_back(&algo_binary_vec_sca);
all_algos.emplace_back(&algo_binary_vec_bcast101);
all_algos.emplace_back(&algo_binary_vec_bcastX0X);
all_algos.emplace_back(&algo_binary_vec_bcast110);
all_algos.emplace_back(&algo_binary_VEC_BCAST101xX);
all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca);
all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101);
all_algos.emplace_back(&algo_ternaryfma3_bcast110_vec_bcast110);
all_algos.emplace_back(&algo_ternaryfma3_bcast101xX_vec_bcast101xX);
all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_bcast110_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101xX_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca);
#endif
}
SmallVector<AlgoBase*> all_algos;
};
bool ElemwiseImpl::exec_gi_intrinsic(
const TensorNDArray& srcs, _megdnn_tensor_out dst) {
m_src = &srcs;
m_dst = &dst;
if (m_dst->layout.dtype == dtype::Float32() ||
m_dst->layout.dtype == dtype::Int32() ||
m_dst->layout.dtype == dtype::Int16() || m_dst->layout.dtype == dtype::Int8()) {
auto kern_param = make_kern_param(this);
kern_param.m_dst = &dst;
static AlgoPack m_algo_pack;
for (auto& m_algo : m_algo_pack.all_algos) {
if (m_algo->is_available(kern_param)) {
m_algo->exec(kern_param);
return true;
}
}
}
return false;
}
ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
KernParam kern_param;
kern_param.broad_cast_type = BcastType::UNKNOWN_BCAST_TYPE;
kern_param.mode = opr->param().mode;
kern_param.handle = opr->handle();
auto is_legal_layout_for_nhwc = [](const TensorLayout& l) {
if (is_vector(l))
return true;
if (l.ndim == 2 && l.stride[1] == 1)
return true;
return false;
};
if ((opr->m_src->size() == 3) && (opr->param().mode == Mode::FUSE_MUL_ADD3)) {
kern_param.ternary_elparam = opr->make_elemwise_op_param<3>();
bool c_is_scalar;
opr->prepare_fma3(kern_param.ternary_elparam, c_is_scalar);
auto &src0 = kern_param.ternary_elparam[0],
&src1 = kern_param.ternary_elparam[1],
&src2 = kern_param.ternary_elparam[2];
BroadcastChannelInfo binfo;
if (is_vector(src0.layout) && is_vector(src1.layout) &&
is_vector(src2.layout)) {
kern_param.broad_cast_type = BcastType::VEC_VEC_VEC;
return kern_param;
}
if (is_vector(src0.layout) && is_vector(src1.layout) && c_is_scalar) {
kern_param.broad_cast_type = BcastType::VEC_VEC_SCALAR;
return kern_param;
}
if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo) &&
src0.layout.eq_layout(src2.layout)) {
kern_param.broad_cast_type = BcastType::BCAST101_VEC_BCAST101;
return kern_param;
}
if (is_vector(src1.layout) &&
(is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
is_broadcastedx_channel_like<8>(src0.layout, binfo)) &&
src0.layout.eq_layout(src2.layout)) {
kern_param.broad_cast_type = BcastType::BCAST101xX_VEC_BCAST101xX;
return kern_param;
}
if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
is_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101_VEC;
return kern_param;
}
if (is_legal_layout_for_nhwc(src1.layout) &&
is_NHWC_broadcasted_channel_like(src0.layout, binfo) &&
src0.layout.eq_layout(src2.layout)) {
kern_param.broad_cast_type = BcastType::BCAST111C_VEC_BCAST111C;
return kern_param;
}
if (is_legal_layout_for_nhwc(src0.layout) &&
src2.layout.eq_layout(src0.layout) &&
is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST111C_VEC;
return kern_param;
}
if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
(is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101xX_VEC;
return kern_param;
}
if (is_vector(src0.layout) && is_vector(src2.layout) &&
is_broadcasted_scalar(src1.layout)) {
kern_param.broad_cast_type = BcastType::VEC_SCALAR_VEC;
return kern_param;
}
if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout) &&
is_broadcasted_scalar(src2.layout)) {
kern_param.broad_cast_type = BcastType::VEC_SCALAR_SCALAR;
return kern_param;
}
} else if (opr->m_src->size() == 2) {
kern_param.binary_elparam = opr->make_elemwise_op_param<2>();
auto &src0 = kern_param.binary_elparam[0], &src1 = kern_param.binary_elparam[1];
BroadcastChannelInfo binfo;
if (is_vector(src0.layout) && is_vector(src1.layout)) {
kern_param.broad_cast_type = BcastType::VEC_VEC;
return kern_param;
}
if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout)) {
kern_param.broad_cast_type = BcastType::VEC_SCALAR;
return kern_param;
}
if (is_vector(src1.layout) && is_broadcasted_scalar(src0.layout)) {
kern_param.broad_cast_type = BcastType::SCALAR_VEC;
return kern_param;
}
if (is_vector(src0.layout) && is_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101;
return kern_param;
}
if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo)) {
kern_param.broad_cast_type = BcastType::BCAST101_VEC;
return kern_param;
}
if (is_vector(src0.layout) && is_broadcasted_3dim_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCASTX0X;
return kern_param;
}
if (is_vector(src1.layout) && is_broadcasted_3dim_like(src0.layout, binfo)) {
kern_param.broad_cast_type = BcastType::BCASTX0X_VEC;
return kern_param;
}
if (is_legal_layout_for_nhwc(src1.layout) &&
is_NHWC_broadcasted_channel_like(src0.layout, binfo)) {
kern_param.broad_cast_type = BcastType::BCAST111C_VEC;
return kern_param;
}
if (is_legal_layout_for_nhwc(src0.layout) &&
is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST111C;
return kern_param;
}
if (is_vector(src0.layout) &&
(is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101xX;
return kern_param;
}
if (is_vector(src1.layout) &&
(is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
is_broadcastedx_channel_like<8>(src0.layout, binfo))) {
kern_param.broad_cast_type = BcastType::BCAST101xX_VEC;
return kern_param;
}
} else if (opr->m_src->size() == 1) {
kern_param.broad_cast_type = BcastType::VEC;
kern_param.unary_elparam = opr->make_elemwise_op_param<1>();
return kern_param;
}
return kern_param;
}
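/* Hedged example (illustrative, not part of this commit): adding an NCHW
 * tensor of shape (N, C, H, W) to a bias of shape (1, C, 1, 1) makes src0 a
 * contiguous vector and src1 broadcast channel-like, so the classification
 * above yields BcastType::VEC_BCAST101, which AlgoBinaryVecBcast101 handles. */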
// vim: syntax=cpp.doxygen
@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/fallback/elemwise_op.h"
#include "src/naive/elemwise/opr_impl.h"
namespace megdnn {
@@ -33,13 +34,69 @@ class ElemwiseImpl : public naive::ElemwiseForwardImpl {
template <uint32_t mode>
void exec_BINARY_FLOAT();
void exec_fallback(const TensorNDArray& srcs, _megdnn_tensor_out dst);
bool exec_gi_intrinsic(const TensorNDArray& srcs, _megdnn_tensor_out dst);
private:
class AlgoUnary;
class AlgoBinaryVecVec;
class AlgoBinaryVecScalar;
class AlgoBinaryVecBcast101;
class AlgoBinaryVecBcastX0X;
class AlgoBinaryVecBcast111C;
class AlgoBinaryVecBcast101xX;
class AlgoTernaryFma3VecVecVec;
class AlgoTernaryFma3VecVecScalar;
class AlgoTernaryFma3Bcast101VecBcast101;
class AlgoTernaryFma3Bcast111CVecBcast111C;
class AlgoTernaryFma3Bcast101xXVecBcast101xX;
class AlgoTernaryFma3VecBcast101Vec;
class AlgoTernaryFma3VecBcast111CVec;
class AlgoTernaryFma3VecBcast101xXVec;
class AlgoTernaryFma3VecScalarVec;
class AlgoTernaryFma3VecScalarScalar;
class AlgoPack;
public:
class AlgoBase;
struct KernParam {
BcastType broad_cast_type;
Mode mode;
const TensorND* m_dst;
Handle* handle;
ElemwiseOpParamN<3> ternary_elparam;
ElemwiseOpParamN<2> binary_elparam;
ElemwiseOpParamN<1> unary_elparam;
};
KernParam make_kern_param(ElemwiseImpl* opr);
using naive::ElemwiseForwardImpl::ElemwiseForwardImpl;
void exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) override;
const char* get_algorithm_set_name() const { return "FALLBACK ELEMWISE"; }
bool is_thread_safe() const override { return true; }
};
/*!
* \brief base class for Elemwise algo
*/
class ElemwiseImpl::AlgoBase : public detail::Algorithm {
public:
virtual bool is_available(const KernParam&) const = 0;
virtual void exec(const KernParam&) const = 0;
virtual ~AlgoBase() = default;
uint32_t type() const override { return INVALID_ALGO_TYPE; };
};
//! fallback only supports float, int32, and int8
#define DISPATCH_TYPE_FALLBACK(_case) \
if (src0.layout.dtype == dtype::Float32{}) { \
DISPATCH_MODE_FLOAT(_case, float, 0); \
} else if (src0.layout.dtype == dtype::Int32{}) { \
DISPATCH_MODE_INT(_case, int, 2); \
} else if (src0.layout.dtype == dtype::Int8{}) { \
DISPATCH_MODE_INT(_case, dt_int8, 4); \
}
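//! Hedged usage sketch (hypothetical helpers): each call site defines the two
//! per-dtype hooks before expanding the macro, e.g.
//!     #define DISPATCH_MODE_FLOAT(_case, _type, _id) handle_float<_type>();
//!     #define DISPATCH_MODE_INT(_case, _type, _id) handle_int<_type>();
//!     DISPATCH_TYPE_FALLBACK("example"_hash);
//! as AlgoUnary::is_available and AlgoUnary::exec do in gi_impl/unary/algo.cpp.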
} // namespace fallback
} // namespace megdnn
......
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/abs.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct AbsOpBase : UnaryOpBase<src_ctype, dst_ctype> {
using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;
void operator()(const src_ctype& src, dst_ctype* dst) const {
*dst = operator()(src);
}
dst_ctype operator()(const src_ctype& src) const { return src > 0 ? src : (-src); }
};
template <typename src_ctype, typename dst_ctype = src_ctype>
struct AbsOp;
#define OP(_ctype, _gi_type, _func_suffix, _simd_width) \
template <> \
struct AbsOp<_ctype> : AbsOpBase<_ctype> { \
using AbsOpBase::AbsOpBase; \
using AbsOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()(const _gi_type& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_gi_type operator()(const _gi_type& src) const { \
auto vitem0 = GiAbs##_func_suffix(src.val[0]); \
auto vitem1 = GiAbs##_func_suffix(src.val[1]); \
return {{vitem0, vitem1}}; \
} \
};
OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(dt_float32))
OP(dt_int32, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(dt_int32))
OP(dt_int8, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(dt_int8))
#undef OP
template <>
struct AbsOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> {
using UnaryOpBase::UnaryOpBase;
void operator()(const dt_qint8& src, dt_qint8* dst) const {
*dst = operator()(src);
}
dt_qint8 operator()(const dt_qint8& src) const {
float fsrc = src.as_int8() * this->scale;
fsrc = fsrc > 0 ? fsrc : -fsrc;
return QConverter::convert<dt_qint8, float>(fsrc);
}
};
template <>
struct AbsOp<dt_qint8, dt_qint8> : AbsOpBase<dt_qint8, dt_qint8> {
using AbsOpBase::AbsOpBase;
constexpr static size_t SIMD_WIDTH = 16;
using AbsOpBase::operator();
void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const {
OPERATOR_UNARY_QINT8_FALLBACK;
}
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale);
auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale);
vitem0 = GiAbsFloat32(vitem0);
vitem1 = GiAbsFloat32(vitem1);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/add.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct AddOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
return src0 + src1;
}
};
template <typename src_ctype, typename dst_ctype = src_ctype>
struct AddOp;
#define OP(_ctype, _gi_type, _gi_type2, _func_suffix, _simd_width) \
template <> \
struct AddOp<_ctype> : AddOpBase<_ctype> { \
using AddOpBase::AddOpBase; \
using AddOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _gi_type2& src0, const _gi_type2& src1, dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_gi_type2 operator()(const _gi_type2& src0, const _gi_type2& src1) const { \
auto vitem0 = GiAdd##_func_suffix(src0.val[0], src1.val[0]); \
auto vitem1 = GiAdd##_func_suffix(src0.val[1], src1.val[1]); \
return {{vitem0, vitem1}}; \
} \
void operator()( \
const _gi_type& src0, const _gi_type& src1, dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem); \
} \
_gi_type operator()(const _gi_type& src0, const _gi_type& src1) const { \
return GiAdd##_func_suffix(src0, src1); \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32,
GI_SIMD_LEN_BYTE / sizeof(dt_float32))
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(dt_int32))
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(dt_int8))
#undef OP
template <>
struct AddOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
using BinaryOpBase::BinaryOpBase;
void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const {
*dst = operator()(src0, src1);
}
dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const {
return QConverter::convert<dt_qint8, float>(
src0.as_int8() * this->scale0 + src1.as_int8() * this->scale1);
}
};
template <>
struct AddOp<dt_qint8, dt_qint8> : AddOpBase<dt_qint8, dt_qint8> {
using AddOpBase::AddOpBase;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
using AddOpBase::operator();
void operator()(
const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const {
OPERATOR_BINARY_QINT8_FALLBACK;
}
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
auto vitem0 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1));
auto vitem1 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1));
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};
template <>
struct AddOpBase<dt_qint32, dt_qint8> : BinaryOpBase<dt_qint32, dt_qint8> {
using BinaryOpBase::BinaryOpBase;
void operator()(const dt_qint32& src0, const dt_qint32& src1, dt_qint8* dst) const {
*dst = operator()(src0, src1);
}
dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const {
return QConverter::convert<dt_qint8, float>(
src0.as_int32() * this->scale0 + src1.as_int32() * this->scale1);
}
};
template <>
struct AddOp<dt_qint32, dt_qint8> : AddOpBase<dt_qint32, dt_qint8> {
using AddOpBase::AddOpBase;
using AddOpBase::operator();
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);
void operator()(
const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1,
dt_qint8* dst) const {
GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc0, vsrc1));
}
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
auto vitem0 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1));
auto vitem1 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1));
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/exp.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct ExpOpBase : UnaryOpBase<src_ctype, dst_ctype> {
using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;
void operator()(const src_ctype& src, dst_ctype* dst) const {
*dst = operator()(src);
}
dst_ctype operator()(const src_ctype& src) const {
float tmp = src;
return exp(tmp);
}
};
template <typename src_ctype, typename dst_ctype = src_ctype>
struct ExpOp;
#define OP(_ctype, _simd_type, _func_suffix, _simd_width) \
template <> \
struct ExpOp<_ctype> : ExpOpBase<_ctype> { \
using ExpOpBase::ExpOpBase; \
using ExpOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()(const _simd_type& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type operator()(const _simd_type& src) const { \
auto vitem0 = GiExpPs##_func_suffix(src.val[0]); \
auto vitem1 = GiExpPs##_func_suffix(src.val[1]); \
return {{vitem0, vitem1}}; \
} \
};
OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/fast_tanh.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
//! tanh = x * (27 + x^2) / (27 + 9 * x^2)
template <typename src_ctype, typename dst_ctype = src_ctype>
struct FastTanhOpBase : UnaryOpBase<src_ctype, dst_ctype> {
using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;
void operator()(const src_ctype& src, dst_ctype* dst) const {
*dst = operator()(src);
}
dst_ctype operator()(const src_ctype& src) const {
float x = src;
return x * (27.f + x * x) / (27.f + 9.f * x * x);
}
};
template <typename src_ctype, typename dst_ctype = src_ctype>
struct FastTanhOp;
#define OP(_ctype, _simd_type, _func_suffix, _fix_func_suffix, _simd_width) \
template <> \
struct FastTanhOp<_ctype> : FastTanhOpBase<_ctype> { \
using FastTanhOpBase::FastTanhOpBase; \
using FastTanhOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()(const _simd_type& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type operator()(const _simd_type& src) const { \
auto val_27 = GiBroadcast##_func_suffix(27.f); \
auto val_9 = GiBroadcast##_func_suffix(9.f); \
auto valx = src.val[0]; \
auto valx1 = src.val[1]; \
auto valxp2 = GiMultiply##_fix_func_suffix(valx, valx); \
auto valx1p2 = GiMultiply##_fix_func_suffix(valx1, valx1); \
auto denominator = GiAdd##_fix_func_suffix(valxp2, val_27); \
auto denominator1 = GiAdd##_fix_func_suffix(valx1p2, val_27); \
valx = GiMultiply##_fix_func_suffix(valx, denominator); \
valx1 = GiMultiply##_fix_func_suffix(valx1, denominator1); \
denominator = GiMultiplyAdd##_fix_func_suffix(val_27, valxp2, val_9); \
denominator1 = GiMultiplyAdd##_fix_func_suffix(val_27, valx1p2, val_9); \
auto r_denominator = GiRecpe##_func_suffix(denominator); \
auto r_denominator1 = GiRecpe##_func_suffix(denominator1); \
r_denominator = GiMultiply##_fix_func_suffix( \
GiRecpeS##_func_suffix(denominator, r_denominator), \
r_denominator); \
r_denominator1 = GiMultiply##_fix_func_suffix( \
GiRecpeS##_func_suffix(denominator1, r_denominator1), \
r_denominator1); \
valx = GiMultiply##_fix_func_suffix(valx, r_denominator); \
valx1 = GiMultiply##_fix_func_suffix(valx1, r_denominator1); \
return {{valx, valx1}}; \
} \
};
OP(dt_float32, GI_FLOAT32_V2_t, Float32, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
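A standalone check of the rational approximation above, using only the C++ standard library (illustrative, not part of the sources):
#include <algorithm>
#include <cmath>
#include <cstdio>
// Scan [-3, 3] and report the worst absolute deviation of
// x * (27 + x^2) / (27 + 9 * x^2) from std::tanh.
int main() {
    float max_err = 0.f;
    for (float x = -3.f; x <= 3.f; x += 1e-3f) {
        float approx = x * (27.f + x * x) / (27.f + 9.f * x * x);
        max_err = std::max(max_err, std::fabs(approx - std::tanh(x)));
    }
    std::printf("max abs error on [-3, 3]: %g\n", max_err);
    return 0;
}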
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h"
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseAddHSwishOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
float tmp = src0 + src1;
tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f;
return tmp;
}
};
template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseAddHSwishOp;
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct FuseAddHSwishOp<_ctype> : FuseAddHSwishOpBase<_ctype> { \
using FuseAddHSwishOpBase::FuseAddHSwishOpBase; \
using FuseAddHSwishOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type2& src0, const _simd_type2& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()( \
const _simd_type2& src0, const _simd_type2& src1) const { \
auto val1 = src0.val[0]; \
auto val2 = src0.val[1]; \
auto val3 = src1.val[0]; \
auto val4 = src1.val[1]; \
val1 = GiAdd##_func_suffix(val1, val3); \
val2 = GiAdd##_func_suffix(val2, val4); \
H_SWISH_KERN_FALLBACK(_func_suffix, val1, val2); \
return {{val1, val2}}; \
} \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
auto val1 = src0; \
auto val2 = src1; \
val1 = GiAdd##_func_suffix(val1, val2); \
H_SWISH_KERN_N1_FALLBACK(_func_suffix, val1); \
return val1; \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP
template <>
struct FuseAddHSwishOpBase<dt_qint32, dt_qint8> : BinaryOpBase<dt_qint32, dt_qint8> {
using BinaryOpBase::BinaryOpBase;
void operator()(const dt_qint32& src0, const dt_qint32& src1, dt_qint8* dst) const {
*dst = operator()(src0, src1);
}
dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const {
float tmp =
src0.as_int32() * this->scale_src0 + src1.as_int32() * this->scale_src1;
tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f;
tmp *= this->scale_dst;
return QConverter::convert<dt_qint8, float>(tmp);
}
};
template <>
struct FuseAddHSwishOp<dt_qint32, dt_qint8> : FuseAddHSwishOpBase<dt_qint32, dt_qint8> {
using FuseAddHSwishOpBase::FuseAddHSwishOpBase;
using FuseAddHSwishOpBase::operator();
constexpr static size_t SIMD_WIDTH = 4;
void operator()(
const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1,
dt_qint8* dst) const {
GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc0, vsrc1));
}
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
GI_FLOAT32_t vitem0, vitem1;
vitem0 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale_src0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale_src1));
vitem1 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale_src0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale_src1));
H_SWISH_KERN_FALLBACK(Float32, vitem0, vitem1);
vitem0 = GiMultiplyFloat32(vitem0, this->vscale_dst);
vitem1 = GiMultiplyFloat32(vitem1, this->vscale_dst);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};
#include "src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h"
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/fuse_add_relu.h
*/
#pragma once
#include "gi_util_impl_helper.h"
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseAddReluOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
auto tmp = src0 + src1;
return tmp > 0 ? tmp : 0;
}
};
template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseAddReluOp;
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct FuseAddReluOp<_ctype> : FuseAddReluOpBase<_ctype> { \
using FuseAddReluOpBase::FuseAddReluOpBase; \
using FuseAddReluOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type2& src0, const _simd_type2& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()( \
const _simd_type2& src0, const _simd_type2& src1) const { \
auto val1 = src0.val[0]; \
auto val2 = src0.val[1]; \
auto val3 = src1.val[0]; \
auto val4 = src1.val[1]; \
FUSE_ADD_RELU_SIMD_PACK2_FALLBACK(val1, val2, val3, val4, _func_suffix); \
return {{val1, val2}}; \
} \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
auto val1 = src0; \
auto val2 = src1; \
FUSE_ADD_RELU_SIMD_PACK_FALLBACK(val1, val2, _func_suffix); \
return val1; \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
#undef OP
template <typename ctype>
struct FuseAddReluOpCommon;
template <>
struct FuseAddReluOpCommon<float> {
inline static GI_FLOAT32_t vzero() { return GiBroadcastFloat32(0); }
};
template <>
struct FuseAddReluOpCommon<int> {
inline static GI_INT32_t vzero() { return GiBroadcastInt32(0); }
};
template <>
struct FuseAddReluOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
using BinaryOpBase::BinaryOpBase;
void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const {
*dst = operator()(src0, src1);
}
dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const {
return QConverter::convert<dt_qint8, float>(std::max<float>(
src0.as_int8() * this->scale0 + src1.as_int8() * this->scale1, 0.f));
}
};
template <>
struct FuseAddReluOp<dt_qint8, dt_qint8> : FuseAddReluOpBase<dt_qint8, dt_qint8>,
FuseAddReluOpCommon<float> {
using FuseAddReluOpBase::FuseAddReluOpBase;
using FuseAddReluOpBase::operator();
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
void operator()(
const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const {
OPERATOR_BINARY_QINT8_FALLBACK;
}
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
auto vitem0 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1));
auto vitem1 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1));
vitem0 = GiMaximumFloat32(vitem0, this->vzero());
vitem1 = GiMaximumFloat32(vitem1, this->vzero());
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};
template <>
struct FuseAddReluOpBase<dt_qint32, dt_qint8> : BinaryOpBase<dt_qint32, dt_qint8> {
using BinaryOpBase::BinaryOpBase;
void operator()(const dt_qint32& src0, const dt_qint32& src1, dt_qint8* dst) const {
*dst = operator()(src0, src1);
}
dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const {
return QConverter::convert<dt_qint8, float>(std::max<float>(
src0.as_int32() * this->scale0 + src1.as_int32() * this->scale1, 0.f));
}
};
template <>
struct FuseAddReluOp<dt_qint32, dt_qint8> : FuseAddReluOpBase<dt_qint32, dt_qint8>,
FuseAddReluOpCommon<float> {
using FuseAddReluOpBase::FuseAddReluOpBase;
using FuseAddReluOpBase::operator();
constexpr static size_t SIMD_WIDTH = 4;
void operator()(
const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1,
dt_qint8* dst) const {
GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc0, vsrc1));
}
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
auto vitem0 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1));
auto vitem1 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1));
vitem0 = GiMaximumFloat32(vitem0, this->vzero());
vitem1 = GiMaximumFloat32(vitem1, this->vzero());
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/fuse_add_sigmoid.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseAddSigmoidOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
float tmpf = src0 + src1;
tmpf = exp(-tmpf);
tmpf = 1.f / (1.f + tmpf);
return tmpf;
}
};
template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseAddSigmoidOp;
#define OP(_ctype, _simd_type, _func_suffix, _simd_width) \
template <> \
struct FuseAddSigmoidOp<_ctype> : FuseAddSigmoidOpBase<_ctype> { \
using FuseAddSigmoidOpBase::FuseAddSigmoidOpBase; \
using FuseAddSigmoidOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
auto val1 = src0.val[0]; \
auto val2 = src0.val[1]; \
auto val3 = src1.val[0]; \
auto val4 = src1.val[1]; \
val1 = GiAdd##_func_suffix(val1, val3); \
val2 = GiAdd##_func_suffix(val2, val4); \
val1 = GiSigmoidPs##_func_suffix(val1); \
val2 = GiSigmoidPs##_func_suffix(val2); \
return {{val1, val2}}; \
} \
};
OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/fuse_add_tanh.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseAddTanhOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
        float tmpf = exp(src0 + src1);
float tmpf2 = 1 / tmpf;
return (tmpf - tmpf2) / (tmpf + tmpf2);
}
};
template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseAddTanhOp;
#define OP(_ctype, _simd_type, _func_suffix, _simd_width) \
template <> \
struct FuseAddTanhOp<_ctype> : FuseAddTanhOpBase<_ctype> { \
using FuseAddTanhOpBase::FuseAddTanhOpBase; \
using FuseAddTanhOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
auto val1 = src0.val[0]; \
auto val2 = src0.val[1]; \
auto val3 = src1.val[0]; \
auto val4 = src1.val[1]; \
val1 = GiAdd##_func_suffix(val1, val3); \
val2 = GiAdd##_func_suffix(val2, val4); \
auto exp1 = GiExpPs##_func_suffix(val1); \
auto exp2 = GiExpPs##_func_suffix(val2); \
auto rexp1 = GiRecpe##_func_suffix(exp1); \
auto rexp2 = GiRecpe##_func_suffix(exp2); \
rexp1 = GiMultiply##_func_suffix( \
GiRecpeS##_func_suffix(exp1, rexp1), rexp1); \
rexp2 = GiMultiply##_func_suffix( \
GiRecpeS##_func_suffix(exp2, rexp2), rexp2); \
val1 = GiSubtract##_func_suffix(exp1, rexp1); \
val2 = GiSubtract##_func_suffix(exp2, rexp2); \
exp1 = GiAdd##_func_suffix(exp1, rexp1); \
exp2 = GiAdd##_func_suffix(exp2, rexp2); \
rexp1 = GiRecpe##_func_suffix(exp1); \
rexp2 = GiRecpe##_func_suffix(exp2); \
rexp1 = GiMultiply##_func_suffix( \
GiRecpeS##_func_suffix(exp1, rexp1), rexp1); \
rexp2 = GiMultiply##_func_suffix( \
GiRecpeS##_func_suffix(exp2, rexp2), rexp2); \
val1 = GiMultiply##_func_suffix(val1, rexp1); \
val2 = GiMultiply##_func_suffix(val2, rexp2); \
return {{val1, val2}}; \
} \
};
OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/fuse_mul_add3.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseMulAdd3OpBase : TernaryOpBase<src_ctype, dst_ctype> {
using TernaryOpBase<src_ctype, dst_ctype>::TernaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, const src_ctype src2,
dst_ctype* dst) const {
*dst = operator()(src0, src1, src2);
}
dst_ctype operator()(
const src_ctype& src0, const src_ctype& src1, const src_ctype& src2) const {
return (src0 * src1) + src2;
}
};
template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseMulAdd3Op;
#define OP(_ctype, _simd_type, _func_suffix, _simd_width) \
template <> \
struct FuseMulAdd3Op<_ctype> : FuseMulAdd3OpBase<_ctype> { \
using FuseMulAdd3OpBase::FuseMulAdd3OpBase; \
using FuseMulAdd3OpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
const _simd_type& src2, dst_ctype* dst) const { \
auto vitem = operator()(src0, src1, src2); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type operator()( \
const _simd_type& src0, const _simd_type& src1, \
const _simd_type& src2) const { \
auto vitem0 = GiMultiplyAdd##_func_suffix( \
src2.val[0], src0.val[0], src1.val[0]); \
auto vitem1 = GiMultiplyAdd##_func_suffix( \
src2.val[1], src0.val[1], src1.val[1]); \
return {{vitem0, vitem1}}; \
} \
};
OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
OP(dt_int32, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
OP(dt_int8, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
#undef OP
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
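Note the operand order in the vector path above: as used both here and in fast_tanh.h, GiMultiplyAdd(acc, a, b) evaluates acc + a * b, so passing src2 first reproduces the scalar (src0 * src1) + src2.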
/**
* \file dnn/src/fallback/elemwise/gi_impl/gi_util_impl_helper.h
*/
#pragma once
/*!
 * \brief compute fuse_add_relu on two simd packs
 *
 * Compute
 *
 * val1 = fuse_add_relu(val1, val3)
 * val2 = fuse_add_relu(val2, val4)
 *
 * The max-then-add order handles integer wrap-around: max(x + y, 0) is
 * rewritten as max(x, -y) + y, so whenever the true sum is negative the
 * result is exactly -y + y = 0 and the overflowing addition is never
 * formed (see the worked example after these macros).
 */
#define FUSE_ADD_RELU_SIMD_PACK2_FALLBACK(val1, val2, val3, val4, func_suffix) \
do { \
val1 = GiMaximum##func_suffix(val1, GiNeg##func_suffix(val3)); \
val2 = GiMaximum##func_suffix(val2, GiNeg##func_suffix(val4)); \
val1 = GiAdd##func_suffix(val1, val3); \
val2 = GiAdd##func_suffix(val2, val4); \
} while (0)
#define FUSE_ADD_RELU_SIMD_PACK_FALLBACK(val1, val2, func_suffix) \
do { \
val1 = GiMaximum##func_suffix(val1, GiNeg##func_suffix(val2)); \
val1 = GiAdd##func_suffix(val1, val2); \
} while (0)
// vim: syntax=cpp.doxygen
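A scalar int8 illustration of the wrap-around the reordering avoids. This is a sketch: the int8_t casts mimic a wrapping SIMD add, and the INT8_MIN corner of the negation is glossed over here.
#include <algorithm>
#include <cstdint>
#include <cstdio>
int main() {
    int8_t x = -100, y = -100;  // true result: max(x + y, 0) == 0
    // naive order: the sum wraps to +56 before the relu clamp sees it
    int8_t naive = std::max<int8_t>(int8_t(x + y), 0);
    // reordered: when the true sum is negative this is exactly -y + y == 0,
    // so the out-of-range addition is never formed
    int8_t safe = int8_t(std::max<int8_t>(x, int8_t(-y)) + y);
    std::printf("naive = %d, reordered = %d\n", naive, safe);
    return 0;
}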
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/hswish.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h"
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct HSwishOpBase : UnaryOpBase<src_ctype, dst_ctype> {
using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;
void operator()(const src_ctype& src, dst_ctype* dst) const {
*dst = operator()(src);
}
dst_ctype operator()(const src_ctype& src) const {
float tmp = src;
tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f;
        return tmp;
}
};
//! h_swish(x) = x * clip(x + 3, 0, 6) / 6
template <typename src_ctype, typename dst_ctype = src_ctype>
struct HSwishOp;
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct HSwishOp<_ctype> : HSwishOpBase<_ctype> { \
using HSwishOpBase::HSwishOpBase; \
using HSwishOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()(const _simd_type2& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
void operator()(const _simd_type& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type2 operator()(const _simd_type2& src) const { \
auto val1 = src.val[0]; \
auto val2 = src.val[1]; \
H_SWISH_KERN_FALLBACK(_func_suffix, val1, val2); \
return {{val1, val2}}; \
} \
_simd_type operator()(const _simd_type& src) const { \
auto val_zero = GiBroadcast##_func_suffix(0.f); \
auto val_six = GiBroadcast##_func_suffix(6.f); \
auto val_three = GiBroadcast##_func_suffix(3.f); \
auto val_rec_six = GiBroadcast##_func_suffix(1.f / 6.f); \
auto clip1 = GiMaximum##_func_suffix( \
GiMinimum##_func_suffix( \
GiAdd##_func_suffix(src, val_three), val_six), \
val_zero); \
return GiMultiply##_func_suffix( \
GiMultiply##_func_suffix(src, clip1), val_rec_six); \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP
template <>
struct HSwishOpBase<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> {
using UnaryOpBase::UnaryOpBase;
void operator()(const dt_qint32& src, dt_qint8* dst) const {
*dst = operator()(src);
}
dt_qint8 operator()(const dt_qint32& src) const {
float tmp = src.as_int32() * this->scale_src;
tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f;
tmp *= this->scale_dst;
return QConverter::convert<dt_qint8, float>(tmp);
}
};
template <>
struct HSwishOp<dt_qint32, dt_qint8> : HSwishOpBase<dt_qint32, dt_qint8> {
using HSwishOpBase::HSwishOpBase;
using HSwishOpBase::operator();
constexpr static size_t SIMD_WIDTH = 4;
void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const {
GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc));
}
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale_src);
auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale_src);
H_SWISH_KERN_FALLBACK(Float32, vitem0, vitem1);
vitem0 = GiMultiplyFloat32(vitem0, this->vscale_dst);
vitem1 = GiMultiplyFloat32(vitem1, this->vscale_dst);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};
} // namespace fallback
} // namespace megdnn
#include "src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h"
// vim: syntax=cpp.doxygen
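A scalar reference for the fused kernels above (standalone, mirroring HSwishOpBase::operator()): h_swish is identically 0 for x <= -3 and the identity for x >= 3.
#include <algorithm>
#include <cstdio>
#include <initializer_list>
static float h_swish(float x) {
    return x * std::max(std::min(x + 3.f, 6.f), 0.f) / 6.f;
}
int main() {
    for (float x : {-4.f, -3.f, -1.f, 0.f, 1.f, 3.f, 6.f})
        std::printf("h_swish(%g) = %g\n", x, h_swish(x));
    return 0;
}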
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h
*/
#undef H_SWISH_KERN_FALLBACK
#undef H_SWISH_KERN_N1_FALLBACK
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h
*/
#define H_SWISH_KERN_FALLBACK(_func_suffix, _val1, _val2) \
do { \
auto val_zero = GiBroadcast##_func_suffix(0.f); \
auto val_six = GiBroadcast##_func_suffix(6.f); \
auto val_three = GiBroadcast##_func_suffix(3.f); \
auto val_rec_six = GiBroadcast##_func_suffix(1.f / 6.f); \
auto clip1 = GiMaximum##_func_suffix( \
GiMinimum##_func_suffix( \
GiAdd##_func_suffix(_val1, val_three), val_six), \
val_zero); \
auto clip2 = GiMaximum##_func_suffix( \
GiMinimum##_func_suffix( \
GiAdd##_func_suffix(_val2, val_three), val_six), \
val_zero); \
_val1 = GiMultiply##_func_suffix( \
GiMultiply##_func_suffix(_val1, clip1), val_rec_six); \
_val2 = GiMultiply##_func_suffix( \
GiMultiply##_func_suffix(_val2, clip2), val_rec_six); \
} while (0);
#define H_SWISH_KERN_N1_FALLBACK(_func_suffix, _val1) \
do { \
auto val_zero = GiBroadcast##_func_suffix(0.f); \
auto val_six = GiBroadcast##_func_suffix(6.f); \
auto val_three = GiBroadcast##_func_suffix(3.f); \
auto val_rec_six = GiBroadcast##_func_suffix(1.f / 6.f); \
auto clip1 = GiMaximum##_func_suffix( \
GiMinimum##_func_suffix( \
GiAdd##_func_suffix(_val1, val_three), val_six), \
val_zero); \
_val1 = GiMultiply##_func_suffix( \
GiMultiply##_func_suffix(_val1, clip1), val_rec_six); \
} while (0);
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/max.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct MaxOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
return src0 > src1 ? src0 : src1;
}
};
template <typename src_ctype, typename dst_ctype = src_ctype>
struct MaxOp;
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct MaxOp<_ctype> : MaxOpBase<_ctype> { \
using MaxOpBase::MaxOpBase; \
using MaxOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type2& src0, const _simd_type2& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()( \
const _simd_type2& src0, const _simd_type2& src1) const { \
auto vitem0 = GiMaximum##_func_suffix(src0.val[0], src1.val[0]); \
auto vitem1 = GiMaximum##_func_suffix(src0.val[1], src1.val[1]); \
return {{vitem0, vitem1}}; \
} \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
return GiMaximum##_func_suffix(src0, src1); \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
#undef OP
template <>
struct MaxOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
using src_ctype = dt_qint8;
using dst_ctype = dt_qint8;
using BinaryOpBase::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
float fsrc0 = src0.as_int8() * this->scale0;
float fsrc1 = src1.as_int8() * this->scale1;
return QConverter::convert<dst_ctype, float>(fsrc0 > fsrc1 ? fsrc0 : fsrc1);
}
};
template <>
struct MaxOp<dt_qint8, dt_qint8> : MaxOpBase<dt_qint8, dt_qint8> {
using MaxOpBase::MaxOpBase;
constexpr static size_t SIMD_WIDTH = 16;
using MaxOpBase::operator();
void operator()(
const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const {
OPERATOR_BINARY_QINT8_FALLBACK;
}
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
auto vitem0 = GiMaximumFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1));
auto vitem1 = GiMaximumFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1));
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/min.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct MinOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
return src0 < src1 ? src0 : src1;
}
};
template <typename src_ctype, typename dst_ctype = src_ctype>
struct MinOp;
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct MinOp<_ctype> : MinOpBase<_ctype> { \
using MinOpBase::MinOpBase; \
using MinOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type2& src0, const _simd_type2& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()( \
const _simd_type2& src0, const _simd_type2& src1) const { \
auto vitem0 = GiMinimum##_func_suffix(src0.val[0], src1.val[0]); \
auto vitem1 = GiMinimum##_func_suffix(src0.val[1], src1.val[1]); \
return {{vitem0, vitem1}}; \
} \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
return GiMinimum##_func_suffix(src0, src1); \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
#undef OP
template <>
struct MinOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
using BinaryOpBase::BinaryOpBase;
void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const {
*dst = operator()(src0, src1);
}
dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const {
float fsrc0 = src0.as_int8() * this->scale0;
float fsrc1 = src1.as_int8() * this->scale1;
return QConverter::convert<dt_qint8, float>(fsrc0 < fsrc1 ? fsrc0 : fsrc1);
}
};
template <>
struct MinOp<dt_qint8, dt_qint8> : MinOpBase<dt_qint8, dt_qint8> {
using MinOpBase::MinOpBase;
constexpr static size_t SIMD_WIDTH = 16;
using MinOpBase::operator();
void operator()(
const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const {
OPERATOR_BINARY_QINT8_FALLBACK;
}
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
auto vitem0 = GiMinimumFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1));
auto vitem1 = GiMinimumFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1));
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/mul.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct MulOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
return src0 * src1;
}
};
template <typename src_ctype, typename dst_ctype = src_ctype>
struct MulOp;
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct MulOp<_ctype> : MulOpBase<_ctype> { \
using MulOpBase::MulOpBase; \
using MulOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type2& src0, const _simd_type2& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()( \
const _simd_type2& src0, const _simd_type2& src1) const { \
auto vitem0 = GiMultiply##_func_suffix(src0.val[0], src1.val[0]); \
auto vitem1 = GiMultiply##_func_suffix(src0.val[1], src1.val[1]); \
return {{vitem0, vitem1}}; \
} \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
return GiMultiply##_func_suffix(src0, src1); \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
#undef OP
template <>
struct MulOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
using BinaryOpBase::BinaryOpBase;
void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const {
*dst = operator()(src0, src1);
}
dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const {
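        // note the asymmetry, assuming the op_base.h convention: scale_src0
        // is the raw src0 scale while scale1 already folds in 1 / dst_scale,
        // so the product rescales by s0 * s1 / s_dst exactly once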
return QConverter::convert<dt_qint8, float>(
src0.as_int8() * scale_src0 * src1.as_int8() * scale1);
}
};
template <>
struct MulOp<dt_qint8, dt_qint8> : MulOpBase<dt_qint8, dt_qint8> {
using MulOpBase::MulOpBase;
constexpr static size_t SIMD_WIDTH = 16;
using MulOpBase::operator();
void operator()(
const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const {
OPERATOR_BINARY_QINT8_FALLBACK;
}
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
auto vitem0 = GiMultiplyFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale_src0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1));
auto vitem1 = GiMultiplyFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale_src0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1));
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/none.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct NoneOpBase : UnaryOpBase<src_ctype, dst_ctype> {
using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;
dst_ctype operator()(const src_ctype& src) const { return src; }
};
template <typename src_ctype, typename dst_type = src_ctype>
struct NoneOp;
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct NoneOp<_ctype> : NoneOpBase<_ctype> { \
NoneOp(){}; \
NoneOp(float, float){}; \
using NoneOpBase::NoneOpBase; \
using NoneOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
_simd_type2 operator()(const _simd_type2& src) const { return src; } \
void operator()(const _simd_type2& src, _ctype* dst) const { \
GiStore##_func_suffix(dst, src.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, src.val[1]); \
} \
void operator()(const _simd_type& src, _ctype* dst) const { \
GiStore##_func_suffix(dst, src); \
} \
_simd_type operator()(const _simd_type& src) const { return src; } \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
#undef OP
template <>
struct NoneOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> {
using UnaryOpBase::UnaryOpBase;
void operator()(const dt_qint8& src, dt_qint8* dst) const { *dst = src; }
};
template <>
struct NoneOpBase<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> {
using UnaryOpBase::UnaryOpBase;
void operator()(const dt_qint32& src, dt_qint8* dst) const {
*(reinterpret_cast<dt_qint32*>(dst)) = src;
}
};
#pragma GCC diagnostic ignored "-Waddress-of-packed-member"
template <>
struct NoneOp<dt_qint32, dt_qint8> : NoneOpBase<dt_qint32, dt_qint8> {
using NoneOpBase::NoneOpBase;
using NoneOpBase::operator();
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);
void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const {
GiStoreInt32(reinterpret_cast<int32_t*>(dst), vsrc.val[0]);
GiStoreInt32(reinterpret_cast<int32_t*>(dst + 16), vsrc.val[1]);
}
void operator()(const GI_INT32_t& src, dt_qint8* dst) const {
GiStoreInt32(reinterpret_cast<int32_t*>(dst), src);
}
};
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/pow.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
/////////////////////// POW float only ////////////////////////////
template <typename src_ctype, typename dst_ctype = src_ctype>
struct PowOp : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
constexpr static size_t SIMD_WIDTH = 1;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
return powf(src0, src1);
}
};
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/relu.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct ReluOpBase : UnaryOpBase<src_ctype, dst_ctype> {
using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;
void operator()(const src_ctype& src, dst_ctype* dst) const {
*dst = operator()(src);
}
dst_ctype operator()(const src_ctype& src) const { return src > 0 ? src : 0; }
};
template <typename src_ctype, typename dst_type = src_ctype>
struct ReluOp;
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct ReluOp<_ctype> : ReluOpBase<_ctype> { \
using ReluOpBase::ReluOpBase; \
using ReluOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()(const _simd_type2& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()(const _simd_type2& src) const { \
auto vzero = GiBroadcast##_func_suffix(0); \
auto vitem0 = GiMaximum##_func_suffix(src.val[0], vzero); \
auto vitem1 = GiMaximum##_func_suffix(src.val[1], vzero); \
return {{vitem0, vitem1}}; \
} \
void operator()(const _simd_type& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type operator()(const _simd_type& src) const { \
auto vzero = GiBroadcast##_func_suffix(0); \
return GiMaximum##_func_suffix(src, vzero); \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
#undef OP
template <>
struct ReluOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> {
using UnaryOpBase::UnaryOpBase;
void operator()(const dt_qint8& src, dt_qint8* dst) const {
*dst = operator()(src);
}
dt_qint8 operator()(const dt_qint8& src) const {
float fsrc = src.as_int8() * this->scale;
fsrc = std::max<float>(fsrc, 0.f);
return QConverter::convert<dt_qint8, float>(fsrc);
}
};
template <>
struct ReluOp<dt_qint8, dt_qint8> : ReluOpBase<dt_qint8, dt_qint8> {
using ReluOpBase::ReluOpBase;
constexpr static size_t SIMD_WIDTH = 16;
using ReluOpBase::operator();
void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const {
OPERATOR_UNARY_QINT8_FALLBACK;
}
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
auto vzero = GiBroadcastFloat32(0.f);
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale);
auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale);
vitem0 = GiMaximumFloat32(vitem0, vzero);
vitem1 = GiMaximumFloat32(vitem1, vzero);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};
template <>
struct ReluOpBase<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> {
using UnaryOpBase::UnaryOpBase;
void operator()(const dt_qint32& src, dt_qint8* dst) const {
*dst = operator()(src);
}
dt_qint8 operator()(const dt_qint32& src) const {
float fsrc = src.as_int32() * this->scale;
fsrc = std::max<float>(fsrc, 0.f);
return QConverter::convert<dt_qint8, float>(fsrc);
}
};
//! old armv7 lacks a round-to-nearest float->int conversion, so relu is
//! specialized with a fixed-point multiplier fixup
#if defined(__ARM_ARCH) && __ARM_ARCH < 8
template <>
struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase {
using ReluOpBase::operator();
constexpr static size_t SIMD_WIDTH = 4;
ReluOp(DType src_dtype, DType dst_dtype)
: ReluOpBase(src_dtype, dst_dtype), FixupBase(scale) {}
ReluOp(float src_scale, float dst_scale)
: ReluOpBase(src_scale, dst_scale), FixupBase(scale) {}
void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const {
vst1_s8(reinterpret_cast<int8_t*>(dst), operator()(vsrc));
}
int8x8_t operator()(const int32x4x2_t& vsrc) const {
int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier);
int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier);
vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero());
vitem1 = vmaxq_s32(vitem1, QConverterBase::vzero());
return vqmovn_s16(vcombine_s16(
vqmovn_s32(vrshlq_s32(vitem0, vshift)),
vqmovn_s32(vrshlq_s32(vitem1, vshift))));
}
int8x8_t operator()(const float32x4_t& vsrc) const {
int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier);
vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero());
vitem0 = vrshlq_s32(vitem0, vshift);
int16x4_t vitem = vqmovn_s32(vitem0);
return vqmovn_s16(vcombine_s16(vitem, vitem));
}
void operator()(const int32x4_t& src, dt_qint8* dst) const {
auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale);
vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero());
auto result = QConverter::convert<int8x8_t, float32x4_t>(vitem0);
vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0);
}
void operator()(const float32x4_t& src, dt_qint8* dst) const {
auto vitem0 = vmulq_f32(src, this->vscale);
vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero());
auto result = QConverter::convert<int8x8_t, float32x4_t>(vitem0);
vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0);
}
};
#else
template <>
struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8> {
using ReluOpBase::ReluOpBase;
using ReluOpBase::operator();
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);
void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const {
GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc));
}
void operator()(const GI_INT32_t& src, dt_qint8* dst) const {
GiStoreLane0Int32(
reinterpret_cast<int32_t*>(dst), (GI_INT32_t)(operator()(src)));
}
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale);
auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale);
vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero());
vitem1 = GiMaximumFloat32(vitem1, QConverterBase::vfzero());
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
GI_INT8_t operator()(const GI_INT32_t& src) const {
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(src), this->vscale);
vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero());
return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0);
}
GI_INT8_t operator()(const GI_FLOAT32_t& src) const {
auto vitem0 = GiMultiplyFloat32(src, this->vscale);
vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero());
return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0);
}
};
#endif
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/sigmoid.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct SigmoidOpBase : UnaryOpBase<src_ctype, dst_ctype> {
using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;
void operator()(const src_ctype& src, dst_ctype* dst) const {
*dst = operator()(src);
}
dst_ctype operator()(const src_ctype& src) const {
float tmpf = src;
tmpf = exp(-tmpf);
tmpf = 1.f / (1.f + tmpf);
return tmpf;
}
};
template <typename src_ctype, typename dst_ctype = src_ctype>
struct SigmoidOp;
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct SigmoidOp<_ctype> : SigmoidOpBase<_ctype> { \
using SigmoidOpBase::SigmoidOpBase; \
using SigmoidOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()(const _simd_type2& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
void operator()(const _simd_type& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type2 operator()(const _simd_type2& src) const { \
return {{operator()(src.val[0]), operator()(src.val[1])}}; \
} \
_simd_type operator()(const _simd_type& src) const { \
return GiSigmoidPs##_func_suffix(src); \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/sub.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct SubOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
return src0 - src1;
}
};
template <typename src_ctype, typename dst_ctype = src_ctype>
struct SubOp;
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct SubOp<_ctype> : SubOpBase<_ctype> { \
using SubOpBase::SubOpBase; \
using SubOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type2& src0, const _simd_type2& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()( \
const _simd_type2& src0, const _simd_type2& src1) const { \
auto vitem0 = GiSubtract##_func_suffix(src0.val[0], src1.val[0]); \
auto vitem1 = GiSubtract##_func_suffix(src0.val[1], src1.val[1]); \
return {{vitem0, vitem1}}; \
} \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
return GiSubtract##_func_suffix(src0, src1); \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
#undef OP
template <>
struct SubOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
using BinaryOpBase::BinaryOpBase;
void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const {
*dst = operator()(src0, src1);
}
dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const {
return QConverter::convert<dt_qint8, float>(
src0.as_int8() * scale0 - src1.as_int8() * scale1);
}
};
template <>
struct SubOp<dt_qint8, dt_qint8> : SubOpBase<dt_qint8, dt_qint8> {
using SubOpBase::SubOpBase;
constexpr static size_t SIMD_WIDTH = 16;
using SubOpBase::operator();
void operator()(
const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const {
OPERATOR_BINARY_QINT8_FALLBACK;
}
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
auto vitem0 = GiSubtractFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1));
auto vitem1 = GiSubtractFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1));
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/tanh.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct TanhOpBase : UnaryOpBase<src_ctype, dst_ctype> {
using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;
void operator()(const src_ctype& src, dst_ctype* dst) const {
*dst = operator()(src);
}
dst_ctype operator()(const src_ctype& src) const {
float tmp = src;
return tanh(tmp);
}
};
template <typename src_ctype, typename dst_type = src_ctype>
struct TanhOp;
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct TanhOp<_ctype> : TanhOpBase<_ctype> { \
using TanhOpBase::TanhOpBase; \
using TanhOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()(const _simd_type2& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()(const _simd_type2& src) const { \
auto one_val = GiBroadcast##_func_suffix(1.f); \
auto two_val = GiBroadcast##_func_suffix(2.f); \
auto val1 = src.val[0]; \
auto val2 = src.val[1]; \
val1 = GiMultiply##_func_suffix(two_val, val1); \
val2 = GiMultiply##_func_suffix(two_val, val2); \
val1 = GiExpPs##_func_suffix(val1); \
val2 = GiExpPs##_func_suffix(val2); \
val1 = GiAdd##_func_suffix(one_val, val1); \
val2 = GiAdd##_func_suffix(one_val, val2); \
auto rval1 = GiRecpe##_func_suffix(val1); \
auto rval2 = GiRecpe##_func_suffix(val2); \
rval1 = GiMultiply##_func_suffix( \
GiRecpeS##_func_suffix(val1, rval1), rval1); \
rval2 = GiMultiply##_func_suffix( \
GiRecpeS##_func_suffix(val2, rval2), rval2); \
val1 = GiMultiply##_func_suffix(two_val, rval1); \
val2 = GiMultiply##_func_suffix(two_val, rval2); \
val1 = GiSubtract##_func_suffix(one_val, val1); \
val2 = GiSubtract##_func_suffix(one_val, val2); \
return {{val1, val2}}; \
} \
_simd_type operator()(const _simd_type& src) const { \
auto one_val = GiBroadcast##_func_suffix(1.f); \
auto two_val = GiBroadcast##_func_suffix(2.f); \
auto val1 = src; \
val1 = GiMultiply##_func_suffix(two_val, val1); \
val1 = GiExpPs##_func_suffix(val1); \
val1 = GiAdd##_func_suffix(one_val, val1); \
auto rval1 = GiRecpe##_func_suffix(val1); \
rval1 = GiMultiply##_func_suffix( \
GiRecpeS##_func_suffix(val1, rval1), rval1); \
val1 = GiMultiply##_func_suffix(two_val, rval1); \
val1 = GiSubtract##_func_suffix(one_val, val1); \
return val1; \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
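The vector path relies on the identity tanh(x) = (e^(2x) - 1) / (e^(2x) + 1) = 1 - 2 / (e^(2x) + 1), so one vector exp plus a refined reciprocal suffices. A scalar sanity check of the identity (standard library only):
#include <cmath>
#include <cstdio>
#include <initializer_list>
int main() {
    for (float x : {-2.f, -0.5f, 0.f, 0.5f, 2.f}) {
        float via_exp = 1.f - 2.f / (std::exp(2.f * x) + 1.f);
        std::printf("x = %g: identity = %.6f, std::tanh = %.6f\n", x, via_exp,
                    std::tanh(x));
    }
    return 0;
}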
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/true_div.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
//! A / B can be lowered to a reciprocal estimate refined by Newton-Raphson
//! steps (the NEON spelling of one refinement):
//! 1. rB = vrecpeq_f32(B)  2. rB = vmulq_f32(vrecpsq_f32(B, rB), rB)
//! 3. A * rB
template <typename src_ctype, typename dst_ctype = src_ctype>
struct TrueDivOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
return src0 / src1;
}
};
template <typename src_ctype, typename dst_ctype = src_ctype>
struct TrueDivOp;
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct TrueDivOp<_ctype> : TrueDivOpBase<_ctype> { \
using TrueDivOpBase::TrueDivOpBase; \
using TrueDivOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type2& src0, const _simd_type2& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()( \
const _simd_type2& src0, const _simd_type2& src1) const { \
auto val1 = src0.val[0]; \
auto val2 = src0.val[1]; \
auto val3 = src1.val[0]; \
auto val4 = src1.val[1]; \
val1 = GiDivide##_func_suffix(val1, val3); \
val2 = GiDivide##_func_suffix(val2, val4); \
return {{val1, val2}}; \
} \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
return GiDivide##_func_suffix(src0, src1); \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
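A scalar sketch of the refinement described at the top of this file; the literal 0.3f stands in for the low-precision hardware estimate.
#include <cstdio>
int main() {
    float b = 3.0f;
    float r = 0.3f;  // stand-in for the hardware estimate of 1/b
    for (int i = 0; i < 2; ++i)
        r = r * (2.0f - b * r);  // GiRecpeS computes (2 - b * r); each step
                                 // roughly doubles the correct bits
    std::printf("1/%g ~= %.7f (libm: %.7f)\n", b, r, 1.0f / b);
    return 0;
}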
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/typecvt.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/op_base.h"
namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct TypeCvtOp;
template <>
struct TypeCvtOp<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> {
using UnaryOpBase::UnaryOpBase;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(float);
void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const {
GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc));
}
void operator()(const GI_INT32_t& vsrc, dt_qint8* dst) const {
GiStoreLane0Int32(
reinterpret_cast<int32_t*>(dst), (GI_INT32_t)(operator()(vsrc)));
}
void operator()(const src_ctype& src, dst_ctype* dst) const {
*dst = operator()(src);
}
dt_qint8 operator()(const dt_qint32& src) const {
float fsrc = src.as_int32() * this->scale;
return QConverter::convert<dt_qint8, float>(fsrc);
}
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale);
auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
GI_INT8_t operator()(const GI_INT32_t& src) const {
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(src), this->vscale);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0);
}
GI_INT8_t operator()(const GI_FLOAT32_t& src) const {
auto vitem0 = GiMultiplyFloat32(src, this->vscale);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0);
}
};
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
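TypeCvtOp here is the plain requantize: one multiply by this->scale followed by the saturating round inside QConverter. Assuming the op_base.h convention (collapsed in this diff), scale presumably combines src_scale / dst_scale.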
/**
* \file dnn/src/fallback/elemwise_helper/op_binary.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/add.h"
#include "src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h"
#include "src/fallback/elemwise_helper/kimpl/fuse_add_relu.h"
#include "src/fallback/elemwise_helper/kimpl/fuse_add_sigmoid.h"
#include "src/fallback/elemwise_helper/kimpl/fuse_add_tanh.h"
#include "src/fallback/elemwise_helper/kimpl/max.h"
#include "src/fallback/elemwise_helper/kimpl/min.h"
#include "src/fallback/elemwise_helper/kimpl/mul.h"
#include "src/fallback/elemwise_helper/kimpl/pow.h"
#include "src/fallback/elemwise_helper/kimpl/sub.h"
#include "src/fallback/elemwise_helper/kimpl/true_div.h"
//////////////////// quantization //////////////////////////////
namespace megdnn {
namespace fallback {
#define cb(op) \
template <> \
struct op<dt_qint8, dt_qint8> \
: BinaryQuantizationOp<dt_qint8, dt_qint8, op<float, float>> { \
using BinaryQuantizationOp< \
dt_qint8, dt_qint8, op<float, float>>::BinaryQuantizationOp; \
};
cb(TrueDivOp);
cb(FuseAddSigmoidOp);
cb(FuseAddTanhOp);
cb(FuseAddHSwishOp);
#undef cb
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
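The four ops wrapped here are exactly the binary ops included above that have no hand-written (dt_qint8, dt_qint8) specialization of their own; BinaryQuantizationOp (defined in op_base.h, collapsed in this diff) presumably dequantizes both operands to float, applies op<float, float>, and requantizes the result.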
/**
* \file dnn/src/fallback/elemwise_helper/op_common.h
*/
#pragma once
namespace megdnn {
/*!
 * \brief broadcast type
 * BCAST_x[0]x[1]...: digit x[i] is 1 iff stride[i] == 0, i.e. that
 * dimension of the operand is broadcast
 */
enum BcastType {
VEC,
VEC_VEC,
VEC_BCAST101,
VEC_BCASTX0X,
VEC_BCAST111C,
VEC_BCAST101xX,
VEC_SCALAR,
SCALAR_VEC,
BCAST101_VEC,
BCASTX0X_VEC,
BCAST111C_VEC,
BCAST101xX_VEC,
VEC_VEC_VEC,
VEC_VEC_SCALAR,
BCAST101_VEC_BCAST101,
BCAST111C_VEC_BCAST111C,
BCAST101xX_VEC_BCAST101xX,
VEC_BCAST101_VEC,
VEC_BCAST111C_VEC,
VEC_BCAST101xX_VEC,
VEC_SCALAR_VEC,
VEC_SCALAR_SCALAR,
UNKNOWN_BCAST_TYPE
};
} // namespace megdnn
// vim: syntax=cpp.doxygen
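To make the encoding concrete: viewing an NCHW tensor as (N, C, H*W), a per-channel (1, C, 1) bias has stride 0 in dims 0 and 2, giving the digit string 101, so vector-plus-bias dispatches as VEC_BCAST101; as the suffix suggests, the 101xX variants extend this to blocked NCHWxX layouts where an innermost dense pack of X channels is kept.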
/**
* \file dnn/src/fallback/elemwise_helper/op_ternary.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/fuse_mul_add3.h"
//////////////////// quantization //////////////////////////////
namespace megdnn {
namespace fallback {
#define cb(op) \
template <> \
struct op<dt_qint8, dt_qint8> \
: TernaryQuantizationOp<dt_qint8, dt_qint8, op<float, float>> { \
using TernaryQuantizationOp< \
dt_qint8, dt_qint8, op<float, float>>::TernaryQuantizationOp; \
};
cb(FuseMulAdd3Op);
#undef cb
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/fallback/elemwise_helper/op_unary.h
*/
#pragma once
#include "src/fallback/elemwise_helper/kimpl/abs.h"
#include "src/fallback/elemwise_helper/kimpl/exp.h"
#include "src/fallback/elemwise_helper/kimpl/fast_tanh.h"
#include "src/fallback/elemwise_helper/kimpl/hswish.h"
#include "src/fallback/elemwise_helper/kimpl/none.h"
#include "src/fallback/elemwise_helper/kimpl/relu.h"
#include "src/fallback/elemwise_helper/kimpl/sigmoid.h"
#include "src/fallback/elemwise_helper/kimpl/tanh.h"
#include "src/fallback/elemwise_helper/kimpl/typecvt.h"
//////////////////// quantization //////////////////////////////
namespace megdnn {
namespace fallback {
#define cb(op) \
template <> \
struct op<dt_qint8, dt_qint8> \
: UnaryQuantizationOp<dt_qint8, dt_qint8, op<float, float>> { \
using UnaryQuantizationOp< \
dt_qint8, dt_qint8, op<float, float>>::UnaryQuantizationOp; \
};
cb(SigmoidOp);
cb(ExpOp);
cb(TanhOp);
cb(FastTanhOp);
cb(HSwishOp);
#undef cb
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -19,7 +19,7 @@
#include <windows.h>
#else
#if defined(__arm__) || defined(__aarch64__)
#include <arm_neon.h>
#include "src/arm_common/simd_macro/marm_neon.h"
#endif
#if defined(__x86_64__) || defined(__i386__)
#include <cpuid.h>
......