/**
 * \file src/opr/impl/dnn/pooling.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/search_policy/algo_chooser.h"
#include "../internal/megdnn_opr_wrapper.inl"

#include "../search_policy/workspace_need_limit_getter.inl"

using namespace mgb;
using namespace opr;

MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingForward);

/*!
 * \brief Construct a pooling forward opr on \p i0 with the given megdnn
 *        \p param and algorithm-selection \p policy.
 */
PoolingForward::PoolingForward(VarNode* i0, const Param& param,
                               const ExecutionPolicy& policy,
                               const OperatorNodeConfig& config)
        : Super(OperatorNodeBaseCtorParam{
                  i0->owner_graph(), config, "pooling", {i0}}) {
    // Create the underlying megdnn operator before registering inputs,
    // mirroring the ctor sequence used by PoolingBackward below.
    init_megdnn_opr(*this, param);
    add_input({i0});
    // Stored policy is consumed later by AlgoChooser::setup_algo via the
    // `this` pointer in get_workspace_size_bytes().
    m_policy = policy;
    
    intl::MegDNNOprInitPostCtor<PoolingForward>::apply(*this);
}

/*!
 * \brief Factory: insert a PoolingForward opr into i0's graph and return
 *        its output var.
 */
SymbolVar PoolingForward::make(SymbolVar i0, const Param& param,
                               const OperatorNodeConfig& config,
                               const ExecutionPolicy& policy) {
    // Framework hook: may rewrite the input var before node insertion.
    intl::MegDNNOprInitInputsModifier<PoolingForward>::apply(param, {&i0});
    auto ret = i0.insert_single_output_opr<PoolingForward>(i0.node(), param,
                                                           policy, config);
    return ret;
}

void PoolingForward::init_output_static_infer_desc() {
    // Exclude the trailing workspace output from the default managed
    // outputs, run the inherited shape inference, then register a dedicated
    // inference entry for the workspace size.
    Super::set_nr_managed_outputs(this->output().size() - 1);
    Super::Super::init_output_static_infer_desc();
    init_output_static_infer_desc_workspace(
            intl::AutoAddWorkspaceNeedLimitGetter<megdnn::PoolingForward>::val);
}

size_t PoolingForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    return AlgoChooser<megdnn::PoolingForward>::setup_algo(
            {TensorLayout{input_shapes[0], input(0)->dtype(),
                          input(0)->format()},
             {output_shapes[0], output(0)->dtype(), output(0)->format()}},
            megdnn_opr(), this, false);
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(PoolingForward) {
    // Pooling has exactly one differentiable input; its gradient is
    // produced by the dedicated PoolingBackward opr, which takes the
    // forward input, forward output and output gradient.
    mgb_assert(wrt_idx == 0);
    SymbolVar grad = PoolingBackward::make(
            opr.input(0), opr.output(0), out_grad[0], opr.param());
    return grad.node();
}
#endif

MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingBackward);

/*!
 * \brief Construct a pooling backward opr.
 *
 * Inputs (as wired up by MGB_IMPL_OPR_GRAD(PoolingForward) above):
 *   \p i0 — forward input, \p i1 — forward output, \p i2 — output gradient.
 */
PoolingBackward::PoolingBackward(VarNode* i0, VarNode* i1, VarNode* i2,
                                 const Param& param,
                                 const ExecutionPolicy& policy,
                                 const OperatorNodeConfig& config)
        : Super(
                  OperatorNodeBaseCtorParam{
                          i0->owner_graph(), config, "pooling_bwd", {i0}},
                  0, true) {
    init_megdnn_opr(*this, param);
    add_input({i0, i1, i2});
    // Fix: `policy` was previously accepted but never stored, so the
    // user-supplied execution policy was silently ignored during algorithm
    // selection in get_workspace_size_bytes(). Store it, matching the
    // PoolingForward ctor.
    m_policy = policy;
    intl::MegDNNOprInitPostCtor<PoolingBackward>::apply(*this);
}

/*!
 * \brief Factory: insert a PoolingBackward opr into i0's graph and return
 *        its output var (the gradient w.r.t. the forward input).
 */
SymbolVar PoolingBackward::make(SymbolVar i0, SymbolVar i1, SymbolVar i2,
                                const Param& param,
                                const OperatorNodeConfig& config,
                                const ExecutionPolicy& policy) {
    // Framework hook: may rewrite the input vars before node insertion.
    intl::MegDNNOprInitInputsModifier<PoolingBackward>::apply(param,
                                                              {&i0, &i1, &i2});
    auto ret = i0.insert_single_output_opr<PoolingBackward>(
            i0.node(), i1.node(), i2.node(), param, policy, config);
    return ret;
}

size_t PoolingBackward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    return AlgoChooser<megdnn::PoolingBackward>::setup_algo(
            {TensorLayout{input_shapes[0], input(0)->dtype(),
                          input(0)->format()},
             {input_shapes[1], input(1)->dtype(), input(1)->format()},
             {input_shapes[2], input(2)->dtype(), input(2)->format()},
             {output_shapes[0], output(0)->dtype(), output(0)->format()}},
            megdnn_opr(), this, false);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}