提交 f30c0e06 编写于 作者: M Megvii Engine Team 提交者: huangxinda

feat(mgb/opr): add lsq opr

GitOrigin-RevId: 45494a2b57fe1dacf498040c8290fcec3c335990
上级 eb66681f
......@@ -1741,6 +1741,67 @@ protected:
const TensorLayout& grad_s, size_t workspace_in_bytes);
};
class LSQBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(LSQBase, OperatorBase);
DEF_OPR_PARAM(LSQ);
protected:
void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
void check_layout_fwd(const TensorLayout& input, const TensorLayout& scale,
const TensorLayout& zero_point,
const TensorLayout& grad_scale,
const TensorLayout& output);
};
class LSQForward : public LSQBase {
DEF_OPR_IMPL(LSQForward, LSQBase, 4, 1);
public:
virtual void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
_megdnn_tensor_in zero_point,
_megdnn_tensor_in grad_scale, _megdnn_tensor_out output,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& input, const TensorLayout& scale,
const TensorLayout& zero_point,
const TensorLayout& grad_scale, TensorLayout& output);
virtual size_t get_workspace_in_bytes(const TensorLayout& input,
const TensorLayout& scale,
const TensorLayout& zero_point,
const TensorLayout& grad_scale,
const TensorLayout& output) = 0;
protected:
void check_exec(const TensorLayout& input, const TensorLayout& scale,
const TensorLayout& zero_point,
const TensorLayout& grad_scale, const TensorLayout& output,
size_t workspace_in_bytes);
};
using LSQ = LSQForward;
class LSQBackward : public LSQBase {
DEF_OPR_IMPL(LSQBackward, LSQBase, 5, 2);
public:
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
_megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
_megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x,
_megdnn_tensor_out grad_s,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
const TensorLayout& input,
const TensorLayout& scale,
const TensorLayout& zero_point,
const TensorLayout& grad_scale,
const TensorLayout& grad_x,
const TensorLayout& grad_s) = 0;
protected:
void check_exec(const TensorLayout& diff, const TensorLayout& input,
const TensorLayout& scale, const TensorLayout& zero_point,
const TensorLayout& grad_scale, const TensorLayout& grad_x,
const TensorLayout& grad_s, size_t workspace_in_bytes);
};
} // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h"
......
......@@ -1124,3 +1124,8 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
add_fields('int32', 'qmin', '-2147483648').
add_fields('int32', 'qmax', '2147483647')
)
(pdef('LSQ').
add_fields('int32', 'qmin', '-2147483648').
add_fields('int32', 'qmax', '2147483647')
)
......@@ -37,6 +37,7 @@ namespace megdnn {
megdnn_assert(size, "uninitialized ElemwiseOpParamN");
}
template struct ElemwiseOpParamN<7>;
template struct ElemwiseOpParamN<6>;
template struct ElemwiseOpParamN<5>;
template struct ElemwiseOpParamN<4>;
......
......@@ -208,7 +208,9 @@ private:
cb(FakeQuantBackward) \
cb(TQTForward) \
cb(TQTBackward) \
cb(CheckHasInf)
cb(CheckHasInf) \
cb(LSQForward) \
cb(LSQBackward)
/*!
* \brief specialize HandleImpl::create_operator for a single opr type;
......
/**
* \file dnn/src/common/lsq.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
void LSQBase::deduce_layout_fwd(const TensorLayout& input,
TensorLayout& output) {
output = TensorLayout(input, input.dtype);
}
void LSQBase::check_layout_fwd(const TensorLayout& input,
const TensorLayout& scale,
const TensorLayout& zero_point,
const TensorLayout& grad_scale,
const TensorLayout& output) {
megdnn_assert(input.dtype == dtype::Float32());
megdnn_assert(scale.dtype == dtype::Float32());
megdnn_assert(zero_point.dtype == dtype::Float32());
megdnn_assert(grad_scale.dtype == dtype::Float32());
TensorLayout expected;
deduce_layout_fwd(input, expected);
megdnn_assert_eq_layout(expected, output);
}
void LSQForward::deduce_layout(const TensorLayout& input,
const TensorLayout& /* scale */,
const TensorLayout& /*zero_point*/,
const TensorLayout& /*grad_scale*/,
TensorLayout& output) {
deduce_layout_fwd(input, output);
}
void LSQForward::check_exec(const TensorLayout& input,
const TensorLayout& scale,
const TensorLayout& zero_point,
const TensorLayout& grad_scale,
const TensorLayout& output,
size_t workspace_in_bytes) {
check_layout_fwd(input, scale, zero_point, grad_scale, output);
auto required_workspace_space = get_workspace_in_bytes(
input, scale, zero_point, grad_scale, output);
megdnn_assert(workspace_in_bytes >= required_workspace_space);
}
void LSQBackward::check_exec(
const TensorLayout& diff, const TensorLayout& input,
const TensorLayout& scale, const TensorLayout& zero_point,
const TensorLayout& grad_scale, const TensorLayout& grad_x,
const TensorLayout& grad_s, size_t workspace_in_bytes) {
megdnn_assert_eq_shape(diff, input);
megdnn_assert_eq_shape(grad_x, input);
auto required_worspace_space = get_workspace_in_bytes(
diff, input, scale, zero_point, grad_scale, grad_x, grad_s);
megdnn_assert(workspace_in_bytes >= required_worspace_space);
}
} // namespace megdnn
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
......@@ -121,6 +122,8 @@ DEF(UniformRNG, 1, true, true);
DEF(GaussianRNG, 1, true, true);
DEF(ChecksumForward, 1, true, false);
DEF(CheckHasInf, 2, true, true);
DEF(LSQForward, 5, true, true);
DEF(LSQBackward, 7, true, false);
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -947,6 +947,119 @@ struct OpCallerUniform<Op, 5, PVis> {
}
};
//! specialization for arity == 6
template <class Op, class PVis>
struct OpCallerUniform<Op, 6, PVis> {
Op op;
PVis par[6];
static const uint32_t packed_size = PVis::packed_size;
devfunc void thread_init(uint32_t idx) {
idx = idx * packed_size;
par[0].thread_init(idx);
par[1].thread_init(idx);
par[2].thread_init(idx);
par[3].thread_init(idx);
par[4].thread_init(idx);
par[5].thread_init(idx);
}
devfunc void on(uint32_t idx) {
idx = idx * packed_size;
op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), par[3].at(idx),
par[4].at(idx), par[5].at(idx));
}
devfunc void on(uint32_t idx, uint32_t remain) {
idx = idx * packed_size;
if (remain >= packed_size) {
op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx),
par[3].at(idx), par[4].at(idx), par[5].at(idx));
} else {
auto ptr0 = par[0].ptr();
auto ptr1 = par[1].ptr();
auto ptr2 = par[2].ptr();
auto ptr3 = par[3].ptr();
auto ptr4 = par[4].ptr();
auto ptr5 = par[5].ptr();
for (int i = 0; i < remain; i++) {
op(idx + i, ptr0[par[0].offset(idx + i)],
ptr1[par[1].offset(idx + i)], ptr2[par[2].offset(idx + i)],
ptr3[par[3].offset(idx + i)], ptr4[par[4].offset(idx + i)],
ptr5[par[5].offset(idx + i)]);
}
}
}
devfunc void next() {
par[0].next();
par[1].next();
par[2].next();
par[3].next();
par[4].next();
par[5].next();
}
};
//! specialization for arity == 7
template <class Op, class PVis>
struct OpCallerUniform<Op, 7, PVis> {
Op op;
PVis par[7];
static const uint32_t packed_size = PVis::packed_size;
devfunc void thread_init(uint32_t idx) {
idx = idx * packed_size;
par[0].thread_init(idx);
par[1].thread_init(idx);
par[2].thread_init(idx);
par[3].thread_init(idx);
par[4].thread_init(idx);
par[5].thread_init(idx);
par[6].thread_init(idx);
}
devfunc void on(uint32_t idx) {
idx = idx * packed_size;
op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), par[3].at(idx),
par[4].at(idx), par[5].at(idx), par[6].at(idx));
}
devfunc void on(uint32_t idx, uint32_t remain) {
idx = idx * packed_size;
if (remain >= packed_size) {
op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx),
par[3].at(idx), par[4].at(idx), par[5].at(idx), par[6].at(idx));
} else {
auto ptr0 = par[0].ptr();
auto ptr1 = par[1].ptr();
auto ptr2 = par[2].ptr();
auto ptr3 = par[3].ptr();
auto ptr4 = par[4].ptr();
auto ptr5 = par[5].ptr();
auto ptr6 = par[6].ptr();
for (int i = 0; i < remain; i++) {
op(idx + i, ptr0[par[0].offset(idx + i)],
ptr1[par[1].offset(idx + i)], ptr2[par[2].offset(idx + i)],
ptr3[par[3].offset(idx + i)], ptr4[par[4].offset(idx + i)],
ptr5[par[5].offset(idx + i)], ptr6[par[6].offset(idx + i)]);
}
}
}
devfunc void next() {
par[0].next();
par[1].next();
par[2].next();
par[3].next();
par[4].next();
par[5].next();
par[6].next();
}
};
/*!
* \brief call binary (i.e. arity == 2) operator with different param
* visitors
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/common/handle_impl.h"
......@@ -15,6 +16,7 @@
#include "src/cuda/add_update/opr_impl.h"
#include "src/cuda/argmxx/opr_impl.h"
#include "src/cuda/argsort/opr_impl.h"
#include "src/cuda/batch_conv_bias/opr_impl.h"
#include "src/cuda/batch_normalization/opr_impl.h"
#include "src/cuda/batched_matrix_mul/opr_impl.h"
#include "src/cuda/check_has_inf/opr_impl.h"
......@@ -35,6 +37,7 @@
#include "src/cuda/elemwise/opr_impl.h"
#include "src/cuda/elemwise_multi_type/opr_impl.h"
#include "src/cuda/eye/opr_impl.h"
#include "src/cuda/fake_quant/opr_impl.h"
#include "src/cuda/flip/opr_impl.h"
#include "src/cuda/gaussian_blur/opr_impl.h"
#include "src/cuda/group_local/opr_impl.h"
......@@ -45,6 +48,7 @@
#include "src/cuda/local/opr_impl.h"
#include "src/cuda/local_share/opr_impl.h"
#include "src/cuda/lrn/opr_impl.h"
#include "src/cuda/lsq/opr_impl.h"
#include "src/cuda/mask_conv/opr_impl.h"
#include "src/cuda/matrix_inverse/opr_impl.h"
#include "src/cuda/matrix_mul/opr_impl.h"
......@@ -56,9 +60,11 @@
#include "src/cuda/reduce/opr_impl.h"
#include "src/cuda/relayout/opr_impl.h"
#include "src/cuda/relayout_format/opr_impl.h"
#include "src/cuda/remap/opr_impl.h"
#include "src/cuda/repeat/opr_impl.h"
#include "src/cuda/resize/opr_impl.h"
#include "src/cuda/rng/opr_impl.h"
#include "src/cuda/roi_align/opr_impl.h"
#include "src/cuda/roi_copy/opr_impl.h"
#include "src/cuda/roi_pooling/opr_impl.h"
#include "src/cuda/rotate/opr_impl.h"
......@@ -70,16 +76,11 @@
#include "src/cuda/tensor_remap/opr_impl.h"
#include "src/cuda/tile/opr_impl.h"
#include "src/cuda/topk/opr_impl.h"
#include "src/cuda/tqt/opr_impl.h"
#include "src/cuda/transpose/opr_impl.h"
#include "src/cuda/type_cvt/opr_impl.h"
#include "src/cuda/warp_affine/opr_impl.h"
#include "src/cuda/warp_perspective/opr_impl.h"
#include "src/cuda/local_share/opr_impl.h"
#include "src/cuda/roi_align/opr_impl.h"
#include "src/cuda/batch_conv_bias/opr_impl.h"
#include "src/cuda/remap/opr_impl.h"
#include "src/cuda/fake_quant/opr_impl.h"
#include "src/cuda/tqt/opr_impl.h"
namespace megdnn {
namespace cuda {
......
/**
* \file dnn/src/cuda/lsq/kern.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./kern.cuh"
namespace megdnn {
namespace cuda {
#define cb(_dtype) \
INST_RUN_ELEMWISE(LSQKernOp<DTypeTrait<_dtype>::ctype>, \
DTypeTrait<_dtype>::ctype, 3); \
INST_RUN_ELEMWISE(LSQBwdKernOp<DTypeTrait<_dtype>::ctype>, \
DTypeTrait<_dtype>::ctype, 3); \
INST_RUN_ELEMWISE(LSQKernOpNonContig<DTypeTrait<_dtype>::ctype>, \
DTypeTrait<_dtype>::ctype, 5); \
INST_RUN_ELEMWISE(LSQBwdKernOpNonContig<DTypeTrait<_dtype>::ctype>, \
DTypeTrait<_dtype>::ctype, 7);
cb(megdnn::dtype::Float32)
} // namespace cuda
} // namespace megdnn
/**
* \file dnn/src/cuda/lsq/kern.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/cuda/elemwise_helper.cuh"
#include "src/cuda/utils.cuh"
#if MEGDNN_CC_HOST
#include "megdnn/oprs.h"
#endif
namespace megdnn {
namespace cuda {
template <typename ctype>
struct LSQKernOp {
ctype* input;
ctype* output;
ctype qmin, qmax;
__device__ void operator()(uint32_t idx, ctype scale, ctype zero_point,
ctype grad_scale) {
ctype x = input[idx] / scale + zero_point;
x = fmaxf(fminf(x, qmax), qmin);
x = round(x);
output[idx] = (x - zero_point) * scale;
}
#if MEGDNN_CC_HOST
LSQKernOp(const TensorND& input, const TensorND& output,
const LSQ::Param& param)
: input{input.ptr<ctype>()},
output{output.ptr<ctype>()},
qmin(param.qmin),
qmax(param.qmax) {}
#endif
};
template <typename ctype>
struct LSQBwdKernOp {
ctype* diff;
ctype* input;
ctype* grad_x;
ctype* grad_s;
ctype qmin, qmax;
__device__ void operator()(uint32_t idx, ctype scale, ctype zero_point,
ctype grad_scale) {
ctype x = input[idx] / scale + zero_point;
bool ind_small = x < qmin;
bool ind_big = x > qmax;
bool ind_middle = ind_small ^ ind_big;
ind_middle = !ind_middle;
grad_s[idx] = ind_small * qmin + ind_big * qmax +
ind_middle * (-x + round(x));
grad_s[idx] = grad_s[idx] * grad_scale * diff[idx];
grad_x[idx] = ind_middle * diff[idx];
}
#if MEGDNN_CC_HOST
LSQBwdKernOp(const TensorND& diff, const TensorND& input,
const TensorND& grad_x, const TensorND& grad_s,
const LSQ::Param& param)
: diff{diff.ptr<ctype>()},
input{input.ptr<ctype>()},
grad_x{grad_x.ptr<ctype>()},
grad_s{grad_s.ptr<ctype>()},
qmin(param.qmin),
qmax(param.qmax) {}
#endif
};
template <typename ctype>
struct LSQKernOpNonContig {
ctype qmin;
ctype qmax;
__device__ void operator()(uint32_t, ctype& output, ctype& input,
ctype& scale, ctype& zero_point,
ctype grad_scale) {
ctype x = input / scale + zero_point;
x = fmaxf(fminf(x, qmax), qmin);
x = round(x);
output = (x - zero_point) * scale;
}
#if MEGDNN_CC_HOST
LSQKernOpNonContig(const LSQ::Param& param)
: qmin(param.qmin), qmax(param.qmax) {}
#endif
};
template <typename ctype>
struct LSQBwdKernOpNonContig {
ctype qmin;
ctype qmax;
__device__ void operator()(uint32_t, ctype& grad_x, ctype& grad_s,
ctype& diff, ctype& input, ctype& scale,
ctype& zero_point, ctype grad_scale) {
ctype x = input / scale + zero_point;
bool ind_small = x < qmin;
bool ind_big = x > qmax;
bool ind_middle = ind_small ^ ind_big;
ind_middle = !ind_middle;
grad_s = ind_small * qmin + ind_big * qmax +
ind_middle * (-x + round(x));
grad_s = grad_s * grad_scale * diff;
grad_x = ind_middle * diff;
}
#if MEGDNN_CC_HOST
LSQBwdKernOpNonContig(const LSQ::Param& param)
: qmin(param.qmin), qmax(param.qmax) {}
#endif
};
} // namespace cuda
} // namespace megdnn
/**
* \file dnn/src/cuda/lsq/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./opr_impl.h"
#include "./kern.cuh"
#include "src/common/utils.h"
namespace megdnn {
namespace cuda {
void LSQForwardImpl::exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
_megdnn_tensor_in zero_point,
_megdnn_tensor_in grad_scale,
_megdnn_tensor_out output,
_megdnn_workspace workspace) {
check_exec(input.layout, scale.layout, zero_point.layout, grad_scale.layout,
output.layout, workspace.size);
if (!input.layout.is_contiguous() || !output.layout.is_contiguous())
return exec_noncontig(input, scale, zero_point, grad_scale, output);
ElemwiseOpParamN<3> ele_param;
ele_param[0] = scale;
ele_param[0].layout = ele_param[0].layout.broadcast(input.layout);
ele_param[1] = zero_point;
ele_param[1].layout = ele_param[1].layout.broadcast(input.layout);
ele_param[2] = grad_scale;
ele_param[2].layout = ele_param[2].layout.broadcast(input.layout);
ele_param.init_from_given_tensor();
auto m_param = param();
auto stream = cuda_stream(handle());
#define cb(DType) \
if (input.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
run_elemwise<LSQKernOp<T>, T, 3>(ele_param, stream, \
{input, output, m_param}); \
return; \
}
cb(megdnn::dtype::Float32)
#undef cb
}
void LSQForwardImpl::exec_noncontig(_megdnn_tensor_in input,
_megdnn_tensor_in scale,
_megdnn_tensor_in zero_point,
_megdnn_tensor_in grad_scale,
_megdnn_tensor_out output) {
ElemwiseOpParamN<5> ele_param;
ele_param[0] = output;
ele_param[1] = input;
ele_param[2] = scale;
ele_param[2].layout = ele_param[2].layout.broadcast(input.layout);
ele_param[3] = zero_point;
ele_param[3].layout = ele_param[3].layout.broadcast(input.layout);
ele_param[4] = grad_scale;
ele_param[4].layout = ele_param[4].layout.broadcast(input.layout);
ele_param.init_from_given_tensor();
auto m_param = param();
auto stream = cuda_stream(handle());
#define cb(DType) \
if (input.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
run_elemwise<LSQKernOpNonContig<T>, T, 5>(ele_param, stream, \
{m_param}); \
return; \
}
cb(megdnn::dtype::Float32)
#undef cb
}
void LSQBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
_megdnn_tensor_in scale,
_megdnn_tensor_in zero_point,
_megdnn_tensor_in grad_scale,
_megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s,
_megdnn_workspace workspace) {
check_exec(diff.layout, input.layout, scale.layout, zero_point.layout,
grad_scale.layout, grad_x.layout, grad_s.layout, workspace.size);
if (!input.layout.is_contiguous() || !diff.layout.is_contiguous() ||
!grad_x.layout.is_contiguous() || !grad_s.layout.is_contiguous())
return exec_noncontig(diff, input, scale, zero_point, grad_scale,
grad_x, grad_s);
ElemwiseOpParamN<3> ele_param;
ele_param[0] = scale;
ele_param[0].layout = ele_param[0].layout.broadcast(input.layout);
ele_param[1] = zero_point;
ele_param[1].layout = ele_param[1].layout.broadcast(input.layout);
ele_param[2] = grad_scale;
ele_param[2].layout = ele_param[2].layout.broadcast(input.layout);
ele_param.init_from_given_tensor();
auto m_param = param();
auto stream = cuda_stream(handle());
#define cb(DType) \
if (grad_x.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
run_elemwise<LSQBwdKernOp<T>, T, 3>( \
ele_param, stream, {diff, input, grad_x, grad_s, m_param}); \
return; \
}
cb(megdnn::dtype::Float32)
#undef cb
}
void LSQBackwardImpl::exec_noncontig(_megdnn_tensor_in diff,
_megdnn_tensor_in input,
_megdnn_tensor_in scale,
_megdnn_tensor_in zero_point,
_megdnn_tensor_in grad_scale,
_megdnn_tensor_out grad_x,
_megdnn_tensor_out grad_s) {
ElemwiseOpParamN<7> ele_param;
ele_param[0] = grad_x;
ele_param[1] = grad_s;
ele_param[2] = diff;
ele_param[3] = input;
ele_param[4] = scale;
ele_param[4].layout = ele_param[4].layout.broadcast(input.layout);
ele_param[5] = zero_point;
ele_param[5].layout = ele_param[5].layout.broadcast(input.layout);
ele_param[6] = grad_scale;
ele_param[6].layout = ele_param[6].layout.broadcast(input.layout);
ele_param.init_from_given_tensor();
auto m_param = param();
auto stream = cuda_stream(handle());
#define cb(DType) \
if (input.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
run_elemwise<LSQBwdKernOpNonContig<T>, T, 7>(ele_param, stream, \
{m_param}); \
return; \
}
cb(megdnn::dtype::Float32)
#undef cb
}
} // namespace cuda
} // namespace megdnn
\ No newline at end of file
/**
* \file dnn/src/cuda/lsq/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
class LSQForwardImpl final : public LSQForward {
public:
using LSQForward::LSQForward;
void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
_megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale,
_megdnn_tensor_out output, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, /* input */
const TensorLayout&, /* scale */
const TensorLayout&, /* zero_point */
const TensorLayout&, /* grad_scale */
const TensorLayout& /* output */) override {
return 0;
}
private:
void exec_noncontig(_megdnn_tensor_in input, _megdnn_tensor_in scale,
_megdnn_tensor_in zero_point,
_megdnn_tensor_in grad_scale,
_megdnn_tensor_out output);
};
class LSQBackwardImpl final : public LSQBackward {
public:
using LSQBackward::LSQBackward;
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
_megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
_megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x,
_megdnn_tensor_out grad_s, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout& /* diff */,
const TensorLayout& /* input */,
const TensorLayout& /* scale */,
const TensorLayout& /* zero_point */,
const TensorLayout& /* grad_scale */,
const TensorLayout& /* grad_x */,
const TensorLayout& /* grad_s */) override {
return 0;
}
private:
void exec_noncontig(_megdnn_tensor_in diff, _megdnn_tensor_in input,
_megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
_megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x,
_megdnn_tensor_out grad_s);
};
} // namespace cuda
} // namespace megdnn
......@@ -50,6 +50,7 @@
#include "src/naive/local/opr_impl.h"
#include "src/naive/local_share/opr_impl.h"
#include "src/naive/lrn/opr_impl.h"
#include "src/naive/lsq/opr_impl.h"
#include "src/naive/mask_conv/opr_impl.h"
#include "src/naive/matrix_inverse/opr_impl.h"
#include "src/naive/matrix_mul/opr_impl.h"
......
/**
* \file dnn/src/naive/lsq/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/naive/lsq/opr_impl.h"
#include <cmath>
#include "megdnn/tensor_iter.h"
#include "src/common/elemwise_helper.cuh"
#include "src/common/utils.h"
#include "src/naive/handle.h"
namespace {
using namespace megdnn;
template <typename T>
void forward_impl(const ElemwiseOpParamN<5> src, float qmin, float qmax) {
auto inp = tensor_iter_valonly<T>(src[0]).begin();
auto out = tensor_iter_valonly<T>(src[1]).begin();
auto scale = tensor_iter_valonly<T>(src[2]).begin();
auto zero_point = tensor_iter_valonly<T>(src[3]).begin();
auto grad_scale = tensor_iter_valonly<T>(src[4]).begin();
size_t total = src[0].layout.total_nr_elems();
for (size_t i = 0; i < total; ++i) {
T x = (*inp) / (*scale) + (*zero_point);
x = x <= qmin ? qmin : x;
x = x >= qmax ? qmax : x;
x = round(x);
*out = (x - (*zero_point)) * (*scale);
++inp;
++out;
++scale;
++zero_point;
++grad_scale;
}
}
template <typename T>
void backward_impl(const ElemwiseOpParamN<7> src, float qmin, float qmax) {
auto diff = tensor_iter_valonly<T>(src[0]).begin();
auto input = tensor_iter_valonly<T>(src[1]).begin();
auto scale = tensor_iter_valonly<T>(src[2]).begin();
auto zero_point = tensor_iter_valonly<T>(src[3]).begin();
auto grad_scale = tensor_iter_valonly<T>(src[4]).begin();
auto grad_x = tensor_iter_valonly<T>(src[5]).begin();
auto grad_s = tensor_iter_valonly<T>(src[6]).begin();
size_t total = src[0].layout.total_nr_elems();
for (size_t i = 0; i < total; ++i) {
T x = (*input) / (*scale) + (*zero_point);
bool ind_small = x < qmin;
bool ind_big = x > qmax;
bool ind_middle = ind_small ^ ind_big;
ind_middle = !ind_middle;
*grad_s = ind_small * qmin + ind_big * qmax +
ind_middle * (-x + round(x));
*grad_s = (*grad_s) * (*grad_scale) * (*diff);
*grad_x = ind_middle * (*diff);
++diff;
++input;
++scale;
++zero_point;
++grad_scale;
++grad_x;
++grad_s;
}
}
} // namespace
namespace megdnn {
namespace naive {
void LSQForwardImpl::exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
_megdnn_tensor_in zero_point,
_megdnn_tensor_in grad_scale,
_megdnn_tensor_out output,
_megdnn_workspace workspace) {
check_exec(input.layout, scale.layout, zero_point.layout, grad_scale.layout,
output.layout, workspace.size);
ElemwiseOpParamN<5> src;
src[0] = input;
src[1] = output;
src[2] = scale;
src[2].layout = src[2].layout.broadcast(input.layout);
src[3] = zero_point;
src[3].layout = src[3].layout.broadcast(input.layout);
src[4] = grad_scale;
src[4].layout = src[4].layout.broadcast(input.layout);
#define cb(DType) \
if (input.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
forward_impl<T>(src, param().qmin, param().qmax)); \
return; \
}
cb(dtype::Float32)
#undef cb
}
void LSQBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
_megdnn_tensor_in scale,
_megdnn_tensor_in zero_point,
_megdnn_tensor_in grad_scale,
_megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s,
_megdnn_workspace workspace) {
check_exec(diff.layout, input.layout, scale.layout, zero_point.layout,
grad_scale.layout, grad_x.layout, grad_s.layout, workspace.size);
ElemwiseOpParamN<7> src;
src[0] = diff;
src[1] = input;
src[2] = scale;
src[2].layout = src[2].layout.broadcast(input.layout);
src[3] = zero_point;
src[3].layout = src[3].layout.broadcast(input.layout);
src[4] = grad_scale;
src[4].layout = src[4].layout.broadcast(input.layout);
src[5] = grad_x;
src[6] = grad_s;
#define cb(DType) \
if (diff.layout.dtype == DType() && grad_x.layout.dtype == DType() && \
input.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
backward_impl<T>(src, param().qmin, param().qmax)); \
return; \
}
cb(dtype::Float32)
#undef cb
}
} // namespace naive
} // namespace megdnn
/**
* \file dnn/src/naive/lsq/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace naive {
class LSQForwardImpl final : public LSQForward {
public:
using LSQForward::LSQForward;
void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale,
_megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale,
_megdnn_tensor_out output, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout& /* input */,
const TensorLayout& /* scale */,
const TensorLayout& /* zero_point */,
const TensorLayout& /* grad_scale */,
const TensorLayout& /* output */) override {
return 0;
}
};
class LSQBackwardImpl final : public LSQBackward {
public:
using LSQBackward::LSQBackward;
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input,
_megdnn_tensor_in scale, _megdnn_tensor_in zero_point,
_megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x,
_megdnn_tensor_out grad_s, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout& /* diff */,
const TensorLayout& /* input */,
const TensorLayout& /* scale */,
const TensorLayout& /* zero_point */,
const TensorLayout& /* grad_scale */,
const TensorLayout& /* grad_x */,
const TensorLayout& /* grad_s */) override {
return 0;
}
};
} // namespace naive
} // namespace megdnn
/**
* \file dnn/test/common/lsq.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h"
namespace megdnn {
namespace test {
namespace lsq {
struct TestArg {
param::LSQ param;
TensorShape ishape;
TensorShape scale_shape;
TensorShape zeropoint_shape;
TensorShape gradscale_shape;
TestArg(param::LSQ param, TensorShape ishape, TensorShape scale_shape,
TensorShape zeropoint_shape, TensorShape gradscale_shape)
: param(param),
ishape(ishape),
scale_shape(scale_shape),
zeropoint_shape(zeropoint_shape),
gradscale_shape(gradscale_shape) {}
};
inline std::vector<TestArg> get_args() {
std::vector<TestArg> args;
param::LSQ cur_param;
cur_param.qmin = -127;
cur_param.qmax = 127;
for (size_t i = 10; i < 30; i += 2) {
args.emplace_back(cur_param, TensorShape{10, 64, i, i}, TensorShape{1},
TensorShape{1}, TensorShape{1});
}
return args;
}
} // namespace lsq
} // namespace test
} // namespace megdnn
\ No newline at end of file
/**
* \file dnn/test/cuda/lsq.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/common/lsq.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/cuda/fixture.h"
namespace megdnn {
namespace test {
using namespace lsq;
TEST_F(CUDA, LSQ) {
std::vector<TestArg> args = get_args();
auto dtype = dtype::Float32();
for (auto&& arg : args) {
auto param = arg.param;
auto ishape = arg.ishape;
auto scale_shape = arg.scale_shape;
auto zeropoint_shape = arg.zeropoint_shape;
auto gradscale_shape = arg.gradscale_shape;
Checker<LSQForward> checker(handle_cuda());
checker.set_param(param)
.set_dtype(0, dtype)
.set_dtype(1, dtype)
.set_dtype(2, dtype)
.set_dtype(3, dtype)
.set_dtype(4, dtype)
.execs({ishape, scale_shape, zeropoint_shape, gradscale_shape,
ishape});
}
// test noncontiguous layout
for (auto&& arg : args) {
auto param = arg.param;
auto ishape = arg.ishape;
auto sshape = arg.scale_shape;
auto zeropoint_shape = arg.zeropoint_shape;
auto gradscale_shape = arg.gradscale_shape;
Checker<LSQForward> checker(handle_cuda());
TensorLayout ilayout(
ishape,
{(long int)(ishape[1] * ishape[2] * ishape[3] * 2),
(long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1},
dtype::Float32());
checker.set_param(param).execl({ilayout,
{sshape, dtype::Float32()},
{zeropoint_shape, dtype::Float32()},
{gradscale_shape, dtype::Float32()},
ilayout});
}
}
TEST_F(CUDA, LSQ_BACKWARD) {
std::vector<TestArg> args = get_args();
auto dtype = dtype::Float32();
for (auto&& arg : args) {
auto param = arg.param;
auto ishape = arg.ishape;
auto scale_shape = arg.scale_shape;
auto zeropoint_shape = arg.zeropoint_shape;
auto gradscale_shape = arg.gradscale_shape;
Checker<LSQBackward> checker(handle_cuda());
checker.set_param(param)
.set_dtype(0, dtype)
.set_dtype(1, dtype)
.set_dtype(2, dtype)
.set_dtype(3, dtype)
.set_dtype(4, dtype)
.set_dtype(5, dtype)
.set_dtype(6, dtype)
.execs({ishape, ishape, scale_shape, zeropoint_shape,
gradscale_shape, ishape, ishape});
}
// test noncontiguous layout
for (auto&& arg : args) {
auto param = arg.param;
auto ishape = arg.ishape;
auto sshape = arg.scale_shape;
auto zeropoint_shape = arg.zeropoint_shape;
auto gradscale_shape = arg.gradscale_shape;
Checker<LSQBackward> checker(handle_cuda());
TensorLayout ilayout(
ishape,
{(long int)(ishape[1] * ishape[2] * ishape[3] * 2),
(long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1},
dtype::Float32());
checker.set_param(param).execl({ilayout,
ilayout,
{sshape, dtype::Float32()},
{zeropoint_shape, dtype::Float32()},
{gradscale_shape, dtype::Float32()},
ilayout,
ilayout});
}
}
} // namespace test
} // namespace megdnn
\ No newline at end of file
/**
* \file dnn/test/naive/sliding_window_transpose.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/naive/fixture.h"
#include "megdnn/oprs/nn.h"
#include "test/common/checker.h"
using namespace megdnn;
using namespace test;
TEST_F(NAIVE, LSQ_FORWARD) {
Checker<LSQ> checker(handle(), /* check_dispatch */ false);
param::LSQ param;
param.qmin = -127;
param.qmax = 127;
TensorND input =
TensorValue({2, 2, 2, 2}, dtype::Float32(),
{0, 1, 3, 4, 1, 2, 4, 5, 3, 4, 6, 7, 4, 5, 7, 8});
TensorND scale_shape = TensorValue({1}, dtype::Float32(), {2});
TensorND zero_point = TensorValue({1}, dtype::Float32(), {1});
TensorND grad_scale = TensorValue({1}, dtype::Float32(), {0.5});
TensorND output =
TensorValue({2, 2, 2, 2}, dtype::Float32(),
{0, 2, 4, 4, 2, 2, 4, 6, 4, 4, 6, 8, 4, 6, 8, 8});
checker.set_param(param).exect(
Testcase{input, scale_shape, zero_point, grad_scale, {}},
Testcase{{}, {}, {}, {}, output});
}
......@@ -6,7 +6,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .fake_quant import TQT, FakeQuantize
from .fake_quant import LSQ, TQT, FakeQuantize
from .observer import (
ExponentialMovingAverageObserver,
HistogramObserver,
......
......@@ -12,13 +12,15 @@ from .. import functional as F
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
from ..logger import get_logger
from ..module import Module
from ..tensor import Parameter
from ..tensor import Parameter, Tensor
from .utils import (
LSQParams,
QParams,
QParamsModuleMixin,
QuantMode,
create_qparams,
fake_quant_tensor,
lsq_forward,
tqt_forward,
)
......@@ -117,3 +119,58 @@ class FakeQuantize(_FakeQuantize):
qparams.dtype_meta, self.dtype
)
return fake_quant_tensor(inp, qparams)
class LSQ(_FakeQuantize, QParamsModuleMixin):
r"""
LSQ: https://arxiv.org/pdf/1902.08153.pdf Estimating and scaling the
task loss gradient at each weight and activation layer's quantizer step size
:param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
quantization dtype of input.
:param enable: whether do ``normal_forward`` or ``fake_quant_forward``.
:param eps:a small value to avoid division by zero. Default: 1e-5
"""
def init(
self,
dtype: Union[str, QuantDtypeMeta],
enable: bool = True,
eps: float = 1e-5,
**kwargs
):
super().__init__(dtype=dtype, enable=enable, **kwargs)
self.eps = Tensor(eps, dtype="float32")
self.step_size = Parameter(1.0, dtype="float32")
def set_qparams(self, qparams: LSQParams):
self.mode = qparams.mode
if qparams.mode == QuantMode.ASYMMERTIC:
self.zero_point = qparams.zero_point
else:
self.zero_point = Tensor([0.0], dtype="float32")
if qparams.scale is None:
raise AssertionError("Can not get an initialized scale")
init_step_size = qparams.scale
if init_step_size < self.eps:
init_step_size = 0
else:
init_step_size = init_step_size - self.eps
self.step_size = Parameter(init_step_size, dtype="float32")
self.grad_scale = qparams.grad_scale
def fake_quant_forward(self, inp, qparams: LSQParams = None):
step_size = F.abs(self.step_size) + self.eps
return lsq_forward(
self.qmin, self.qmax, inp, step_size, self.zero_point, self.grad_scale
)
def get_qparams(self):
return LSQParams(
mode=self.mode,
dtype_meta=self.dtype,
scale=F.abs(self.step_size.detach()) + self.eps,
zero_point=self.zero_point,
grad_scale=self.grad_scale,
)
......@@ -43,6 +43,16 @@ def tqt_forward(qmin, qmax, inp, scale):
return output
def lsq_forward(qmin, qmax, inp, step_size, zero_point=None, scale_grad=None):
if zero_point is None:
zero_point = Tensor([0.0], dtype=np.float32)
if scale_grad is None:
scale_grad = Tensor([1.0], dtype=np.float32)
op = builtin.LSQ(qmin=qmin, qmax=qmax)
(output,) = apply(op, inp, step_size, zero_point, scale_grad)
return output
def register_method_to_class(cls):
def decorator(func):
@wraps(func)
......@@ -105,6 +115,47 @@ class QParams:
return "QParams({})".format(content)
class LSQParams:
"""
To standardize LSQ's qparams format. If custom
qparams is needed, inherit this class and add custom ``__slots__``.
"""
__slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale"
def __init__(
self,
mode: QuantMode,
dtype_meta: QuantDtypeMeta,
scale: Tensor,
zero_point: Tensor,
grad_scale: Tensor,
):
self.mode = mode
self.dtype_meta = dtype_meta
self.scale = scale
self.zero_point = zero_point
self.grad_scale = grad_scale
def update(self, lsqparams: "LSQParams"):
for key in self.__slots__:
setattr(self, key, getattr(lsqparams, key))
def __eq__(self, other):
if len(self.__slots__) != len(other.__slots__):
return False
for key in self.__slots__:
if not hasattr(other, key) or getattr(self, key) != getattr(other, key):
return False
return True
def __repr__(self):
content = ", ".join(
["{}={}".format(key, getattr(self, key)) for key in self.__slots__]
)
return "LSQParams({})".format(content)
class QParamsModuleMixin(abc.ABC):
def get_quantized_dtype(self):
qparams = self.get_qparams()
......
......@@ -10,6 +10,7 @@ import numpy as np
import pytest
import megengine as mge
import megengine.functional as F
from megengine import tensor
from megengine.core.autodiff.grad import Function, Grad
from megengine.core.tensor.dtype import QuantDtypeMeta
......@@ -19,6 +20,7 @@ from megengine.quantization.utils import (
QuantMode,
create_qparams,
fake_quant_tensor,
lsq_forward,
tqt_forward,
)
......@@ -150,3 +152,78 @@ def test_fakequant():
zero_point = tensor(1.0 * np.ones((1, 32, 1, 1)), dtype=np.float32)
scale = tensor(4.0 * np.ones((1, 32, 1, 1)), dtype=np.float32)
run(zero_point, scale)
class LSQ_numpy:
def __init__(self, lowerbound, upperbound):
super().__init__()
self.lowerbound = lowerbound
self.upperbound = upperbound
def forward(self, inp, scale, zero_point, grad_scale):
inp_scaled = inp / scale + zero_point
inp_clipped = np.maximum(
np.minimum(inp_scaled, self.upperbound), self.lowerbound
)
inp_rounded = np.floor(inp_clipped + 0.5)
inp_flq = (inp_rounded - zero_point) * scale
self.saved_tensors = (inp_scaled, inp_rounded, scale, grad_scale)
return inp_flq
def backward(self, grad_inp_flq):
(inp_scaled, inp_rounded, scale, grad_scale) = self.saved_tensors
ind_small = inp_scaled < self.lowerbound
ind_big = inp_scaled > self.upperbound
ind_middle = np.logical_xor(ind_small, ind_big)
ind_middle = np.abs(ind_middle - 1)
grad_s = (
ind_small * self.lowerbound
+ ind_big * self.upperbound
+ ind_middle * (-inp_scaled + inp_rounded)
)
grad_s = grad_s * grad_scale * grad_inp_flq
grad_s = grad_s.sum()
grad_inp = grad_inp_flq * ind_middle
return grad_inp, grad_s
def test_lsq():
def preprocess(scale, eps):
scale = np.array([0]) if scale < eps else scale - eps
return np.abs(scale) + eps
g = []
def cb(grad):
g.append(grad)
x = np.random.randint(-128, 128, size=(1, 2, 3, 4)).astype("float32")
s = np.random.rand(1)
eps = np.array([1e-5], dtype="float32")
s = preprocess(s, eps)
zero_point = np.array([1.0], dtype="float32")
grad_s = np.array([2.0], dtype="float32")
g_y = np.ones(shape=(1, 2, 3, 4), dtype="float32")
n = LSQ_numpy(-127, 127)
y_np = n.forward(x, s, zero_point, grad_s)
g_x_np, g_s_np = n.backward(g_y)
x = mge.tensor(x, dtype="float32")
s = mge.tensor(s, dtype="float32")
zero_point = mge.tensor(zero_point, dtype="float32")
grad_s = mge.tensor(grad_s, dtype="float32")
g_y = mge.tensor(g_y, dtype="float32")
grad = Grad().wrt(x, s, callback=cb)
y = lsq_forward(-127, 127, x, s, zero_point, grad_s)
grad(y, g_y)
g_x, g_s = g
np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-7, atol=1e-7)
np.testing.assert_allclose(g_x.numpy(), g_x_np, rtol=1e-7, atol=1e-7)
np.testing.assert_allclose(g_s.numpy(), g_s_np, rtol=5e-7, atol=5e-7)
......@@ -6,23 +6,26 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
// FIXME: split this file into separate files for each specialized op
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lsq.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/io.h"
......@@ -32,25 +35,23 @@
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "../op_trait.h"
namespace mgb::imperative {
namespace { namespace dimshuffle {
namespace {
namespace dimshuffle {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::Dimshuffle>();
std::vector<int> pattern(node->param().pattern_len);
for (size_t i = 0; i < node->param().pattern_len; ++ i) {
for (size_t i = 0; i < node->param().pattern_len; ++i) {
pattern[i] = node->param().pattern[i];
}
return Dimshuffle::make(pattern);
}
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& ds = static_cast<const Dimshuffle&>(def);
OperatorNodeConfig config{ds.make_name()};
return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
......@@ -60,12 +61,12 @@ OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // dimshuffle
} // namespace dimshuffle
} // namespace
namespace { namespace add_axis {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace add_axis {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& add_axis = static_cast<const AddAxis&>(def);
using Desc = opr::AxisAddRemove::AxisDesc;
std::vector<Desc> param;
......@@ -76,15 +77,13 @@ auto apply_on_var_node(
return opr::AxisAddRemove::make(inputs[0], param, config);
}
OP_TRAIT_REG(AddAxis, AddAxis)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // add_axis
OP_TRAIT_REG(AddAxis, AddAxis).apply_on_var_node(apply_on_var_node).fallback();
} // namespace add_axis
} // namespace
namespace { namespace remove_axis {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace remove_axis {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& remove_axis = static_cast<const RemoveAxis&>(def);
using Desc = opr::AxisAddRemove::AxisDesc;
std::vector<Desc> param;
......@@ -98,34 +97,33 @@ auto apply_on_var_node(
OP_TRAIT_REG(RemoveAxis, RemoveAxis)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // remove_axis
} // namespace remove_axis
} // namespace
namespace { namespace top_k {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace top_k {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& topk = static_cast<const TopK&>(def);
OperatorNodeConfig config{topk.make_name()};
return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0]
.node()->owner_opr();
.node()
->owner_opr();
}
OP_TRAIT_REG(TopK, TopK)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // top_k
OP_TRAIT_REG(TopK, TopK).apply_on_var_node(apply_on_var_node).fallback();
} // namespace top_k
} // namespace
namespace { namespace reduce {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace reduce {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& reduce = static_cast<const Reduce&>(def);
OperatorNodeConfig config{reduce.make_name()};
if (inputs.size() > 1) {
return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config);
} else {
return opr::Reduce::make(
inputs[0], reduce.param(), (cg::VarNode*)nullptr, config);
return opr::Reduce::make(inputs[0], reduce.param(),
(cg::VarNode*)nullptr, config);
}
}
......@@ -138,35 +136,39 @@ OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // reduce
} // namespace reduce
} // namespace
namespace { namespace adaptive_pooling {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace adaptive_pooling {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& pool = static_cast<const AdaptivePooling&>(def);
OperatorNodeConfig config{pool.make_name()};
return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), config);
return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(),
config);
}
OP_TRAIT_REG(AdaptivePooling, AdaptivePooling)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // adaptive_pooling
} // namespace adaptive_pooling
} // namespace
namespace { namespace conv_bias {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace conv_bias {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& conv = static_cast<const ConvBias&>(def);
cg::OperatorNodeConfig config{conv.dtype};
config.name(conv.make_name());
if (inputs.size() == 2) {
return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
return opr::ConvBias::make(inputs[0], inputs[1], conv.param(),
conv.policy(), config);
} else if (inputs.size() == 3) {
return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
return opr::ConvBias::make(inputs[0], inputs[1], inputs[2],
conv.param(), conv.policy(), config);
} else if (inputs.size() == 4) {
return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config);
return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3],
conv.param(), conv.policy(), config);
}
mgb_assert(0);
}
......@@ -174,21 +176,25 @@ auto apply_on_var_node(
OP_TRAIT_REG(ConvBias, ConvBias)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // conv_bias
} // namespace conv_bias
} // namespace
namespace { namespace batch_conv_bias {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace batch_conv_bias {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& conv = static_cast<const BatchConvBias&>(def);
cg::OperatorNodeConfig config{conv.dtype};
config.name(conv.make_name());
if (inputs.size() == 2) {
return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(),
conv.policy(), config);
} else if (inputs.size() == 3) {
return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2],
conv.param(), conv.policy(), config);
} else if (inputs.size() == 4) {
return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config);
return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2],
inputs[3], conv.param(), conv.policy(),
config);
}
mgb_assert(0);
}
......@@ -196,25 +202,23 @@ auto apply_on_var_node(
OP_TRAIT_REG(BatchConvBias, BatchConvBias)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // batch_conv_bias
} // namespace batch_conv_bias
} // namespace
namespace { namespace pooling {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace pooling {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& pool = static_cast<const Pooling&>(def);
OperatorNodeConfig config{pool.make_name()};
return opr::Pooling::make(inputs[0], pool.param(), config);
}
OP_TRAIT_REG(Pooling, Pooling)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // pooling
OP_TRAIT_REG(Pooling, Pooling).apply_on_var_node(apply_on_var_node).fallback();
} // namespace pooling
} // namespace
namespace { namespace matrix_mul {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace matrix_mul {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& matmul = static_cast<const MatrixMul&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{matmul.make_name()};
......@@ -224,12 +228,12 @@ auto apply_on_var_node(
OP_TRAIT_REG(MatrixMul, MatrixMul)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // matrix_mul
} // namespace matrix_mul
} // namespace
namespace { namespace batched_matrix_mul {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace batched_matrix_mul {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{matmul.make_name()};
......@@ -239,84 +243,77 @@ auto apply_on_var_node(
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // batched_matrix_mul
} // namespace batched_matrix_mul
} // namespace
namespace { namespace dot {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace dot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<Dot>();
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{op.make_name()};
return opr::Dot::make(inputs[0], inputs[1], config);
}
OP_TRAIT_REG(Dot, Dot)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // dot
OP_TRAIT_REG(Dot, Dot).apply_on_var_node(apply_on_var_node).fallback();
} // namespace dot
} // namespace
namespace { namespace argsort {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace argsort {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& argsort = static_cast<const Argsort&>(def);
OperatorNodeConfig config{argsort.make_name()};
return opr::Argsort::make(inputs[0], argsort.param(), config);
}
OP_TRAIT_REG(Argsort, Argsort)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // argsort
OP_TRAIT_REG(Argsort, Argsort).apply_on_var_node(apply_on_var_node).fallback();
} // namespace argsort
} // namespace
namespace { namespace argmax {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace argmax {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& argmax = static_cast<const Argmax&>(def);
OperatorNodeConfig config{argmax.make_name()};
return opr::Argmax::make(inputs[0], argmax.param(), config);
}
OP_TRAIT_REG(Argmax, Argmax)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // argmax
OP_TRAIT_REG(Argmax, Argmax).apply_on_var_node(apply_on_var_node).fallback();
} // namespace argmax
} // namespace
namespace { namespace argmin {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace argmin {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& argmin = static_cast<const Argmin&>(def);
OperatorNodeConfig config{argmin.make_name()};
return opr::Argmin::make(inputs[0], argmin.param(), config);
}
OP_TRAIT_REG(Argmin, Argmin)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // argmin
OP_TRAIT_REG(Argmin, Argmin).apply_on_var_node(apply_on_var_node).fallback();
} // namespace argmin
} // namespace
namespace { namespace warp_perspective {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace warp_perspective {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& warp = static_cast<const WarpPerspective&>(def);
OperatorNodeConfig config{warp.make_name()};
if (inputs.size() == 3) {
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param(), config);
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2],
warp.param(), config);
} else {
mgb_assert(inputs.size() == 4);
return opr::WarpPerspective::make(
inputs[0], inputs[1], inputs[2], inputs[3], warp.param(), config);
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2],
inputs[3], warp.param(), config);
}
}
OP_TRAIT_REG(WarpPerspective, WarpPerspective)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // warp_perspective
} // namespace warp_perspective
} // namespace
namespace { namespace group_local {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace group_local {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& local = static_cast<const GroupLocal&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{local.make_name()};
......@@ -325,12 +322,12 @@ auto apply_on_var_node(
OP_TRAIT_REG(GroupLocal, GroupLocal)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // group_local
} // namespace group_local
} // namespace
namespace { namespace indexing_one_hot {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace indexing_one_hot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingOneHot&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{op.make_name()};
......@@ -339,64 +336,60 @@ auto apply_on_var_node(
OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // indexing_one_hot
} // namespace indexing_one_hot
} // namespace
namespace { namespace indexing_set_one_hot {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace indexing_set_one_hot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingSetOneHot&>(def);
mgb_assert(inputs.size() == 3);
OperatorNodeConfig config{op.make_name()};
return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param(), config);
return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2],
op.param(), config);
}
OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // indexing_set_one_hot
} // namespace indexing_set_one_hot
} // namespace
namespace { namespace typecvt {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace typecvt {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const TypeCvt&>(def);
mgb_assert(inputs.size() == 1);
OperatorNodeConfig config{op.make_name()};
return opr::TypeCvt::make(inputs[0], op.dtype, config);
}
OP_TRAIT_REG(TypeCvt, TypeCvt)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // typecvt
OP_TRAIT_REG(TypeCvt, TypeCvt).apply_on_var_node(apply_on_var_node).fallback();
} // namespace typecvt
} // namespace
namespace { namespace concat {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace concat {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Concat&>(def);
cg::OperatorNodeConfig config{op.comp_node};
config.name(op.make_name());
return opr::Concat::make(inputs, op.axis, config);
}
OP_TRAIT_REG(Concat, Concat)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // concat
OP_TRAIT_REG(Concat, Concat).apply_on_var_node(apply_on_var_node).fallback();
} // namespace concat
} // namespace
namespace { namespace copy {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace copy {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Copy&>(def);
mgb_assert(inputs.size() == 1);
cg::OperatorNodeConfig config{op.comp_node};
config.name(op.make_name());
return opr::Copy::make(inputs[0], config);
}
OP_TRAIT_REG(Copy, Copy)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // copy
OP_TRAIT_REG(Copy, Copy).apply_on_var_node(apply_on_var_node).fallback();
} // namespace copy
} // namespace
namespace { namespace assert_equal {
auto apply_on_var_node(
......@@ -408,81 +401,81 @@ auto apply_on_var_node(
} else {
// workaround for MiniGraph, which only allow one opr in the graph
mgb_assert(inputs.size() == 3);
return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], op.param(), {});
return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2],
op.param(), {});
}
}
OP_TRAIT_REG(AssertEqual, AssertEqual)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // assert_equal
} // namespace assert_equal
} // namespace
namespace { namespace roi_align {
VarNodeArray apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace roi_align {
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const ROIAlign&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{op.make_name()};
auto* opr = opr::ROIAlign::make(
inputs[0], inputs[1], op.param(), config).node()->owner_opr();
auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config)
.node()
->owner_opr();
return {opr->output(0), opr->output(1)};
}
OP_TRAIT_REG(ROIAlign, ROIAlign)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // roi_align
} // namespace roi_align
} // namespace
namespace { namespace correlation {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace correlation {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Correlation&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{op.make_name()};
return opr::Correlation::make(
inputs[0], inputs[1], op.param(), config);
return opr::Correlation::make(inputs[0], inputs[1], op.param(), config);
}
OP_TRAIT_REG(Correlation, Correlation)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // correlation
} // namespace correlation
} // namespace
#if MGB_CUDA
namespace { namespace nvof {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace nvof {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const NvOf&>(def);
mgb_assert(inputs.size() == 1);
OperatorNodeConfig config{op.make_name()};
return opr::NvOf::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(NvOf, NvOf)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // nvof
OP_TRAIT_REG(NvOf, NvOf).apply_on_var_node(apply_on_var_node).fallback();
} // namespace nvof
} // namespace
#endif
namespace { namespace linspace {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace linspace {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Linspace&>(def);
mgb_assert(inputs.size() == 3);
cg::OperatorNodeConfig config{op.comp_node};
config.name(op.make_name());
return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config);
return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(),
config);
}
OP_TRAIT_REG(Linspace, Linspace)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // linspace
} // namespace linspace
} // namespace
namespace { namespace eye {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace eye {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Eye&>(def);
mgb_assert(inputs.size() == 1);
cg::OperatorNodeConfig config{op.comp_node};
......@@ -490,41 +483,39 @@ auto apply_on_var_node(
opr::Eye::Param param{op.k, op.dtype.enumv()};
return opr::Eye::make(inputs[0], param, config);
}
OP_TRAIT_REG(Eye, Eye)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // eye
OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback();
} // namespace eye
} // namespace
namespace { namespace roi_pooling {
VarNodeArray apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace roi_pooling {
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const ROIPooling&>(def);
mgb_assert(inputs.size() == 3);
OperatorNodeConfig config{op.make_name()};
auto* opr = opr::ROIPooling::make(
inputs[0], inputs[1], inputs[2], op.param(), config
).node()->owner_opr();
auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2],
op.param(), config)
.node()
->owner_opr();
return {opr->output(0), opr->output(1)};
}
OP_TRAIT_REG(ROIPooling, ROIPooling)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // roi_pooling
} // namespace roi_pooling
} // namespace
namespace { namespace remap {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace remap {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Remap&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{op.make_name()};
return opr::Remap::make(inputs[0], inputs[1], op.param(), config);
}
OP_TRAIT_REG(Remap, Remap)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // remap
OP_TRAIT_REG(Remap, Remap).apply_on_var_node(apply_on_var_node).fallback();
} // namespace remap
} // namespace
namespace {
auto get_index(
......@@ -532,16 +523,19 @@ auto get_index(
const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
size_t length = mask.size();
opr::Subtensor::IndexDesc ret(length);
for (size_t i = 0; i < length; ++ i) {
for (size_t i = 0; i < length; ++i) {
auto&& [axis, begin, end, step, idx] = mask[i];
ret[i].axis = axis;
if (idx) {
ret[i].idx = inputs[vidx++];
} else {
mgb_assert(begin || end || step);
if (begin) ret[i].begin = inputs[vidx++];
if (end) ret[i].end = inputs[vidx++];
if (step) ret[i].step = inputs[vidx++];
if (begin)
ret[i].begin = inputs[vidx++];
if (end)
ret[i].end = inputs[vidx++];
if (step)
ret[i].step = inputs[vidx++];
}
}
mgb_assert(vidx == inputs.size());
......@@ -551,18 +545,18 @@ auto get_index(
#define IN2 inputs[0], inputs[1]
#define FANCY_INDEXING_IMPL(NAME, NR_INPUT) \
namespace NAME##_impl { \
auto apply_on_var_node( \
const OpDef& def, \
const VarNodeArray& inputs) { \
namespace NAME##_impl { \
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { \
auto&& op = static_cast<const NAME&>(def); \
OperatorNodeConfig config{op.make_name()}; \
return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items), config); \
} \
OP_TRAIT_REG(NAME, NAME) \
return opr::NAME::make(IN##NR_INPUT, \
get_index(inputs, NR_INPUT, op.items), \
config); \
} \
OP_TRAIT_REG(NAME, NAME) \
.apply_on_var_node(apply_on_var_node) \
.fallback(); \
}
}
FANCY_INDEXING_IMPL(Subtensor, 1)
FANCY_INDEXING_IMPL(SetSubtensor, 2)
......@@ -582,38 +576,36 @@ FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2)
#undef IN2
} // anonymous namespace
namespace { namespace fake_quant {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace fake_quant {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const FakeQuant&>(def);
mgb_assert(inputs.size() == 3);
OperatorNodeConfig config{op.make_name()};
return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), config);
return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(),
config);
}
OP_TRAIT_REG(FakeQuant, FakeQuant)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // fake_quant
} // namespace fake_quant
} // namespace
namespace { namespace tqt {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace tqt {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const TQT&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{op.make_name()};
return opr::TQT::make(inputs[0], inputs[1], op.param(), config);
}
OP_TRAIT_REG(TQT, TQT)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // tqt
OP_TRAIT_REG(TQT, TQT).apply_on_var_node(apply_on_var_node).fallback();
} // namespace tqt
} // namespace
namespace { namespace elemwise_multi_type {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace elemwise_multi_type {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const ElemwiseMultiType&>(def);
OperatorNodeConfig config{op.dtype};
config.name(op.make_name());
......@@ -622,27 +614,27 @@ auto apply_on_var_node(
OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // elemwise_multi_type
} // namespace elemwise_multi_type
} // namespace
namespace { namespace svd {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace svd {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const SVD&>(def);
mgb_assert(inputs.size() == 1);
OperatorNodeConfig config{op.make_name()};
return opr::SVD::make(inputs[0], op.param(), config)[0]
.node()->owner_opr()->usable_output();
.node()
->owner_opr()
->usable_output();
}
OP_TRAIT_REG(SVD, SVD)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // svd
OP_TRAIT_REG(SVD, SVD).apply_on_var_node(apply_on_var_node).fallback();
} // namespace svd
} // namespace
namespace { namespace images2neibs {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
namespace {
namespace images2neibs {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Images2Neibs&>(def);
OperatorNodeConfig config{op.make_name()};
return opr::Images2Neibs::make(inputs[0], op.param(), config);
......@@ -650,6 +642,20 @@ auto apply_on_var_node(
OP_TRAIT_REG(Images2Neibs, Images2Neibs)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // images2neibs
} // namespace images2neibs
} // namespace
namespace {
namespace lsq {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const LSQ&>(def);
mgb_assert(inputs.size() == 4);
OperatorNodeConfig config{op.make_name()};
return opr::LSQ::make(inputs[0], inputs[1], inputs[2], inputs[3],
op.param(), config);
}
OP_TRAIT_REG(LSQ, LSQ).apply_on_var_node(apply_on_var_node).fallback();
} // namespace lsq
} // namespace
} // namespace mgb::imperative
......@@ -6,22 +6,24 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./helper.h"
#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/backward_graph_opt.h"
using namespace mgb;
using namespace cg;
using namespace imperative;
template <typename T>
T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, const T& outputs, const T& grads) {
T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs,
const T& outputs, const T& grads) {
T ret;
size_t i = 0;
for (auto&& t : inputs) {
......@@ -54,7 +56,9 @@ T expand_grads(const U& bg, const T& outputs) {
}
template <typename T>
T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, const T& precomp, const T& inputs, const T& outputs, const T& grads) {
T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg,
const T& precomp, const T& inputs,
const T& outputs, const T& grads) {
T ret = precomp;
size_t i = 0;
for (auto&& t : inputs) {
......@@ -75,7 +79,8 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, cons
return ret;
}
SmallVector<TensorPtr> apply_shared_on_physical_tensor(std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) {
SmallVector<TensorPtr> apply_shared_on_physical_tensor(
std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) {
return OpDef::apply_on_physical_tensor(*def, inputs);
}
......@@ -83,7 +88,7 @@ TEST(TestImperative, BackwardGraphBasic) {
HostTensorGenerator<> gen;
SmallVector<HostTensorND> hvs;
SmallVector<TensorPtr> inputs;
for(size_t i = 0; i < 2; ++ i) {
for (size_t i = 0; i < 2; ++i) {
hvs.push_back(*gen({42}));
inputs.push_back(Tensor::make(hvs.back()));
}
......@@ -97,7 +102,8 @@ TEST(TestImperative, BackwardGraphBasic) {
for (auto&& i : inputs) {
input_descs.push_back({i->layout(), i->comp_node()});
}
auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, {true});
auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true},
{true});
auto&& save_for_backward = result.save_for_backward;
auto&& input_has_grad = result.input_has_grad;
......@@ -106,7 +112,7 @@ TEST(TestImperative, BackwardGraphBasic) {
hvs.push_back(*gen({42}));
inputs.push_back(Tensor::make(hvs.back()));
mgb_assert(save_for_backward.size() == inputs.size());
for (size_t i = 0; i < inputs.size(); ++ i) {
for (size_t i = 0; i < inputs.size(); ++i) {
if (!save_for_backward[i]) {
inputs[i].reset(); // drop unused tensor
}
......@@ -118,13 +124,11 @@ TEST(TestImperative, BackwardGraphBasic) {
}
}
inputs.clear();
auto input_grads = result.backward.apply(
backward_graph_inputs,
auto input_grads = result.backward.apply(backward_graph_inputs,
apply_shared_on_physical_tensor,
[&](auto&& x){ return x; }
);
[&](auto&& x) { return x; });
mgb_assert(input_grads.size() == input_has_grad.size());
for (size_t i = 0; i < input_has_grad.size(); ++ i) {
for (size_t i = 0; i < input_has_grad.size(); ++i) {
mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i]));
}
......@@ -133,9 +137,10 @@ TEST(TestImperative, BackwardGraphBasic) {
res.emplace_back();
res.back().copy_from(i->dev_tensor()).sync();
}
for (size_t i = 0; i < 42; ++ i) {
for (size_t j = 0; j < 1; ++ j) {
ASSERT_EQ(hvs[2].ptr<float>()[i] * hvs[j].ptr<float>()[i], res[j ^ 1].ptr<float>()[i]);
for (size_t i = 0; i < 42; ++i) {
for (size_t j = 0; j < 1; ++j) {
ASSERT_EQ(hvs[2].ptr<float>()[i] * hvs[j].ptr<float>()[i],
res[j ^ 1].ptr<float>()[i]);
}
}
}
......@@ -152,7 +157,8 @@ TEST(TestImperative, BackwardGraphIdentity) {
SmallVector<LogicalTensorDesc> input_descs;
input_descs.push_back({a->layout(), a->comp_node()});
auto result = OpDef::make_backward_graph(*attr, input_descs, {true}, {true});
auto result =
OpDef::make_backward_graph(*attr, input_descs, {true}, {true});
auto&& save_for_backward = result.save_for_backward;
auto&& input_has_grad = result.input_has_grad;
......@@ -160,7 +166,7 @@ TEST(TestImperative, BackwardGraphIdentity) {
inputs.push_back(outputs[0]);
inputs.push_back(dc);
mgb_assert(save_for_backward.size() == inputs.size());
for (size_t i = 0; i < inputs.size(); ++ i) {
for (size_t i = 0; i < inputs.size(); ++i) {
if (!save_for_backward[i]) {
inputs[i].reset(); // drop unused tensor
}
......@@ -172,19 +178,17 @@ TEST(TestImperative, BackwardGraphIdentity) {
}
}
inputs.clear();
auto input_grads = result.backward.apply(
backward_graph_inputs,
auto input_grads = result.backward.apply(backward_graph_inputs,
apply_shared_on_physical_tensor,
[&](auto&& x){ return x; }
);
[&](auto&& x) { return x; });
mgb_assert(input_grads.size() == input_has_grad.size());
for (size_t i = 0; i < input_has_grad.size(); ++ i) {
for (size_t i = 0; i < input_has_grad.size(); ++i) {
mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i]));
}
HostTensorND hv;
hv.copy_from(input_grads[0]->dev_tensor()).sync();
for (size_t i = 0; i < 42; ++ i) {
for (size_t i = 0; i < 42; ++i) {
ASSERT_EQ(host_dc->ptr<float>()[i], hv.ptr<float>()[i]);
}
}
......@@ -192,7 +196,7 @@ TEST(TestImperative, BackwardGraphIdentity) {
TEST(TestImperative, BatchNormGrad) {
auto cn = CompNode::load("xpux");
using Param = opr::BatchNorm::Param;
size_t N=2, C=3, H=5, W=5;
size_t N = 2, C = 3, H = 5, W = 5;
LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn};
LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn};
{
......@@ -202,7 +206,8 @@ TEST(TestImperative, BatchNormGrad) {
param.fwd_mode = Param::FwdMode::TRAINING;
attr.param.write_pod(param);
OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat},
{true, true ,true, false, false}, {false, false, false, false, true});
{true, true, true, false, false},
{false, false, false, false, true});
}
{
auto op = OprAttr::make("BatchNorm");
......@@ -210,8 +215,8 @@ TEST(TestImperative, BatchNormGrad) {
Param param;
param.fwd_mode = Param::FwdMode::TRAINING;
attr.param.write_pod(param);
OpDef::make_backward_graph(attr, {inp, stat, stat},
{true, true ,true}, {false, false, true});
OpDef::make_backward_graph(attr, {inp, stat, stat}, {true, true, true},
{false, false, true});
}
}
......@@ -220,7 +225,8 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
LogicalTensorDesc desc = {TensorLayout(dtype::Float32()), cn};
HostTensorGenerator<> gen;
auto op = std::shared_ptr<OpDef>(Elemwise::make(Elemwise::Mode::ADD));
auto bg = OpDef::make_backward_graph(*op, {desc, desc}, {true, true}, {true});
auto bg =
OpDef::make_backward_graph(*op, {desc, desc}, {true, true}, {true});
auto obg = OptimizedBackwardGraphResult(bg);
ASSERT_EQ(obg.save_for_backward.size(), 4);
ASSERT_FALSE(obg.save_for_backward[0]);
......@@ -235,30 +241,30 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
auto dc_tn = Tensor::make(*dc_hv);
auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0];
auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
auto grads = expand_grads(bg, bg.backward.apply(
backward_graph_inputs,
auto backward_graph_inputs =
prepare_backward_graph_inputs<SmallVector<TensorPtr>>(
bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
auto grads =
expand_grads(bg, bg.backward.apply(backward_graph_inputs,
apply_shared_on_physical_tensor,
[&](auto&& x){ return x; }
));
[&](auto&& x) { return x; }));
auto precomp = obg.precomp.apply(
SmallVector<TensorPtr>{a_tn, b_tn, c_tn},
auto precomp = obg.precomp.apply(SmallVector<TensorPtr>{a_tn, b_tn, c_tn},
apply_shared_on_physical_tensor,
[&](auto&& x){ return x; }
);
[&](auto&& x) { return x; });
ASSERT_EQ(precomp.size(), 2);
ASSERT_EQ(precomp[0]->shape().ndim, 1);
ASSERT_LE(precomp[0]->shape()[0], 2);
ASSERT_EQ(precomp[1]->shape().ndim, 1);
ASSERT_LE(precomp[1]->shape()[0], 2);
auto backward_inputs = prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
auto grads2 = expand_grads(obg, obg.backward.apply(
backward_inputs,
apply_shared_on_physical_tensor,
[&](auto&& x){ return x; }
));
auto backward_inputs =
prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(
obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
auto grads2 = expand_grads(
obg,
obg.backward.apply(backward_inputs, apply_shared_on_physical_tensor,
[&](auto&& x) { return x; }));
ASSERT_EQ(grads2.size(), 2);
MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value());
......
......@@ -271,6 +271,7 @@ def BatchedSetMeshIndexing: FancyIndexingBase<"BatchedSetMeshIndexing">;
def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>;
def AssertEqual: MgbHashableOp<"AssertEqual",[AssertEqualParam]>;
def TQT: MgbHashableOp<"TQT", [TQTParam]>;
def LSQ: MgbHashableOp<"LSQ", [LSQParam]>;
def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> {
let extraArguments = (ins
MgbDTypeAttr:$dtype
......
......@@ -324,5 +324,7 @@ decl_opr('FakeQuant',
decl_opr('TQT',
inputs=[Doc('src','input tensor'),Doc('scale','scale tensor')],
params='TQT')
decl_opr('LSQ',
inputs=[Doc('src','input tensor'),Doc('scale','scale tensor'),Doc('zero_point','zero point tensor'),Doc('grad_scale','grad scale tensor')],
params='LSQ')
# vim: ft=python
......@@ -6,20 +6,22 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/lsq.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
......@@ -183,7 +185,8 @@ struct ConvLoadDumpImpl {
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.write_param<ConvParam>(opr.param());
ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy_transient());
ctx.write_param<megdnn::param::ExecutionPolicy>(
opr.execution_policy_transient());
}
static VarNode* make(const cg::VarNodeArray& inputs, const ConvParam& param,
......@@ -251,6 +254,20 @@ struct OprMaker<opr::TQTBackward, 3> {
}
};
template <>
struct OprMaker<opr::LSQBackward, 5> {
using Param = opr::LSQBackward::Param;
static cg::OperatorNodeBase* make(const Param& param,
const cg::VarNodeArray& i,
ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
return opr::LSQBackward::make(i[0], i[1], i[2], i[3], i[4], param,
config)[0]
.node()
->owner_opr();
}
};
template <>
struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0>
: public PoolingLoadDumpImpl<opr::AdaptivePoolingBackward,
......@@ -587,6 +604,8 @@ MGB_SEREG_OPR(FakeQuant, 3);
MGB_SEREG_OPR(FakeQuantBackward, 4);
MGB_SEREG_OPR(TQT, 2);
MGB_SEREG_OPR(TQTBackward, 3);
MGB_SEREG_OPR(LSQ, 4);
MGB_SEREG_OPR(LSQBackward, 5);
} // namespace opr
......
/**
* \file src/opr/impl/dnn/lsq.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/opr/dnn/lsq.h"
#include "../internal/megdnn_opr_wrapper.inl"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
using namespace mgb;
using namespace opr;
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSQForward);
MEGDNN_OPR_INIT4(LSQForward, "lsq_fwd");
#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LSQForward) {
SymbolVarArray grad =
LSQBackward::make(out_grad[0], opr.input(0), opr.input(1),
opr.input(2), opr.input(3), opr.param());
if (wrt_idx == 0) {
return grad[0].node();
} else if (wrt_idx == 1) {
return reduce_sum(grad[1], GetVarShape::make(opr.input(wrt_idx)))
.node();
} else {
return nullptr;
}
}
#endif
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSQBackward);
LSQBackward::LSQBackward(VarNode* y_grad, VarNode* x, VarNode* scale,
VarNode* zero_point, VarNode* grad_scale,
const Param& param, const OperatorNodeConfig& config)
: Super({x->owner_graph(),
config,
"lsq_bwd",
{y_grad, x, scale, zero_point, grad_scale}},
1, true) {
init_megdnn_opr(*this, param);
add_input({y_grad, x, scale, zero_point, grad_scale});
}
SymbolVarArray LSQBackward::make(SymbolVar y_grad, SymbolVar x, SymbolVar scale,
SymbolVar zero_point, SymbolVar grad_scale,
const Param& param,
const OperatorNodeConfig& config) {
auto&& out = x.node()->owner_graph()
->insert_opr(std::make_unique<LSQBackward>(
y_grad.node(), x.node(), scale.node(),
zero_point.node(), grad_scale.node(), param,
config))
->output();
SymbolVarArray ret(out.size());
for (size_t i = 0; i < ret.size(); ++i) {
ret[i] = out[i];
}
return ret;
}
void LSQBackward::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
mgr.register_shape_infer(output(0),
ShapeInferDesc::make_identity(input(1)));
mgr.register_shape_infer(output(1),
ShapeInferDesc::make_identity(input(1)));
this->init_output_static_infer_desc_workspace(
intl::AutoAddWorkspaceNeedLimitGetter<megdnn::LSQBackward>::val);
}
void LSQBackward::init_output_dtype() {
output(0)->dtype(input(1)->dtype());
output(1)->dtype(input(2)->dtype());
}
/**
* \file src/opr/include/megbrain/opr/dnn/lsq.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megdnn/oprs.h"
namespace mgb {
namespace opr {
MGB_DEFINE_OPR_CLASS(LSQForward,
intl::MegDNNOprWrapperFwd<megdnn::LSQForward>) // {
public:
LSQForward(VarNode* src, VarNode* scale, VarNode* zero_point,
VarNode* grad_scale, const Param& param,
const OperatorNodeConfig& config);
static SymbolVar make(SymbolVar src, SymbolVar scale, SymbolVar zero_point,
SymbolVar grad_scale, const Param& param = {},
const OperatorNodeConfig& config = {});
};
using LSQ = LSQForward;
MGB_DEFINE_OPR_CLASS(LSQBackward,
intl::MegDNNOprWrapperBwd<megdnn::LSQBackward>) // {
public:
LSQBackward(VarNode* y_grad, VarNode* x, VarNode* scale, VarNode* zero_point,
VarNode* grad_scale, const Param& param,
const OperatorNodeConfig& config);
static SymbolVarArray make(SymbolVar y_grad, SymbolVar x, SymbolVar scale,
SymbolVar zero_point, SymbolVar grad_scale,
const Param& param = {},
const OperatorNodeConfig& config = {});
private:
void init_output_static_infer_desc() override;
void init_output_dtype() override;
};
} // namespace opr
} // namespace mgb
/**
* \file src/opr/test/dnn/lsq.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/opr/dnn/lsq.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/test/autocheck.h"
using namespace std;
using namespace mgb;
namespace {
void run() {
using Checker = AutoOprChecker<4, 1>;
auto make_graph =
[&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
auto o0 = opr::LSQForward::make(inputs[0], inputs[1], inputs[2],
inputs[3]);
return {o0};
};
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
auto opr = MegDNNHandle::get(
CompNodeEnv::from_comp_node(CompNode::default_cpu()))
->create_operator<megdnn::LSQForward>();
dest[0].dtype(dtype::Float32())
.comp_node(inp[0]->comp_node())
.resize(inp[0]->shape());
opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(),
inp[3]->as_megdnn(), dest[0].as_megdnn(), {});
};
auto gen = [&](HostTensorND& src) {
HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN>
src_gen(10.f);
src = *src_gen(src.shape(), src.comp_node());
};
Checker::RunOptions opt;
opt.numdiff_max_err = 1e-5;
Checker checker{make_graph, fwd};
checker.set_input_generator(0, gen)
.set_input_generator(1, gen)
.set_input_generator(2, gen)
.set_input_generator(3, gen)
.set_input_allow_grad(0, false)
.set_input_allow_grad(1, false)
.set_input_allow_grad(2, false)
.set_input_allow_grad(3, false)
.set_output_allow_grad(0, false);
checker.run({TensorShape{1, 2, 3, 4}, TensorShape{1}, TensorShape{1},
TensorShape{1}},
opt)
.run({TensorShape{2, 3, 8, 8}, TensorShape{1}, TensorShape{1},
TensorShape{1}},
opt)
.run({TensorShape{1, 3, 4, 4}, TensorShape{1}, TensorShape{1},
TensorShape{1}},
opt);
}
} // anonymous namespace
TEST(TestOprDNN, LSQForward) {
REQUIRE_GPU(1);
run();
}
\ No newline at end of file
......@@ -107,6 +107,7 @@ union OperatorParam {
param.FakeQuant = 73,
param.TQT = 74,
param.Correlation = 75,
param.LSQ = 76,
}
table Operator {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册