#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "../internal/megdnn_opr_wrapper.inl"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/utility.h"

#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/nn.h"

using namespace mgb;
using namespace opr;

MGB_DYN_TYPE_OBJ_FINAL_IMPL(AdaptivePoolingForward);
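
// AdaptivePoolingForward pools the 4-d input down to a target 2-d spatial
// shape. The target shape is supplied as a second input var (out_shape);
// outshape_by_symvar_enable() lets static shape inference read its value,
// and the var is not forwarded to the kernel itself.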
AdaptivePoolingForward::AdaptivePoolingForward(
        VarNode* src, VarNode* out_shape, const Param& param,
        const OperatorNodeConfig& config)
        : Super(OperatorNodeBaseCtorParam{
                  src->owner_graph(), config, "adaptive_pooling", {src, out_shape}}) {
    init_megdnn_opr(*this, param);
    add_input({src, out_shape});
    outshape_by_symvar_enable(1, 1);
}

SymbolVar AdaptivePoolingForward::make(
        SymbolVar src, SymbolVar out_shape, const Param& param,
        const OperatorNodeConfig& config) {
    return src.insert_single_output_opr<AdaptivePoolingForward>(
            src.node(), out_shape.node(), param, config);
}

void AdaptivePoolingForward::scn_do_execute() {
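    // run the MegDNN adaptive pooling kernel: src -> dst, using the workspace
    // held by the last (workspace) output var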
    megdnn_opr()->exec(
            input(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
            intl::get_megdnn_workspace_from_var(output().back()));
}

void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape(
        TensorShape& dest, const ShapeInferInfo& shpinfo) {
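    // read the target spatial shape from the host value of the out_shape
    // input; it has either two elements (H, W) or one element (square output)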
    TensorShape oshp2d;
    cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0));
    auto src = shpinfo.shape_inp_shp.at(0);
    mgb_assert(
            src.ndim == 4 && (oshp2d.ndim == 2 || oshp2d.ndim == 1),
            "shape mismatch for AdaptivePooling: src=%s, out2d=%s",
            src.to_string().c_str(), oshp2d.to_string().c_str());

    auto param_format = param().format;
    bool tshp1n = oshp2d.ndim == 1;
    if (param_format == Param::Format::NCHW) {
        dest.ndim = 4;
        dest.shape[0] = src.shape[0];
        dest.shape[1] = src.shape[1];
        dest.shape[2] = oshp2d.shape[0];
        dest.shape[3] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1];
    } else if (param_format == Param::Format::NHWC) {
        dest.ndim = 4;
        dest.shape[0] = src.shape[0];
        dest.shape[1] = oshp2d.shape[0];
        dest.shape[2] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1];
        dest.shape[3] = src.shape[3];
    } else {
        mgb_throw(MegBrainError, "AdaptivePooling only supports NCHW or NHWC format");
    }
}

size_t AdaptivePoolingForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
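    // build full layouts (shape + dtype + format) and let MegDNN report the
    // workspace it needs for them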
    return megdnn_opr()->get_workspace_in_bytes(
            {input_shapes[0], this->input(0)->dtype(), this->input(0)->format()},
            {output_shapes[0], this->output(0)->dtype(), this->output(0)->format()});
}

void AdaptivePoolingForward::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
}

void AdaptivePoolingForward::add_input_layout_constraint() {
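    // the underlying MegDNN kernel requires contiguous input layouts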
    mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}

void AdaptivePoolingForward::init_output_static_infer_desc() {
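    // shape inference is handled by the base class; additionally register
    // size inference for the trailing workspace output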
    Super::init_output_static_infer_desc();
    init_output_static_infer_desc_workspace(false);
}

void AdaptivePoolingForward::record_execute_deps(ExecDependencyArray& deps) {
    record_megdnn_opr(deps);
}

#if MGB_ENABLE_GRAD
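// gradient is defined only w.r.t. the source tensor; the shape input has no
// meaningful gradient, so taking grad w.r.t. it yields InvalidGrad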
MGB_IMPL_OPR_GRAD(AdaptivePoolingForward) {
    if (wrt_idx == 0) {
        // wrt src
        SymbolVar grad = AdaptivePoolingBackward::make(
                opr.input(0), opr.input(1), opr.output(0), out_grad[0], opr.param());
        return grad.node();
    } else {
        mgb_assert(wrt_idx == 1);
        return InvalidGrad::make(opr, wrt_idx);
    }
}
#endif

MGB_DYN_TYPE_OBJ_FINAL_IMPL(AdaptivePoolingBackward);
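// AdaptivePoolingBackward computes the gradient w.r.t. src from the forward
// input (src), the forward output (dst) and the incoming gradient (diff); the
// out_shape input is kept as a dependency but is not passed to the kernel.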
AdaptivePoolingBackward::AdaptivePoolingBackward(
        VarNode* src, VarNode* out_shape, VarNode* dst, VarNode* diff,
        const Param& param, const OperatorNodeConfig& config)
        : Super(
                  OperatorNodeBaseCtorParam{
                          src->owner_graph(), config, "adaptive_pooling_bwd", {src}},
                  0, true) {
    init_megdnn_opr(*this, param);
    add_input({src, out_shape, dst, diff});
}

SymbolVar AdaptivePoolingBackward::make(
        SymbolVar src, SymbolVar out_shape, SymbolVar dst, SymbolVar diff,
        const Param& param, const OperatorNodeConfig& config) {
    return src.insert_single_output_opr<AdaptivePoolingBackward>(
            src.node(), out_shape.node(), dst.node(), diff.node(), param, config);
}

void AdaptivePoolingBackward::scn_do_execute() {
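    // inputs: 0 = src, 2 = forward output dst, 3 = diff (gradient of dst);
    // output 0 receives the gradient w.r.t. src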
    megdnn_opr()->exec(
            input(0)->dev_tensor().as_megdnn(), input(2)->dev_tensor().as_megdnn(),
            input(3)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
            intl::get_megdnn_workspace_from_var(output().back()));
}
size_t AdaptivePoolingBackward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    return megdnn_opr()->get_workspace_in_bytes(
            {input_shapes[0], input(0)->dtype(), input(0)->format()},
            {input_shapes[2], input(2)->dtype(), input(2)->format()},
            {input_shapes[3], input(3)->dtype(), input(3)->format()},
            {output_shapes[0], output(0)->dtype(), output(0)->format()});
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}