#include "megbrain/gopt/inference.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"

#include "megdnn/opr_param_defs.h"
#include "megdnn/tensor_format.h"

#include "megbrain/opr/internal/megdnn_opr_wrapper.h"

#include "megbrain/gopt/misc.h"
#include "megbrain/utils/hash_ct.h"

#include "midout.h"

#include "megbrain/gopt/reformat_manager.h"

MIDOUT_DECL(megbrain_padding_channel)
#define MIDOUT_B(tag) \
    MIDOUT_BEGIN(megbrain_padding_channel, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
    }            \
    MIDOUT_END();

using namespace mgb;
using namespace gopt;
using ReformatKey = ReformatManager::ReformatKey;

/* ==================== PaddingChannelPass ================= */
namespace {
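//! Helpers computing how many channels must be appended so that in_channel
//! becomes a multiple of the alignment required by the target layout
//! (returning 0 when it is already aligned).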

size_t padding_int4(size_t in_channel, bool) {
    if (in_channel <= 32) {
        return (8 - (in_channel % 8)) % 8;
    } else {
        return (64 - (in_channel % 64)) % 64;
    }
}

//! The flag lets the caller distinguish cases with different alignment rules,
//! e.g. in NCHW64 it separates ConvBias from ConvolutionBackwardData.
size_t padding_int8(size_t in_channel, bool flag) {
    if (flag) {
        if (in_channel <= 16) {
            return (4 - (in_channel % 4)) % 4;
        } else {
            return (32 - (in_channel % 32)) % 32;
        }
    } else {
        return (4 - (in_channel % 4)) % 4;
    }
}
size_t padding_4(size_t in_channel, bool) {
    return (4 - (in_channel % 4)) % 4;
};

size_t padding_8(size_t in_channel, bool) {
    return (8 - (in_channel % 8)) % 8;
};

}  // namespace

std::unique_ptr<PaddingChannelPass> PaddingChannelPass::make(
        cg::GraphCommonOptimizeOptions::LayoutTransform layout_transform,
        bool only_padding_weights) {
    MIDOUT_B("PaddingChannelPass::make")
    using LayoutTrans = cg::GraphCommonOptimizeOptions::LayoutTransform;
    auto ret = std::unique_ptr<PaddingChannelPass>(
            new PaddingChannelPass(only_padding_weights));
    auto& alignment_map = ret->m_alignment_map;
    if (layout_transform == LayoutTrans::NCHW64) {
        alignment_map[DTypeEnum::QuantizedS4] = padding_int4;
        alignment_map[DTypeEnum::Quantized4Asymm] = padding_int4;
        alignment_map[DTypeEnum::QuantizedS8] = padding_int8;
    } else if (
            layout_transform == LayoutTrans::NHWCD4 ||
            layout_transform == LayoutTrans::NCHW44 ||
            layout_transform == LayoutTrans::NCHW44_DOT) {
        alignment_map[DTypeEnum::QuantizedS8] = padding_4;
        alignment_map[DTypeEnum::Quantized8Asymm] = padding_4;
        alignment_map[DTypeEnum::Float32] = padding_4;
#if !MEGDNN_DISABLE_FLOAT16
        alignment_map[DTypeEnum::Float16] = padding_4;
#endif
    } else if (layout_transform == LayoutTrans::NCHW88) {
        alignment_map[DTypeEnum::QuantizedS8] = padding_8;
        alignment_map[DTypeEnum::Quantized8Asymm] = padding_8;
        alignment_map[DTypeEnum::Float32] = padding_8;
#if !MEGDNN_DISABLE_FLOAT16
        alignment_map[DTypeEnum::Float16] = padding_8;
#endif
    }
    ret->fill_opr_convert_fun(layout_transform);
    return ret;
    MIDOUT_E
}
const char* PaddingChannelPass::name() const {
    return mgb_cstr_log("padding output channel to multiple of 4/32");
}

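//! Walk the graph once; every operator with a registered replace function is
//! rebuilt on the (possibly padded) inputs, and padded endpoint vars are
//! sliced back to their original shape via extract_subtensor.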
void PaddingChannelPass::apply(OptState& opt) const {
    MIDOUT_B("PaddingChannelPass::apply");
    // do not check shape
    opt.set_var_replace_check_flag(
            VarReplaceCheckFlag::CHECK_ALL ^ VarReplaceCheckFlag::CHECK_SHAPE);
    m_padding_oprs.clear();
    auto rewriter = opt.graph().make_rewriter();
    auto on_opr = [this, &opt, &rewriter](OperatorNodeBase* opr) {
        auto it = m_opr_replace_funcs.find(opr->dyn_typeinfo());
        auto is_skip = false;
        //! skip the opr if any of its inputs has a dynamic shape
        for (size_t id = 0; id < opr->input().size(); id++) {
            if (0 == opr->input(id)->shape().ndim) {
                is_skip = true;
            }
        }
        if (it != m_opr_replace_funcs.end() && !is_skip) {
            VarNodeArray new_inp;
            new_inp.reserve(opr->input().size());
            for (auto&& inp : opr->input()) {
                new_inp.push_back(rewriter.get_var(inp));
            }
            auto new_opr = (it->second)(opr, new_inp);
            auto &&out0 = opr->output(), &&out1 = new_opr->output();
            mgb_assert(
                    out0.size() == out1.size(),
                    "bad opr replace: src=%s{%s} dst=%s{%s}, "
                    "src.size=%zu "
                    "dst.size=%zu",
                    opr->cname(), opr->dyn_typeinfo()->name, new_opr->cname(),
                    new_opr->dyn_typeinfo()->name, out0.size(), out1.size());
            for (size_t i = 0; i < out0.size(); ++i) {
                if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                    mgb_assert(!out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
                    auto src = out0[i];
                    auto dst = out1[i];
                    if (opt.graph().endpoint_contain(src) &&
                        !src->shape().eq_shape(dst->shape())) {
                        dst = extract_subtensor(dst, src->shape());
                    }
                    rewriter.replace_var(src, dst, nullptr);
                }
            }
        } else {
            rewriter.auto_replace_outputs(opr);
        }
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();

    MIDOUT_E
}

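//! Slice a padded NCHW var back to its original channel count (axis 1);
//! returns the input unchanged if no channel padding was applied.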
VarNode* PaddingChannelPass::extract_subtensor(
        VarNode* inp, const TensorShape& orig_shape) const {
    mgb_assert(inp->shape().ndim == 4);
    mgb_assert(inp->shape()[0] == orig_shape[0]);
    mgb_assert(inp->shape()[2] == orig_shape[2]);
    mgb_assert(inp->shape()[3] == orig_shape[3]);
    size_t orig_channels = orig_shape[1];
    //! if the channel count was not padded, do nothing
    if (orig_channels == inp->shape()[1]) {
        return inp;
    }
    auto x = SymbolVar(inp);
    auto cv = [&x](int v) { return x.make_scalar(v); };
    using AIdx = opr::Subtensor::AxisIndexer;
    auto sub = opr::Subtensor::make(
            x, {AIdx::make_interval(0, None, None, cv(1)),
                AIdx::make_interval(1, None, cv(orig_channels), None),
                AIdx::make_interval(2, None, None, cv(1)),
                AIdx::make_interval(3, None, None, cv(1))});
    return sub.node();
};

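//! Append pad_channels zero-filled channels to inp by concatenating a zero
//! ImmutableTensor along the channel axis (axis 1 for NCHW tensors, axis 0
//! for the 5-dim channel-wise weights).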
VarNode* PaddingChannelPass::pad_in_channels(VarNode* inp, size_t pad_channels) {
    TensorShape shape;
    size_t axis = 0;
    if (inp->shape().ndim == 4) {
        shape = TensorShape{
                inp->shape()[0], pad_channels, inp->shape()[2], inp->shape()[3]};
        axis = 1;
    } else {
        mgb_assert(inp->shape().ndim == 5);
        //! the channel wise convolution
        if (inp->shape()[1] == 1 && inp->shape()[2] == 1) {
            shape = TensorShape{
                    pad_channels, inp->shape()[1], inp->shape()[2], inp->shape()[3],
                    inp->shape()[4]};
            axis = 0;
        } else {
            //! the group convolution
            mgb_assert(0, "group convolution can't pad channels\n");
        }
    }
    std::shared_ptr<HostTensorND> host_val =
            std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
    host_val->resize(shape);
    auto ptr = host_val->raw_ptr();
    size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte();
    std::memset(ptr, 0, size_bytes);
    auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
    auto out = opr::Concat::make({inp, padding}, axis);
    return out.node();
};

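//! Same as pad_in_channels, but along the output channel axis of the weights
//! (axis 0 for both the dense and the channel-wise layout).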
VarNode* PaddingChannelPass::pad_out_channels(VarNode* inp, size_t pad_channels) {
    TensorShape shape;
    size_t axis = 0;
    if (inp->shape().ndim == 4) {
        shape = TensorShape{
                pad_channels, inp->shape()[1], inp->shape()[2], inp->shape()[3]};
        axis = 0;
    } else {
        mgb_assert(inp->shape().ndim == 5);
        //! the channel wise convolution
        if (inp->shape()[1] == 1 && inp->shape()[2] == 1) {
            shape = TensorShape{
                    pad_channels, inp->shape()[1], inp->shape()[2], inp->shape()[3],
                    inp->shape()[4]};
            axis = 0;
        } else {
            //! the group convolution
            mgb_assert(0, "group convolution can't pad channels\n");
        }
    }
    std::shared_ptr<HostTensorND> host_val =
            std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
    host_val->resize(shape);
    auto ptr = host_val->raw_ptr();
    size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte();
    std::memset(ptr, 0, size_bytes);
    auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
    auto out = opr::Concat::make({inp, padding}, axis);
    return out.node();
};

// padding policy for dense convolution
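// It pads the input channels of data and weights when needed, then pads the
// output channels of the weights (and the bias channels, if a bias is
// present) and records the opr in m_padding_oprs so that consumers know its
// outputs carry padded channels.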
OperatorNodeBase* PaddingChannelPass::padding_conv_policy(
        OperatorNodeBase* opr, const VarNodeArray& new_inp) {
    mgb_assert(opr->input().size() == new_inp.size());
    mgb_assert(new_inp.size() >= 2);
    //! new weights and old weights have the same shape
    mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()));
    auto inps = new_inp;
    size_t out_channels = opr->input(1)->shape()[0];
    size_t in_channels = opr->input(1)->shape()[1];
    size_t new_in_channels = new_inp[0]->shape()[1];
    auto it = m_alignment_map.find(opr->input(0)->dtype().enumv());
    if (it != m_alignment_map.end()) {
        mgb_assert(it->second);
    } else {
        return serialization::copy_opr_shallow(*opr, inps, opr->config());
    }
    // pad input channels
    if (m_padding_oprs.count(opr->input(0)->owner_opr())) {
        //! the producer of the input var has already been padded, but since the
        //! input and output dtypes of that producer may differ, the required
        //! alignment may not be the same
        size_t pad_channels_0 =
                m_only_padding_weights ? 0 : it->second(new_in_channels, true);
        size_t pad_channels_1 = it->second(in_channels, true);
        if (pad_channels_0) {
            inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
        } else {
            pad_channels_1 = new_in_channels - in_channels;
        }
        if (pad_channels_1) {
            inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
        }
    } else {
        mgb_assert(new_in_channels == in_channels);
        size_t pad_channels = it->second(in_channels, true);
        if (pad_channels > 0 && !m_only_padding_weights) {
            inps[0] = pad_in_channels(new_inp[0], pad_channels);
            inps[1] = pad_in_channels(new_inp[1], pad_channels);
        }
    }
    out_channels = inps[1]->shape()[0];
    size_t pad_channels = it->second(out_channels, true);
    if (pad_channels > 0) {
        inps[1] = pad_out_channels(inps[1], pad_channels);
        if (inps.size() >= 3) {
            inps[2] = pad_in_channels(inps[2], pad_channels);
        }
        m_padding_oprs.insert(opr);
    }
    return serialization::copy_opr_shallow(*opr, inps, opr->config());
};

//! padding policy for channel wise convolution
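//! For channel-wise convolution the group count must match the input channel
//! count, so when the input has already been padded the weights (and bias)
//! are padded along the group axis by the same amount.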
OperatorNodeBase* PaddingChannelPass::padding_channel_wise_conv_policy(
        OperatorNodeBase* opr, const VarNodeArray& new_inp) {
    mgb_assert(opr->input().size() == new_inp.size());
    mgb_assert(opr->input()[1]->shape().ndim == 5);
    mgb_assert(new_inp.size() >= 2);
    //! new weights and old weights have the same shape
    mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()));
    auto inps = new_inp;
    size_t group = opr->input(1)->shape()[0];
    size_t new_in_channels = new_inp[0]->shape()[1];
    auto it = m_alignment_map.find(opr->input(0)->dtype().enumv());
    if (it != m_alignment_map.end()) {
        mgb_assert(it->second);
    } else {
        return serialization::copy_opr_shallow(*opr, inps, opr->config());
    }
    // pad input channels
    if (m_padding_oprs.count(opr->input(0)->owner_opr())) {
        size_t pad_channels_1 = new_in_channels - group;
        if (pad_channels_1) {
            inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
            if (inps.size() >= 3) {
                inps[2] = pad_in_channels(new_inp[2], pad_channels_1);
            }
            m_padding_oprs.insert(opr);
        }
    }
    return serialization::copy_opr_shallow(*opr, inps, opr->config());
};

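//! Register the per-operator-type replace functions; dispatch in apply() is
//! keyed by operator typeinfo, so the registration order does not matter.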
void PaddingChannelPass::fill_opr_convert_fun(LayoutTrans layout_trans) {
    add_conv_replace_func(layout_trans);
    add_conv_backward_data_replace_func(layout_trans);
    add_format_aware_opr_replace_func(layout_trans);
    add_elemwise_like_opr_replace_func(layout_trans);
    add_condition_padding_oprs_replace_func(layout_trans);
    add_nonpadding_oprs_replace_func(layout_trans);
}

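//! Convolution/ConvBias: NCHW64 only pads dense quantized convolutions, while
//! NCHW44/NCHW44_DOT/NCHW88 also handle channel-wise convolution and fall
//! back to slicing the inputs for grouped convolution.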
void PaddingChannelPass::add_conv_replace_func(LayoutTrans layout_trans) {
    if (layout_trans == LayoutTrans::NCHW64) {
        m_opr_replace_funcs[opr::ConvBiasForward::typeinfo()] =
                [this](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
                    mgb_assert(
                            opr->input()[1]->shape().ndim == 4,
                            "nchw64 format only supports channel padding for dense "
                            "convolution\n");
                    if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                        opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 ||
                        opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm) {
                        return padding_conv_policy(opr, new_inp);
                    } else {
                        mgb_assert(
                                m_padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                                "channel padding is not supported for the conv "
                                "bias operator with data type(%s). "
                                "consumer(%s), producer(%s)",
                                opr->input(0)->dtype().name(), opr->cname(),
                                opr->input(0)->owner_opr()->cname());
                        return serialization::copy_opr_shallow(
                                *opr, new_inp, opr->config());
                    }
                };
    } else if (
            layout_trans == LayoutTrans::NCHW44 ||
            layout_trans == LayoutTrans::NCHW44_DOT ||
            layout_trans == LayoutTrans::NCHW88) {
        auto padding_conv = [this](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
            if (opr->input()[1]->shape().ndim == 4) {
                return padding_conv_policy(opr, new_inp);
            } else {
                mgb_assert(opr->input()[1]->shape().ndim == 5);
                if (opr->input()[1]->shape()[1] == 1 &&
                    opr->input()[1]->shape()[2] == 1) {
                    return padding_channel_wise_conv_policy(opr, new_inp);
                } else {
                    //! group convolution can't pad channels
                    mgb_assert(opr->input().size() == new_inp.size());
                    auto inps = new_inp;
                    for (size_t i = 0; i < new_inp.size(); ++i) {
                        auto cur_inp = opr->input(i);
                        bool padding_cur_inp =
                                m_padding_oprs.count(cur_inp->owner_opr()) > 0;
                        if (padding_cur_inp) {
                            inps[i] = extract_subtensor(inps[i], cur_inp->shape());
                        }
                    }
                    return serialization::copy_opr_shallow(*opr, inps, opr->config());
                }
            }
        };
        m_opr_replace_funcs[opr::ConvBiasForward::typeinfo()] = padding_conv;
        m_opr_replace_funcs[opr::Convolution::typeinfo()] = padding_conv;
    }
}

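//! ConvolutionBackwardData (deconv): the filter is padded to match a padded
//! data producer; in NCHW64 (QuantizedS8 only) the deconv output channels are
//! additionally padded to the required alignment and the opr is recorded in
//! m_padding_oprs.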
void PaddingChannelPass::add_conv_backward_data_replace_func(LayoutTrans layout_trans) {
    if (layout_trans == LayoutTrans::NCHW64) {
        m_opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] =
                [this](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
                    if (opr->input(1)->dtype().enumv() != DTypeEnum::QuantizedS8) {
                        mgb_assert(
                                m_padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                                "channel padding is not supported for the conv "
                                "bwd data operator with data type(%s). "
                                "consumer(%s), producer(%s)",
                                opr->input(0)->dtype().name(), opr->cname(),
                                opr->input(0)->owner_opr()->cname());
                        return serialization::copy_opr_shallow(
                                *opr, new_inp, opr->config());
                    }
                    mgb_assert(opr->input().size() == new_inp.size());
                    mgb_assert(
                            new_inp.size() == 2,
                            "deconv (conv bwd data) operator for inference can "
                            "only have 2 input vars(got:%zu)",
                            new_inp.size());
                    mgb_assert(opr->input(0)->shape().eq_shape(new_inp[0]->shape()));
                    auto inps = new_inp;
                    size_t out_channels = opr->input(0)->shape()[0];
                    size_t in_channels = opr->input(0)->shape()[1];
                    size_t new_out_channels = new_inp[1]->shape()[1];
                    auto it = m_alignment_map.find(opr->input(1)->dtype().enumv());
                    // pad output channels
                    if (m_padding_oprs.count(opr->input(1)->owner_opr())) {
                        size_t pad_channels = new_out_channels - out_channels;
                        inps[0] = pad_out_channels(new_inp[0], pad_channels);
                    } else {
                        size_t pad_channels = m_only_padding_weights
                                                    ? 0
                                                    : it->second(out_channels, false);
                        if (pad_channels > 0) {
                            inps[0] = pad_out_channels(new_inp[0], pad_channels);
                            inps[1] = pad_in_channels(new_inp[1], pad_channels);
                        }
                    }
                    out_channels = inps[0]->shape()[0];
                    // pad input channels
                    size_t pad_channels = it->second(in_channels, false);
                    if (pad_channels > 0) {
                        inps[0] = pad_in_channels(inps[0], pad_channels);
                        m_padding_oprs.insert(opr);
                    }
                    return serialization::copy_opr_shallow(*opr, inps, opr->config());
                };
    } else {
        m_opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] =
                [this](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
                    mgb_assert(opr->input(0)->shape().eq_shape(new_inp[0]->shape()));
                    auto inps = new_inp;
                    size_t out_channels = opr->input(0)->shape()[0];
                    size_t new_out_channels = new_inp[1]->shape()[1];
                    // pad output channels
                    if (m_padding_oprs.count(opr->input(1)->owner_opr())) {
                        size_t pad_channels = new_out_channels - out_channels;
                        inps[0] = pad_out_channels(new_inp[0], pad_channels);
                    }
                    out_channels = inps[0]->shape()[0];

                    return serialization::copy_opr_shallow(*opr, inps, opr->config());
                };
    }
}

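//! Format-aware oprs (Pooling, WarpPerspective, WarpAffine, AdaptivePooling,
//! Resize) simply forward padding: their outputs are treated as padded iff
//! the producer of input(0) was padded. NCHW64 additionally restricts this
//! forwarding to the quantized dtypes.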
void PaddingChannelPass::add_format_aware_opr_replace_func(LayoutTrans layout_trans) {
    auto replace_format_aware_opr = [this, layout_trans](
                                            OperatorNodeBase* opr,
                                            const VarNodeArray& new_inp) {
        if (layout_trans == LayoutTrans::NCHW64) {
            if (opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 &&
                opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS4 &&
                opr->input(0)->dtype().enumv() != DTypeEnum::Quantized4Asymm) {
                mgb_assert(
                        m_padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                        "channel padding is not supported for operator"
                        "(type:%s,name:%s) with data type(%s). extra info:"
                        "consumer(%s), producer(%s)",
                        opr->dyn_typeinfo()->name, opr->cname(),
                        opr->input(0)->dtype().name(), opr->cname(),
                        opr->input(0)->owner_opr()->cname());
                return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
            }
        }
        mgb_assert(opr->input().size() == new_inp.size());
        if (m_padding_oprs.count(opr->input(0)->owner_opr())) {
            m_padding_oprs.insert(opr);
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    m_opr_replace_funcs[opr::PoolingForward::typeinfo()] = replace_format_aware_opr;
    m_opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] =
            replace_format_aware_opr;
    m_opr_replace_funcs[opr::WarpAffine::typeinfo()] = replace_format_aware_opr;
    m_opr_replace_funcs[opr::AdaptivePooling::typeinfo()] = replace_format_aware_opr;
    m_opr_replace_funcs[opr::ResizeForward::typeinfo()] = replace_format_aware_opr;
}

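//! Elemwise-like oprs may keep padded inputs only when every non-scalar input
//! is padded to the same channel count; otherwise padded inputs are sliced
//! back to their original shapes.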
void PaddingChannelPass::add_elemwise_like_opr_replace_func(LayoutTrans) {
    auto replace_elemwise_like_opr = [this](OperatorNodeBase* opr,
                                            const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        bool have_padding_inp = false;
        bool padding_all_inps = true;
        bool same_padding = true;
        size_t channels_after_padding = 0;
        size_t i = 0;
        for (auto&& cur_inp : opr->input()) {
            if (cur_inp->shape().is_scalar()) {
                ++i;
                continue;
            }
            bool padding_cur_inp = m_padding_oprs.count(cur_inp->owner_opr()) > 0;
            if (padding_cur_inp) {
                have_padding_inp = true;
                if (channels_after_padding == 0) {
                    channels_after_padding = new_inp[i]->shape()[1];
                } else {
                    same_padding = channels_after_padding == new_inp[i]->shape()[1];
                }
            }
            if (padding_all_inps && (!padding_cur_inp || !same_padding)) {
                padding_all_inps = false;
            }
            ++i;
        }
        if (have_padding_inp && !padding_all_inps) {
            auto inps = new_inp;
            for (size_t i = 0; i < new_inp.size(); ++i) {
                auto cur_inp = opr->input(i);
                bool padding_cur_inp = m_padding_oprs.count(cur_inp->owner_opr()) > 0;
                if (padding_cur_inp) {
                    inps[i] = extract_subtensor(inps[i], cur_inp->shape());
                }
            }
            return serialization::copy_opr_shallow(*opr, inps, opr->config());
        }
        if (padding_all_inps && have_padding_inp) {
            m_padding_oprs.insert(opr);
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    m_opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_like_opr;
    m_opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr;
    m_opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr;
    m_opr_replace_funcs[opr::PowC::typeinfo()] = replace_elemwise_like_opr;
}

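//! Reduce and Subtensor forward padding only when the padded tail cannot
//! affect the result (the reduction does not involve the channel dim, or the
//! Subtensor on the channel dim has an explicit index/end); otherwise padded
//! inputs are sliced back first.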
void PaddingChannelPass::add_condition_padding_oprs_replace_func(LayoutTrans) {
    auto replace_condition_oprs = [this](OperatorNodeBase* opr,
                                         const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        bool can_forward_padding = true;
        if (auto reduce = opr->try_cast_final<opr::Reduce>()) {
            auto axis = reduce->param().axis;
            if (axis < 0) {
                axis += reduce->input(0)->shape().ndim;
            }
            //! padding can't be forwarded if the reduction involves the channel dim
            if (reduce->input().size() > 1) {
                can_forward_padding = false;
            } else {
                can_forward_padding = axis != 1;
            }
        } else if (auto subtensor = opr->try_cast_final<opr::Subtensor>()) {
            auto indexs = subtensor->index_desc();
            size_t input_dim = subtensor->input(0)->shape().ndim;
            for (size_t id = 0; id < indexs.size(); id++) {
                if (indexs[id].axis.get(input_dim) == 1) {
                    //! when the Subtensor works on the channel dim, padding can
                    //! be forwarded (without inserting an extra Subtensor) only
                    //! if it uses index mode or has an explicit end
                    can_forward_padding &=
                            indexs[id].idx.node() || indexs[id].end.node();
                }
            }
        }
        auto inps = new_inp;
        for (size_t i = 0; i < new_inp.size(); ++i) {
            auto cur_inp = opr->input(i);
            bool padding_cur_inp = m_padding_oprs.count(cur_inp->owner_opr()) > 0;
            if (padding_cur_inp) {
                if (can_forward_padding) {
                    m_padding_oprs.insert(opr);
                } else {
                    inps[i] = extract_subtensor(inps[i], cur_inp->shape());
                }
            }
        }
        return serialization::copy_opr_shallow(*opr, inps, opr->config());
    };
    m_opr_replace_funcs[opr::Reduce::typeinfo()] = replace_condition_oprs;
    m_opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_condition_oprs;
}

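//! Shape-sensitive oprs (Reshape, Concat, GetVarShape, Dimshuffle, ...) never
//! consume padded vars: any padded input is sliced back to its original shape
//! first.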
void PaddingChannelPass::add_nonpadding_oprs_replace_func(LayoutTrans) {
    auto replace_nonpadding_oprs = [this](OperatorNodeBase* opr,
                                          const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto inps = new_inp;
        for (size_t i = 0; i < new_inp.size(); ++i) {
            auto cur_inp = opr->input(i);
            bool padding_cur_inp = m_padding_oprs.count(cur_inp->owner_opr()) > 0;
            if (padding_cur_inp) {
                inps[i] = extract_subtensor(inps[i], cur_inp->shape());
            }
        }
        return serialization::copy_opr_shallow(*opr, inps, opr->config());
    };
    m_opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs;
    m_opr_replace_funcs[opr::AxisAddRemove::typeinfo()] = replace_nonpadding_oprs;
    m_opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs;
    m_opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs;
    m_opr_replace_funcs[opr::Dimshuffle::typeinfo()] = replace_nonpadding_oprs;
    m_opr_replace_funcs[opr::Argmax::typeinfo()] = replace_nonpadding_oprs;
    m_opr_replace_funcs[opr::Argmin::typeinfo()] = replace_nonpadding_oprs;
    m_opr_replace_funcs[opr::IncrSubtensor::typeinfo()] = replace_nonpadding_oprs;
    m_opr_replace_funcs[opr::AssertEqual::typeinfo()] = replace_nonpadding_oprs;
}
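
//! Usage sketch (illustrative only, not part of this pass): how the pass would
//! typically be chained into a graph optimization, assuming the usual
//! gopt::GraphOptimizer add_pass()/apply_inplace() API; the variable names
//! below are hypothetical.
//!
//!     VarNodeArray outputs = /* endpoint vars of the graph */;
//!     gopt::GraphOptimizer optimizer;
//!     optimizer.add_pass(PaddingChannelPass::make(
//!             cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64,
//!             /* only_padding_weights */ false));
//!     optimizer.apply_inplace(outputs);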

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}