提交 f509b1be 编写于 作者: M Megvii Engine Team

fix(build): split elemwise_multi_type cpp

GitOrigin-RevId: 13267e9db6fa3194291965f50fe08eb892815e8a
上级 3252016e
......@@ -11,29 +11,139 @@
#pragma once
#include "megdnn/tensor_iter.h"
#include "src/common/elemwise_multi_type/opr_impl_helper.h"
#include "src/naive/handle.h"
namespace megdnn {
namespace naive {
class ElemwiseMultiTypeImpl : public ElemwiseMultiTypeImplHelper {
template <typename KernImpl, typename ElemParam>
void dispatch_qint_op_dtype(const ElemParam& param, const TensorND& dst_tensor);
template <typename KernImpl, typename src_ctype, typename ElemParam>
void dispatch_add_qint_op_dst(const ElemParam& param, const TensorND& dst_tensor);
void dispatch_add_qint_op_dst(const ElemParam& param, const TensorND& dst) {
switch (dst.layout.dtype.enumv()) {
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
dispatch_add_qint_op<KernImpl, src_ctype, typename DTypeTrait<_dt>::ctype>( \
param, dst); \
break;
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
#undef cb
default:
megdnn_assert(
0, "not support %s %s\n", param[0].layout.dtype.name(),
dst.layout.dtype.name());
}
}
template <typename KernImpl, typename ElemParam>
void dispatch_qint_op_dtype(const ElemParam& param, const TensorND& dst) {
switch (param[0].layout.dtype.enumv()) {
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
dispatch_add_qint_op_dst< \
KernImpl, typename DTypeTrait<_dt>::ctype, ElemParam>(param, dst); \
break;
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
#undef cb
default:
megdnn_assert_internal(0);
}
}
template <typename KernImpl, typename src_ctype, typename dst_ctype>
void dispatch_add_qint_op(
const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor);
const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor) {
auto src = param[0];
auto size = param.size;
auto work = [src, size, dst_tensor]() {
auto iA = tensor_iter_valonly<src_ctype>(src).begin();
auto pD = tensor_iter_valonly<dst_ctype>(dst_tensor).begin();
auto param0 =
src.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
auto dst_param = dst_tensor.layout.dtype
.param<typename DTypeTrait<dst_ctype>::dtype>();
for (size_t i = 0; i < size; i++) {
src_ctype a = *iA;
*pD = dst_param.quantize(KernImpl::apply(param0.dequantize(a)));
++iA;
++pD;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}
template <typename KernImpl, typename src_ctype, typename dst_ctype>
void dispatch_add_qint_op(
const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor);
const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto work = [src0, src1, size, dst_tensor]() {
// This is needed as these iterators are captured as const value.
auto iA = tensor_iter_valonly<src_ctype>(src0).begin();
auto iB = tensor_iter_valonly<src_ctype>(src1).begin();
auto pD = tensor_iter_valonly<dst_ctype>(dst_tensor).begin();
auto param0 =
src0.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
auto param1 =
src1.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
auto dst_param = dst_tensor.layout.dtype
.param<typename DTypeTrait<dst_ctype>::dtype>();
for (size_t i = 0; i < size; i++) {
src_ctype a = *iA;
src_ctype b = *iB;
*pD = dst_param.quantize(
KernImpl::apply(param0.dequantize(a), param1.dequantize(b)));
++iA;
++iB;
++pD;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}
template <typename KernImpl, typename src_ctype, typename dst_ctype>
void dispatch_add_qint_op(
const ElemwiseOpParamN<3>& param, const TensorND& dst_tensor);
const ElemwiseOpParamN<3>& param, const TensorND& dst_tensor) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto work = [src0, src1, src2, size, dst_tensor]() {
// This is needed as these iterators are captured as const value.
auto iA = tensor_iter_valonly<src_ctype>(src0).begin();
auto iB = tensor_iter_valonly<src_ctype>(src1).begin();
auto iC = tensor_iter_valonly<src_ctype>(src2).begin();
auto pD = tensor_iter_valonly<dst_ctype>(dst_tensor).begin();
auto param0 =
src0.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
auto param1 =
src1.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
auto param2 =
src2.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
auto dst_param = dst_tensor.layout.dtype
.param<typename DTypeTrait<dst_ctype>::dtype>();
for (size_t i = 0; i < size; i++) {
src_ctype a = *iA;
src_ctype b = *iB;
src_ctype c = *iC;
*pD = dst_param.quantize(KernImpl::apply(
param0.dequantize(a), param1.dequantize(b),
param2.dequantize(c)));
++iA;
++iB;
++iC;
++pD;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}
protected:
template <typename ctype>
......
......@@ -10,135 +10,12 @@
*/
#include "./opr_impl.h"
#include "megdnn/tensor_iter.h"
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "src/naive/handle.h"
using namespace megdnn;
using namespace naive;
template <typename KernImpl, typename src_ctype, typename dst_ctype>
void ElemwiseMultiTypeImpl::dispatch_add_qint_op(
const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor) {
auto src = param[0];
auto size = param.size;
auto work = [src, size, dst_tensor]() {
auto iA = tensor_iter_valonly<src_ctype>(src).begin();
auto pD = tensor_iter_valonly<dst_ctype>(dst_tensor).begin();
auto param0 = src.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
auto dst_param =
dst_tensor.layout.dtype.param<typename DTypeTrait<dst_ctype>::dtype>();
for (size_t i = 0; i < size; i++) {
src_ctype a = *iA;
*pD = dst_param.quantize(KernImpl::apply(param0.dequantize(a)));
++iA;
++pD;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}
template <typename KernImpl, typename src_ctype, typename dst_ctype>
void ElemwiseMultiTypeImpl::dispatch_add_qint_op(
const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto work = [src0, src1, size, dst_tensor]() {
// This is needed as these iterators are captured as const value.
auto iA = tensor_iter_valonly<src_ctype>(src0).begin();
auto iB = tensor_iter_valonly<src_ctype>(src1).begin();
auto pD = tensor_iter_valonly<dst_ctype>(dst_tensor).begin();
auto param0 = src0.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
auto param1 = src1.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
auto dst_param =
dst_tensor.layout.dtype.param<typename DTypeTrait<dst_ctype>::dtype>();
for (size_t i = 0; i < size; i++) {
src_ctype a = *iA;
src_ctype b = *iB;
*pD = dst_param.quantize(
KernImpl::apply(param0.dequantize(a), param1.dequantize(b)));
++iA;
++iB;
++pD;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}
template <typename KernImpl, typename src_ctype, typename dst_ctype>
void ElemwiseMultiTypeImpl::dispatch_add_qint_op(
const ElemwiseOpParamN<3>& param, const TensorND& dst_tensor) {
auto size = param.size;
auto src0 = param[0];
auto src1 = param[1];
auto src2 = param[2];
auto work = [src0, src1, src2, size, dst_tensor]() {
// This is needed as these iterators are captured as const value.
auto iA = tensor_iter_valonly<src_ctype>(src0).begin();
auto iB = tensor_iter_valonly<src_ctype>(src1).begin();
auto iC = tensor_iter_valonly<src_ctype>(src2).begin();
auto pD = tensor_iter_valonly<dst_ctype>(dst_tensor).begin();
auto param0 = src0.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
auto param1 = src1.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
auto param2 = src2.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
auto dst_param =
dst_tensor.layout.dtype.param<typename DTypeTrait<dst_ctype>::dtype>();
for (size_t i = 0; i < size; i++) {
src_ctype a = *iA;
src_ctype b = *iB;
src_ctype c = *iC;
*pD = dst_param.quantize(KernImpl::apply(
param0.dequantize(a), param1.dequantize(b), param2.dequantize(c)));
++iA;
++iB;
++iC;
++pD;
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}
template <typename KernImpl, typename src_ctype, typename ElemParam>
void ElemwiseMultiTypeImpl::dispatch_add_qint_op_dst(
const ElemParam& param, const TensorND& dst) {
switch (dst.layout.dtype.enumv()) {
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
dispatch_add_qint_op<KernImpl, src_ctype, typename DTypeTrait<_dt>::ctype>( \
param, dst); \
break;
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
#undef cb
default:
megdnn_assert(
0, "not support %s %s\n", param[0].layout.dtype.name(),
dst.layout.dtype.name());
}
}
template <typename KernImpl, typename ElemParam>
void ElemwiseMultiTypeImpl::dispatch_qint_op_dtype(
const ElemParam& param, const TensorND& dst) {
switch (param[0].layout.dtype.enumv()) {
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
dispatch_add_qint_op_dst< \
KernImpl, typename DTypeTrait<_dt>::ctype, ElemParam>(param, dst); \
break;
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
#undef cb
default:
megdnn_assert_internal(0);
}
}
void ElemwiseMultiTypeImpl::on_quantized_mode(
const ElemwiseOpParamN<1>& param, const TensorND& dst, Elemwise::Mode mode) {
megdnn_assert(param[0].layout.dtype.category() == DTypeCategory::QUANTIZED);
......@@ -182,79 +59,4 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
}
}
void ElemwiseMultiTypeImpl::on_quantized_mode(
const ElemwiseOpParamN<2>& param, const TensorND& dst, Elemwise::Mode mode) {
megdnn_assert(
param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv() &&
param[0].layout.dtype.category() == DTypeCategory::QUANTIZED);
megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED);
switch (mode) {
#define DISPATCH(_mode) \
case Elemwise::Mode::_mode: { \
typedef ElemwiseKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
KernImpl; \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<2>>(param, dst); \
break; \
}
DISPATCH(ABS_GRAD);
DISPATCH(ADD);
DISPATCH(FLOOR_DIV);
DISPATCH(MAX);
DISPATCH(MIN);
DISPATCH(MOD);
DISPATCH(MUL);
DISPATCH(POW);
DISPATCH(SIGMOID_GRAD);
DISPATCH(SUB);
DISPATCH(SWITCH_GT0);
DISPATCH(TANH_GRAD);
DISPATCH(TRUE_DIV);
DISPATCH(LOG_SUM_EXP);
DISPATCH(LT);
DISPATCH(LEQ);
DISPATCH(EQ);
DISPATCH(FUSE_ADD_RELU);
DISPATCH(FUSE_ADD_SIGMOID);
DISPATCH(FUSE_ADD_TANH);
DISPATCH(FAST_TANH_GRAD);
DISPATCH(ATAN2);
DISPATCH(H_SWISH_GRAD);
DISPATCH(FUSE_ADD_H_SWISH);
#undef DISPATCH
default:
megdnn_assert_internal(0);
}
}
void ElemwiseMultiTypeImpl::on_quantized_mode(
const ElemwiseOpParamN<3>& param, const TensorND& dst, Elemwise::Mode mode) {
megdnn_assert(
param[0].layout.dtype.category() == DTypeCategory::QUANTIZED &&
param[0].layout.dtype.category() == param[1].layout.dtype.category() &&
param[0].layout.dtype.category() == param[2].layout.dtype.category());
megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED);
switch (mode) {
#define DISPATCH(_mode) \
case Elemwise::Mode::_mode: { \
typedef ElemwiseKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
KernImpl; \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<3>>(param, dst); \
break; \
}
DISPATCH(FUSE_MUL_ADD3);
DISPATCH(COND_LEQ_MOV);
#undef DISPATCH
default:
megdnn_assert_internal(0);
}
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/naive/elemwise_multi_type/opr_impl_4.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./opr_impl.h"
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
using namespace megdnn;
using namespace naive;
void ElemwiseMultiTypeImpl::on_quantized_mode(
const ElemwiseOpParamN<2>& param, const TensorND& dst, Elemwise::Mode mode) {
megdnn_assert(
param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv() &&
param[0].layout.dtype.category() == DTypeCategory::QUANTIZED);
megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED);
switch (mode) {
#define DISPATCH(_mode) \
case Elemwise::Mode::_mode: { \
typedef ElemwiseKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
KernImpl; \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<2>>(param, dst); \
break; \
}
DISPATCH(ABS_GRAD);
DISPATCH(ADD);
DISPATCH(FLOOR_DIV);
DISPATCH(MAX);
DISPATCH(MIN);
DISPATCH(MOD);
DISPATCH(MUL);
DISPATCH(POW);
DISPATCH(SIGMOID_GRAD);
DISPATCH(SUB);
DISPATCH(SWITCH_GT0);
DISPATCH(TANH_GRAD);
DISPATCH(TRUE_DIV);
DISPATCH(LOG_SUM_EXP);
DISPATCH(LT);
DISPATCH(LEQ);
DISPATCH(EQ);
DISPATCH(FUSE_ADD_RELU);
DISPATCH(FUSE_ADD_SIGMOID);
DISPATCH(FUSE_ADD_TANH);
DISPATCH(FAST_TANH_GRAD);
DISPATCH(ATAN2);
DISPATCH(H_SWISH_GRAD);
DISPATCH(FUSE_ADD_H_SWISH);
#undef DISPATCH
default:
megdnn_assert_internal(0);
}
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/naive/elemwise_multi_type/opr_impl_5.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./opr_impl.h"
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
using namespace megdnn;
using namespace naive;
void ElemwiseMultiTypeImpl::on_quantized_mode(
const ElemwiseOpParamN<3>& param, const TensorND& dst, Elemwise::Mode mode) {
megdnn_assert(
param[0].layout.dtype.category() == DTypeCategory::QUANTIZED &&
param[0].layout.dtype.category() == param[1].layout.dtype.category() &&
param[0].layout.dtype.category() == param[2].layout.dtype.category());
megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED);
switch (mode) {
#define DISPATCH(_mode) \
case Elemwise::Mode::_mode: { \
typedef ElemwiseKern< \
megcorePlatformCPU, param_enumv::Elemwise::Mode::_mode, float> \
KernImpl; \
dispatch_qint_op_dtype<KernImpl, ElemwiseOpParamN<3>>(param, dst); \
break; \
}
DISPATCH(FUSE_MUL_ADD3);
DISPATCH(COND_LEQ_MOV);
#undef DISPATCH
default:
megdnn_assert_internal(0);
}
}
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册