Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
f509b1be
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
f509b1be
编写于
1月 11, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(build): split elemwise_multi_type cpp
GitOrigin-RevId: 13267e9db6fa3194291965f50fe08eb892815e8a
上级
3252016e
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
230 addition
and
205 deletion
+230
-205
dnn/src/naive/elemwise_multi_type/opr_impl.h
dnn/src/naive/elemwise_multi_type/opr_impl.h
+117
-7
dnn/src/naive/elemwise_multi_type/opr_impl_3.cpp
dnn/src/naive/elemwise_multi_type/opr_impl_3.cpp
+0
-198
dnn/src/naive/elemwise_multi_type/opr_impl_4.cpp
dnn/src/naive/elemwise_multi_type/opr_impl_4.cpp
+68
-0
dnn/src/naive/elemwise_multi_type/opr_impl_5.cpp
dnn/src/naive/elemwise_multi_type/opr_impl_5.cpp
+45
-0
未找到文件。
dnn/src/naive/elemwise_multi_type/opr_impl.h
浏览文件 @
f509b1be
...
...
@@ -11,29 +11,139 @@
#pragma once
#include "megdnn/tensor_iter.h"
#include "src/common/elemwise_multi_type/opr_impl_helper.h"
#include "src/naive/handle.h"
namespace
megdnn
{
namespace
naive
{
class
ElemwiseMultiTypeImpl
:
public
ElemwiseMultiTypeImplHelper
{
template
<
typename
KernImpl
,
typename
ElemParam
>
void
dispatch_qint_op_dtype
(
const
ElemParam
&
param
,
const
TensorND
&
dst_tensor
);
template
<
typename
KernImpl
,
typename
src_ctype
,
typename
ElemParam
>
void
dispatch_add_qint_op_dst
(
const
ElemParam
&
param
,
const
TensorND
&
dst_tensor
);
void
dispatch_add_qint_op_dst
(
const
ElemParam
&
param
,
const
TensorND
&
dst
)
{
switch
(
dst
.
layout
.
dtype
.
enumv
())
{
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
dispatch_add_qint_op<KernImpl, src_ctype, typename DTypeTrait<_dt>::ctype>( \
param, dst); \
break;
MEGDNN_FOREACH_QUANTIZED_DTYPE
(
cb
)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE
(
cb
)
#undef cb
default:
megdnn_assert
(
0
,
"not support %s %s
\n
"
,
param
[
0
].
layout
.
dtype
.
name
(),
dst
.
layout
.
dtype
.
name
());
}
}
template
<
typename
KernImpl
,
typename
ElemParam
>
void
dispatch_qint_op_dtype
(
const
ElemParam
&
param
,
const
TensorND
&
dst
)
{
switch
(
param
[
0
].
layout
.
dtype
.
enumv
())
{
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
dispatch_add_qint_op_dst< \
KernImpl, typename DTypeTrait<_dt>::ctype, ElemParam>(param, dst); \
break;
MEGDNN_FOREACH_QUANTIZED_DTYPE
(
cb
)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE
(
cb
)
#undef cb
default:
megdnn_assert_internal
(
0
);
}
}
template
<
typename
KernImpl
,
typename
src_ctype
,
typename
dst_ctype
>
void
dispatch_add_qint_op
(
const
ElemwiseOpParamN
<
1
>&
param
,
const
TensorND
&
dst_tensor
);
const
ElemwiseOpParamN
<
1
>&
param
,
const
TensorND
&
dst_tensor
)
{
auto
src
=
param
[
0
];
auto
size
=
param
.
size
;
auto
work
=
[
src
,
size
,
dst_tensor
]()
{
auto
iA
=
tensor_iter_valonly
<
src_ctype
>
(
src
).
begin
();
auto
pD
=
tensor_iter_valonly
<
dst_ctype
>
(
dst_tensor
).
begin
();
auto
param0
=
src
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
src_ctype
>::
dtype
>
();
auto
dst_param
=
dst_tensor
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
dst_ctype
>::
dtype
>
();
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
src_ctype
a
=
*
iA
;
*
pD
=
dst_param
.
quantize
(
KernImpl
::
apply
(
param0
.
dequantize
(
a
)));
++
iA
;
++
pD
;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR
(
work
());
}
template
<
typename
KernImpl
,
typename
src_ctype
,
typename
dst_ctype
>
void
dispatch_add_qint_op
(
const
ElemwiseOpParamN
<
2
>&
param
,
const
TensorND
&
dst_tensor
);
const
ElemwiseOpParamN
<
2
>&
param
,
const
TensorND
&
dst_tensor
)
{
auto
size
=
param
.
size
;
auto
src0
=
param
[
0
];
auto
src1
=
param
[
1
];
auto
work
=
[
src0
,
src1
,
size
,
dst_tensor
]()
{
// This is needed as these iterators are captured as const value.
auto
iA
=
tensor_iter_valonly
<
src_ctype
>
(
src0
).
begin
();
auto
iB
=
tensor_iter_valonly
<
src_ctype
>
(
src1
).
begin
();
auto
pD
=
tensor_iter_valonly
<
dst_ctype
>
(
dst_tensor
).
begin
();
auto
param0
=
src0
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
src_ctype
>::
dtype
>
();
auto
param1
=
src1
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
src_ctype
>::
dtype
>
();
auto
dst_param
=
dst_tensor
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
dst_ctype
>::
dtype
>
();
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
src_ctype
a
=
*
iA
;
src_ctype
b
=
*
iB
;
*
pD
=
dst_param
.
quantize
(
KernImpl
::
apply
(
param0
.
dequantize
(
a
),
param1
.
dequantize
(
b
)));
++
iA
;
++
iB
;
++
pD
;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR
(
work
());
}
template
<
typename
KernImpl
,
typename
src_ctype
,
typename
dst_ctype
>
void
dispatch_add_qint_op
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst_tensor
);
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst_tensor
)
{
auto
size
=
param
.
size
;
auto
src0
=
param
[
0
];
auto
src1
=
param
[
1
];
auto
src2
=
param
[
2
];
auto
work
=
[
src0
,
src1
,
src2
,
size
,
dst_tensor
]()
{
// This is needed as these iterators are captured as const value.
auto
iA
=
tensor_iter_valonly
<
src_ctype
>
(
src0
).
begin
();
auto
iB
=
tensor_iter_valonly
<
src_ctype
>
(
src1
).
begin
();
auto
iC
=
tensor_iter_valonly
<
src_ctype
>
(
src2
).
begin
();
auto
pD
=
tensor_iter_valonly
<
dst_ctype
>
(
dst_tensor
).
begin
();
auto
param0
=
src0
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
src_ctype
>::
dtype
>
();
auto
param1
=
src1
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
src_ctype
>::
dtype
>
();
auto
param2
=
src2
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
src_ctype
>::
dtype
>
();
auto
dst_param
=
dst_tensor
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
dst_ctype
>::
dtype
>
();
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
src_ctype
a
=
*
iA
;
src_ctype
b
=
*
iB
;
src_ctype
c
=
*
iC
;
*
pD
=
dst_param
.
quantize
(
KernImpl
::
apply
(
param0
.
dequantize
(
a
),
param1
.
dequantize
(
b
),
param2
.
dequantize
(
c
)));
++
iA
;
++
iB
;
++
iC
;
++
pD
;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR
(
work
());
}
protected:
template
<
typename
ctype
>
...
...
dnn/src/naive/elemwise_multi_type/opr_impl_3.cpp
浏览文件 @
f509b1be
...
...
@@ -10,135 +10,12 @@
*/
#include "./opr_impl.h"
#include "megdnn/tensor_iter.h"
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "src/naive/handle.h"
using
namespace
megdnn
;
using
namespace
naive
;
template
<
typename
KernImpl
,
typename
src_ctype
,
typename
dst_ctype
>
void
ElemwiseMultiTypeImpl
::
dispatch_add_qint_op
(
const
ElemwiseOpParamN
<
1
>&
param
,
const
TensorND
&
dst_tensor
)
{
auto
src
=
param
[
0
];
auto
size
=
param
.
size
;
auto
work
=
[
src
,
size
,
dst_tensor
]()
{
auto
iA
=
tensor_iter_valonly
<
src_ctype
>
(
src
).
begin
();
auto
pD
=
tensor_iter_valonly
<
dst_ctype
>
(
dst_tensor
).
begin
();
auto
param0
=
src
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
src_ctype
>::
dtype
>
();
auto
dst_param
=
dst_tensor
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
dst_ctype
>::
dtype
>
();
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
src_ctype
a
=
*
iA
;
*
pD
=
dst_param
.
quantize
(
KernImpl
::
apply
(
param0
.
dequantize
(
a
)));
++
iA
;
++
pD
;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR
(
work
());
}
template
<
typename
KernImpl
,
typename
src_ctype
,
typename
dst_ctype
>
void
ElemwiseMultiTypeImpl
::
dispatch_add_qint_op
(
const
ElemwiseOpParamN
<
2
>&
param
,
const
TensorND
&
dst_tensor
)
{
auto
size
=
param
.
size
;
auto
src0
=
param
[
0
];
auto
src1
=
param
[
1
];
auto
work
=
[
src0
,
src1
,
size
,
dst_tensor
]()
{
// This is needed as these iterators are captured as const value.
auto
iA
=
tensor_iter_valonly
<
src_ctype
>
(
src0
).
begin
();
auto
iB
=
tensor_iter_valonly
<
src_ctype
>
(
src1
).
begin
();
auto
pD
=
tensor_iter_valonly
<
dst_ctype
>
(
dst_tensor
).
begin
();
auto
param0
=
src0
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
src_ctype
>::
dtype
>
();
auto
param1
=
src1
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
src_ctype
>::
dtype
>
();
auto
dst_param
=
dst_tensor
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
dst_ctype
>::
dtype
>
();
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
src_ctype
a
=
*
iA
;
src_ctype
b
=
*
iB
;
*
pD
=
dst_param
.
quantize
(
KernImpl
::
apply
(
param0
.
dequantize
(
a
),
param1
.
dequantize
(
b
)));
++
iA
;
++
iB
;
++
pD
;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR
(
work
());
}
template
<
typename
KernImpl
,
typename
src_ctype
,
typename
dst_ctype
>
void
ElemwiseMultiTypeImpl
::
dispatch_add_qint_op
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst_tensor
)
{
auto
size
=
param
.
size
;
auto
src0
=
param
[
0
];
auto
src1
=
param
[
1
];
auto
src2
=
param
[
2
];
auto
work
=
[
src0
,
src1
,
src2
,
size
,
dst_tensor
]()
{
// This is needed as these iterators are captured as const value.
auto
iA
=
tensor_iter_valonly
<
src_ctype
>
(
src0
).
begin
();
auto
iB
=
tensor_iter_valonly
<
src_ctype
>
(
src1
).
begin
();
auto
iC
=
tensor_iter_valonly
<
src_ctype
>
(
src2
).
begin
();
auto
pD
=
tensor_iter_valonly
<
dst_ctype
>
(
dst_tensor
).
begin
();
auto
param0
=
src0
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
src_ctype
>::
dtype
>
();
auto
param1
=
src1
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
src_ctype
>::
dtype
>
();
auto
param2
=
src2
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
src_ctype
>::
dtype
>
();
auto
dst_param
=
dst_tensor
.
layout
.
dtype
.
param
<
typename
DTypeTrait
<
dst_ctype
>::
dtype
>
();
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
src_ctype
a
=
*
iA
;
src_ctype
b
=
*
iB
;
src_ctype
c
=
*
iC
;
*
pD
=
dst_param
.
quantize
(
KernImpl
::
apply
(
param0
.
dequantize
(
a
),
param1
.
dequantize
(
b
),
param2
.
dequantize
(
c
)));
++
iA
;
++
iB
;
++
iC
;
++
pD
;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR
(
work
());
}
template
<
typename
KernImpl
,
typename
src_ctype
,
typename
ElemParam
>
void
ElemwiseMultiTypeImpl
::
dispatch_add_qint_op_dst
(
const
ElemParam
&
param
,
const
TensorND
&
dst
)
{
switch
(
dst
.
layout
.
dtype
.
enumv
())
{
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
dispatch_add_qint_op<KernImpl, src_ctype, typename DTypeTrait<_dt>::ctype>( \
param, dst); \
break;
MEGDNN_FOREACH_QUANTIZED_DTYPE
(
cb
)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE
(
cb
)
#undef cb
default:
megdnn_assert
(
0
,
"not support %s %s
\n
"
,
param
[
0
].
layout
.
dtype
.
name
(),
dst
.
layout
.
dtype
.
name
());
}
}
template
<
typename
KernImpl
,
typename
ElemParam
>
void
ElemwiseMultiTypeImpl
::
dispatch_qint_op_dtype
(
const
ElemParam
&
param
,
const
TensorND
&
dst
)
{
switch
(
param
[
0
].
layout
.
dtype
.
enumv
())
{
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
dispatch_add_qint_op_dst< \
KernImpl, typename DTypeTrait<_dt>::ctype, ElemParam>(param, dst); \
break;
MEGDNN_FOREACH_QUANTIZED_DTYPE
(
cb
)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE
(
cb
)
#undef cb
default:
megdnn_assert_internal
(
0
);
}
}
void
ElemwiseMultiTypeImpl
::
on_quantized_mode
(
const
ElemwiseOpParamN
<
1
>&
param
,
const
TensorND
&
dst
,
Elemwise
::
Mode
mode
)
{
megdnn_assert
(
param
[
0
].
layout
.
dtype
.
category
()
==
DTypeCategory
::
QUANTIZED
);
...
...
@@ -182,79 +59,4 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
}
}
void
ElemwiseMultiTypeImpl
::
on_quantized_mode
(
const
ElemwiseOpParamN
<
2
>&
param
,
const
TensorND
&
dst
,
Elemwise
::
Mode
mode
)
{
megdnn_assert
(
param
[
0
].
layout
.
dtype
.
enumv
()
==
param
[
1
].
layout
.
dtype
.
enumv
()
&&
param
[
0
].
layout
.
dtype
.
category
()
==
DTypeCategory
::
QUANTIZED
);
megdnn_assert
(
dst
.
layout
.
dtype
.
category
()
==
DTypeCategory
::
QUANTIZED
);
switch
(
mode
)
{
#define DISPATCH(_mode) \
case Elemwise::Mode::_mode: { \
typedef ElemwiseKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
KernImpl; \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<2>>(param, dst); \
break; \
}
DISPATCH
(
ABS_GRAD
);
DISPATCH
(
ADD
);
DISPATCH
(
FLOOR_DIV
);
DISPATCH
(
MAX
);
DISPATCH
(
MIN
);
DISPATCH
(
MOD
);
DISPATCH
(
MUL
);
DISPATCH
(
POW
);
DISPATCH
(
SIGMOID_GRAD
);
DISPATCH
(
SUB
);
DISPATCH
(
SWITCH_GT0
);
DISPATCH
(
TANH_GRAD
);
DISPATCH
(
TRUE_DIV
);
DISPATCH
(
LOG_SUM_EXP
);
DISPATCH
(
LT
);
DISPATCH
(
LEQ
);
DISPATCH
(
EQ
);
DISPATCH
(
FUSE_ADD_RELU
);
DISPATCH
(
FUSE_ADD_SIGMOID
);
DISPATCH
(
FUSE_ADD_TANH
);
DISPATCH
(
FAST_TANH_GRAD
);
DISPATCH
(
ATAN2
);
DISPATCH
(
H_SWISH_GRAD
);
DISPATCH
(
FUSE_ADD_H_SWISH
);
#undef DISPATCH
default:
megdnn_assert_internal
(
0
);
}
}
void
ElemwiseMultiTypeImpl
::
on_quantized_mode
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
,
Elemwise
::
Mode
mode
)
{
megdnn_assert
(
param
[
0
].
layout
.
dtype
.
category
()
==
DTypeCategory
::
QUANTIZED
&&
param
[
0
].
layout
.
dtype
.
category
()
==
param
[
1
].
layout
.
dtype
.
category
()
&&
param
[
0
].
layout
.
dtype
.
category
()
==
param
[
2
].
layout
.
dtype
.
category
());
megdnn_assert
(
dst
.
layout
.
dtype
.
category
()
==
DTypeCategory
::
QUANTIZED
);
switch
(
mode
)
{
#define DISPATCH(_mode) \
case Elemwise::Mode::_mode: { \
typedef ElemwiseKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
KernImpl; \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<3>>(param, dst); \
break; \
}
DISPATCH
(
FUSE_MUL_ADD3
);
DISPATCH
(
COND_LEQ_MOV
);
#undef DISPATCH
default:
megdnn_assert_internal
(
0
);
}
}
// vim: syntax=cpp.doxygen
dnn/src/naive/elemwise_multi_type/opr_impl_4.cpp
0 → 100644
浏览文件 @
f509b1be
/**
* \file dnn/src/naive/elemwise_multi_type/opr_impl_4.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./opr_impl.h"
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
using
namespace
megdnn
;
using
namespace
naive
;
void
ElemwiseMultiTypeImpl
::
on_quantized_mode
(
const
ElemwiseOpParamN
<
2
>&
param
,
const
TensorND
&
dst
,
Elemwise
::
Mode
mode
)
{
megdnn_assert
(
param
[
0
].
layout
.
dtype
.
enumv
()
==
param
[
1
].
layout
.
dtype
.
enumv
()
&&
param
[
0
].
layout
.
dtype
.
category
()
==
DTypeCategory
::
QUANTIZED
);
megdnn_assert
(
dst
.
layout
.
dtype
.
category
()
==
DTypeCategory
::
QUANTIZED
);
switch
(
mode
)
{
#define DISPATCH(_mode) \
case Elemwise::Mode::_mode: { \
typedef ElemwiseKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
KernImpl; \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<2>>(param, dst); \
break; \
}
DISPATCH
(
ABS_GRAD
);
DISPATCH
(
ADD
);
DISPATCH
(
FLOOR_DIV
);
DISPATCH
(
MAX
);
DISPATCH
(
MIN
);
DISPATCH
(
MOD
);
DISPATCH
(
MUL
);
DISPATCH
(
POW
);
DISPATCH
(
SIGMOID_GRAD
);
DISPATCH
(
SUB
);
DISPATCH
(
SWITCH_GT0
);
DISPATCH
(
TANH_GRAD
);
DISPATCH
(
TRUE_DIV
);
DISPATCH
(
LOG_SUM_EXP
);
DISPATCH
(
LT
);
DISPATCH
(
LEQ
);
DISPATCH
(
EQ
);
DISPATCH
(
FUSE_ADD_RELU
);
DISPATCH
(
FUSE_ADD_SIGMOID
);
DISPATCH
(
FUSE_ADD_TANH
);
DISPATCH
(
FAST_TANH_GRAD
);
DISPATCH
(
ATAN2
);
DISPATCH
(
H_SWISH_GRAD
);
DISPATCH
(
FUSE_ADD_H_SWISH
);
#undef DISPATCH
default:
megdnn_assert_internal
(
0
);
}
}
// vim: syntax=cpp.doxygen
dnn/src/naive/elemwise_multi_type/opr_impl_5.cpp
0 → 100644
浏览文件 @
f509b1be
/**
* \file dnn/src/naive/elemwise_multi_type/opr_impl_5.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./opr_impl.h"
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
using
namespace
megdnn
;
using
namespace
naive
;
void
ElemwiseMultiTypeImpl
::
on_quantized_mode
(
const
ElemwiseOpParamN
<
3
>&
param
,
const
TensorND
&
dst
,
Elemwise
::
Mode
mode
)
{
megdnn_assert
(
param
[
0
].
layout
.
dtype
.
category
()
==
DTypeCategory
::
QUANTIZED
&&
param
[
0
].
layout
.
dtype
.
category
()
==
param
[
1
].
layout
.
dtype
.
category
()
&&
param
[
0
].
layout
.
dtype
.
category
()
==
param
[
2
].
layout
.
dtype
.
category
());
megdnn_assert
(
dst
.
layout
.
dtype
.
category
()
==
DTypeCategory
::
QUANTIZED
);
switch
(
mode
)
{
#define DISPATCH(_mode) \
case Elemwise::Mode::_mode: { \
typedef ElemwiseKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
KernImpl; \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<3>>(param, dst); \
break; \
}
DISPATCH
(
FUSE_MUL_ADD3
);
DISPATCH
(
COND_LEQ_MOV
);
#undef DISPATCH
default:
megdnn_assert_internal
(
0
);
}
}
// vim: syntax=cpp.doxygen
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录