提交 b60cc8ca 编写于 作者: M Megvii Engine Team

feat(mgb/opr): add megbrain fake quant opr

GitOrigin-RevId: a858bce939c7a68e311bbaf8417ce3feb2227537
上级 c03249c0
......@@ -315,5 +315,9 @@ r"""
"""),
has_out_dtype=True)
decl_opr('FakeQuant',
inputs=[Doc('src','input tenosr'),Doc('scale','scale tensor'),Doc('zero_point','zero point tensor')],
params='FakeQuant')
# vim: ft=python
......@@ -18,6 +18,7 @@
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/serialization/sereg.h"
......@@ -423,6 +424,8 @@ namespace opr {
MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5);
MGB_SEREG_OPR(BatchConvBiasForward, 0);
MGB_SEREG_OPR(FakeQuant, 3);
MGB_SEREG_OPR(FakeQuantBackward, 4);
} // namespace opr
......
/**
* \file src/opr/impl/dnn/fake_quant.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/opr/dnn/fake_quant.h"
#include "../internal/megdnn_opr_wrapper.inl"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megbrain/opr/utility.h"
using namespace mgb;
using namespace opr;
MGB_DYN_TYPE_OBJ_FINAL_IMPL(FakeQuantForward);
MEGDNN_OPR_INIT3(FakeQuantForward, "fakequant_fwd");
#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(FakeQuantForward) {
if (wrt_idx == 0) {
// wrt src
SymbolVar grad =
FakeQuantBackward::make(out_grad[0], opr.input(0), opr.input(1),
opr.input(2), opr.param());
return grad.node();
} else {
return nullptr;
}
}
#endif
MGB_DYN_TYPE_OBJ_FINAL_IMPL(FakeQuantBackward);
MEGDNN_OPR_INIT4(FakeQuantBackward, "fakequant_bwd", 1, true);
/**
* \file src/opr/include/megbrain/opr/dnn/fake_quant.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megdnn/oprs.h"
namespace mgb {
namespace opr {
MGB_DEFINE_OPR_CLASS(FakeQuantForward,
intl::MegDNNOprWrapperFwd<megdnn::FakeQuantForward>) // {
public:
FakeQuantForward(VarNode* src, VarNode* scale, VarNode* zero_point,
const Param& param, const OperatorNodeConfig& config);
static SymbolVar make(SymbolVar src, SymbolVar scale, SymbolVar zero_point,
const Param& param = {},
const OperatorNodeConfig& config = {});
}; // namespace opr
using FakeQuant = FakeQuantForward;
MGB_DEFINE_OPR_CLASS(FakeQuantBackward,
intl::MegDNNOprWrapperBwd<megdnn::FakeQuantBackward>) // {
public:
FakeQuantBackward(VarNode* diff, VarNode* input, VarNode* scale,
VarNode* zero_point, const Param& param,
const OperatorNodeConfig& config);
static SymbolVar make(SymbolVar diff, SymbolVar input, SymbolVar scale,
SymbolVar zero_point, const Param& param = {},
const OperatorNodeConfig& config = {});
};
} // namespace mgb
} // namespace opr
\ No newline at end of file
......@@ -102,6 +102,7 @@ union OperatorParam {
param.AdaptivePooling = 70,
param.NvOf = 71,
param.DctChannelSelect = 72,
param.FakeQuant = 73,
}
table Operator {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册