diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h
index a61adf2599bcf20fafac562d2acc534d133ad97e..c4126ff1a91c552a67ec905cabafa88d0415889f 100644
--- a/dnn/include/megdnn/oprs/nn.h
+++ b/dnn/include/megdnn/oprs/nn.h
@@ -1741,6 +1741,67 @@ protected:
                     const TensorLayout& grad_s, size_t workspace_in_bytes);
 };
 
+class LSQBase : public OperatorBase {
+    DEF_OPR_IMPL_CTOR(LSQBase, OperatorBase);
+    DEF_OPR_PARAM(LSQ);
+
+protected:
+    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
+    void check_layout_fwd(const TensorLayout& input, const TensorLayout& scale,
+                          const TensorLayout& zero_point,
+                          const TensorLayout& grad_scale,
+                          const TensorLayout& output);
+};
+
+class LSQForward : public LSQBase {
+    DEF_OPR_IMPL(LSQForward, LSQBase, 4, 1);
+
+public:
+    virtual void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
+                      _megdnn_tensor_in zero_point,
+                      _megdnn_tensor_in grad_scale, _megdnn_tensor_out output,
+                      _megdnn_workspace workspace) = 0;
+    void deduce_layout(const TensorLayout& input, const TensorLayout& scale,
+                       const TensorLayout& zero_point,
+                       const TensorLayout& grad_scale, TensorLayout& output);
+    virtual size_t get_workspace_in_bytes(const TensorLayout& input,
+                                          const TensorLayout& scale,
+                                          const TensorLayout& zero_point,
+                                          const TensorLayout& grad_scale,
+                                          const TensorLayout& output) = 0;
+
+protected:
+    void check_exec(const TensorLayout& input, const TensorLayout& scale,
+                    const TensorLayout& zero_point,
+                    const TensorLayout& grad_scale, const TensorLayout& output,
+                    size_t workspace_in_bytes);
+};
+using LSQ = LSQForward;
+
+class LSQBackward : public LSQBase {
+    DEF_OPR_IMPL(LSQBackward, LSQBase, 5, 2);
+
+public:
+    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
+                      _megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
+                      _megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x,
+                      _megdnn_tensor_out grad_s,
+                      _megdnn_workspace workspace) = 0;
+    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
+                                          const TensorLayout& input,
+                                          const TensorLayout& scale,
+                                          const TensorLayout& zero_point,
+                                          const TensorLayout& grad_scale,
+                                          const TensorLayout& grad_x,
+                                          const TensorLayout& grad_s) = 0;
+
+protected:
+    void check_exec(const TensorLayout& diff, const TensorLayout& input,
+                    const TensorLayout& scale, const TensorLayout& zero_point,
+                    const TensorLayout& grad_scale, const TensorLayout& grad_x,
+                    const TensorLayout& grad_s, size_t workspace_in_bytes);
+};
+
 } // namespace megdnn
 
 #include "megdnn/internal/opr_header_epilogue.h"
diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py
index dde581637ce9ead171897e396a0e88273bada7bf..566dc2f5283d154c6865a4dc6f02d251d769dbc1 100755
--- a/dnn/scripts/opr_param_defs.py
+++ b/dnn/scripts/opr_param_defs.py
@@ -1124,3 +1124,8 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
  add_fields('int32', 'qmin', '-2147483648').
  add_fields('int32', 'qmax', '2147483647')
  )
+(pdef('LSQ').
+ add_fields('int32', 'qmin', '-2147483648').
+ add_fields('int32', 'qmax', '2147483647')
+ )
+
diff --git a/dnn/src/common/elemwise_helper.cpp b/dnn/src/common/elemwise_helper.cpp
index 62200430984ce26cfdce46f4880d8392d00d20d5..200a99d4778e05b1da71d3f8d4f4a0c47b2822bd 100644
--- a/dnn/src/common/elemwise_helper.cpp
+++ b/dnn/src/common/elemwise_helper.cpp
@@ -37,6 +37,7 @@ namespace megdnn {
     megdnn_assert(size, "uninitialized ElemwiseOpParamN");
 }
 
+template struct ElemwiseOpParamN<7>;
 template struct ElemwiseOpParamN<6>;
 template struct ElemwiseOpParamN<5>;
 template struct ElemwiseOpParamN<4>;
diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h
index 4d6354e3baebdcd0e7f13350eb9dfec8c8a15dd2..03c83053143bef9ddef0aa38a08748e724877af9 100644
--- a/dnn/src/common/handle_impl.h
+++ b/dnn/src/common/handle_impl.h
@@ -208,7 +208,9 @@ private:
     cb(FakeQuantBackward) \
     cb(TQTForward) \
     cb(TQTBackward) \
-    cb(CheckHasInf)
+    cb(CheckHasInf) \
+    cb(LSQForward) \
+    cb(LSQBackward)
 
 /*!
  * \brief specialize HandleImpl::create_operator for a single opr type;
diff --git a/dnn/src/common/lsq.cpp b/dnn/src/common/lsq.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c8078ec365f02f84dbd78d1f223a446ee6cb6871
--- /dev/null
+++ b/dnn/src/common/lsq.cpp
@@ -0,0 +1,69 @@
+/**
+ * \file dnn/src/common/lsq.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "megdnn/oprs.h"
+#include "src/common/utils.h"
+
+namespace megdnn {
+
+void LSQBase::deduce_layout_fwd(const TensorLayout& input,
+                                TensorLayout& output) {
+    output = TensorLayout(input, input.dtype);
+}
+
+void LSQBase::check_layout_fwd(const TensorLayout& input,
+                               const TensorLayout& scale,
+                               const TensorLayout& zero_point,
+                               const TensorLayout& grad_scale,
+                               const TensorLayout& output) {
+    megdnn_assert(input.dtype == dtype::Float32());
+    megdnn_assert(scale.dtype == dtype::Float32());
+    megdnn_assert(zero_point.dtype == dtype::Float32());
+    megdnn_assert(grad_scale.dtype == dtype::Float32());
+    TensorLayout expected;
+    deduce_layout_fwd(input, expected);
+    megdnn_assert_eq_layout(expected, output);
+}
+
+void LSQForward::deduce_layout(const TensorLayout& input,
+                               const TensorLayout& /* scale */,
+                               const TensorLayout& /* zero_point */,
+                               const TensorLayout& /* grad_scale */,
+                               TensorLayout& output) {
+    deduce_layout_fwd(input, output);
+}
+
+void LSQForward::check_exec(const TensorLayout& input,
+                            const TensorLayout& scale,
+                            const TensorLayout& zero_point,
+                            const TensorLayout& grad_scale,
+                            const TensorLayout& output,
+                            size_t workspace_in_bytes) {
+    check_layout_fwd(input, scale, zero_point, grad_scale, output);
+    auto required_workspace_space = get_workspace_in_bytes(
+            input, scale, zero_point, grad_scale, output);
+    megdnn_assert(workspace_in_bytes >= required_workspace_space);
+}
+
+void LSQBackward::check_exec(
+        const TensorLayout& diff, const TensorLayout& input,
+        const TensorLayout& scale, const TensorLayout& zero_point,
+        const TensorLayout& grad_scale, const TensorLayout& grad_x,
+        const TensorLayout& grad_s, size_t workspace_in_bytes) {
+    megdnn_assert_eq_shape(diff, input);
+    megdnn_assert_eq_shape(grad_x, input);
+    auto required_workspace_space = get_workspace_in_bytes(
+            diff, input, scale, zero_point, grad_scale, grad_x, grad_s);
+    megdnn_assert(workspace_in_bytes >= required_workspace_space);
+}
+
+} // namespace megdnn
diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h
index 313bac68b5015f53f9a4a6165e4694e5d5181df4..646ba28433890d96527bf6c52b6b7651f1222207 100644
--- a/dnn/src/common/opr_trait.h
+++ b/dnn/src/common/opr_trait.h
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
 #include "megdnn/oprs.h"
@@ -121,6 +122,8 @@ DEF(UniformRNG, 1, true, true);
 DEF(GaussianRNG, 1, true, true);
 DEF(ChecksumForward, 1, true, false);
 DEF(CheckHasInf, 2, true, true);
+DEF(LSQForward, 5, true, true);
+DEF(LSQBackward, 7, true, false);
 } // namespace megdnn
 
 // vim: syntax=cpp.doxygen
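For reference, the element-wise math that the kernels added below implement is the standard LSQ fake quantization and its straight-through gradient. A minimal NumPy sketch, illustrative only and not part of the patch (`np.floor(x + 0.5)` mirrors the rounding used by the tests; C `round()` differs for negative halfway values):

    import numpy as np

    def lsq_forward_ref(x, scale, zero_point, qmin, qmax):
        # scale and shift, clip to the quantized range, round to nearest
        xq = np.floor(np.clip(x / scale + zero_point, qmin, qmax) + 0.5)
        # dequantize back to the input range
        return (xq - zero_point) * scale

    def lsq_backward_ref(dy, x, scale, zero_point, grad_scale, qmin, qmax):
        xs = x / scale + zero_point
        small, big = xs < qmin, xs > qmax
        # same as the kernels' !(small ^ big): both can never hold at once
        middle = ~(small | big)
        grad_x = dy * middle  # straight-through estimator w.r.t. the input
        grad_s = small * qmin + big * qmax + middle * (np.floor(xs + 0.5) - xs)
        return grad_x, grad_s * grad_scale * dy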
diff --git a/dnn/src/cuda/elemwise_helper.cuh b/dnn/src/cuda/elemwise_helper.cuh
index 5424389bc1006326574f989e6a05e61c04bbc263..66712dde4a54fba193a194fe963cca1f90c4df1f 100644
--- a/dnn/src/cuda/elemwise_helper.cuh
+++ b/dnn/src/cuda/elemwise_helper.cuh
@@ -947,6 +947,119 @@ struct OpCallerUniform {
     }
 };
 
+
+//! specialization for arity == 6
+template <class Op, class PVis>
+struct OpCallerUniform<Op, 6, PVis> {
+    Op op;
+    PVis par[6];
+    static const uint32_t packed_size = PVis::packed_size;
+
+    devfunc void thread_init(uint32_t idx) {
+        idx = idx * packed_size;
+        par[0].thread_init(idx);
+        par[1].thread_init(idx);
+        par[2].thread_init(idx);
+        par[3].thread_init(idx);
+        par[4].thread_init(idx);
+        par[5].thread_init(idx);
+    }
+
+    devfunc void on(uint32_t idx) {
+        idx = idx * packed_size;
+        op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), par[3].at(idx),
+           par[4].at(idx), par[5].at(idx));
+    }
+
+    devfunc void on(uint32_t idx, uint32_t remain) {
+        idx = idx * packed_size;
+        if (remain >= packed_size) {
+            op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx),
+               par[3].at(idx), par[4].at(idx), par[5].at(idx));
+        } else {
+            auto ptr0 = par[0].ptr();
+            auto ptr1 = par[1].ptr();
+            auto ptr2 = par[2].ptr();
+            auto ptr3 = par[3].ptr();
+            auto ptr4 = par[4].ptr();
+            auto ptr5 = par[5].ptr();
+            for (int i = 0; i < remain; i++) {
+                op(idx + i, ptr0[par[0].offset(idx + i)],
+                   ptr1[par[1].offset(idx + i)], ptr2[par[2].offset(idx + i)],
+                   ptr3[par[3].offset(idx + i)], ptr4[par[4].offset(idx + i)],
+                   ptr5[par[5].offset(idx + i)]);
+            }
+        }
+    }
+
+    devfunc void next() {
+        par[0].next();
+        par[1].next();
+        par[2].next();
+        par[3].next();
+        par[4].next();
+        par[5].next();
+    }
+};
+
+
+//! specialization for arity == 7
+template <class Op, class PVis>
+struct OpCallerUniform<Op, 7, PVis> {
+    Op op;
+    PVis par[7];
+    static const uint32_t packed_size = PVis::packed_size;
+
+    devfunc void thread_init(uint32_t idx) {
+        idx = idx * packed_size;
+        par[0].thread_init(idx);
+        par[1].thread_init(idx);
+        par[2].thread_init(idx);
+        par[3].thread_init(idx);
+        par[4].thread_init(idx);
+        par[5].thread_init(idx);
+        par[6].thread_init(idx);
+    }
+
+    devfunc void on(uint32_t idx) {
+        idx = idx * packed_size;
+        op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), par[3].at(idx),
+           par[4].at(idx), par[5].at(idx), par[6].at(idx));
+    }
+
+    devfunc void on(uint32_t idx, uint32_t remain) {
+        idx = idx * packed_size;
+        if (remain >= packed_size) {
+            op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx),
+               par[3].at(idx), par[4].at(idx), par[5].at(idx), par[6].at(idx));
+        } else {
+            auto ptr0 = par[0].ptr();
+            auto ptr1 = par[1].ptr();
+            auto ptr2 = par[2].ptr();
+            auto ptr3 = par[3].ptr();
+            auto ptr4 = par[4].ptr();
+            auto ptr5 = par[5].ptr();
+            auto ptr6 = par[6].ptr();
+            for (int i = 0; i < remain; i++) {
+                op(idx + i, ptr0[par[0].offset(idx + i)],
+                   ptr1[par[1].offset(idx + i)], ptr2[par[2].offset(idx + i)],
+                   ptr3[par[3].offset(idx + i)], ptr4[par[4].offset(idx + i)],
+                   ptr5[par[5].offset(idx + i)], ptr6[par[6].offset(idx + i)]);
+            }
+        }
+    }
+
+    devfunc void next() {
+        par[0].next();
+        par[1].next();
+        par[2].next();
+        par[3].next();
+        par[4].next();
+        par[5].next();
+        par[6].next();
+    }
+};
+
 /*!
  * \brief call binary (i.e. arity == 2) operator with different param
  * visitors
diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp
index afbd1dee63729531a041b436bc38feee86ef3e6b..a0316ff2072cce9802932bc96d65ef648fff2e24 100644
--- a/dnn/src/cuda/handle_create.cpp
+++ b/dnn/src/cuda/handle_create.cpp
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 
 #include "src/common/handle_impl.h"
@@ -15,6 +16,7 @@
 #include "src/cuda/add_update/opr_impl.h"
 #include "src/cuda/argmxx/opr_impl.h"
 #include "src/cuda/argsort/opr_impl.h"
+#include "src/cuda/batch_conv_bias/opr_impl.h"
 #include "src/cuda/batch_normalization/opr_impl.h"
 #include "src/cuda/batched_matrix_mul/opr_impl.h"
 #include "src/cuda/check_has_inf/opr_impl.h"
@@ -35,6 +37,7 @@
 #include "src/cuda/elemwise/opr_impl.h"
 #include "src/cuda/elemwise_multi_type/opr_impl.h"
 #include "src/cuda/eye/opr_impl.h"
+#include "src/cuda/fake_quant/opr_impl.h"
 #include "src/cuda/flip/opr_impl.h"
 #include "src/cuda/gaussian_blur/opr_impl.h"
 #include "src/cuda/group_local/opr_impl.h"
@@ -45,6 +48,7 @@
 #include "src/cuda/local/opr_impl.h"
 #include "src/cuda/local_share/opr_impl.h"
 #include "src/cuda/lrn/opr_impl.h"
+#include "src/cuda/lsq/opr_impl.h"
 #include "src/cuda/mask_conv/opr_impl.h"
 #include "src/cuda/matrix_inverse/opr_impl.h"
 #include "src/cuda/matrix_mul/opr_impl.h"
@@ -56,9 +60,11 @@
 #include "src/cuda/reduce/opr_impl.h"
 #include "src/cuda/relayout/opr_impl.h"
 #include "src/cuda/relayout_format/opr_impl.h"
+#include "src/cuda/remap/opr_impl.h"
 #include "src/cuda/repeat/opr_impl.h"
 #include "src/cuda/resize/opr_impl.h"
 #include "src/cuda/rng/opr_impl.h"
+#include "src/cuda/roi_align/opr_impl.h"
 #include "src/cuda/roi_copy/opr_impl.h"
 #include "src/cuda/roi_pooling/opr_impl.h"
 #include "src/cuda/rotate/opr_impl.h"
@@ -70,16 +76,11 @@
 #include "src/cuda/tensor_remap/opr_impl.h"
 #include "src/cuda/tile/opr_impl.h"
 #include "src/cuda/topk/opr_impl.h"
+#include "src/cuda/tqt/opr_impl.h"
 #include "src/cuda/transpose/opr_impl.h"
 #include "src/cuda/type_cvt/opr_impl.h"
 #include "src/cuda/warp_affine/opr_impl.h"
 #include "src/cuda/warp_perspective/opr_impl.h"
-#include "src/cuda/local_share/opr_impl.h"
-#include "src/cuda/roi_align/opr_impl.h"
-#include "src/cuda/batch_conv_bias/opr_impl.h"
-#include "src/cuda/remap/opr_impl.h"
-#include "src/cuda/fake_quant/opr_impl.h"
-#include "src/cuda/tqt/opr_impl.h"
 
 namespace megdnn {
 namespace cuda {
diff --git a/dnn/src/cuda/lsq/kern.cu b/dnn/src/cuda/lsq/kern.cu
new file mode 100644
index 0000000000000000000000000000000000000000..74950fc7780ee4b900343108916f75d458364527
--- /dev/null
+++ b/dnn/src/cuda/lsq/kern.cu
@@ -0,0 +1,30 @@
+/**
+ * \file dnn/src/cuda/lsq/kern.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "./kern.cuh"
+
+namespace megdnn {
+namespace cuda {
+
+#define cb(_dtype)                                                         \
+    INST_RUN_ELEMWISE(LSQKernOp<DTypeTrait<_dtype>::ctype>,                \
+                      DTypeTrait<_dtype>::ctype, 3);                       \
+    INST_RUN_ELEMWISE(LSQBwdKernOp<DTypeTrait<_dtype>::ctype>,             \
+                      DTypeTrait<_dtype>::ctype, 3);                       \
+    INST_RUN_ELEMWISE(LSQKernOpNonContig<DTypeTrait<_dtype>::ctype>,       \
+                      DTypeTrait<_dtype>::ctype, 5);                       \
+    INST_RUN_ELEMWISE(LSQBwdKernOpNonContig<DTypeTrait<_dtype>::ctype>,    \
+                      DTypeTrait<_dtype>::ctype, 7);
+cb(megdnn::dtype::Float32)
+
+} // namespace cuda
+} // namespace megdnn
diff --git a/dnn/src/cuda/lsq/kern.cuh b/dnn/src/cuda/lsq/kern.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..6bed31be21a32d403050f6af4ad28124d3248d20
--- /dev/null
+++ b/dnn/src/cuda/lsq/kern.cuh
@@ -0,0 +1,126 @@
+/**
+ * \file dnn/src/cuda/lsq/kern.cuh
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+
+#include "src/cuda/elemwise_helper.cuh"
+#include "src/cuda/utils.cuh"
+
+#if MEGDNN_CC_HOST
+#include "megdnn/oprs.h"
+#endif
+
+namespace megdnn {
+namespace cuda {
+template <typename ctype>
+struct LSQKernOp {
+    ctype* input;
+    ctype* output;
+    ctype qmin, qmax;
+
+    __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point,
+                               ctype grad_scale) {
+        ctype x = input[idx] / scale + zero_point;
+        x = fmaxf(fminf(x, qmax), qmin);
+        x = round(x);
+        output[idx] = (x - zero_point) * scale;
+    }
+
+#if MEGDNN_CC_HOST
+    LSQKernOp(const TensorND& input, const TensorND& output,
+              const LSQ::Param& param)
+            : input{input.ptr<ctype>()},
+              output{output.ptr<ctype>()},
+              qmin(param.qmin),
+              qmax(param.qmax) {}
+#endif
+};
+
+template <typename ctype>
+struct LSQBwdKernOp {
+    ctype* diff;
+    ctype* input;
+    ctype* grad_x;
+    ctype* grad_s;
+    ctype qmin, qmax;
+
+    __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point,
+                               ctype grad_scale) {
+        ctype x = input[idx] / scale + zero_point;
+        bool ind_small = x < qmin;
+        bool ind_big = x > qmax;
+        bool ind_middle = ind_small ^ ind_big;
+        ind_middle = !ind_middle;
+        grad_s[idx] = ind_small * qmin + ind_big * qmax +
+                      ind_middle * (-x + round(x));
+        grad_s[idx] = grad_s[idx] * grad_scale * diff[idx];
+        grad_x[idx] = ind_middle * diff[idx];
+    }
+
+#if MEGDNN_CC_HOST
+    LSQBwdKernOp(const TensorND& diff, const TensorND& input,
+                 const TensorND& grad_x, const TensorND& grad_s,
+                 const LSQ::Param& param)
+            : diff{diff.ptr<ctype>()},
+              input{input.ptr<ctype>()},
+              grad_x{grad_x.ptr<ctype>()},
+              grad_s{grad_s.ptr<ctype>()},
+              qmin(param.qmin),
+              qmax(param.qmax) {}
+#endif
+};
+
+template <typename ctype>
+struct LSQKernOpNonContig {
+    ctype qmin;
+    ctype qmax;
+
+    __device__ void operator()(uint32_t, ctype& output, ctype& input,
+                               ctype& scale, ctype& zero_point,
+                               ctype grad_scale) {
+        ctype x = input / scale + zero_point;
+        x = fmaxf(fminf(x, qmax), qmin);
+        x = round(x);
+        output = (x - zero_point) * scale;
+    }
+#if MEGDNN_CC_HOST
+    LSQKernOpNonContig(const LSQ::Param& param)
+            : qmin(param.qmin), qmax(param.qmax) {}
+#endif
+};
+
+template <typename ctype>
+struct LSQBwdKernOpNonContig {
+    ctype qmin;
+    ctype qmax;
+
+    __device__ void operator()(uint32_t, ctype& grad_x, ctype& grad_s,
+                               ctype& diff, ctype& input, ctype& scale,
+                               ctype& zero_point, ctype grad_scale) {
+        ctype x = input / scale + zero_point;
+        bool ind_small = x < qmin;
+        bool ind_big = x > qmax;
+        bool ind_middle = ind_small ^ ind_big;
+        ind_middle = !ind_middle;
+        grad_s = ind_small * qmin + ind_big * qmax +
+                 ind_middle * (-x + round(x));
+        grad_s = grad_s * grad_scale * diff;
+        grad_x = ind_middle * diff;
+    }
+#if MEGDNN_CC_HOST
+    LSQBwdKernOpNonContig(const LSQ::Param& param)
+            : qmin(param.qmin), qmax(param.qmax) {}
+#endif
+};
+
+} // namespace cuda
+} // namespace megdnn
diff --git a/dnn/src/cuda/lsq/opr_impl.cpp b/dnn/src/cuda/lsq/opr_impl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d4338c09fdbc4e814ceb3b06fee311e452158bde
--- /dev/null
+++ b/dnn/src/cuda/lsq/opr_impl.cpp
@@ -0,0 +1,151 @@
+/**
+ * \file dnn/src/cuda/lsq/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "./opr_impl.h"
+#include "./kern.cuh"
+#include "src/common/utils.h"
+namespace megdnn {
+namespace cuda {
+
+void LSQForwardImpl::exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
+                          _megdnn_tensor_in zero_point,
+                          _megdnn_tensor_in grad_scale,
+                          _megdnn_tensor_out output,
+                          _megdnn_workspace workspace) {
+    check_exec(input.layout, scale.layout, zero_point.layout, grad_scale.layout,
+               output.layout, workspace.size);
+
+    if (!input.layout.is_contiguous() || !output.layout.is_contiguous())
+        return exec_noncontig(input, scale, zero_point, grad_scale, output);
+
+    ElemwiseOpParamN<3> ele_param;
+    ele_param[0] = scale;
+    ele_param[0].layout = ele_param[0].layout.broadcast(input.layout);
+    ele_param[1] = zero_point;
+    ele_param[1].layout = ele_param[1].layout.broadcast(input.layout);
+    ele_param[2] = grad_scale;
+    ele_param[2].layout = ele_param[2].layout.broadcast(input.layout);
+    ele_param.init_from_given_tensor();
+    auto m_param = param();
+    auto stream = cuda_stream(handle());
+
+#define cb(DType)                                                         \
+    if (input.layout.dtype == DType()) {                                  \
+        using T = typename DTypeTrait<DType>::ctype;                      \
+        run_elemwise<LSQKernOp<T>, T, 3>(ele_param, stream,               \
+                                         {input, output, m_param});       \
+        return;                                                           \
+    }
+    cb(megdnn::dtype::Float32)
+#undef cb
+}
+
+void LSQForwardImpl::exec_noncontig(_megdnn_tensor_in input,
+                                    _megdnn_tensor_in scale,
+                                    _megdnn_tensor_in zero_point,
+                                    _megdnn_tensor_in grad_scale,
+                                    _megdnn_tensor_out output) {
+    ElemwiseOpParamN<5> ele_param;
+    ele_param[0] = output;
+    ele_param[1] = input;
+    ele_param[2] = scale;
+    ele_param[2].layout = ele_param[2].layout.broadcast(input.layout);
+    ele_param[3] = zero_point;
+    ele_param[3].layout = ele_param[3].layout.broadcast(input.layout);
+    ele_param[4] = grad_scale;
+    ele_param[4].layout = ele_param[4].layout.broadcast(input.layout);
+    ele_param.init_from_given_tensor();
+    auto m_param = param();
+    auto stream = cuda_stream(handle());
+
+#define cb(DType)                                                         \
+    if (input.layout.dtype == DType()) {                                  \
+        using T = typename DTypeTrait<DType>::ctype;                      \
+        run_elemwise<LSQKernOpNonContig<T>, T, 5>(ele_param, stream,      \
+                                                  {m_param});             \
+        return;                                                           \
+    }
+    cb(megdnn::dtype::Float32)
+#undef cb
+}
+
+void LSQBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
+                           _megdnn_tensor_in scale,
+                           _megdnn_tensor_in zero_point,
+                           _megdnn_tensor_in grad_scale,
+                           _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s,
+                           _megdnn_workspace workspace) {
+    check_exec(diff.layout, input.layout, scale.layout, zero_point.layout,
+               grad_scale.layout, grad_x.layout, grad_s.layout,
+               workspace.size);
+
+    if (!input.layout.is_contiguous() || !diff.layout.is_contiguous() ||
+        !grad_x.layout.is_contiguous() || !grad_s.layout.is_contiguous())
+        return exec_noncontig(diff, input, scale, zero_point, grad_scale,
+                              grad_x, grad_s);
+
+    ElemwiseOpParamN<3> ele_param;
+    ele_param[0] = scale;
+    ele_param[0].layout = ele_param[0].layout.broadcast(input.layout);
+    ele_param[1] = zero_point;
+    ele_param[1].layout = ele_param[1].layout.broadcast(input.layout);
+    ele_param[2] = grad_scale;
+    ele_param[2].layout = ele_param[2].layout.broadcast(input.layout);
+    ele_param.init_from_given_tensor();
+    auto m_param = param();
+    auto stream = cuda_stream(handle());
+
+#define cb(DType)                                                         \
+    if (grad_x.layout.dtype == DType()) {                                 \
+        using T = typename DTypeTrait<DType>::ctype;                      \
+        run_elemwise<LSQBwdKernOp<T>, T, 3>(                              \
+                ele_param, stream, {diff, input, grad_x, grad_s, m_param}); \
+        return;                                                           \
+    }
+    cb(megdnn::dtype::Float32)
+#undef cb
+}
+
+void LSQBackwardImpl::exec_noncontig(_megdnn_tensor_in diff,
+                                     _megdnn_tensor_in input,
+                                     _megdnn_tensor_in scale,
+                                     _megdnn_tensor_in zero_point,
+                                     _megdnn_tensor_in grad_scale,
+                                     _megdnn_tensor_out grad_x,
+                                     _megdnn_tensor_out grad_s) {
+    ElemwiseOpParamN<7> ele_param;
+    ele_param[0] = grad_x;
+    ele_param[1] = grad_s;
+    ele_param[2] = diff;
+    ele_param[3] = input;
+    ele_param[4] = scale;
+    ele_param[4].layout = ele_param[4].layout.broadcast(input.layout);
+    ele_param[5] = zero_point;
+    ele_param[5].layout = ele_param[5].layout.broadcast(input.layout);
+    ele_param[6] = grad_scale;
+    ele_param[6].layout = ele_param[6].layout.broadcast(input.layout);
+    ele_param.init_from_given_tensor();
+    auto m_param = param();
+    auto stream = cuda_stream(handle());
+
+#define cb(DType)                                                         \
+    if (input.layout.dtype == DType()) {                                  \
+        using T = typename DTypeTrait<DType>::ctype;                      \
+        run_elemwise<LSQBwdKernOpNonContig<T>, T, 7>(ele_param, stream,   \
+                                                     {m_param});          \
+        return;                                                           \
+    }
+    cb(megdnn::dtype::Float32)
+#undef cb
+}
+
+} // namespace cuda
+} // namespace megdnn
\ No newline at end of file
diff --git a/dnn/src/cuda/lsq/opr_impl.h b/dnn/src/cuda/lsq/opr_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..aba0caf4658c42f310aeb5b1740e05c53cfb403a
--- /dev/null
+++ b/dnn/src/cuda/lsq/opr_impl.h
@@ -0,0 +1,65 @@
+/**
+ * \file dnn/src/cuda/lsq/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+#include "megdnn/oprs.h"
+#include "src/cuda/utils.h"
+namespace megdnn {
+namespace cuda {
+
+class LSQForwardImpl final : public LSQForward {
+public:
+    using LSQForward::LSQForward;
+    void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
+              _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale,
+              _megdnn_tensor_out output, _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(const TensorLayout&, /* input */
+                                  const TensorLayout&, /* scale */
+                                  const TensorLayout&, /* zero_point */
+                                  const TensorLayout&, /* grad_scale */
+                                  const TensorLayout& /* output */) override {
+        return 0;
+    }
+
+private:
+    void exec_noncontig(_megdnn_tensor_in input, _megdnn_tensor_in scale,
+                        _megdnn_tensor_in zero_point,
+                        _megdnn_tensor_in grad_scale,
+                        _megdnn_tensor_out output);
+};
+
+class LSQBackwardImpl final : public LSQBackward {
+public:
+    using LSQBackward::LSQBackward;
+    void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
+              _megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
+              _megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x,
+              _megdnn_tensor_out grad_s, _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(const TensorLayout& /* diff */,
+                                  const TensorLayout& /* input */,
+                                  const TensorLayout& /* scale */,
+                                  const TensorLayout& /* zero_point */,
+                                  const TensorLayout& /* grad_scale */,
+                                  const TensorLayout& /* grad_x */,
+                                  const TensorLayout& /* grad_s */) override {
+        return 0;
+    }
+
+private:
+    void exec_noncontig(_megdnn_tensor_in diff, _megdnn_tensor_in input,
+                        _megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
+                        _megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x,
+                        _megdnn_tensor_out grad_s);
+};
+
+} // namespace cuda
+} // namespace megdnn
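Both the CUDA and the naive backends below broadcast the rank-1 `scale`/`zero_point`/`grad_scale` tensors onto the input layout before dispatching the element-wise kernel, which is why every `get_workspace_in_bytes` can return 0. Roughly the NumPy equivalent of that broadcast, for illustration only:

    import numpy as np

    x = np.empty((10, 64, 12, 12), dtype="float32")
    scale = np.array([0.25], dtype="float32")  # shape {1}, as in the tests
    # broadcast_to yields a zero-stride read-only view, much like
    # TensorLayout::broadcast in the patch
    scale_b = np.broadcast_to(scale, x.shape)
    assert scale_b.strides == (0, 0, 0, 0)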
diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp
index 4ec7e555e1dee8e58f97164adb614097044c3952..29560a6483b59d8d2f525f5305411e98597e5f71 100644
--- a/dnn/src/naive/handle.cpp
+++ b/dnn/src/naive/handle.cpp
@@ -50,6 +50,7 @@
 #include "src/naive/local/opr_impl.h"
 #include "src/naive/local_share/opr_impl.h"
 #include "src/naive/lrn/opr_impl.h"
+#include "src/naive/lsq/opr_impl.h"
 #include "src/naive/mask_conv/opr_impl.h"
 #include "src/naive/matrix_inverse/opr_impl.h"
 #include "src/naive/matrix_mul/opr_impl.h"
diff --git a/dnn/src/naive/lsq/opr_impl.cpp b/dnn/src/naive/lsq/opr_impl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..7568b695294be17e30606d29f22944e04c0e6fe6
--- /dev/null
+++ b/dnn/src/naive/lsq/opr_impl.cpp
@@ -0,0 +1,141 @@
+/**
+ * \file dnn/src/naive/lsq/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "src/naive/lsq/opr_impl.h"
+#include <cmath>
+#include "megdnn/tensor_iter.h"
+#include "src/common/elemwise_helper.cuh"
+#include "src/common/utils.h"
+#include "src/naive/handle.h"
+
+namespace {
+using namespace megdnn;
+
+template <typename T>
+void forward_impl(const ElemwiseOpParamN<5> src, float qmin, float qmax) {
+    auto inp = tensor_iter_valonly<T>(src[0]).begin();
+    auto out = tensor_iter_valonly<T>(src[1]).begin();
+    auto scale = tensor_iter_valonly<T>(src[2]).begin();
+    auto zero_point = tensor_iter_valonly<T>(src[3]).begin();
+    auto grad_scale = tensor_iter_valonly<T>(src[4]).begin();
+    size_t total = src[0].layout.total_nr_elems();
+    for (size_t i = 0; i < total; ++i) {
+        T x = (*inp) / (*scale) + (*zero_point);
+        x = x <= qmin ? qmin : x;
+        x = x >= qmax ? qmax : x;
+        x = round(x);
+        *out = (x - (*zero_point)) * (*scale);
+        ++inp;
+        ++out;
+        ++scale;
+        ++zero_point;
+        ++grad_scale;
+    }
+}
+
+template <typename T>
+void backward_impl(const ElemwiseOpParamN<7> src, float qmin, float qmax) {
+    auto diff = tensor_iter_valonly<T>(src[0]).begin();
+    auto input = tensor_iter_valonly<T>(src[1]).begin();
+    auto scale = tensor_iter_valonly<T>(src[2]).begin();
+    auto zero_point = tensor_iter_valonly<T>(src[3]).begin();
+    auto grad_scale = tensor_iter_valonly<T>(src[4]).begin();
+    auto grad_x = tensor_iter_valonly<T>(src[5]).begin();
+    auto grad_s = tensor_iter_valonly<T>(src[6]).begin();
+    size_t total = src[0].layout.total_nr_elems();
+
+    for (size_t i = 0; i < total; ++i) {
+        T x = (*input) / (*scale) + (*zero_point);
+        bool ind_small = x < qmin;
+        bool ind_big = x > qmax;
+        bool ind_middle = ind_small ^ ind_big;
+        ind_middle = !ind_middle;
+
+        *grad_s = ind_small * qmin + ind_big * qmax +
+                  ind_middle * (-x + round(x));
+        *grad_s = (*grad_s) * (*grad_scale) * (*diff);
+        *grad_x = ind_middle * (*diff);
+
+        ++diff;
+        ++input;
+        ++scale;
+        ++zero_point;
+        ++grad_scale;
+        ++grad_x;
+        ++grad_s;
+    }
+}
+
+} // namespace
+namespace megdnn {
+namespace naive {
+
+void LSQForwardImpl::exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
+                          _megdnn_tensor_in zero_point,
+                          _megdnn_tensor_in grad_scale,
+                          _megdnn_tensor_out output,
+                          _megdnn_workspace workspace) {
+    check_exec(input.layout, scale.layout, zero_point.layout, grad_scale.layout,
+               output.layout, workspace.size);
+    ElemwiseOpParamN<5> src;
+    src[0] = input;
+    src[1] = output;
+    src[2] = scale;
+    src[2].layout = src[2].layout.broadcast(input.layout);
+    src[3] = zero_point;
+    src[3].layout = src[3].layout.broadcast(input.layout);
+    src[4] = grad_scale;
+    src[4].layout = src[4].layout.broadcast(input.layout);
+#define cb(DType)                                                         \
+    if (input.layout.dtype == DType()) {                                  \
+        using T = typename DTypeTrait<DType>::ctype;                      \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                                     \
+                forward_impl<T>(src, param().qmin, param().qmax));        \
+        return;                                                           \
+    }
+    cb(dtype::Float32)
+#undef cb
+}
+
+void LSQBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
+                           _megdnn_tensor_in scale,
+                           _megdnn_tensor_in zero_point,
+                           _megdnn_tensor_in grad_scale,
+                           _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s,
+                           _megdnn_workspace workspace) {
+    check_exec(diff.layout, input.layout, scale.layout, zero_point.layout,
+               grad_scale.layout, grad_x.layout, grad_s.layout, workspace.size);
+    ElemwiseOpParamN<7> src;
+    src[0] = diff;
+    src[1] = input;
+    src[2] = scale;
+    src[2].layout = src[2].layout.broadcast(input.layout);
+    src[3] = zero_point;
+    src[3].layout = src[3].layout.broadcast(input.layout);
+    src[4] = grad_scale;
+    src[4].layout = src[4].layout.broadcast(input.layout);
+    src[5] = grad_x;
+    src[6] = grad_s;
+#define cb(DType)                                                         \
+    if (diff.layout.dtype == DType() && grad_x.layout.dtype == DType() && \
+        input.layout.dtype == DType()) {                                  \
+        using T = typename DTypeTrait<DType>::ctype;                      \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                                     \
+                backward_impl<T>(src, param().qmin, param().qmax));       \
+        return;                                                           \
+    }
+    cb(dtype::Float32)
+#undef cb
+}
+
+} // namespace naive
+} // namespace megdnn
diff --git a/dnn/src/naive/lsq/opr_impl.h b/dnn/src/naive/lsq/opr_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..6b583bd938a896ca0bc9477abc5293112c1b0d18
--- /dev/null
+++ b/dnn/src/naive/lsq/opr_impl.h
@@ -0,0 +1,53 @@
+/**
+ * \file dnn/src/naive/lsq/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+#include "megdnn/oprs.h"
+
+namespace megdnn {
+namespace naive {
+
+class LSQForwardImpl final : public LSQForward {
+public:
+    using LSQForward::LSQForward;
+    void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
+              _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale,
+              _megdnn_tensor_out output, _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(const TensorLayout& /* input */,
+                                  const TensorLayout& /* scale */,
+                                  const TensorLayout& /* zero_point */,
+                                  const TensorLayout& /* grad_scale */,
+                                  const TensorLayout& /* output */) override {
+        return 0;
+    }
+};
+
+class LSQBackwardImpl final : public LSQBackward {
+public:
+    using LSQBackward::LSQBackward;
+    void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
+              _megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
+              _megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x,
+              _megdnn_tensor_out grad_s, _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(const TensorLayout& /* diff */,
+                                  const TensorLayout& /* input */,
+                                  const TensorLayout& /* scale */,
+                                  const TensorLayout& /* zero_point */,
+                                  const TensorLayout& /* grad_scale */,
+                                  const TensorLayout& /* grad_x */,
+                                  const TensorLayout& /* grad_s */) override {
+        return 0;
+    }
+};
+
+} // namespace naive
+} // namespace megdnn
diff --git a/dnn/test/common/lsq.h b/dnn/test/common/lsq.h
new file mode 100644
index 0000000000000000000000000000000000000000..5f165aa50e87b281a3eb6273bcf17b1d04e084ce
--- /dev/null
+++ b/dnn/test/common/lsq.h
@@ -0,0 +1,53 @@
+/**
+ * \file dnn/test/common/lsq.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+#include "megdnn/basic_types.h"
+#include "megdnn/opr_param_defs.h"
+
+namespace megdnn {
+namespace test {
+namespace lsq {
+
+struct TestArg {
+    param::LSQ param;
+    TensorShape ishape;
+    TensorShape scale_shape;
+    TensorShape zeropoint_shape;
+    TensorShape gradscale_shape;
+    TestArg(param::LSQ param, TensorShape ishape, TensorShape scale_shape,
+            TensorShape zeropoint_shape, TensorShape gradscale_shape)
+            : param(param),
+              ishape(ishape),
+              scale_shape(scale_shape),
+              zeropoint_shape(zeropoint_shape),
+              gradscale_shape(gradscale_shape) {}
+};
+
+inline std::vector<TestArg> get_args() {
+    std::vector<TestArg> args;
+    param::LSQ cur_param;
+
+    cur_param.qmin = -127;
+    cur_param.qmax = 127;
+
+    for (size_t i = 10; i < 30; i += 2) {
+        args.emplace_back(cur_param, TensorShape{10, 64, i, i}, TensorShape{1},
+                          TensorShape{1}, TensorShape{1});
+    }
+
+    return args;
+}
+
+} // namespace lsq
+} // namespace test
+} // namespace megdnn
\ No newline at end of file
diff --git a/dnn/test/cuda/lsq.cpp b/dnn/test/cuda/lsq.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..263ae3e2fbe69931bf276ff0e0964f38470c9f18
--- /dev/null
+++ b/dnn/test/cuda/lsq.cpp
@@ -0,0 +1,110 @@
+/**
+ * \file dnn/test/cuda/lsq.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "test/common/lsq.h"
+#include "megdnn/oprs.h"
+#include "test/common/checker.h"
+#include "test/cuda/fixture.h"
+
+namespace megdnn {
+namespace test {
+
+using namespace lsq;
+
+TEST_F(CUDA, LSQ) {
+    std::vector<TestArg> args = get_args();
+    auto dtype = dtype::Float32();
+
+    for (auto&& arg : args) {
+        auto param = arg.param;
+        auto ishape = arg.ishape;
+        auto scale_shape = arg.scale_shape;
+        auto zeropoint_shape = arg.zeropoint_shape;
+        auto gradscale_shape = arg.gradscale_shape;
+        Checker<LSQForward> checker(handle_cuda());
+        checker.set_param(param)
+                .set_dtype(0, dtype)
+                .set_dtype(1, dtype)
+                .set_dtype(2, dtype)
+                .set_dtype(3, dtype)
+                .set_dtype(4, dtype)
+                .execs({ishape, scale_shape, zeropoint_shape, gradscale_shape,
+                        ishape});
+    }
+    // test noncontiguous layout
+    for (auto&& arg : args) {
+        auto param = arg.param;
+        auto ishape = arg.ishape;
+        auto sshape = arg.scale_shape;
+        auto zeropoint_shape = arg.zeropoint_shape;
+        auto gradscale_shape = arg.gradscale_shape;
+        Checker<LSQForward> checker(handle_cuda());
+        TensorLayout ilayout(
+                ishape,
+                {(long int)(ishape[1] * ishape[2] * ishape[3] * 2),
+                 (long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1},
+                dtype::Float32());
+        checker.set_param(param).execl({ilayout,
+                                        {sshape, dtype::Float32()},
+                                        {zeropoint_shape, dtype::Float32()},
+                                        {gradscale_shape, dtype::Float32()},
+                                        ilayout});
+    }
+}
+
+TEST_F(CUDA, LSQ_BACKWARD) {
+    std::vector<TestArg> args = get_args();
+    auto dtype = dtype::Float32();
+
+    for (auto&& arg : args) {
+        auto param = arg.param;
+        auto ishape = arg.ishape;
+        auto scale_shape = arg.scale_shape;
+        auto zeropoint_shape = arg.zeropoint_shape;
+        auto gradscale_shape = arg.gradscale_shape;
+        Checker<LSQBackward> checker(handle_cuda());
+        checker.set_param(param)
+                .set_dtype(0, dtype)
+                .set_dtype(1, dtype)
+                .set_dtype(2, dtype)
+                .set_dtype(3, dtype)
+                .set_dtype(4, dtype)
+                .set_dtype(5, dtype)
+                .set_dtype(6, dtype)
+                .execs({ishape, ishape, scale_shape, zeropoint_shape,
+                        gradscale_shape, ishape, ishape});
+    }
+    // test noncontiguous layout
+    for (auto&& arg : args) {
+        auto param = arg.param;
+        auto ishape = arg.ishape;
+        auto sshape = arg.scale_shape;
+        auto zeropoint_shape = arg.zeropoint_shape;
+        auto gradscale_shape = arg.gradscale_shape;
+        Checker<LSQBackward> checker(handle_cuda());
+        TensorLayout ilayout(
+                ishape,
+                {(long int)(ishape[1] * ishape[2] * ishape[3] * 2),
+                 (long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1},
+                dtype::Float32());
+        checker.set_param(param).execl({ilayout,
+                                        ilayout,
+                                        {sshape, dtype::Float32()},
+                                        {zeropoint_shape, dtype::Float32()},
+                                        {gradscale_shape, dtype::Float32()},
+                                        ilayout,
+                                        ilayout});
+    }
+}
+
+} // namespace test
+} // namespace megdnn
\ No newline at end of file
diff --git a/dnn/test/naive/lsq.cpp b/dnn/test/naive/lsq.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..3ea601cac801bd668fd4273b434f09e61f627bac
--- /dev/null
+++ b/dnn/test/naive/lsq.cpp
@@ -0,0 +1,45 @@
+/**
+ * \file dnn/test/naive/lsq.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "test/naive/fixture.h"
+
+#include "megdnn/oprs/nn.h"
+#include "test/common/checker.h"
+
+using namespace megdnn;
+using namespace test;
+
+TEST_F(NAIVE, LSQ_FORWARD) {
+    Checker<LSQForward> checker(handle(), /* check_dispatch */ false);
+
+    param::LSQ param;
+
+    param.qmin = -127;
+    param.qmax = 127;
+
+    TensorND input =
+            TensorValue({2, 2, 2, 2}, dtype::Float32(),
+                        {0, 1, 3, 4, 1, 2, 4, 5, 3, 4, 6, 7, 4, 5, 7, 8});
+
+    TensorND scale_shape = TensorValue({1}, dtype::Float32(), {2});
+
+    TensorND zero_point = TensorValue({1}, dtype::Float32(), {1});
+
+    TensorND grad_scale = TensorValue({1}, dtype::Float32(), {0.5});
+
+    TensorND output =
+            TensorValue({2, 2, 2, 2}, dtype::Float32(),
+                        {0, 2, 4, 4, 2, 2, 4, 6, 4, 4, 6, 8, 4, 6, 8, 8});
+
+    checker.set_param(param).exect(
+            Testcase{input, scale_shape, zero_point, grad_scale, {}},
+            Testcase{{}, {}, {}, {}, output});
+}
diff --git a/imperative/python/megengine/quantization/__init__.py b/imperative/python/megengine/quantization/__init__.py
index 2d6bf959b3b460152db31a5755728e6fe5439eb9..4177c4e9772e2679a69f6fd19867acf073ae8878 100644
--- a/imperative/python/megengine/quantization/__init__.py
+++ b/imperative/python/megengine/quantization/__init__.py
@@ -6,7 +6,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 
-from .fake_quant import TQT, FakeQuantize
+from .fake_quant import LSQ, TQT, FakeQuantize
 from .observer import (
     ExponentialMovingAverageObserver,
     HistogramObserver,
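The `grad_scale` threaded through the Python API below corresponds to the gradient-scale factor g of the LSQ paper (Esser et al., "Learned Step Size Quantization"), commonly chosen as 1/sqrt(N * Qmax) for a tensor with N elements; the patch leaves the choice to the caller. A hypothetical helper, for illustration only:

    import numpy as np

    def make_grad_scale(num_elements: int, qmax: int) -> np.ndarray:
        # g = 1 / sqrt(N * Qmax), per the LSQ paper (illustrative, not in the patch)
        return np.array([1.0 / np.sqrt(num_elements * qmax)], dtype="float32")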
diff --git a/imperative/python/megengine/quantization/fake_quant.py b/imperative/python/megengine/quantization/fake_quant.py
index b38df20a1c1cd2c468b9b489740253617569f58e..ae6c09123b12733c0e6757a3da0411ef35ec9e6a 100644
--- a/imperative/python/megengine/quantization/fake_quant.py
+++ b/imperative/python/megengine/quantization/fake_quant.py
@@ -12,13 +12,15 @@ from .. import functional as F
 from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
 from ..logger import get_logger
 from ..module import Module
-from ..tensor import Parameter
+from ..tensor import Parameter, Tensor
 from .utils import (
+    LSQParams,
     QParams,
     QParamsModuleMixin,
     QuantMode,
     create_qparams,
     fake_quant_tensor,
+    lsq_forward,
     tqt_forward,
 )
 
@@ -117,3 +119,58 @@ class FakeQuantize(_FakeQuantize):
                 qparams.dtype_meta, self.dtype
             )
         return fake_quant_tensor(inp, qparams)
+
+
+class LSQ(_FakeQuantize, QParamsModuleMixin):
+    r"""
+    LSQ: https://arxiv.org/pdf/1902.08153.pdf — learns each weight and
+    activation layer's quantizer step size by estimating and scaling the
+    task-loss gradient at it.
+
+    :param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
+        quantization dtype of input.
+    :param enable: whether to do ``normal_forward`` or ``fake_quant_forward``.
+    :param eps: a small value to avoid division by zero. Default: 1e-5
+    """
+
+    def __init__(
+        self,
+        dtype: Union[str, QuantDtypeMeta],
+        enable: bool = True,
+        eps: float = 1e-5,
+        **kwargs
+    ):
+        super().__init__(dtype=dtype, enable=enable, **kwargs)
+        self.eps = Tensor(eps, dtype="float32")
+        self.step_size = Parameter(1.0, dtype="float32")
+
+    def set_qparams(self, qparams: LSQParams):
+        self.mode = qparams.mode
+        if qparams.mode == QuantMode.ASYMMERTIC:
+            self.zero_point = qparams.zero_point
+        else:
+            self.zero_point = Tensor([0.0], dtype="float32")
+        if qparams.scale is None:
+            raise AssertionError("Can not get an initialized scale")
+        init_step_size = qparams.scale
+        if init_step_size < self.eps:
+            init_step_size = 0
+        else:
+            init_step_size = init_step_size - self.eps
+        self.step_size = Parameter(init_step_size, dtype="float32")
+
+        self.grad_scale = qparams.grad_scale
+
+    def fake_quant_forward(self, inp, qparams: LSQParams = None):
+        step_size = F.abs(self.step_size) + self.eps
+        return lsq_forward(
+            self.qmin, self.qmax, inp, step_size, self.zero_point, self.grad_scale
+        )
+
+    def get_qparams(self):
+        return LSQParams(
+            mode=self.mode,
+            dtype_meta=self.dtype,
+            scale=F.abs(self.step_size.detach()) + self.eps,
+            zero_point=self.zero_point,
+            grad_scale=self.grad_scale,
+        )
diff --git a/imperative/python/megengine/quantization/utils.py b/imperative/python/megengine/quantization/utils.py
index 127a134c2bcf368603834b93b0ce2986e2dd3c9d..a83844fdd7fb78fd79ea63b3650ae8d5f1c96165 100644
--- a/imperative/python/megengine/quantization/utils.py
+++ b/imperative/python/megengine/quantization/utils.py
@@ -43,6 +43,16 @@ def tqt_forward(qmin, qmax, inp, scale):
     return output
 
 
+def lsq_forward(qmin, qmax, inp, step_size, zero_point=None, scale_grad=None):
+    if zero_point is None:
+        zero_point = Tensor([0.0], dtype=np.float32)
+    if scale_grad is None:
+        scale_grad = Tensor([1.0], dtype=np.float32)
+    op = builtin.LSQ(qmin=qmin, qmax=qmax)
+    (output,) = apply(op, inp, step_size, zero_point, scale_grad)
+    return output
+
+
 def register_method_to_class(cls):
     def decorator(func):
         @wraps(func)
@@ -105,6 +115,47 @@ class QParams:
         return "QParams({})".format(content)
 
 
+class LSQParams:
+    """
+    To standardize LSQ's qparams format. If custom
+    qparams are needed, inherit this class and add custom ``__slots__``.
+    """
+
+    __slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale"
+
+    def __init__(
+        self,
+        mode: QuantMode,
+        dtype_meta: QuantDtypeMeta,
+        scale: Tensor,
+        zero_point: Tensor,
+        grad_scale: Tensor,
+    ):
+        self.mode = mode
+        self.dtype_meta = dtype_meta
+        self.scale = scale
+        self.zero_point = zero_point
+        self.grad_scale = grad_scale
+
+    def update(self, lsqparams: "LSQParams"):
+        for key in self.__slots__:
+            setattr(self, key, getattr(lsqparams, key))
+
+    def __eq__(self, other):
+        if len(self.__slots__) != len(other.__slots__):
+            return False
+        for key in self.__slots__:
+            if not hasattr(other, key) or getattr(self, key) != getattr(other, key):
+                return False
+        return True
+
+    def __repr__(self):
+        content = ", ".join(
+            ["{}={}".format(key, getattr(self, key)) for key in self.__slots__]
+        )
+        return "LSQParams({})".format(content)
+
+
 class QParamsModuleMixin(abc.ABC):
     def get_quantized_dtype(self):
         qparams = self.get_qparams()
diff --git a/imperative/python/test/unit/quantization/test_fake_quant.py b/imperative/python/test/unit/quantization/test_fake_quant.py
index a72aa6a48dcfe53980622d1e46aa440199bd61ad..9f93a1759f03a890383cad568d8a3ea9c56ee3bd 100644
--- a/imperative/python/test/unit/quantization/test_fake_quant.py
+++ b/imperative/python/test/unit/quantization/test_fake_quant.py
@@ -10,6 +10,7 @@ import numpy as np
 import pytest
 
 import megengine as mge
+import megengine.functional as F
 from megengine import tensor
 from megengine.core.autodiff.grad import Function, Grad
 from megengine.core.tensor.dtype import QuantDtypeMeta
@@ -19,6 +20,7 @@ from megengine.quantization.utils import (
     QuantMode,
     create_qparams,
     fake_quant_tensor,
+    lsq_forward,
     tqt_forward,
 )
 
@@ -150,3 +152,78 @@ def test_fakequant():
     zero_point = tensor(1.0 * np.ones((1, 32, 1, 1)), dtype=np.float32)
     scale = tensor(4.0 * np.ones((1, 32, 1, 1)), dtype=np.float32)
     run(zero_point, scale)
+
+
+class LSQ_numpy:
+    def __init__(self, lowerbound, upperbound):
+        super().__init__()
+        self.lowerbound = lowerbound
+        self.upperbound = upperbound
+
+    def forward(self, inp, scale, zero_point, grad_scale):
+        inp_scaled = inp / scale + zero_point
+        inp_clipped = np.maximum(
+            np.minimum(inp_scaled, self.upperbound), self.lowerbound
+        )
+        inp_rounded = np.floor(inp_clipped + 0.5)
+        inp_flq = (inp_rounded - zero_point) * scale
+        self.saved_tensors = (inp_scaled, inp_rounded, scale, grad_scale)
+        return inp_flq
+
+    def backward(self, grad_inp_flq):
+        (inp_scaled, inp_rounded, scale, grad_scale) = self.saved_tensors
+
+        ind_small = inp_scaled < self.lowerbound
+        ind_big = inp_scaled > self.upperbound
+        ind_middle = np.logical_xor(ind_small, ind_big)
+        ind_middle = np.abs(ind_middle - 1)
+
+        grad_s = (
+            ind_small * self.lowerbound
+            + ind_big * self.upperbound
+            + ind_middle * (-inp_scaled + inp_rounded)
+        )
+        grad_s = grad_s * grad_scale * grad_inp_flq
+        grad_s = grad_s.sum()
+        grad_inp = grad_inp_flq * ind_middle
+
+        return grad_inp, grad_s
+
+
+def test_lsq():
+    def preprocess(scale, eps):
+        scale = np.array([0]) if scale < eps else scale - eps
+        return np.abs(scale) + eps
+
+    g = []
+
+    def cb(grad):
+        g.append(grad)
+
+    x = np.random.randint(-128, 128, size=(1, 2, 3, 4)).astype("float32")
+    s = np.random.rand(1)
+    eps = np.array([1e-5], dtype="float32")
+    s = preprocess(s, eps)
+    zero_point = np.array([1.0], dtype="float32")
+    grad_s = np.array([2.0], dtype="float32")
+
+    g_y = np.ones(shape=(1, 2, 3, 4), dtype="float32")
+
+    n = LSQ_numpy(-127, 127)
+    y_np = n.forward(x, s, zero_point, grad_s)
+    g_x_np, g_s_np = n.backward(g_y)
+
+    x = mge.tensor(x, dtype="float32")
+    s = mge.tensor(s, dtype="float32")
+    zero_point = mge.tensor(zero_point, dtype="float32")
+    grad_s = mge.tensor(grad_s, dtype="float32")
+
+    g_y = mge.tensor(g_y, dtype="float32")
+    grad = Grad().wrt(x, s, callback=cb)
+    y = lsq_forward(-127, 127, x, s, zero_point, grad_s)
+    grad(y, g_y)
+    g_x, g_s = g
+
+    np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-7, atol=1e-7)
+    np.testing.assert_allclose(g_x.numpy(), g_x_np, rtol=1e-7, atol=1e-7)
+    np.testing.assert_allclose(g_s.numpy(), g_s_np, rtol=5e-7, atol=5e-7)
diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp
index 185b1d7ac2e26180d075e020afa72f8a3e5a6b3d..cc6bb192e0ff3671f6cf3bcd532f7451a3d55d6a 100644
--- a/imperative/src/impl/ops/specializations.cpp
+++ b/imperative/src/impl/ops/specializations.cpp
@@ -6,23 +6,26 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 
 // FIXME: split this file into separate files for each specialized op
 
 #include "megbrain/imperative/ops/autogen.h"
-#include "megbrain/opr/dnn/convolution.h"
+#include "megbrain/opr/basic_arith.h"
+#include "megbrain/opr/blas.h"
 #include "megbrain/opr/dnn/adaptive_pooling.h"
+#include "megbrain/opr/dnn/convolution.h"
+#include "megbrain/opr/dnn/correlation.h"
 #include "megbrain/opr/dnn/fake_quant.h"
-#include "megbrain/opr/dnn/tqt.h"
-#include "megbrain/opr/dnn/pooling.h"
+#include "megbrain/opr/dnn/images2neibs.h"
 #include "megbrain/opr/dnn/local.h"
+#include "megbrain/opr/dnn/lsq.h"
+#include "megbrain/opr/dnn/pooling.h"
 #include "megbrain/opr/dnn/roi_align.h"
-#include "megbrain/opr/dnn/correlation.h"
 #include "megbrain/opr/dnn/roi_pooling.h"
-#include "megbrain/opr/basic_arith.h"
-#include "megbrain/opr/blas.h"
+#include "megbrain/opr/dnn/tqt.h"
 #include "megbrain/opr/imgproc.h"
 #include "megbrain/opr/indexing.h"
 #include "megbrain/opr/io.h"
@@ -32,40 +35,38 @@
 #include "megbrain/opr/tensor_gen.h"
 #include "megbrain/opr/tensor_manip.h"
 #include "megbrain/opr/utility.h"
-#include "megbrain/opr/dnn/images2neibs.h"
 
 #include "../op_trait.h"
 
 namespace mgb::imperative {
 
-namespace { namespace dimshuffle {
+namespace {
+namespace dimshuffle {
 
 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
     auto* node = &node_->cast_final_safe<opr::Dimshuffle>();
     std::vector<int> pattern(node->param().pattern_len);
-    for (size_t i = 0; i < node->param().pattern_len; ++ i) {
+    for (size_t i = 0; i < node->param().pattern_len; ++i) {
         pattern[i] = node->param().pattern[i];
    }
     return Dimshuffle::make(pattern);
 }
 
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& ds = static_cast<const Dimshuffle&>(def);
     OperatorNodeConfig config{ds.make_name()};
     return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
 }
 
 OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
-    .make_from_op_node(make_from_op_node)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // dimshuffle
+        .make_from_op_node(make_from_op_node)
+        .apply_on_var_node(apply_on_var_node)
+        .fallback();
+} // namespace dimshuffle
+} // namespace
 
-namespace { namespace add_axis {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
+namespace {
+namespace add_axis {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& add_axis = static_cast<const AddAxis&>(def);
     using Desc = opr::AxisAddRemove::AxisDesc;
     std::vector<Desc> param;
@@ -76,15 +77,13 @@ auto apply_on_var_node(
     return opr::AxisAddRemove::make(inputs[0], param, config);
 }
 
-OP_TRAIT_REG(AddAxis, AddAxis)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // add_axis
+OP_TRAIT_REG(AddAxis, AddAxis).apply_on_var_node(apply_on_var_node).fallback();
+} // namespace add_axis
+} // namespace
 
-namespace { namespace remove_axis {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
+namespace {
+namespace remove_axis {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& remove_axis = static_cast<const RemoveAxis&>(def);
     using Desc = opr::AxisAddRemove::AxisDesc;
     std::vector<Desc> param;
@@ -96,36 +95,35 @@ auto apply_on_var_node(
 }
 
 OP_TRAIT_REG(RemoveAxis, RemoveAxis)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // remove_axis
+        .apply_on_var_node(apply_on_var_node)
+        .fallback();
+} // namespace remove_axis
+} // namespace
 
-namespace { namespace top_k {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
+namespace {
+namespace top_k {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& topk = static_cast<const TopK&>(def);
     OperatorNodeConfig config{topk.make_name()};
     return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0]
-        .node()->owner_opr();
+            .node()
+            ->owner_opr();
 }
 
-OP_TRAIT_REG(TopK, TopK)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // top_k
+OP_TRAIT_REG(TopK, TopK).apply_on_var_node(apply_on_var_node).fallback();
+} // namespace top_k
+} // namespace
 
-namespace { namespace reduce {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
+namespace {
+namespace reduce {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& reduce = static_cast<const Reduce&>(def);
     OperatorNodeConfig config{reduce.make_name()};
     if (inputs.size() > 1) {
         return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config);
     } else {
-        return opr::Reduce::make(
-                inputs[0], reduce.param(), (cg::VarNode*)nullptr, config);
+        return opr::Reduce::make(inputs[0], reduce.param(),
+                                 (cg::VarNode*)nullptr, config);
     }
 }
 
@@ -135,86 +133,92 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
 }
 
 OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
-    .make_from_op_node(make_from_op_node)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // reduce
+        .make_from_op_node(make_from_op_node)
+        .apply_on_var_node(apply_on_var_node)
+        .fallback();
+} // namespace reduce
+} // namespace
 
-namespace { namespace adaptive_pooling {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
+namespace {
+namespace adaptive_pooling {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& pool = static_cast<const AdaptivePooling&>(def);
     OperatorNodeConfig config{pool.make_name()};
-    return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), config);
+    return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(),
+                                      config);
 }
 
 OP_TRAIT_REG(AdaptivePooling, AdaptivePooling)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // adaptive_pooling
+        .apply_on_var_node(apply_on_var_node)
+        .fallback();
+} // namespace adaptive_pooling
+} // namespace
 
-namespace { namespace conv_bias {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
+namespace {
+namespace conv_bias {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& conv = static_cast<const ConvBias&>(def);
     cg::OperatorNodeConfig config{conv.dtype};
     config.name(conv.make_name());
     if (inputs.size() == 2) {
-        return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
+        return opr::ConvBias::make(inputs[0], inputs[1], conv.param(),
+                                   conv.policy(), config);
     } else if (inputs.size() == 3) {
-        return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
+        return opr::ConvBias::make(inputs[0], inputs[1], inputs[2],
+                                   conv.param(), conv.policy(), config);
     } else if (inputs.size() == 4) {
-        return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config);
+        return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3],
+                                   conv.param(), conv.policy(), config);
     }
     mgb_assert(0);
 }
 
 OP_TRAIT_REG(ConvBias, ConvBias)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // conv_bias
+        .apply_on_var_node(apply_on_var_node)
+        .fallback();
+} // namespace conv_bias
+} // namespace
 
-namespace { namespace batch_conv_bias {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
+namespace {
+namespace batch_conv_bias {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& conv = static_cast<const BatchConvBias&>(def);
     cg::OperatorNodeConfig config{conv.dtype};
     config.name(conv.make_name());
     if (inputs.size() == 2) {
-        return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
+        return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(),
+                                        conv.policy(), config);
     } else if (inputs.size() == 3) {
-        return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
+        return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2],
+                                        conv.param(), conv.policy(), config);
    } else if (inputs.size() == 4) {
-        return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config);
+        return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2],
+                                        inputs[3], conv.param(), conv.policy(),
+                                        config);
     }
     mgb_assert(0);
 }
 
 OP_TRAIT_REG(BatchConvBias, BatchConvBias)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // batch_conv_bias
+        .apply_on_var_node(apply_on_var_node)
+        .fallback();
+} // namespace batch_conv_bias
+} // namespace
 
-namespace { namespace pooling {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
+namespace {
+namespace pooling {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& pool = static_cast<const Pooling&>(def);
     OperatorNodeConfig config{pool.make_name()};
     return opr::Pooling::make(inputs[0], pool.param(), config);
 }
 
-OP_TRAIT_REG(Pooling, Pooling)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // pooling
+OP_TRAIT_REG(Pooling, Pooling).apply_on_var_node(apply_on_var_node).fallback();
+} // namespace pooling
+} // namespace
 
-namespace { namespace matrix_mul {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
+namespace {
+namespace matrix_mul {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& matmul = static_cast<const MatrixMul&>(def);
     mgb_assert(inputs.size() == 2);
     OperatorNodeConfig config{matmul.make_name()};
@@ -222,14 +226,14 @@ auto apply_on_var_node(
                                matmul.policy(), config);
 }
 OP_TRAIT_REG(MatrixMul, MatrixMul)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // matrix_mul
+        .apply_on_var_node(apply_on_var_node)
+        .fallback();
+} // namespace matrix_mul
+} // namespace
 
-namespace { namespace batched_matrix_mul {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
+namespace {
+namespace batched_matrix_mul {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
     mgb_assert(inputs.size() == 2);
     OperatorNodeConfig config{matmul.make_name()};
@@ -237,166 +241,155 @@ auto apply_on_var_node(
                                matmul.policy(), config);
 }
 OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // batched_matrix_mul
+        .apply_on_var_node(apply_on_var_node)
+        .fallback();
+} // namespace batched_matrix_mul
+} // namespace
 
-namespace { namespace dot {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
+namespace {
+namespace dot {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& op = def.cast_final_safe<Dot>();
     mgb_assert(inputs.size() == 2);
     OperatorNodeConfig config{op.make_name()};
     return opr::Dot::make(inputs[0], inputs[1], config);
 }
-OP_TRAIT_REG(Dot, Dot)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // dot
+OP_TRAIT_REG(Dot, Dot).apply_on_var_node(apply_on_var_node).fallback();
+} // namespace dot
+} // namespace
 
-namespace { namespace argsort {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
+namespace {
+namespace argsort {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& argsort = static_cast<const Argsort&>(def);
     OperatorNodeConfig config{argsort.make_name()};
     return opr::Argsort::make(inputs[0], argsort.param(), config);
 }
-OP_TRAIT_REG(Argsort, Argsort)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // argsort
+OP_TRAIT_REG(Argsort, Argsort).apply_on_var_node(apply_on_var_node).fallback();
+} // namespace argsort
+} // namespace
 
-namespace { namespace argmax {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
+namespace {
+namespace argmax {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& argmax = static_cast<const Argmax&>(def);
     OperatorNodeConfig config{argmax.make_name()};
     return opr::Argmax::make(inputs[0], argmax.param(), config);
 }
-OP_TRAIT_REG(Argmax, Argmax)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // argmax
+OP_TRAIT_REG(Argmax, Argmax).apply_on_var_node(apply_on_var_node).fallback();
+} // namespace argmax
+} // namespace
 
-namespace { namespace argmin {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
+namespace {
+namespace argmin {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& argmin = static_cast<const Argmin&>(def);
     OperatorNodeConfig config{argmin.make_name()};
     return opr::Argmin::make(inputs[0], argmin.param(), config);
 }
-OP_TRAIT_REG(Argmin, Argmin)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // argmin
+OP_TRAIT_REG(Argmin, Argmin).apply_on_var_node(apply_on_var_node).fallback();
+} // namespace argmin
+} // namespace
 
-namespace { namespace warp_perspective {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
+namespace {
+namespace warp_perspective {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& warp = static_cast<const WarpPerspective&>(def);
     OperatorNodeConfig config{warp.make_name()};
     if (inputs.size() == 3) {
-        return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param(), config);
+        return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2],
+                                          warp.param(), config);
     } else {
         mgb_assert(inputs.size() == 4);
-        return
opr::WarpPerspective::make( - inputs[0], inputs[1], inputs[2], inputs[3], warp.param(), config); + return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], + inputs[3], warp.param(), config); } } OP_TRAIT_REG(WarpPerspective, WarpPerspective) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // warp_perspective + .apply_on_var_node(apply_on_var_node) + .fallback(); +} // namespace warp_perspective +} // namespace -namespace { namespace group_local { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +namespace { +namespace group_local { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& local = static_cast(def); mgb_assert(inputs.size() == 2); OperatorNodeConfig config{local.make_name()}; return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config); } OP_TRAIT_REG(GroupLocal, GroupLocal) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // group_local + .apply_on_var_node(apply_on_var_node) + .fallback(); +} // namespace group_local +} // namespace -namespace { namespace indexing_one_hot { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +namespace { +namespace indexing_one_hot { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 2); OperatorNodeConfig config{op.make_name()}; return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config); } OP_TRAIT_REG(IndexingOneHot, IndexingOneHot) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // indexing_one_hot + .apply_on_var_node(apply_on_var_node) + .fallback(); +} // namespace indexing_one_hot +} // namespace -namespace { namespace indexing_set_one_hot { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +namespace { +namespace indexing_set_one_hot { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 3); OperatorNodeConfig config{op.make_name()}; - return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param(), config); + return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], + op.param(), config); } OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // indexing_set_one_hot + .apply_on_var_node(apply_on_var_node) + .fallback(); +} // namespace indexing_set_one_hot +} // namespace -namespace { namespace typecvt { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +namespace { +namespace typecvt { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 1); OperatorNodeConfig config{op.make_name()}; return opr::TypeCvt::make(inputs[0], op.dtype, config); } -OP_TRAIT_REG(TypeCvt, TypeCvt) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // typecvt +OP_TRAIT_REG(TypeCvt, TypeCvt).apply_on_var_node(apply_on_var_node).fallback(); +} // namespace typecvt +} // namespace -namespace { namespace concat { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +namespace { +namespace concat { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); cg::OperatorNodeConfig config{op.comp_node}; config.name(op.make_name()); return opr::Concat::make(inputs, op.axis, config); } -OP_TRAIT_REG(Concat, Concat) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // concat 
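
One-statement registrations such as OP_TRAIT_REG(TypeCvt, TypeCvt).apply_on_var_node(apply_on_var_node).fallback() read naturally because the macro evidently expands to a builder whose setters return the registrar itself, with .fallback() filling any hooks left unset. A minimal standalone sketch of that fluent-registration idiom (OpTrait, Registrar and registry are hypothetical stand-ins, not MegEngine's real machinery):

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

struct OpTrait {
    std::function<std::string(const std::string&)> apply_on_var_node;
};

std::unordered_map<std::string, OpTrait>& registry() {
    static std::unordered_map<std::string, OpTrait> r;
    return r;
}

struct Registrar {
    std::string name;
    OpTrait trait;
    Registrar& apply_on_var_node(std::function<std::string(const std::string&)> f) {
        trait.apply_on_var_node = std::move(f);
        return *this;  // returning *this is what makes the chaining work
    }
    Registrar& fallback() {
        registry()[name] = trait;  // commit, defaulting any unset hooks
        return *this;
    }
};

int main() {
    Registrar{"Pooling"}
            .apply_on_var_node([](const std::string& in) { return "Pooling(" + in + ")"; })
            .fallback();
    std::cout << registry()["Pooling"].apply_on_var_node("x") << '\n';
}
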
+OP_TRAIT_REG(Concat, Concat).apply_on_var_node(apply_on_var_node).fallback(); +} // namespace concat +} // namespace -namespace { namespace copy { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +namespace { +namespace copy { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 1); cg::OperatorNodeConfig config{op.comp_node}; config.name(op.make_name()); return opr::Copy::make(inputs[0], config); } -OP_TRAIT_REG(Copy, Copy) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // copy +OP_TRAIT_REG(Copy, Copy).apply_on_var_node(apply_on_var_node).fallback(); +} // namespace copy +} // namespace namespace { namespace assert_equal { auto apply_on_var_node( @@ -408,81 +401,81 @@ auto apply_on_var_node( } else { // workaround for MiniGraph, which only allow one opr in the graph mgb_assert(inputs.size() == 3); - return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], op.param(), {}); + return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], + op.param(), {}); } } OP_TRAIT_REG(AssertEqual, AssertEqual) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // assert_equal + .apply_on_var_node(apply_on_var_node) + .fallback(); +} // namespace assert_equal +} // namespace -namespace { namespace roi_align { -VarNodeArray apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +namespace { +namespace roi_align { +VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 2); OperatorNodeConfig config{op.make_name()}; - auto* opr = opr::ROIAlign::make( - inputs[0], inputs[1], op.param(), config).node()->owner_opr(); + auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config) + .node() + ->owner_opr(); return {opr->output(0), opr->output(1)}; } OP_TRAIT_REG(ROIAlign, ROIAlign) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // roi_align + .apply_on_var_node(apply_on_var_node) + .fallback(); +} // namespace roi_align +} // namespace -namespace { namespace correlation { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +namespace { +namespace correlation { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 2); OperatorNodeConfig config{op.make_name()}; - return opr::Correlation::make( - inputs[0], inputs[1], op.param(), config); + return opr::Correlation::make(inputs[0], inputs[1], op.param(), config); } OP_TRAIT_REG(Correlation, Correlation) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // correlation + .apply_on_var_node(apply_on_var_node) + .fallback(); +} // namespace correlation +} // namespace #if MGB_CUDA -namespace { namespace nvof { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +namespace { +namespace nvof { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 1); OperatorNodeConfig config{op.make_name()}; return opr::NvOf::make(inputs[0], op.param(), config); } -OP_TRAIT_REG(NvOf, NvOf) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // nvof +OP_TRAIT_REG(NvOf, NvOf).apply_on_var_node(apply_on_var_node).fallback(); +} // namespace nvof +} // namespace #endif -namespace { namespace linspace { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +namespace { +namespace linspace { +auto 
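
ROIAlign above deliberately returns {opr->output(0), opr->output(1)}: like ROIPooling later in this file, the opr carries a second index output recording which sampled source element produced each pooled value, which the backward pass needs to route gradients. The value-plus-index pattern in a standalone 1-D max-pool miniature:

#include <cstdio>

int main() {
    float src[6] = {3, 1, 4, 1, 5, 9};
    float out[3];
    int index[3];  // the "second output": winning source position per window
    for (int w = 0; w < 3; ++w) {
        int base = 2 * w;
        int best = src[base] >= src[base + 1] ? base : base + 1;
        out[w] = src[best];
        index[w] = best;
    }
    for (int w = 0; w < 3; ++w)
        std::printf("out=%g idx=%d\n", out[w], index[w]);  // (3,0) (4,2) (9,5)
}
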
apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 3); cg::OperatorNodeConfig config{op.comp_node}; config.name(op.make_name()); - return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config); + return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), + config); } OP_TRAIT_REG(Linspace, Linspace) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // linspace + .apply_on_var_node(apply_on_var_node) + .fallback(); +} // namespace linspace +} // namespace -namespace { namespace eye { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +namespace { +namespace eye { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 1); cg::OperatorNodeConfig config{op.comp_node}; @@ -490,58 +483,59 @@ auto apply_on_var_node( opr::Eye::Param param{op.k, op.dtype.enumv()}; return opr::Eye::make(inputs[0], param, config); } -OP_TRAIT_REG(Eye, Eye) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // eye +OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback(); +} // namespace eye +} // namespace -namespace { namespace roi_pooling { -VarNodeArray apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +namespace { +namespace roi_pooling { +VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 3); OperatorNodeConfig config{op.make_name()}; - auto* opr = opr::ROIPooling::make( - inputs[0], inputs[1], inputs[2], op.param(), config - ).node()->owner_opr(); + auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], + op.param(), config) + .node() + ->owner_opr(); return {opr->output(0), opr->output(1)}; } OP_TRAIT_REG(ROIPooling, ROIPooling) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // roi_pooling + .apply_on_var_node(apply_on_var_node) + .fallback(); +} // namespace roi_pooling +} // namespace -namespace { namespace remap { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +namespace { +namespace remap { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 2); OperatorNodeConfig config{op.make_name()}; return opr::Remap::make(inputs[0], inputs[1], op.param(), config); } -OP_TRAIT_REG(Remap, Remap) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // remap +OP_TRAIT_REG(Remap, Remap).apply_on_var_node(apply_on_var_node).fallback(); +} // namespace remap +} // namespace namespace { auto get_index( - const VarNodeArray& inputs, size_t vidx, - const std::vector>& mask) { + const VarNodeArray& inputs, size_t vidx, + const std::vector>& mask) { size_t length = mask.size(); opr::Subtensor::IndexDesc ret(length); - for (size_t i = 0; i < length; ++ i) { + for (size_t i = 0; i < length; ++i) { auto&& [axis, begin, end, step, idx] = mask[i]; ret[i].axis = axis; if (idx) { ret[i].idx = inputs[vidx++]; } else { mgb_assert(begin || end || step); - if (begin) ret[i].begin = inputs[vidx++]; - if (end) ret[i].end = inputs[vidx++]; - if (step) ret[i].step = inputs[vidx++]; + if (begin) + ret[i].begin = inputs[vidx++]; + if (end) + ret[i].end = inputs[vidx++]; + if (step) + ret[i].step = inputs[vidx++]; } } mgb_assert(vidx == inputs.size()); @@ -550,19 +544,19 @@ auto get_index( #define IN1 inputs[0] #define IN2 inputs[0], inputs[1] -#define 
FANCY_INDEXING_IMPL(NAME, NR_INPUT) \ -namespace NAME##_impl { \ -auto apply_on_var_node( \ - const OpDef& def, \ - const VarNodeArray& inputs) { \ - auto&& op = static_cast(def); \ - OperatorNodeConfig config{op.make_name()}; \ - return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items), config); \ -} \ -OP_TRAIT_REG(NAME, NAME) \ - .apply_on_var_node(apply_on_var_node) \ - .fallback(); \ -} +#define FANCY_INDEXING_IMPL(NAME, NR_INPUT) \ + namespace NAME##_impl { \ + auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { \ + auto&& op = static_cast(def); \ + OperatorNodeConfig config{op.make_name()}; \ + return opr::NAME::make(IN##NR_INPUT, \ + get_index(inputs, NR_INPUT, op.items), \ + config); \ + } \ + OP_TRAIT_REG(NAME, NAME) \ + .apply_on_var_node(apply_on_var_node) \ + .fallback(); \ + } FANCY_INDEXING_IMPL(Subtensor, 1) FANCY_INDEXING_IMPL(SetSubtensor, 2) @@ -580,76 +574,88 @@ FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2) #undef FANCY_INDEXING_IMPL #undef IN1 #undef IN2 -} // anonymous namespace +} // anonymous namespace -namespace { namespace fake_quant { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +namespace { +namespace fake_quant { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 3); OperatorNodeConfig config{op.make_name()}; - return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), config); + return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), + config); } OP_TRAIT_REG(FakeQuant, FakeQuant) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // fake_quant + .apply_on_var_node(apply_on_var_node) + .fallback(); +} // namespace fake_quant +} // namespace -namespace { namespace tqt { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +namespace { +namespace tqt { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 2); OperatorNodeConfig config{op.make_name()}; return opr::TQT::make(inputs[0], inputs[1], op.param(), config); } -OP_TRAIT_REG(TQT, TQT) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // tqt +OP_TRAIT_REG(TQT, TQT).apply_on_var_node(apply_on_var_node).fallback(); +} // namespace tqt +} // namespace -namespace { namespace elemwise_multi_type { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +namespace { +namespace elemwise_multi_type { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); OperatorNodeConfig config{op.dtype}; config.name(op.make_name()); return opr::ElemwiseMultiType::make(inputs, op.param(), config); } OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // elemwise_multi_type + .apply_on_var_node(apply_on_var_node) + .fallback(); +} // namespace elemwise_multi_type +} // namespace -namespace { namespace svd { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { +namespace { +namespace svd { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 1); OperatorNodeConfig config{op.make_name()}; return opr::SVD::make(inputs[0], op.param(), config)[0] - .node()->owner_opr()->usable_output(); + .node() + ->owner_opr() + ->usable_output(); } -OP_TRAIT_REG(SVD, SVD) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // 
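
get_index above is the subtle part of the fancy-indexing traits: the static mask declares which of begin/end/step/idx each axis uses, the dynamic VarNodes are consumed strictly left to right, and the final mgb_assert(vidx == inputs.size()) proves every input was claimed. The same slot-consumption logic as a standalone sketch (strings stand in for VarNodes):

#include <cassert>
#include <cstdio>
#include <tuple>
#include <vector>

int main() {
    // (axis, has_begin, has_end, has_step, has_idx)
    std::vector<std::tuple<int, bool, bool, bool, bool>> mask = {
            {0, true, true, false, false},   // x[b0:e0, ...]
            {1, false, false, false, true},  // x[..., i1]
    };
    std::vector<const char*> inputs = {"tensor", "b0", "e0", "i1"};
    size_t vidx = 1;  // inputs[0] is the tensor itself (NR_INPUT == 1)
    for (auto&& [axis, b, e, s, idx] : mask) {
        if (idx) {
            std::printf("axis %d: idx=%s\n", axis, inputs[vidx++]);
        } else {
            assert(b || e || s);  // at least one slice component per axis
            if (b) std::printf("axis %d: begin=%s\n", axis, inputs[vidx++]);
            if (e) std::printf("axis %d: end=%s\n", axis, inputs[vidx++]);
            if (s) std::printf("axis %d: step=%s\n", axis, inputs[vidx++]);
        }
    }
    assert(vidx == inputs.size());  // all dynamic inputs consumed
}
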
svd
+OP_TRAIT_REG(SVD, SVD).apply_on_var_node(apply_on_var_node).fallback();
+}  // namespace svd
+}  // namespace

-namespace { namespace images2neibs {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
+namespace {
+namespace images2neibs {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& op = static_cast<const Images2Neibs&>(def);
     OperatorNodeConfig config{op.make_name()};
     return opr::Images2Neibs::make(inputs[0], op.param(), config);
 }
 OP_TRAIT_REG(Images2Neibs, Images2Neibs)
-        .apply_on_var_node(apply_on_var_node)
-        .fallback();
-}} // images2neibs
+        .apply_on_var_node(apply_on_var_node)
+        .fallback();
+}  // namespace images2neibs
+}  // namespace
+
+namespace {
+namespace lsq {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& op = static_cast<const LSQ&>(def);
+    mgb_assert(inputs.size() == 4);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::LSQ::make(inputs[0], inputs[1], inputs[2], inputs[3],
+                          op.param(), config);
+}
+OP_TRAIT_REG(LSQ, LSQ).apply_on_var_node(apply_on_var_node).fallback();
+}  // namespace lsq
+}  // namespace

-} // namespace mgb::imperative
+}  // namespace mgb::imperative
diff --git a/imperative/src/test/backward_graph.cpp b/imperative/src/test/backward_graph.cpp
index e4ab27d0b582ede7c8b0bb00530b7a5b08e09c13..1ef56c588a2610c98ef393cc793db91e7200ffa4 100644
--- a/imperative/src/test/backward_graph.cpp
+++ b/imperative/src/test/backward_graph.cpp
@@ -6,22 +6,24 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
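
The lsq trait just registered forwards four inputs (input, scale, zero_point, grad_scale) to opr::LSQ::make. For reference, the transform LSQForward computes should be, per the LSQ paper (Esser et al.): q = clamp(x / scale + zero_point, qmin, qmax), then y = (round(q) - zero_point) * scale; the in-tree kernels may order the round and clamp slightly differently, so treat this as a numeric sketch rather than the exact kernel:

#include <algorithm>
#include <cmath>
#include <cstdio>

float lsq_forward(float x, float scale, float zero_point, int qmin, int qmax) {
    float q = x / scale + zero_point;
    q = std::min(std::max(q, float(qmin)), float(qmax));  // clip to [qmin, qmax]
    return (std::round(q) - zero_point) * scale;          // snap to the scale grid
}

int main() {
    // int8-style range: outputs land on multiples of scale
    std::printf("%g\n", lsq_forward(0.37f, 0.1f, 0.f, -128, 127));  // 0.4
    std::printf("%g\n", lsq_forward(100.f, 0.1f, 0.f, -128, 127));  // 12.7, clipped
}
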
*/ #include "./helper.h" +#include "megbrain/imperative/backward_graph_opt.h" +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/opr/basic_arith.h" #include "megbrain/opr/dnn/batch_norm.h" -#include "megbrain/imperative/ops/opr_attr.h" -#include "megbrain/imperative/ops/autogen.h" -#include "megbrain/imperative/backward_graph_opt.h" using namespace mgb; using namespace cg; using namespace imperative; template -T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, const T& outputs, const T& grads) { +T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, + const T& outputs, const T& grads) { T ret; size_t i = 0; for (auto&& t : inputs) { @@ -54,7 +56,9 @@ T expand_grads(const U& bg, const T& outputs) { } template -T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, const T& precomp, const T& inputs, const T& outputs, const T& grads) { +T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, + const T& precomp, const T& inputs, + const T& outputs, const T& grads) { T ret = precomp; size_t i = 0; for (auto&& t : inputs) { @@ -75,7 +79,8 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, cons return ret; } -SmallVector apply_shared_on_physical_tensor(std::shared_ptr def, SmallVector inputs) { +SmallVector apply_shared_on_physical_tensor( + std::shared_ptr def, SmallVector inputs) { return OpDef::apply_on_physical_tensor(*def, inputs); } @@ -83,7 +88,7 @@ TEST(TestImperative, BackwardGraphBasic) { HostTensorGenerator<> gen; SmallVector hvs; SmallVector inputs; - for(size_t i = 0; i < 2; ++ i) { + for (size_t i = 0; i < 2; ++i) { hvs.push_back(*gen({42})); inputs.push_back(Tensor::make(hvs.back())); } @@ -97,7 +102,8 @@ TEST(TestImperative, BackwardGraphBasic) { for (auto&& i : inputs) { input_descs.push_back({i->layout(), i->comp_node()}); } - auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, {true}); + auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, + {true}); auto&& save_for_backward = result.save_for_backward; auto&& input_has_grad = result.input_has_grad; @@ -106,9 +112,9 @@ TEST(TestImperative, BackwardGraphBasic) { hvs.push_back(*gen({42})); inputs.push_back(Tensor::make(hvs.back())); mgb_assert(save_for_backward.size() == inputs.size()); - for (size_t i = 0; i < inputs.size(); ++ i) { + for (size_t i = 0; i < inputs.size(); ++i) { if (!save_for_backward[i]) { - inputs[i].reset(); // drop unused tensor + inputs[i].reset(); // drop unused tensor } } SmallVector backward_graph_inputs; @@ -118,13 +124,11 @@ TEST(TestImperative, BackwardGraphBasic) { } } inputs.clear(); - auto input_grads = result.backward.apply( - backward_graph_inputs, - apply_shared_on_physical_tensor, - [&](auto&& x){ return x; } - ); + auto input_grads = result.backward.apply(backward_graph_inputs, + apply_shared_on_physical_tensor, + [&](auto&& x) { return x; }); mgb_assert(input_grads.size() == input_has_grad.size()); - for (size_t i = 0; i < input_has_grad.size(); ++ i) { + for (size_t i = 0; i < input_has_grad.size(); ++i) { mgb_assert(input_has_grad[i] == static_cast(input_grads[i])); } @@ -133,9 +137,10 @@ TEST(TestImperative, BackwardGraphBasic) { res.emplace_back(); res.back().copy_from(i->dev_tensor()).sync(); } - for (size_t i = 0; i < 42; ++ i) { - for (size_t j = 0; j < 1; ++ j) { - ASSERT_EQ(hvs[2].ptr()[i] * hvs[j].ptr()[i], res[j ^ 1].ptr()[i]); + for (size_t i = 0; i < 42; 
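
prepare_backward_graph_inputs above concatenates (inputs, outputs, grads) in that order and keeps only the slots flagged in bg.save_for_backward; prepare_optimized_backward_inputs does the same with the precomputed tensors prepended. The packing logic, standalone, with ints standing in for TensorPtrs:

#include <cstdio>
#include <vector>

std::vector<int> pack(const std::vector<bool>& save_for_backward,
                      const std::vector<int>& inputs,
                      const std::vector<int>& outputs,
                      const std::vector<int>& grads) {
    std::vector<int> all;
    all.insert(all.end(), inputs.begin(), inputs.end());
    all.insert(all.end(), outputs.begin(), outputs.end());
    all.insert(all.end(), grads.begin(), grads.end());
    std::vector<int> ret;
    for (size_t i = 0; i < all.size(); ++i)
        if (save_for_backward[i]) ret.push_back(all[i]);  // keep flagged slots only
    return ret;
}

int main() {
    // e.g. c = a * b: backward wants a, b and dc, but not the output c
    for (int v : pack({true, true, false, true}, {10, 11}, {12}, {13}))
        std::printf("%d ", v);  // 10 11 13
}
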
++i) { + for (size_t j = 0; j < 1; ++j) { + ASSERT_EQ(hvs[2].ptr()[i] * hvs[j].ptr()[i], + res[j ^ 1].ptr()[i]); } } } @@ -152,7 +157,8 @@ TEST(TestImperative, BackwardGraphIdentity) { SmallVector input_descs; input_descs.push_back({a->layout(), a->comp_node()}); - auto result = OpDef::make_backward_graph(*attr, input_descs, {true}, {true}); + auto result = + OpDef::make_backward_graph(*attr, input_descs, {true}, {true}); auto&& save_for_backward = result.save_for_backward; auto&& input_has_grad = result.input_has_grad; @@ -160,9 +166,9 @@ TEST(TestImperative, BackwardGraphIdentity) { inputs.push_back(outputs[0]); inputs.push_back(dc); mgb_assert(save_for_backward.size() == inputs.size()); - for (size_t i = 0; i < inputs.size(); ++ i) { + for (size_t i = 0; i < inputs.size(); ++i) { if (!save_for_backward[i]) { - inputs[i].reset(); // drop unused tensor + inputs[i].reset(); // drop unused tensor } } SmallVector backward_graph_inputs; @@ -172,19 +178,17 @@ TEST(TestImperative, BackwardGraphIdentity) { } } inputs.clear(); - auto input_grads = result.backward.apply( - backward_graph_inputs, - apply_shared_on_physical_tensor, - [&](auto&& x){ return x; } - ); + auto input_grads = result.backward.apply(backward_graph_inputs, + apply_shared_on_physical_tensor, + [&](auto&& x) { return x; }); mgb_assert(input_grads.size() == input_has_grad.size()); - for (size_t i = 0; i < input_has_grad.size(); ++ i) { + for (size_t i = 0; i < input_has_grad.size(); ++i) { mgb_assert(input_has_grad[i] == static_cast(input_grads[i])); } HostTensorND hv; hv.copy_from(input_grads[0]->dev_tensor()).sync(); - for (size_t i = 0; i < 42; ++ i) { + for (size_t i = 0; i < 42; ++i) { ASSERT_EQ(host_dc->ptr()[i], hv.ptr()[i]); } } @@ -192,7 +196,7 @@ TEST(TestImperative, BackwardGraphIdentity) { TEST(TestImperative, BatchNormGrad) { auto cn = CompNode::load("xpux"); using Param = opr::BatchNorm::Param; - size_t N=2, C=3, H=5, W=5; + size_t N = 2, C = 3, H = 5, W = 5; LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn}; LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn}; { @@ -202,7 +206,8 @@ TEST(TestImperative, BatchNormGrad) { param.fwd_mode = Param::FwdMode::TRAINING; attr.param.write_pod(param); OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat}, - {true, true ,true, false, false}, {false, false, false, false, true}); + {true, true, true, false, false}, + {false, false, false, false, true}); } { auto op = OprAttr::make("BatchNorm"); @@ -210,8 +215,8 @@ TEST(TestImperative, BatchNormGrad) { Param param; param.fwd_mode = Param::FwdMode::TRAINING; attr.param.write_pod(param); - OpDef::make_backward_graph(attr, {inp, stat, stat}, - {true, true ,true}, {false, false, true}); + OpDef::make_backward_graph(attr, {inp, stat, stat}, {true, true, true}, + {false, false, true}); } } @@ -220,7 +225,8 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { LogicalTensorDesc desc = {TensorLayout(dtype::Float32()), cn}; HostTensorGenerator<> gen; auto op = std::shared_ptr(Elemwise::make(Elemwise::Mode::ADD)); - auto bg = OpDef::make_backward_graph(*op, {desc, desc}, {true, true}, {true}); + auto bg = + OpDef::make_backward_graph(*op, {desc, desc}, {true, true}, {true}); auto obg = OptimizedBackwardGraphResult(bg); ASSERT_EQ(obg.save_for_backward.size(), 4); ASSERT_FALSE(obg.save_for_backward[0]); @@ -235,30 +241,30 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { auto dc_tn = Tensor::make(*dc_hv); auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0]; - auto 
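
The res[j ^ 1] assertion in BackwardGraphBasic is just the product rule: for c = a * b with incoming gradient dc (hvs[2]), grad_a = dc * b and grad_b = dc * a, so each gradient pairs dc with the other input. Standalone:

#include <cassert>

int main() {
    float in[2] = {3.f, 5.f};  // a, b
    float dc = 2.f;
    float grad[2] = {dc * in[1], dc * in[0]};  // grad_a, grad_b
    for (int j = 0; j < 2; ++j)
        assert(grad[j] == dc * in[j ^ 1]);  // j ^ 1 selects the other input
}
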
backward_graph_inputs = prepare_backward_graph_inputs>(bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); - auto grads = expand_grads(bg, bg.backward.apply( - backward_graph_inputs, - apply_shared_on_physical_tensor, - [&](auto&& x){ return x; } - )); + auto backward_graph_inputs = + prepare_backward_graph_inputs>( + bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); + auto grads = + expand_grads(bg, bg.backward.apply(backward_graph_inputs, + apply_shared_on_physical_tensor, + [&](auto&& x) { return x; })); - auto precomp = obg.precomp.apply( - SmallVector{a_tn, b_tn, c_tn}, - apply_shared_on_physical_tensor, - [&](auto&& x){ return x; } - ); + auto precomp = obg.precomp.apply(SmallVector{a_tn, b_tn, c_tn}, + apply_shared_on_physical_tensor, + [&](auto&& x) { return x; }); ASSERT_EQ(precomp.size(), 2); ASSERT_EQ(precomp[0]->shape().ndim, 1); ASSERT_LE(precomp[0]->shape()[0], 2); ASSERT_EQ(precomp[1]->shape().ndim, 1); ASSERT_LE(precomp[1]->shape()[0], 2); - auto backward_inputs = prepare_optimized_backward_inputs>(obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); - auto grads2 = expand_grads(obg, obg.backward.apply( - backward_inputs, - apply_shared_on_physical_tensor, - [&](auto&& x){ return x; } - )); + auto backward_inputs = + prepare_optimized_backward_inputs>( + obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); + auto grads2 = expand_grads( + obg, + obg.backward.apply(backward_inputs, apply_shared_on_physical_tensor, + [&](auto&& x) { return x; })); ASSERT_EQ(grads2.size(), 2); MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value()); diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 8da5647e896c2cf10bed5d849f816ebe191d03ed..ffcee330c9bc01080356d0a4186d49188d79d3bb 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -271,6 +271,7 @@ def BatchedSetMeshIndexing: FancyIndexingBase<"BatchedSetMeshIndexing">; def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>; def AssertEqual: MgbHashableOp<"AssertEqual",[AssertEqualParam]>; def TQT: MgbHashableOp<"TQT", [TQTParam]>; +def LSQ: MgbHashableOp<"LSQ", [LSQParam]>; def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> { let extraArguments = (ins MgbDTypeAttr:$dtype diff --git a/src/opr/impl/dnn/dnn.oprdecl b/src/opr/impl/dnn/dnn.oprdecl index c511ed496eee6eedc36ec75f3aba851bb5b60562..cda4eca765109a6860587af59a6b6f302ceb531b 100644 --- a/src/opr/impl/dnn/dnn.oprdecl +++ b/src/opr/impl/dnn/dnn.oprdecl @@ -324,5 +324,7 @@ decl_opr('FakeQuant', decl_opr('TQT', inputs=[Doc('src','input tensor'),Doc('scale','scale tensor')], params='TQT') - +decl_opr('LSQ', + inputs=[Doc('src','input tensor'),Doc('scale','scale tensor'),Doc('zero_point','zero point tensor'),Doc('grad_scale','grad scale tensor')], + params='LSQ') # vim: ft=python diff --git a/src/opr/impl/dnn/dnn.sereg.h b/src/opr/impl/dnn/dnn.sereg.h index 01cedfa50232f0fb34bae85787663e250b2874b9..6e9e70fcff128bed7a1f9b4ea0e15139603a4c78 100644 --- a/src/opr/impl/dnn/dnn.sereg.h +++ b/src/opr/impl/dnn/dnn.sereg.h @@ -6,20 +6,22 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
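
The ops.td line `def LSQ: MgbHashableOp<"LSQ", [LSQParam]>` and the decl_opr('LSQ', ...) entry are the declarative side of the new op: a hashable OpDef keyed by its type plus its param fields (qmin/qmax from LSQParam), so identical configurations compare and hash equal in OpDef caches. A standalone sketch of that contract; the actual generated code's shape is an assumption here:

#include <cstdint>
#include <cstdio>
#include <functional>
#include <string>

struct LSQParamSketch {
    int32_t qmin, qmax;
    size_t hash() const {
        size_t h = std::hash<std::string>{}("LSQ");  // op type tag
        h = h * 31 + std::hash<int32_t>{}(qmin);
        return h * 31 + std::hash<int32_t>{}(qmax);
    }
    bool operator==(const LSQParamSketch& rhs) const {
        return qmin == rhs.qmin && qmax == rhs.qmax;
    }
};

int main() {
    LSQParamSketch a{-128, 127}, b{-128, 127};
    std::printf("equal=%d hash=%zu\n", a == b, a.hash());  // equal params hash equal
}
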
*/ +#include "megbrain/opr/dnn/adaptive_pooling.h" #include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/dnn/correlation.h" +#include "megbrain/opr/dnn/fake_quant.h" #include "megbrain/opr/dnn/images2neibs.h" -#include "megbrain/opr/dnn/pooling.h" -#include "megbrain/opr/dnn/adaptive_pooling.h" -#include "megbrain/opr/dnn/roi_pooling.h" -#include "megbrain/opr/dnn/roi_align.h" #include "megbrain/opr/dnn/local.h" #include "megbrain/opr/dnn/lrn.h" -#include "megbrain/opr/dnn/fake_quant.h" +#include "megbrain/opr/dnn/lsq.h" +#include "megbrain/opr/dnn/pooling.h" +#include "megbrain/opr/dnn/roi_align.h" +#include "megbrain/opr/dnn/roi_pooling.h" #include "megbrain/opr/dnn/tqt.h" #include "megbrain/serialization/sereg.h" #include "megdnn/opr_param_defs.h" @@ -183,7 +185,8 @@ struct ConvLoadDumpImpl { static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { auto&& opr = opr_.cast_final_safe(); ctx.write_param(opr.param()); - ctx.write_param(opr.execution_policy_transient()); + ctx.write_param( + opr.execution_policy_transient()); } static VarNode* make(const cg::VarNodeArray& inputs, const ConvParam& param, @@ -251,6 +254,20 @@ struct OprMaker { } }; +template <> +struct OprMaker { + using Param = opr::LSQBackward::Param; + static cg::OperatorNodeBase* make(const Param& param, + const cg::VarNodeArray& i, + ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + return opr::LSQBackward::make(i[0], i[1], i[2], i[3], i[4], param, + config)[0] + .node() + ->owner_opr(); + } +}; template <> struct OprLoadDumpImpl : public PoolingLoadDumpImplowner_graph(), + config, + "lsq_bwd", + {y_grad, x, scale, zero_point, grad_scale}}, + 1, true) { + init_megdnn_opr(*this, param); + add_input({y_grad, x, scale, zero_point, grad_scale}); +} + +SymbolVarArray LSQBackward::make(SymbolVar y_grad, SymbolVar x, SymbolVar scale, + SymbolVar zero_point, SymbolVar grad_scale, + const Param& param, + const OperatorNodeConfig& config) { + auto&& out = x.node()->owner_graph() + ->insert_opr(std::make_unique( + y_grad.node(), x.node(), scale.node(), + zero_point.node(), grad_scale.node(), param, + config)) + ->output(); + SymbolVarArray ret(out.size()); + for (size_t i = 0; i < ret.size(); ++i) { + ret[i] = out[i]; + } + return ret; +} + +void LSQBackward::init_output_static_infer_desc() { + using namespace cg::static_infer; + auto&& mgr = owner_graph()->static_infer_manager(); + + mgr.register_shape_infer(output(0), + ShapeInferDesc::make_identity(input(1))); + mgr.register_shape_infer(output(1), + ShapeInferDesc::make_identity(input(1))); + this->init_output_static_infer_desc_workspace( + intl::AutoAddWorkspaceNeedLimitGetter::val); +} + +void LSQBackward::init_output_dtype() { + output(0)->dtype(input(1)->dtype()); + output(1)->dtype(input(2)->dtype()); +} diff --git a/src/opr/include/megbrain/opr/dnn/lsq.h b/src/opr/include/megbrain/opr/dnn/lsq.h new file mode 100644 index 0000000000000000000000000000000000000000..8a6e95d25d6051eec55a940381b06b57cea97119 --- /dev/null +++ b/src/opr/include/megbrain/opr/dnn/lsq.h @@ -0,0 +1,50 @@ +/** + * \file src/opr/include/megbrain/opr/dnn/lsq.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
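
OprMaker<opr::LSQBackward, 5> above rebuilds the opr from five serialized inputs in the order (y_grad, x, scale, zero_point, grad_scale), matching LSQBackward::make, and the opr's two outputs are grad_x and grad_s. The gradients themselves follow the LSQ paper, with grad_scale being the externally supplied 1/sqrt(N * qmax)-style factor; the zero_point handling below is my reading, not a quote of the kernels:

#include <cmath>
#include <cstdio>

void lsq_backward(float dy, float x, float s, float zp, float gs, int qmin,
                  int qmax, float* dx, float* ds) {
    float q = x / s + zp;
    if (q < qmin) {         // clipped low: only the scale receives gradient
        *dx = 0.f;
        *ds = (qmin - zp) * gs * dy;
    } else if (q > qmax) {  // clipped high
        *dx = 0.f;
        *ds = (qmax - zp) * gs * dy;
    } else {                // pass-through region, straight-through round
        *dx = dy;
        *ds = (std::round(q) - q) * gs * dy;
    }
}

int main() {
    float dx, ds;
    lsq_backward(1.f, 0.37f, 0.1f, 0.f, 0.01f, -128, 127, &dx, &ds);
    std::printf("dx=%g ds=%g\n", dx, ds);  // dx=1, ds=(4 - 3.7) * 0.01
}
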
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once +#include "megbrain/opr/internal/megdnn_opr_wrapper.h" +#include "megdnn/oprs.h" +namespace mgb { +namespace opr { + +MGB_DEFINE_OPR_CLASS(LSQForward, + intl::MegDNNOprWrapperFwd) // { +public: +LSQForward(VarNode* src, VarNode* scale, VarNode* zero_point, + VarNode* grad_scale, const Param& param, + const OperatorNodeConfig& config); + +static SymbolVar make(SymbolVar src, SymbolVar scale, SymbolVar zero_point, + SymbolVar grad_scale, const Param& param = {}, + const OperatorNodeConfig& config = {}); +}; +using LSQ = LSQForward; + +MGB_DEFINE_OPR_CLASS(LSQBackward, + intl::MegDNNOprWrapperBwd) // { +public: +LSQBackward(VarNode* y_grad, VarNode* x, VarNode* scale, VarNode* zero_point, + VarNode* grad_scale, const Param& param, + const OperatorNodeConfig& config); + +static SymbolVarArray make(SymbolVar y_grad, SymbolVar x, SymbolVar scale, + SymbolVar zero_point, SymbolVar grad_scale, + const Param& param = {}, + const OperatorNodeConfig& config = {}); + +private: +void init_output_static_infer_desc() override; +void init_output_dtype() override; +}; + +} // namespace opr +} // namespace mgb diff --git a/src/opr/test/dnn/lsq.cpp b/src/opr/test/dnn/lsq.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4d220f512df2e9c591cad0aef852b0c0be1fcec2 --- /dev/null +++ b/src/opr/test/dnn/lsq.cpp @@ -0,0 +1,78 @@ +/** + * \file src/opr/test/dnn/lsq.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
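
Per the declarations above, LSQForward is a plain MegDNNOprWrapperFwd with four inputs and one output; grad_scale rides along in the forward signature but, as far as the forward math goes, appears to matter only to the paired backward opr. An elementwise plain-C++ rendering under the unit test's scalar {1} shapes for scale and zero_point (broadcast semantics assumed):

#include <algorithm>
#include <cmath>
#include <cstdio>

void lsq_forward_tensor(const float* x, size_t n, float scale, float zp,
                        int qmin, int qmax, float* y) {
    for (size_t i = 0; i < n; ++i) {  // scalar params broadcast over the tensor
        float q = x[i] / scale + zp;
        q = std::min(std::max(q, float(qmin)), float(qmax));
        y[i] = (std::round(q) - zp) * scale;
    }
}

int main() {
    float x[4] = {0.03f, -0.26f, 1.4f, -100.f}, y[4];
    lsq_forward_tensor(x, 4, 0.1f, 0.f, -128, 127, y);
    for (float v : y) std::printf("%g ", v);  // 0 -0.3 1.4 -12.8
}
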
+ */ + +#include "megbrain/opr/dnn/lsq.h" +#include "megbrain/comp_node_env.h" +#include "megbrain/test/autocheck.h" + +using namespace std; +using namespace mgb; + +namespace { + +void run() { + using Checker = AutoOprChecker<4, 1>; + + auto make_graph = + [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { + auto o0 = opr::LSQForward::make(inputs[0], inputs[1], inputs[2], + inputs[3]); + return {o0}; + }; + + auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { + auto opr = MegDNNHandle::get( + CompNodeEnv::from_comp_node(CompNode::default_cpu())) + ->create_operator(); + dest[0].dtype(dtype::Float32()) + .comp_node(inp[0]->comp_node()) + .resize(inp[0]->shape()); + opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(), + inp[3]->as_megdnn(), dest[0].as_megdnn(), {}); + }; + + auto gen = [&](HostTensorND& src) { + HostTensorGenerator + src_gen(10.f); + src = *src_gen(src.shape(), src.comp_node()); + }; + + Checker::RunOptions opt; + opt.numdiff_max_err = 1e-5; + + Checker checker{make_graph, fwd}; + checker.set_input_generator(0, gen) + .set_input_generator(1, gen) + .set_input_generator(2, gen) + .set_input_generator(3, gen) + .set_input_allow_grad(0, false) + .set_input_allow_grad(1, false) + .set_input_allow_grad(2, false) + .set_input_allow_grad(3, false) + .set_output_allow_grad(0, false); + checker.run({TensorShape{1, 2, 3, 4}, TensorShape{1}, TensorShape{1}, + TensorShape{1}}, + opt) + .run({TensorShape{2, 3, 8, 8}, TensorShape{1}, TensorShape{1}, + TensorShape{1}}, + opt) + .run({TensorShape{1, 3, 4, 4}, TensorShape{1}, TensorShape{1}, + TensorShape{1}}, + opt); +} + +} // anonymous namespace + +TEST(TestOprDNN, LSQForward) { + REQUIRE_GPU(1); + run(); +} \ No newline at end of file diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs index 4696cd7a1613bbd2066abf40377ce1a346854303..2a9c3e9229d779cfa23fff8ca1c7374c9100a280 100644 --- a/src/serialization/impl/schema.fbs +++ b/src/serialization/impl/schema.fbs @@ -107,6 +107,7 @@ union OperatorParam { param.FakeQuant = 73, param.TQT = 74, param.Correlation = 75, + param.LSQ = 76, } table Operator {
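
Finally, the schema.fbs hunk appends param.LSQ = 76 at the end of the OperatorParam union. Union members are identified by ordinal, so new params must be appended after the last entry (Correlation = 75); inserting or reordering would silently re-tag operators in previously serialized models. The same invariant in miniature with a plain enum:

#include <cstdio>

enum class OperatorParam : int {
    FakeQuant = 73,
    TQT = 74,
    Correlation = 75,
    LSQ = 76,  // appended: existing tags keep their values
};

int main() {
    std::printf("LSQ tag = %d\n", static_cast<int>(OperatorParam::LSQ));
}
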