#include "megbrain/opr/dnn/convolution.h"
#include "../algo_chooser.h"
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"

namespace mgb {
namespace imperative {

namespace {
namespace convolution {

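// Rebuild an imperative Convolution OpDef from a computing-graph node.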
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Convolution>();
    return Convolution::make(node->param(), node->execution_policy());
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& conv = static_cast<const Convolution&>(def);
    OperatorNodeConfig config{conv.make_name()};
    return opr::Convolution::make(
            inputs[0], inputs[1], conv.param(), conv.policy(), config);
}

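// Fallible shape inference: with incomplete input layouts the output layout is
// left empty and the second return value (validated) is false.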
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& conv = def.cast_final_safe<Convolution>();
    DnnOprHelper<megdnn::ConvolutionForward> dnn_opr(conv.param());
    auto&& data = inputs[0].layout;
    auto&& filter = inputs[1].layout;
    TensorLayout output_layout{data.dtype};
    if (data.ndim && filter.ndim) {
        // deduce_layout won't override existing dtype
        dnn_opr.opr().deduce_layout(data, filter, output_layout);
    }
    return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0};
}

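// Eager execution (apply_on_physical_tensor below) runs Convolution through
// ConvBiasForward with an empty bias, so the common fields must be copied
// into a ConvBias::Param first.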
// Convolution::Param -> ConvBias::Param
auto conv_bias_param_from_convolution(const Convolution& conv) {
    megdnn::ConvBias::Param param;
    param.pad_h = conv.pad_h;
    param.pad_w = conv.pad_w;
    param.stride_h = conv.stride_h;
    param.stride_w = conv.stride_w;
    param.dilate_h = conv.dilate_h;
    param.dilate_w = conv.dilate_w;
    param.sparse = conv.sparse;
    param.compute_mode = conv.compute_mode;
    param.format = conv.format;
    return param;
}

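// `validated` tells whether output_descs[0] already holds a trustworthy layout,
// letting apply_on_physical_tensor skip deduce_layout.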
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    // create megdnn opr
    auto&& conv = def.cast_final_safe<Convolution>();
    CompNode cn = inputs[0]->comp_node();
    auto&& param = conv_bias_param_from_convolution(conv);
    DnnOprCaller<megdnn::ConvBiasForward> dnn_opr(cn, param, conv.policy());

    // an ndim == 0 bias layout signals ConvBiasForward that no bias is present
    megdnn::TensorND empty_bias;
    empty_bias.layout.dtype = inputs[0]->dtype();
    empty_bias.layout.ndim = 0;

    auto out_layout = [&] {
        if (validated) {
            return output_descs[0].layout;
        } else {
            TensorLayout out_layout{inputs[0]->dtype()};
            dnn_opr.op()->deduce_layout(
                    inputs[0]->layout(), inputs[1]->layout(), empty_bias.layout,
                    empty_bias.layout, out_layout);
            return out_layout;
        }
    }();

    // alloc memory
    auto out = Tensor::make(out_layout, cn);
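    // execute with algorithm selection governed by conv.policy(); "fastrun"
    // profiles candidate algorithms when the policy enables it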
    dnn_opr.exec_fastrun(inputs[0], inputs[1], empty_bias, empty_bias, out);
    return {out};
}

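// Register the handlers above; fallback() fills the remaining op traits with
// default implementations.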
OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace convolution
}  // namespace

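// ConvBias takes two to four inputs: src, filter, and optionally bias and z
// (a second addend fused after the convolution).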
namespace {
namespace conv_bias {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& conv = static_cast<const ConvBias&>(def);
    cg::OperatorNodeConfig config{conv.dtype};
    config.name(conv.make_name());
    if (inputs.size() == 2) {
        return opr::ConvBias::make(
                inputs[0], inputs[1], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 3) {
        return opr::ConvBias::make(
                inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 4) {
        return opr::ConvBias::make(
                inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(),
                config);
    }
    mgb_assert(0, "ConvBias expects 2, 3 or 4 inputs, got %zu", inputs.size());
}

OP_TRAIT_REG(ConvBias, ConvBias).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace conv_bias
}  // namespace

namespace {
namespace convolution_backward_data {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
    OperatorNodeConfig config{conv.make_name()};
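    // honor an explicitly requested output dtype, if one was set on the op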
    DType output_dtype = conv.dtype;
    if (output_dtype.valid()) {
        config.output_dtype(output_dtype);
    }

    if (inputs.size() == 2) {
        return opr::ConvolutionBackwardData::make(
                inputs[0], inputs[1], conv.param(), conv.policy(), config);
    } else {
        mgb_assert(inputs.size() == 3);
        return opr::ConvolutionBackwardData::make(
                inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
    }
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& convbwd = def.cast_final_safe<ConvolutionBackwardData>();
    DnnOprHelper<megdnn::ConvolutionBackwardData> dnn_opr(convbwd.param());
    auto&& filter = inputs[0].layout;
    auto&& diff = inputs[1].layout;
    // force set dtype: the output dtype comes from the op itself
    TensorLayout output_layout{convbwd.dtype};
    if (filter.ndim && diff.ndim) {
        // deduce_layout won't override existing dtype
        dnn_opr.opr().deduce_layout(filter, diff, output_layout);
    }
    return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    // create megdnn opr
    auto&& convbwd = def.cast_final_safe<ConvolutionBackwardData>();
    CompNode cn = inputs[0]->comp_node();
    DnnOprCaller<megdnn::ConvolutionBackwardData> dnn_opr(
            cn, convbwd.param(), convbwd.policy());
    auto out_layout = [&] {
        if (validated) {
            return output_descs[0].layout;
        } else {
            TensorLayout out_layout{inputs[0]->dtype()};
            dnn_opr.op()->deduce_layout(
                    inputs[0]->layout(), inputs[1]->layout(), out_layout);
            return out_layout;
        }
    }();
    auto out = Tensor::make(out_layout, cn);
    dnn_opr.exec_fastrun(inputs[0], inputs[1], out);
    return {out};
}

OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace convolution_backward_data
}  // namespace

namespace {
namespace convolution3d {

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Convolution3D>();
    return Convolution3D::make(node->param(), node->execution_policy());
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& conv = static_cast<const Convolution3D&>(def);
    return opr::Convolution3D::make(inputs[0], inputs[1], conv.param(), conv.policy());
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& conv = def.cast_final_safe<Convolution3D>();
    TensorLayout src = inputs[0].layout;
    TensorLayout filter = inputs[1].layout;
    if (src.ndim == 0 || filter.ndim == 0) {
        return {{{TensorLayout{src.dtype}, inputs[0].comp_node}}, false};
    }
    DnnOprHelper<megdnn::Convolution3DForward> dnn_opr(conv.param());
    auto output = dnn_opr.deduce_layout(src, filter);
    // the layout was fully deduced here, so report the result as validated,
    // matching the other convolution ops in this file
    return {{{output, inputs[0].comp_node}}, true};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    // create megdnn opr
    auto&& conv = def.cast_final_safe<Convolution3D>();
    CompNode cn = inputs[0]->comp_node();
    DnnOprCaller<megdnn::Convolution3D> dnn_opr(cn, conv.param(), conv.policy());
    auto out_layout = [&] {
        if (validated) {
            return output_descs[0].layout;
        } else {
            return dnn_opr.deduce_layout(inputs[0]->layout(), inputs[1]->layout());
        }
    }();
    // alloc memory
    auto out = Tensor::make(out_layout, cn);
    dnn_opr.exec_fastrun(inputs[0], inputs[1], out);
    return {out};
}

OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace convolution3d
}  // namespace

namespace {
namespace convolution3d_backward_data {

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    mgb_assert(
            inputs.size() == 2,
            "conv_transpose3d expects 2 inputs, but got %zu", inputs.size());
    auto&& op_def = def.cast_final_safe<Convolution3DBackwardData>();
    auto&& weight = inputs[0];
    auto&& diff = inputs[1];
    if (!(weight.layout.ndim && diff.layout.ndim)) {
        return {{{TensorLayout{weight.layout.dtype}, weight.comp_node}}, false};
    }
    DnnOprHelper<megdnn::Convolution3DBackwardData> dnn_opr(op_def.param());
    auto oup_layout = dnn_opr.deduce_layout(weight.layout, diff.layout);
    return {{{oup_layout, weight.comp_node}}, true};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& conv = def.cast_final_safe<Convolution3DBackwardData>();
    auto cn = inputs[0]->comp_node();

    auto&& wlayout = inputs[0]->layout();
    auto&& dlayout = inputs[1]->layout();

    DnnOprCaller<megdnn::Convolution3DBackwardData> dnn_op(
            cn, conv.param(), conv.policy());

    auto oup_layout = [&] {
        if (validated) {
            return output_descs[0].layout;
        } else {
            return dnn_op.deduce_layout(wlayout, dlayout);
        }
    }();
    auto oup = Tensor::make(oup_layout, cn);
    dnn_op.exec_fastrun(inputs[0], inputs[1], oup);
    return {oup};
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& conv = static_cast<const Convolution3DBackwardData&>(def);
    OperatorNodeConfig config{conv.make_name()};
    mgb_assert(inputs.size() == 2);
    return opr::Convolution3DBackwardData::make(
            inputs[0], inputs[1], conv.param(), conv.policy(), config);
}
}

OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace convolution3d_backward_data
}  // namespace

}  // namespace imperative
}  // namespace mgb