提交 273f891b · 作者: Megvii Engine Team · 提交者: Xu Xinran

fix(mgb/gopt): fix run-time winograd-transform and nchwxx error

GitOrigin-RevId: aca796f17defd041802e926e9f0742b02dd48de4
上级 02abc36e
...@@ -310,7 +310,8 @@ bool ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44::usable( ...@@ -310,7 +310,8 @@ bool ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44::usable(
(param.filter_meta.dilation[0] == (param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] && param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) && param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::FLOAT32 && (param.compute_mode == param::ConvBias::ComputeMode::FLOAT32 ||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT) &&
param.src_type.enumv() == DTypeEnum::QuantizedS8 && param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.bias_type.enumv() == DTypeEnum::QuantizedS32 && param.bias_type.enumv() == DTypeEnum::QuantizedS32 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8; param.dst_type.enumv() == DTypeEnum::QuantizedS8;
......
...@@ -76,7 +76,7 @@ public: ...@@ -76,7 +76,7 @@ public:
ohw_tile_size)); ohw_tile_size));
all_algos.emplace_back(refhold.back().get()); all_algos.emplace_back(refhold.back().get());
} }
for (size_t oc_tile_size : {24, 48}) { for (size_t oc_tile_size : {48, 24}) {
refhold.emplace_back(new AlgoConv1x1( refhold.emplace_back(new AlgoConv1x1(
static_cast<MatrixMulImpl::AlgoBase*>(algo), static_cast<MatrixMulImpl::AlgoBase*>(algo),
oc_tile_size)); oc_tile_size));
......
...@@ -992,7 +992,6 @@ Args Args::from_argv(int argc, char **argv) { ...@@ -992,7 +992,6 @@ Args Args::from_argv(int argc, char **argv) {
graph_opt.graph_opt.enable_nchw44_dot(); graph_opt.graph_opt.enable_nchw44_dot();
continue; continue;
} }
if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) { if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) {
mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization"); mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization");
graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity(); graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity();
...@@ -1202,7 +1201,7 @@ Args Args::from_argv(int argc, char **argv) { ...@@ -1202,7 +1201,7 @@ Args Args::from_argv(int argc, char **argv) {
} }
if (!strcmp(argv[i], "--winograd-transform")) { if (!strcmp(argv[i], "--winograd-transform")) {
mgb_log_warn("enable winograd transform"); mgb_log_warn("enable winograd transform");
graph_opt.graph_opt.winograd_transform = true; graph_opt.graph_opt.weight_winograd_transform = true;
continue; continue;
} }
......
...@@ -468,10 +468,6 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( ...@@ -468,10 +468,6 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
} }
#endif #endif
if (options().graph_opt.winograd_transform) {
options().graph_opt.winograd_transform = false;
gopt::transform_vars_inplace_with_winograd(dest_vars);
}
#if MGB_JIT #if MGB_JIT
if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) { if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) {
......
...@@ -95,6 +95,8 @@ struct GraphCommonOptimizeOptions { ...@@ -95,6 +95,8 @@ struct GraphCommonOptimizeOptions {
//! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b)
//! + z -> conv_bias(x, w, b, z) //! + z -> conv_bias(x, w, b, z)
bool fuse_conv_bias_with_z = false; bool fuse_conv_bias_with_z = false;
//! whether to enable fast-run profiled winograd opr replace
bool weight_winograd_transform = false;
enum LayoutTransform : uint32_t { enum LayoutTransform : uint32_t {
DEFAULT, DEFAULT,
NCHW4, ///< compute using NCHW4 tensor format NCHW4, ///< compute using NCHW4 tensor format
...@@ -124,6 +126,7 @@ struct GraphCommonOptimizeOptions { ...@@ -124,6 +126,7 @@ struct GraphCommonOptimizeOptions {
SET(f16_io_comp); SET(f16_io_comp);
SET(fuse_conv_bias_nonlinearity); SET(fuse_conv_bias_nonlinearity);
SET(fuse_conv_bias_with_z); SET(fuse_conv_bias_with_z);
SET(weight_winograd_transform);
#undef SET #undef SET
#define SET(_trans, _trans_capital) \ #define SET(_trans, _trans_capital) \
GraphCommonOptimizeOptions& enable_##_trans() { \ GraphCommonOptimizeOptions& enable_##_trans() { \
...@@ -307,8 +310,6 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>, ...@@ -307,8 +310,6 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
uint8_t jit = 0; uint8_t jit = 0;
//! whether to enable fine-grained TensorRT opr replace //! whether to enable fine-grained TensorRT opr replace
bool tensorrt = false; bool tensorrt = false;
//! whether to enable fast-run profiled winograd opr replace
bool winograd_transform = false;
} graph_opt; } graph_opt;
//! get attribute for an operator //! get attribute for an operator
......
...@@ -10,15 +10,16 @@ ...@@ -10,15 +10,16 @@
*/ */
#include "megbrain/gopt/framework.h" #include "megbrain/gopt/framework.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/basic_arith.h" #include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/gopt/gtrans.h" #include "megbrain/gopt/gtrans.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/gopt/weights_preprocess.h"
#include "megbrain/graph/cg.h" #include "megbrain/graph/cg.h"
#include "megbrain/graph/event.h" #include "megbrain/graph/event.h"
#include "megbrain/graph/exc_extra_info.h" #include "megbrain/graph/exc_extra_info.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/serialization/opr_shallow_copy.h" #include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/utils/timer.h" #include "megbrain/utils/timer.h"
#if MGB_JIT #if MGB_JIT
...@@ -773,6 +774,8 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( ...@@ -773,6 +774,8 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
add_pass<FuseConvBiasZPass>(); add_pass<FuseConvBiasZPass>();
}); });
cb(weight_winograd_transform,
{ add_pass<WinogradTransformReplacePass>(); });
#undef cb #undef cb
if (need_param_fuse) { if (need_param_fuse) {
......
...@@ -24,6 +24,10 @@ const char* WinogradTransformReplacePass::name() const { ...@@ -24,6 +24,10 @@ const char* WinogradTransformReplacePass::name() const {
void WinogradTransformReplacePass::apply(OptState& opt) const { void WinogradTransformReplacePass::apply(OptState& opt) const {
auto rewriter = opt.graph().make_rewriter(); auto rewriter = opt.graph().make_rewriter();
ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
opt.graph().iter([&cvprop](OperatorNodeBase *opr) {
cvprop.add_opr(opr);
});
auto get_algo = [](const opr::ConvBias& opr) -> std::string { auto get_algo = [](const opr::ConvBias& opr) -> std::string {
auto&& inputs = opr.input(); auto&& inputs = opr.input();
...@@ -75,12 +79,10 @@ void WinogradTransformReplacePass::apply(OptState& opt) const { ...@@ -75,12 +79,10 @@ void WinogradTransformReplacePass::apply(OptState& opt) const {
for (auto i : inputs) { for (auto i : inputs) {
new_inp.push_back(rewriter.get_var(i)); new_inp.push_back(rewriter.get_var(i));
} }
if (!(cvprop.is_midconst(inputs[1]) ||
if (!inputs[1]->contain_flag( cvprop.is_const(inputs[1]))) {
VarNode::Flag::PERSISTENT_DEVICE_VALUE)) {
break; break;
} }
auto algo_name = get_algo(conv_bias_opr); auto algo_name = get_algo(conv_bias_opr);
auto winograd_param = auto winograd_param =
megdnn::ConvBias::parse_winograd_name(algo_name); megdnn::ConvBias::parse_winograd_name(algo_name);
......
...@@ -672,14 +672,9 @@ void AlgoChooser<megdnn::ConvBias>::get_origin_param_and_layouts( ...@@ -672,14 +672,9 @@ void AlgoChooser<megdnn::ConvBias>::get_origin_param_and_layouts(
auto format = static_cast<megdnn::param::ConvBias::Format>( auto format = static_cast<megdnn::param::ConvBias::Format>(
ctx.megdnn_opr()->param().format); ctx.megdnn_opr()->param().format);
size_t output_block_size = ctx.megdnn_opr()->param().output_block_size; size_t output_block_size = ctx.megdnn_opr()->param().output_block_size;
TensorLayout origin_layout;
megdnn::ConvBias::deduce_winograd_origin_layout_and_param( megdnn::ConvBias::deduce_winograd_origin_layout_and_param(
format, output_block_size, ctx.layouts()[0], ctx.layouts()[1], format, output_block_size, ctx.layouts()[0], ctx.layouts()[1],
origin_layout, param); layouts[1], param);
for (size_t i = 0; i < ctx.layouts().size(); i++) {
layouts[i] = ctx.layouts()[i];
}
layouts[1] = origin_layout;
} }
template <typename Opr> template <typename Opr>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册