Commit c03249c0 authored by Megvii Engine Team

feat(dnn/opr): add megdnn fake quant opr

GitOrigin-RevId: 5a04b6da2ffefe1f7f76da0bca1fcc66b7389bf1
Parent 2e530779
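For context, fake quantization simulates integer quantization while staying in floating point: the forward pass rounds input / scale to the integer grid, adds zero_point, clamps to [qmin, qmax], and maps the result back to the float domain; the backward pass is a straight-through estimator that passes the incoming gradient only where the quantized value stayed in range. A minimal scalar sketch of both rules, mirroring the kernels in this diff (the sketch itself is not part of the commit):

#include <algorithm>
#include <cmath>

// Forward: quantize to the integer grid [qmin, qmax], then dequantize.
inline float fakequant_fwd(float x, float scale, float zero_point,
                           float qmin, float qmax) {
    float q = std::round(x / scale) + zero_point;
    q = std::min(std::max(q, qmin), qmax);  // clamp to the representable grid
    return (q - zero_point) * scale;
}

// Backward (straight-through estimator): pass the gradient only where the
// quantized value fell inside [qmin, qmax].
inline float fakequant_bwd(float diff, float x, float scale, float zero_point,
                           float qmin, float qmax) {
    float q = std::round(x / scale) + zero_point;
    return (q >= qmin && q <= qmax) ? diff : 0.0f;
}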
@@ -60,7 +60,7 @@ struct PreprocessedFilter {
TensorNDArray tensors;
};
} // namespace intl
} // namespace detail
/**
* \brief base class for convolution operation
@@ -1562,6 +1562,58 @@ protected:
};
using BatchConvBias = BatchConvBiasForward;
class FakeQuantBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(FakeQuantBase, OperatorBase);
DEF_OPR_PARAM(FakeQuant);
protected:
void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
void check_layout_fwd(const TensorLayout& input, const TensorLayout& scale,
const TensorLayout& zero_point,
const TensorLayout& output);
};
class FakeQuantForward : public FakeQuantBase {
DEF_OPR_IMPL(FakeQuantForward, FakeQuantBase, 3, 1);
public:
virtual void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
_megdnn_tensor_in zero_point, _megdnn_tensor_out output,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& input, const TensorLayout& scale,
const TensorLayout& zero_point, TensorLayout& output);
virtual size_t get_workspace_in_bytes(const TensorLayout& input,
const TensorLayout& scale,
const TensorLayout& zero_point,
const TensorLayout& output) = 0;
protected:
void check_exec(const TensorLayout& input, const TensorLayout& scale,
const TensorLayout& zero_point, const TensorLayout& output,
size_t workspace_in_bytes);
};
using FakeQuant = FakeQuantForward;
class FakeQuantBackward : public FakeQuantBase {
DEF_OPR_IMPL(FakeQuantBackward, FakeQuantBase, 4, 1);
public:
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
_megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
const TensorLayout& input,
const TensorLayout& scale,
const TensorLayout& zero_point,
const TensorLayout& grad) = 0;
protected:
void check_exec(const TensorLayout& diff, const TensorLayout& input,
const TensorLayout& scale, const TensorLayout& zero_point,
const TensorLayout& grad, size_t workspace_in_bytes);
};
} // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h"
@@ -943,5 +943,9 @@ when the ``I`` suffix is present.
add_enum_alias('Format', 'ConvolutionV0').
add_enum_alias('ComputeMode', 'Convolution', name_field="compute_mode")
)
(pdef('FakeQuant').
add_fields('int32','qmin','-2147483648').
add_fields('int32','qmax','2147483647')
)
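The two fields default to the full int32 range, which leaves the clamp effectively disabled; callers narrow them to the target integer type. A hedged host-side example targeting signed int8 (the helper name is illustrative; only param::FakeQuant and its qmin/qmax fields come from this diff):

#include "megdnn/opr_param_defs.h"

// Build a FakeQuant param for a signed 8-bit grid; -128/127 are the usual
// int8 bounds (illustrative values, not mandated by the commit).
inline megdnn::param::FakeQuant make_int8_fakequant_param() {
    megdnn::param::FakeQuant p;
    p.qmin = -128;
    p.qmax = 127;
    return p;
}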
/**
* \file dnn/src/common/fake_quant.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
void FakeQuantBase::deduce_layout_fwd(const TensorLayout& input,
TensorLayout& output) {
output = TensorLayout(input, input.dtype);
}
void FakeQuantBase::check_layout_fwd(const TensorLayout& input,
const TensorLayout& scale,
const TensorLayout& zero_point,
const TensorLayout& output) {
megdnn_assert(input.dtype == dtype::Float32());
megdnn_assert(scale.dtype == dtype::Float32());
megdnn_assert(zero_point.dtype == dtype::Float32());
TensorLayout expected;
deduce_layout_fwd(input, expected);
megdnn_assert_eq_layout(expected, output);
}
void FakeQuantForward::deduce_layout(const TensorLayout& input,
const TensorLayout& /*scale*/,
const TensorLayout& /*zero_point*/,
TensorLayout& output) {
deduce_layout_fwd(input, output);
}
void FakeQuantForward::check_exec(const TensorLayout& input,
const TensorLayout& scale,
const TensorLayout& zero_point,
const TensorLayout& output,
size_t workspace_in_bytes) {
check_layout_fwd(input, scale, zero_point, output);
auto required_workspace_space =
get_workspace_in_bytes(input, scale, zero_point, output);
megdnn_assert(workspace_in_bytes >= required_workspace_space);
}
void FakeQuantBackward::check_exec(const TensorLayout& diff,
const TensorLayout& input,
const TensorLayout& scale,
const TensorLayout& zero_point,
const TensorLayout& grad,
size_t workspace_in_bytes) {
megdnn_assert_eq_shape(input, diff);
megdnn_assert_eq_shape(input, grad);
auto required_workspace_space =
get_workspace_in_bytes(diff, input, scale, zero_point, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_space);
}
} // namespace megdnn
\ No newline at end of file
@@ -201,7 +201,9 @@ private:
cb(RemapBackwardMat) \
cb(AdaptivePoolingForward) \
cb(AdaptivePoolingBackward) \
cb(DctChannelSelectForward)
cb(DctChannelSelectForward) \
cb(FakeQuantForward) \
cb(FakeQuantBackward)
/*!
* \brief specialize HandleImpl::create_operator for a single opr type;
@@ -13,9 +13,9 @@
#pragma once
#include "src/common/elemwise_helper.cuh"
#include "src/cuda/utils.cuh"
#include "src/cuda/int_fastdiv.cuh"
#include "src/cuda/query_blocksize.cuh"
#include "src/cuda/utils.cuh"
/*
* please note that all arithmetics on GPU are 32-bit for best performance; this
@@ -649,6 +649,102 @@ struct OpCallerUniform<Op, 3, PVis> {
}
};
//! specialization for arity == 4
template <class Op, class PVis>
struct OpCallerUniform<Op, 4, PVis> {
Op op;
PVis par[4];
static const uint32_t packed_size = PVis::packed_size;
devfunc void thread_init(uint32_t idx) {
idx = idx * packed_size;
par[0].thread_init(idx);
par[1].thread_init(idx);
par[2].thread_init(idx);
par[3].thread_init(idx);
}
devfunc void on(uint32_t idx) {
idx = idx * packed_size;
op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), par[3].at(idx));
}
devfunc void on(uint32_t idx, uint32_t remain) {
idx = idx * packed_size;
if (remain >= packed_size) {
op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx),
par[3].at(idx));
} else {
auto ptr0 = par[0].ptr();
auto ptr1 = par[1].ptr();
auto ptr2 = par[2].ptr();
auto ptr3 = par[3].ptr();
for (int i = 0; i < remain; i++) {
op(idx + i, ptr0[par[0].offset(idx + i)],
ptr1[par[1].offset(idx + i)], ptr2[par[2].offset(idx + i)],
ptr3[par[3].offset(idx + i)]);
}
}
}
devfunc void next() {
par[0].next();
par[1].next();
par[2].next();
par[3].next();
}
};
//! specialization for arity == 5
template <class Op, class PVis>
struct OpCallerUniform<Op, 5, PVis> {
Op op;
PVis par[5];
static const uint32_t packed_size = PVis::packed_size;
devfunc void thread_init(uint32_t idx) {
idx = idx * packed_size;
par[0].thread_init(idx);
par[1].thread_init(idx);
par[2].thread_init(idx);
par[3].thread_init(idx);
par[4].thread_init(idx);
}
devfunc void on(uint32_t idx) {
idx = idx * packed_size;
op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), par[3].at(idx),
par[4].at(idx));
}
devfunc void on(uint32_t idx, uint32_t remain) {
idx = idx * packed_size;
if (remain >= packed_size) {
op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx),
par[3].at(idx), par[4].at(idx));
} else {
auto ptr0 = par[0].ptr();
auto ptr1 = par[1].ptr();
auto ptr2 = par[2].ptr();
auto ptr3 = par[3].ptr();
auto ptr4 = par[4].ptr();
for (int i = 0; i < remain; i++) {
op(idx + i, ptr0[par[0].offset(idx + i)],
ptr1[par[1].offset(idx + i)], ptr2[par[2].offset(idx + i)],
ptr3[par[3].offset(idx + i)], ptr4[par[4].offset(idx + i)]);
}
}
}
devfunc void next() {
par[0].next();
par[1].next();
par[2].next();
par[3].next();
par[4].next();
}
};
/*!
* \brief call binary (i.e. arity == 2) operator with different param
* visitors
/**
* \file dnn/src/cuda/fake_quant/kern.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./kern.cuh"
namespace megdnn {
namespace cuda {
#define cb(_dtype) \
INST_RUN_ELEMWISE(FakeQuantKernOp<DTypeTrait<_dtype>::ctype>, \
DTypeTrait<_dtype>::ctype, 2); \
INST_RUN_ELEMWISE(FakeQuantBwdKernOp<DTypeTrait<_dtype>::ctype>, \
DTypeTrait<_dtype>::ctype, 2); \
INST_RUN_ELEMWISE(FakeQuantKernOpNonContig<DTypeTrait<_dtype>::ctype>, \
DTypeTrait<_dtype>::ctype, 4); \
INST_RUN_ELEMWISE(FakeQuantBwdKernOpNonContig<DTypeTrait<_dtype>::ctype>, \
DTypeTrait<_dtype>::ctype, 5);
cb(megdnn::dtype::Float32)
} // namespace cuda
} // namespace megdnn
\ No newline at end of file
/**
* \file dnn/src/cuda/fake_quant/kern.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/cuda/elemwise_helper.cuh"
#include "src/cuda/utils.cuh"
#if MEGDNN_CC_HOST
#include "megdnn/oprs.h"
#endif
namespace megdnn {
namespace cuda {
template <typename ctype>
struct FakeQuantKernOp {
ctype* input;
ctype* output;
ctype qmin, qmax;
__device__ void operator()(uint32_t idx, ctype scale, ctype zero_point) {
ctype x = round(input[idx] / scale) + zero_point;
x = fmaxf(fminf(x, qmax), qmin);
output[idx] = (x - zero_point) * scale;
}
#if MEGDNN_CC_HOST
FakeQuantKernOp(const TensorND& input, const TensorND& output,
const FakeQuant::Param& param)
: input{input.ptr<ctype>()},
output{output.ptr<ctype>()},
qmin(param.qmin),
qmax(param.qmax) {}
#endif
};
template <typename ctype>
struct FakeQuantBwdKernOp {
ctype* diff;
ctype* input;
ctype* grad;
ctype qmin, qmax;
__device__ void operator()(uint32_t idx, ctype scale, ctype zero_point) {
ctype x = round(input[idx] / scale) + zero_point;
grad[idx] = x <= qmax && x >= qmin ? diff[idx] : 0.0;
}
#if MEGDNN_CC_HOST
FakeQuantBwdKernOp(const TensorND& diff, const TensorND& input,
const TensorND& grad, const FakeQuant::Param& param)
: diff{diff.ptr<ctype>()},
input{input.ptr<ctype>()},
grad{grad.ptr<ctype>()},
qmin(param.qmin),
qmax(param.qmax) {}
#endif
};
template <typename ctype>
struct FakeQuantKernOpNonContig {
ctype qmin;
ctype qmax;
__device__ void operator()(uint32_t, ctype& output, ctype input,
ctype scale, ctype zero_point) {
ctype x = round(input / scale) + zero_point;
x = fmaxf(fminf(x, qmax), qmin);
output = (x - zero_point) * scale;
}
#if MEGDNN_CC_HOST
FakeQuantKernOpNonContig(const FakeQuant::Param& param)
: qmin(param.qmin), qmax(param.qmax) {}
#endif
};
template <typename ctype>
struct FakeQuantBwdKernOpNonContig {
ctype qmin;
ctype qmax;
__device__ void operator()(uint32_t, ctype& grad, ctype diff, ctype input,
ctype scale, ctype zero_point) {
ctype x = round(input / scale) + zero_point;
grad = x <= qmax && x >= qmin ? diff : 0.0;
}
#if MEGDNN_CC_HOST
FakeQuantBwdKernOpNonContig(const FakeQuant::Param& param)
: qmin(param.qmin), qmax(param.qmax) {}
#endif
};
} // namespace cuda
} // namespace megdnn
/**
* \file dnn/src/cuda/fake_quant/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./opr_impl.h"
#include "./kern.cuh"
#include "src/common/utils.h"
namespace megdnn {
namespace cuda {
void FakeQuantForwardImpl::exec(_megdnn_tensor_in input,
_megdnn_tensor_in scale,
_megdnn_tensor_in zero_point,
_megdnn_tensor_out output,
_megdnn_workspace workspace) {
check_exec(input.layout, scale.layout, zero_point.layout, output.layout,
workspace.size);
if (!input.layout.is_contiguous() || !output.layout.is_contiguous()) {
return exec_noncontig(input, scale, zero_point, output);
}
ElemwiseOpParamN<2> ele_param;
ele_param[0] = scale;
ele_param[0].layout = ele_param[0].layout.broadcast(input.layout);
ele_param[1] = zero_point;
ele_param[1].layout = ele_param[1].layout.broadcast(input.layout);
ele_param.init_from_given_tensor();
auto stream = cuda_stream(handle());
#define cb(DType) \
if (input.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
run_elemwise<FakeQuantKernOp<T>, T, 2>(ele_param, stream, \
{input, output, m_param}); \
return; \
}
cb(megdnn::dtype::Float32)
#undef cb
}
void FakeQuantForwardImpl::exec_noncontig(_megdnn_tensor_in input,
_megdnn_tensor_in scale,
_megdnn_tensor_in zero_point,
_megdnn_tensor_out output) {
ElemwiseOpParamN<4> ele_param;
ele_param[0] = output;
ele_param[1] = input;
ele_param[2] = scale;
ele_param[2].layout = ele_param[2].layout.broadcast(input.layout);
ele_param[3] = zero_point;
ele_param[3].layout = ele_param[3].layout.broadcast(input.layout);
ele_param.init_from_given_tensor();
auto stream = cuda_stream(handle());
#define cb(DType) \
if (input.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
run_elemwise<FakeQuantKernOpNonContig<T>, T, 4>(ele_param, stream, \
{m_param}); \
return; \
}
cb(megdnn::dtype::Float32)
#undef cb
}
void FakeQuantBackwardImpl::exec(_megdnn_tensor_in diff,
_megdnn_tensor_in input,
_megdnn_tensor_in scale,
_megdnn_tensor_in zero_point,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(diff.layout, input.layout, scale.layout, zero_point.layout,
grad.layout, workspace.size);
if (!input.layout.is_contiguous() || !diff.layout.is_contiguous() ||
!grad.layout.is_contiguous()) {
return exec_noncontig(diff, input, scale, zero_point, grad);
}
ElemwiseOpParamN<2> ele_param;
ele_param[0] = scale;
ele_param[0].layout = ele_param[0].layout.broadcast(input.layout);
ele_param[1] = zero_point;
ele_param[1].layout = ele_param[1].layout.broadcast(input.layout);
ele_param.init_from_given_tensor();
auto m_param = param();
auto stream = cuda_stream(handle());
#define cb(DType) \
if (grad.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
run_elemwise<FakeQuantBwdKernOp<T>, T, 2>( \
ele_param, stream, {diff, input, grad, m_param}); \
return; \
}
cb(megdnn::dtype::Float32)
#undef cb
}
void FakeQuantBackwardImpl::exec_noncontig(_megdnn_tensor_in diff,
_megdnn_tensor_in input,
_megdnn_tensor_in scale,
_megdnn_tensor_in zero_point,
_megdnn_tensor_out grad) {
ElemwiseOpParamN<5> ele_param;
ele_param[0] = grad;
ele_param[1] = diff;
ele_param[2] = input;
ele_param[3] = scale;
ele_param[3].layout = ele_param[3].layout.broadcast(input.layout);
ele_param[4] = zero_point;
ele_param[4].layout = ele_param[4].layout.broadcast(input.layout);
ele_param.init_from_given_tensor();
auto m_param = param();
auto stream = cuda_stream(handle());
#define cb(DType) \
if (grad.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
run_elemwise<FakeQuantBwdKernOpNonContig<T>, T, 5>(ele_param, stream, \
{m_param}); \
return; \
}
cb(megdnn::dtype::Float32)
#undef cb
}
} // namespace cuda
} // namespace megdnn
/**
* \file dnn/src/cuda/fake_quant/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
class FakeQuantForwardImpl : public FakeQuantForward {
public:
using FakeQuantForward::FakeQuantForward;
void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
_megdnn_tensor_in zero_point, _megdnn_tensor_out output,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override {
return 0;
}
private:
void exec_noncontig(_megdnn_tensor_in input, _megdnn_tensor_in scale,
_megdnn_tensor_in zero_point,
_megdnn_tensor_out output);
};
class FakeQuantBackwardImpl : public FakeQuantBackward {
public:
using FakeQuantBackward::FakeQuantBackward;
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
_megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
private:
void exec_noncontig(_megdnn_tensor_in diff, _megdnn_tensor_in input,
_megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
_megdnn_tensor_out grad);
};
} // namespace cuda
} // namespace megdnn
@@ -77,6 +77,7 @@
#include "src/cuda/roi_align/opr_impl.h"
#include "src/cuda/batch_conv_bias/opr_impl.h"
#include "src/cuda/remap/opr_impl.h"
#include "src/cuda/fake_quant/opr_impl.h"
namespace megdnn {
namespace cuda {
/**
* \file dnn/src/naive/fake_quant/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/naive/fake_quant/opr_impl.h"
#include <cmath>
#include "megdnn/tensor_iter.h"
#include "src/common/elemwise_helper.cuh"
#include "src/common/utils.h"
#include "src/naive/handle.h"
namespace {
using namespace megdnn;
template <typename T>
void forward_impl(const ElemwiseOpParamN<4> src, float qmin, float qmax) {
auto inp = tensor_iter_valonly<T>(src[0]).begin();
auto out = tensor_iter_valonly<T>(src[1]).begin();
auto scale = tensor_iter_valonly<T>(src[2]).begin();
auto zero_point = tensor_iter_valonly<T>(src[3]).begin();
size_t total = src[0].layout.total_nr_elems();
for (size_t i = 0; i < total; ++i) {
T x = round(*inp / (*scale)) + *zero_point;
x = x <= qmin ? qmin : x;
x = x >= qmax ? qmax : x;
*out = (x - *zero_point) * *scale;
++inp;
++out;
++scale;
++zero_point;
}
}
template <typename T>
void backward_impl(const ElemwiseOpParamN<5> src, float qmin, float qmax) {
auto diff = tensor_iter_valonly<T>(src[0]).begin();
auto input = tensor_iter_valonly<T>(src[1]).begin();
auto scale = tensor_iter_valonly<T>(src[2]).begin();
auto zero_point = tensor_iter_valonly<T>(src[3]).begin();
auto grad = tensor_iter_valonly<T>(src[4]).begin();
size_t total = src[0].layout.total_nr_elems();
for (size_t i = 0; i < total; ++i) {
T x = round(*input / (*scale)) + *zero_point;
*grad = (x >= qmin && x <= qmax) ? *diff : 0.0;
++diff;
++input;
++scale;
++zero_point;
++grad;
}
}
} // namespace
namespace megdnn {
namespace naive {
void FakeQuantForwardImpl::exec(_megdnn_tensor_in input,
_megdnn_tensor_in scale,
_megdnn_tensor_in zero_point,
_megdnn_tensor_out output,
_megdnn_workspace workspace) {
check_exec(input.layout, scale.layout, zero_point.layout, output.layout,
workspace.size);
ElemwiseOpParamN<4> src;
src[0] = input;
src[1] = output;
src[2] = scale;
src[2].layout = src[2].layout.broadcast(input.layout);
src[3] = zero_point;
src[3].layout = src[3].layout.broadcast(input.layout);
#define cb(DType) \
if (input.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
forward_impl<T>(src, param().qmin, param().qmax)); \
return; \
}
cb(dtype::Float32)
#undef cb
}
void FakeQuantBackwardImpl::exec(_megdnn_tensor_in diff,
_megdnn_tensor_in input,
_megdnn_tensor_in scale,
_megdnn_tensor_in zero_point,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(diff.layout, input.layout, scale.layout, zero_point.layout,
grad.layout, workspace.size);
ElemwiseOpParamN<5> src;
src[0] = diff;
src[1] = input;
src[2] = scale;
src[2].layout = src[2].layout.broadcast(input.layout);
src[3] = zero_point;
src[3].layout = src[3].layout.broadcast(input.layout);
src[4] = grad;
#define cb(DType) \
if (diff.layout.dtype == DType() && grad.layout.dtype == DType() && \
input.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
backward_impl<T>(src, param().qmin, param().qmax)); \
return; \
}
cb(dtype::Float32)
#undef cb
}
} // namespace naive
} // namespace megdnn
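To make the arithmetic in forward_impl concrete, a small self-contained check with illustrative numbers: with scale = 0.5 and zero_point = 0, in-range inputs round-trip to the nearest multiple of the scale, while out-of-range inputs clamp to qmin or qmax:

#include <algorithm>
#include <cassert>
#include <cmath>

int main() {
    const float scale = 0.5f, zero_point = 0.f, qmin = -128.f, qmax = 127.f;
    // Same rule as forward_impl above, specialized to one scalar.
    auto fq = [&](float x) {
        float q = std::round(x / scale) + zero_point;
        q = std::min(std::max(q, qmin), qmax);
        return (q - zero_point) * scale;
    };
    assert(fq(3.2f) == 3.0f);       // round(6.4) = 6  -> 6 * 0.5
    assert(fq(1000.f) == 63.5f);    // clamps at qmax:  127 * 0.5
    assert(fq(-1000.f) == -64.0f);  // clamps at qmin: -128 * 0.5
    return 0;
}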
/**
* \file dnn/src/naive/fake_quant/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace naive {
class FakeQuantForwardImpl : public FakeQuantForward {
public:
using FakeQuantForward::FakeQuantForward;
void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
_megdnn_tensor_in zero_point, _megdnn_tensor_out output,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override {
return 0;
}
};
class FakeQuantBackwardImpl : public FakeQuantBackward {
public:
using FakeQuantBackward::FakeQuantBackward;
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
_megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
};
} // namespace naive
} // namespace megdnn
@@ -79,6 +79,8 @@
#include "src/naive/warp_affine/opr_impl.h"
#include "src/naive/warp_perspective/opr_impl.h"
#include "src/naive/winograd_filter_preprocess/opr_impl.h"
#include "src/naive/remap/opr_impl.h"
#include "src/naive/fake_quant/opr_impl.h"
static size_t g_image2d_pitch_alignment = 1;
/**
* \file dnn/test/common/fake_quant.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h"
namespace megdnn {
namespace test {
namespace fake_quant {
struct TestArg {
param::FakeQuant param;
TensorShape ishape;
TensorShape scale_shape;
TensorShape zeropoint_shape;
TestArg(param::FakeQuant param, TensorShape ishape, TensorShape scale_shape,
TensorShape zeropoint_shape)
: param(param),
ishape(ishape),
scale_shape(scale_shape),
zeropoint_shape(zeropoint_shape) {}
};
inline std::vector<TestArg> get_args() {
std::vector<TestArg> args;
param::FakeQuant cur_param;
cur_param.qmin = -128;
cur_param.qmax = 128;
for (size_t i = 10; i < 40; i += 2) {
args.emplace_back(cur_param, TensorShape{10, 64, i, i}, TensorShape{1},
TensorShape{1});
}
for (size_t m : {1, 10})
for (size_t n : {1, 10})
for (size_t j : {1, 10})
for (size_t k : {1, 10}) {
args.emplace_back(cur_param, TensorShape{10, 64, 10, 10},
TensorShape{10, 64, m, n},
TensorShape{10, 64, j, k});
}
return args;
}
} // namespace fake_quant
} // namespace test
} // namespace megdnn
\ No newline at end of file
@@ -111,6 +111,8 @@ DEF(Remap, 3, true, true);
DEF(RemapBackwardData, 3, true, false);
DEF(RemapBackwardMat, 4, true, false);
DEF(DctChannelSelectForward, 4, true, true);
DEF(FakeQuantForward, 4, true, true);
DEF(FakeQuantBackward, 5, true, false);
} // namespace test
} // namespace megdnn
/**
* \file dnn/test/cuda/fake_quant.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/common/fake_quant.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/cuda/fixture.h"
namespace megdnn {
namespace test {
using namespace fake_quant;
TEST_F(CUDA, FAKE_QUANT) {
std::vector<TestArg> args = get_args();
auto dtype = dtype::Float32();
std::unique_ptr<RNG> rng;
for (auto&& arg : args) {
auto param = arg.param;
auto ishape = arg.ishape;
auto scale_shape = arg.scale_shape;
auto zeropoint_shape = arg.zeropoint_shape;
Checker<FakeQuantForward> checker(handle_cuda());
checker.set_param(param)
.set_dtype(0, dtype)
.set_dtype(1, dtype)
.set_dtype(2, dtype)
.set_dtype(3, dtype)
.execs(TensorShapeArray{ishape, scale_shape, zeropoint_shape,
ishape});
}
// test noncontiguous layout
for (auto&& arg : args) {
auto param = arg.param;
auto ishape = arg.ishape;
auto scale_shape = arg.scale_shape;
auto zeropoint_shape = arg.zeropoint_shape;
Checker<FakeQuantForward> checker(handle_cuda());
TensorLayout ilayout(
ishape,
{(long int)(ishape[1] * ishape[2] * ishape[3] * 2),
(long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1},
dtype::Float32());
checker.set_param(param).execl({ilayout,
{scale_shape, dtype::Float32()},
{zeropoint_shape, dtype::Float32()},
ilayout});
}
}
TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
std::vector<TestArg> args = get_args();
auto dtype = dtype::Float32();
for (auto&& arg : args) {
auto param = arg.param;
auto ishape = arg.ishape;
auto scale_shape = arg.scale_shape;
auto zeropoint_shape = arg.zeropoint_shape;
Checker<FakeQuantBackward> checker(handle_cuda());
checker.set_param(param)
.set_dtype(0, dtype)
.set_dtype(1, dtype)
.set_dtype(2, dtype)
.set_dtype(3, dtype)
.set_dtype(4, dtype)
.execs(TensorShapeArray{ishape, ishape, scale_shape,
zeropoint_shape, ishape});
}
// test noncontiguous layout
for (auto&& arg : args) {
auto param = arg.param;
auto ishape = arg.ishape;
auto scale_shape = arg.scale_shape;
auto zeropoint_shape = arg.zeropoint_shape;
Checker<FakeQuantBackward> checker(handle_cuda());
TensorLayout ilayout(
ishape,
{(long int)(ishape[1] * ishape[2] * ishape[3] * 2),
(long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1},
dtype::Float32());
checker.set_param(param).execl({ilayout,
ilayout,
{scale_shape, dtype::Float32()},
{zeropoint_shape, dtype::Float32()},
ilayout});
}
}
} // namespace test
} // namespace megdnn
\ No newline at end of file