Commit 273f891b authored by Megvii Engine Team, committed by Xu Xinran

fix(mgb/gopt): fix run-time winograd-transform and nchwxx error

GitOrigin-RevId: aca796f17defd041802e926e9f0742b02dd48de4
Parent 02abc36e
......@@ -310,7 +310,8 @@ bool ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44::usable(
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
- param.compute_mode == param::ConvBias::ComputeMode::FLOAT32 &&
+ (param.compute_mode == param::ConvBias::ComputeMode::FLOAT32 ||
+  param.compute_mode == param::ConvBias::ComputeMode::DEFAULT) &&
param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.bias_type.enumv() == DTypeEnum::QuantizedS32 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8;
......
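The relaxed usable() check above is what lets quantized int8 NCHW44 models pick this Winograd F(2,3) 4x4 algorithm when they leave compute_mode at its default. A minimal standalone sketch of the new predicate, using a stand-in enum rather than the real megdnn type:

```cpp
// Stand-in illustration only; the real enum is param::ConvBias::ComputeMode.
enum class ComputeMode { DEFAULT, FLOAT32 };

// After the fix, DEFAULT qualifies alongside FLOAT32, so the algorithm is no
// longer filtered out for models that never set an explicit compute mode.
bool winograd_compute_mode_ok(ComputeMode cm) {
    return cm == ComputeMode::FLOAT32 || cm == ComputeMode::DEFAULT;
}
```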
......@@ -76,7 +76,7 @@ public:
ohw_tile_size));
all_algos.emplace_back(refhold.back().get());
}
- for (size_t oc_tile_size : {24, 48}) {
+ for (size_t oc_tile_size : {48, 24}) {
refhold.emplace_back(new AlgoConv1x1(
static_cast<MatrixMulImpl::AlgoBase*>(algo),
oc_tile_size));
......
......@@ -992,7 +992,6 @@ Args Args::from_argv(int argc, char **argv) {
graph_opt.graph_opt.enable_nchw44_dot();
continue;
}
if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) {
mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization");
graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity();
......@@ -1202,7 +1201,7 @@ Args Args::from_argv(int argc, char **argv) {
}
if (!strcmp(argv[i], "--winograd-transform")) {
mgb_log_warn("enable winograd transform");
- graph_opt.graph_opt.winograd_transform = true;
+ graph_opt.graph_opt.weight_winograd_transform = true;
continue;
}
......
......@@ -468,10 +468,6 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
}
#endif
- if (options().graph_opt.winograd_transform) {
-     options().graph_opt.winograd_transform = false;
-     gopt::transform_vars_inplace_with_winograd(dest_vars);
- }
#if MGB_JIT
if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) {
......
......@@ -95,6 +95,8 @@ struct GraphCommonOptimizeOptions {
//! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b)
//! + z -> conv_bias(x, w, b, z)
bool fuse_conv_bias_with_z = false;
+ //! whether to enable fast-run profiled winograd opr replace
+ bool weight_winograd_transform = false;
enum LayoutTransform : uint32_t {
DEFAULT,
NCHW4, ///< compute using NCHW4 tensor format
......@@ -124,6 +126,7 @@ struct GraphCommonOptimizeOptions {
SET(f16_io_comp);
SET(fuse_conv_bias_nonlinearity);
SET(fuse_conv_bias_with_z);
+ SET(weight_winograd_transform);
#undef SET
#define SET(_trans, _trans_capital) \
GraphCommonOptimizeOptions& enable_##_trans() { \
......@@ -307,8 +310,6 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
uint8_t jit = 0;
//! whether to enable fine-grained TensorRT opr replace
bool tensorrt = false;
- //! whether to enable fast-run profiled winograd opr replace
- bool winograd_transform = false;
} graph_opt;
//! get attribute for an operator
......
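The two header hunks above move the flag from ComputingGraph::Options into GraphCommonOptimizeOptions and rename it to weight_winograd_transform; the SET() entry additionally generates an enable_weight_winograd_transform() helper. A hedged usage sketch (assumed user code, not part of the commit) built only from the accessors visible in this diff:

```cpp
#include "megbrain/graph/cg.h"

// Assumed usage sketch: turn on the relocated option for a graph, mirroring
// what the --winograd-transform command-line flag now does.
void request_weight_winograd(mgb::ComputingGraph& graph) {
    graph.options().graph_opt.weight_winograd_transform = true;
    // Equivalent helper generated by the SET(weight_winograd_transform) entry:
    // graph.options().graph_opt.enable_weight_winograd_transform();
}
```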
......@@ -10,15 +10,16 @@
*/
#include "megbrain/gopt/framework.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/gopt/gtrans.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/gopt/weights_preprocess.h"
#include "megbrain/graph/cg.h"
#include "megbrain/graph/event.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/utils/timer.h"
#if MGB_JIT
......@@ -773,6 +774,8 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
add_pass<FuseConvBiasZPass>();
});
+ cb(weight_winograd_transform,
+    { add_pass<WinogradTransformReplacePass>(); });
#undef cb
if (need_param_fuse) {
......
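The cb() entry above is what wires the renamed option into the optimizer pipeline, replacing the unconditional call that compile_prepare() used to make. Its assumed effect, written out without the macro (the macro body is outside this hunk, and optimize_options is a placeholder name):

```cpp
// Assumed expansion sketch (placeholder names): when the option is requested,
// the Winograd weight-replacement pass joins the optimization pipeline.
if (optimize_options.weight_winograd_transform) {
    add_pass<WinogradTransformReplacePass>();
}
```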
......@@ -24,6 +24,10 @@ const char* WinogradTransformReplacePass::name() const {
void WinogradTransformReplacePass::apply(OptState& opt) const {
auto rewriter = opt.graph().make_rewriter();
+ ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
+ opt.graph().iter([&cvprop](OperatorNodeBase *opr) {
+     cvprop.add_opr(opr);
+ });
auto get_algo = [](const opr::ConvBias& opr) -> std::string {
auto&& inputs = opr.input();
......@@ -75,12 +79,10 @@ void WinogradTransformReplacePass::apply(OptState& opt) const {
for (auto i : inputs) {
new_inp.push_back(rewriter.get_var(i));
}
- if (!inputs[1]->contain_flag(
-             VarNode::Flag::PERSISTENT_DEVICE_VALUE)) {
+ if (!(cvprop.is_midconst(inputs[1]) ||
+       cvprop.is_const(inputs[1]))) {
break;
}
auto algo_name = get_algo(conv_bias_opr);
auto winograd_param =
megdnn::ConvBias::parse_winograd_name(algo_name);
......
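With the ConstVarPropogate seeded over the whole graph (first hunk of this file), the pass can now rewrite convolutions whose weights are constant-foldable expressions, not only vars carrying the PERSISTENT_DEVICE_VALUE flag. A condensed restatement of the new filter test, assuming is_const()/is_midconst() together identify vars whose values are fixed given the model's immutable tensors and parameters:

```cpp
// Sketch of the replacement condition: skip this conv_bias unless its filter
// input is known to be constant at graph-build time.
bool filter_is_const =
        cvprop.is_const(inputs[1]) || cvprop.is_midconst(inputs[1]);
if (!filter_is_const) {
    break;  // leave this conv_bias untouched
}
```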
......@@ -672,14 +672,9 @@ void AlgoChooser<megdnn::ConvBias>::get_origin_param_and_layouts(
auto format = static_cast<megdnn::param::ConvBias::Format>(
ctx.megdnn_opr()->param().format);
size_t output_block_size = ctx.megdnn_opr()->param().output_block_size;
- TensorLayout origin_layout;
megdnn::ConvBias::deduce_winograd_origin_layout_and_param(
format, output_block_size, ctx.layouts()[0], ctx.layouts()[1],
- origin_layout, param);
- for (size_t i = 0; i < ctx.layouts().size(); i++) {
-     layouts[i] = ctx.layouts()[i];
- }
- layouts[1] = origin_layout;
+ layouts[1], param);
}
template <typename Opr>
......