#include "megbrain/opr/dnn/convolution.h"
#include "../algo_chooser.h"
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"

namespace mgb {
namespace imperative {
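
// Imperative (eager-mode) trait registrations for the convolution family of
// ops. Each anonymous namespace below wires one OpDef to its graph builder
// (apply_on_var_node), fallible shape/dtype inference
// (infer_output_attrs_fallible) and eager kernel execution
// (apply_on_physical_tensor).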
namespace {
namespace convolution {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Convolution>();
    return Convolution::make(node->param(), node->execution_policy());
}

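// Build the symbolic graph node when tracing: forward the OpDef's param and
// execution policy to the graph-level opr::Convolution.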
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& conv = static_cast<const Convolution&>(def);
    OperatorNodeConfig config{conv.make_name()};
    return opr::Convolution::make(
            inputs[0], inputs[1], conv.param(), conv.policy(), config);
}

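// Fallible inference: if either input layout is still unknown (ndim == 0) the
// output layout stays empty and the result is reported as not yet validated.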
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& conv = def.cast_final_safe<Convolution>();
    DnnOprHelper<megdnn::ConvolutionForward> dnn_opr(conv.param());
    auto&& data = inputs[0].layout;
    auto&& filter = inputs[1].layout;
    TensorLayout output_layout{data.dtype};
    if (data.ndim && filter.ndim) {
        // deduce_layout won't override existing dtype
        dnn_opr.opr().deduce_layout(data, filter, output_layout);
    }
    return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0};
}

// Convolution::Param -> ConvBias::Param
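// Only the geometry fields are copied; the nonlinearity is left at its
// default (identity), so the fused op behaves like a plain convolution.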
auto conv_bias_param_from_convolution(const Convolution& conv) {
    megdnn::ConvBias::Param param;
    param.pad_h = conv.pad_h;
    param.pad_w = conv.pad_w;
    param.stride_h = conv.stride_h;
    param.stride_w = conv.stride_w;
    param.dilate_h = conv.dilate_h;
    param.dilate_w = conv.dilate_w;
    param.sparse = conv.sparse;
    param.compute_mode = conv.compute_mode;
    param.format = conv.format;
    return param;
}

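// Eager execution: deduce the output layout unless output_descs[0] is already
// validated, allocate the output tensor, and run the kernel through
// exec_fastrun, which selects the algorithm via the algo chooser.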
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    // create megdnn opr
    auto&& conv = def.cast_final_safe<Convolution>();
    CompNode cn = inputs[0]->comp_node();

    // call dnn ConvolutionForward when the device is ROCm,
    // because there is no dnn ConvBiasForward on ROCm
    if (cn.device_type() == CompNode::DeviceType::ROCM) {
        DnnOprCaller<megdnn::ConvolutionForward> dnn_opr(
                cn, conv.param(), conv.policy());
        auto out_layout = [&] {
            if (validated) {
                return output_descs[0].layout;
            } else {
                return dnn_opr.deduce_layout(inputs[0]->layout(), inputs[1]->layout());
            }
        }();

        // alloc memory
        auto out = Tensor::make(out_layout, cn);
        dnn_opr.exec_fastrun(inputs[0], inputs[1], out);
        return {out};
    }

    // call dnn ConvBiasForward on CUDA because it is faster than ConvolutionForward:
    // ConvolutionForward internally uses ConvBiasForward to compute the result anyway
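    // bias and z are passed below as empty (ndim == 0) tensors, so the fused
    // ConvBias computes exactly a plain convolution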
    auto&& param = conv_bias_param_from_convolution(conv);
    DnnOprCaller<megdnn::ConvBiasForward> dnn_opr(cn, param, conv.policy());

    megdnn::TensorND empty_bias;
    empty_bias.layout.dtype = inputs[0]->dtype();
    empty_bias.layout.ndim = 0;

    auto out_layout = [&] {
        if (validated) {
            return output_descs[0].layout;
        } else {
            TensorLayout out_layout{inputs[0]->dtype()};
            dnn_opr.op()->deduce_layout(
                    inputs[0]->layout(), inputs[1]->layout(), empty_bias.layout,
                    empty_bias.layout, out_layout);
            return out_layout;
        }
    }();

    // alloc memory
    auto out = Tensor::make(out_layout, cn);
    dnn_opr.exec_fastrun(inputs[0], inputs[1], empty_bias, empty_bias, out);
    return {out};
}

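// Register the imperative traits for Convolution; fallback() fills any entry
// not set explicitly with a default implementation.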
OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace convolution
}  // namespace

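// ConvBias only needs graph construction here: dispatch on the number of
// inputs (src/filter, optional bias, optional z) and forward dtype, param and
// execution policy.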
namespace {
namespace conv_bias {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& conv = static_cast<const ConvBias&>(def);
    cg::OperatorNodeConfig config{conv.dtype};
    config.name(conv.make_name());
    if (inputs.size() == 2) {
        return opr::ConvBias::make(
                inputs[0], inputs[1], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 3) {
        return opr::ConvBias::make(
                inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 4) {
        return opr::ConvBias::make(
                inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(),
                config);
    }
    mgb_assert(0, "ConvBias expects 2, 3 or 4 inputs, got %zu", inputs.size());
}

OP_TRAIT_REG(ConvBias, ConvBias).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace conv_bias
}  // namespace

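// ConvolutionBackwardData computes the gradient w.r.t. the convolution input
// and also serves as the transposed-convolution (deconvolution) forward op.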
namespace {
namespace convolution_backward_data {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
    OperatorNodeConfig config{conv.make_name()};
    DType output_dtype = conv.dtype;
    if (output_dtype.valid()) {
        config.output_dtype(output_dtype);
    }

    if (inputs.size() == 2) {
        return opr::ConvolutionBackwardData::make(
                inputs[0], inputs[1], conv.param(), conv.policy(), config);
    } else {
        mgb_assert(inputs.size() == 3);
        return opr::ConvolutionBackwardData::make(
                inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
    }
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& convbwd = def.cast_final_safe<ConvolutionBackwardData>();
    DnnOprHelper<megdnn::ConvolutionBackwardData> dnn_opr(convbwd.param());
    // force the output dtype to the one recorded in the OpDef
    auto&& filter = inputs[0].layout;
    auto&& diff = inputs[1].layout;
    TensorLayout output_layout{convbwd.dtype};
    if (filter.ndim && diff.ndim) {
        // deduce_layout won't override existing dtype
        dnn_opr.opr().deduce_layout(filter, diff, output_layout);
    }
    return {{{output_layout, inputs[0].comp_node}}, output_layout.ndim != 0};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    // create megdnn opr
    auto&& convbwd = def.cast_final_safe<ConvolutionBackwardData>();
    CompNode cn = inputs[0]->comp_node();
    DnnOprCaller<megdnn::ConvolutionBackwardData> dnn_opr(
            cn, convbwd.param(), convbwd.policy());
    auto out_layout = [&] {
        if (validated) {
            return output_descs[0].layout;
        } else {
            TensorLayout out_layout{inputs[0]->dtype()};
            dnn_opr.op()->deduce_layout(
                    inputs[0]->layout(), inputs[1]->layout(), out_layout);
            return out_layout;
        }
    }();
    auto out = Tensor::make(out_layout, cn);
    dnn_opr.exec_fastrun(inputs[0], inputs[1], out);
    return {out};
}

OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace convolution_backward_data
}  // namespace

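// Convolution3D mirrors the 2D Convolution registration above, minus the
// ConvBias-based fast path.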
namespace {
namespace convolution3d {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Convolution3D>();
    return Convolution3D::make(node->param(), node->execution_policy());
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& conv = static_cast<const Convolution3D&>(def);
    return opr::Convolution3D::make(inputs[0], inputs[1], conv.param(), conv.policy());
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& conv = def.cast_final_safe<Convolution3D>();
    TensorLayout src = inputs[0].layout;
    TensorLayout filter = inputs[1].layout;
    if (src.ndim == 0 || filter.ndim == 0) {
        return {{{TensorLayout{src.dtype}, inputs[0].comp_node}}, false};
    }
    DnnOprHelper<megdnn::Convolution3DForward> dnn_opr(conv.param());
    auto output = dnn_opr.deduce_layout(src, filter);
    return {{{output, inputs[0].comp_node}}, true};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    // create megdnn opr
    auto&& conv = def.cast_final_safe<Convolution3D>();
    CompNode cn = inputs[0]->comp_node();
    DnnOprCaller<megdnn::Convolution3D> dnn_opr(cn, conv.param(), conv.policy());
    auto out_layout = [&] {
        if (validated) {
            return output_descs[0].layout;
        } else {
            return dnn_opr.deduce_layout(inputs[0]->layout(), inputs[1]->layout());
        }
    }();
    // alloc memory
    auto out = Tensor::make(out_layout, cn);
    dnn_opr.exec_fastrun(inputs[0], inputs[1], out);
    return {out};
}

OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace convolution3d
}  // namespace

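// Convolution3DBackwardData: gradient w.r.t. the 3D convolution input, also
// exposed as conv_transpose3d.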
namespace {
namespace convolution3d_backward_data {

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    mgb_assert(
            inputs.size() == 2,
            "conv_transpose3d expects 2 inputs, but got %zu", inputs.size());
    auto&& op_def = def.cast_final_safe<Convolution3DBackwardData>();
    auto&& weight = inputs[0];
    auto&& diff = inputs[1];
    if (!(weight.layout.ndim && diff.layout.ndim)) {
        return {{{TensorLayout{weight.layout.dtype}, weight.comp_node}}, false};
    }
    DnnOprHelper<megdnn::Convolution3DBackwardData> dnn_opr(op_def.param());
    auto oup_layout = dnn_opr.deduce_layout(weight.layout, diff.layout);
    return {{{oup_layout, weight.comp_node}}, true};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& conv = def.cast_final_safe<Convolution3DBackwardData>();
    auto cn = inputs[0]->comp_node();

    auto&& wlayout = inputs[0]->layout();
    auto&& dlayout = inputs[1]->layout();

    DnnOprCaller<megdnn::Convolution3DBackwardData> dnn_op(
            cn, conv.param(), conv.policy());

    auto oup_layout = [&] {
        if (validated) {
            return output_descs[0].layout;
        } else {
            return dnn_op.deduce_layout(wlayout, dlayout);
        }
    }();
    auto oup = Tensor::make(oup_layout, cn);
    dnn_op.exec_fastrun(inputs[0], inputs[1], oup);
    return {oup};
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& conv = static_cast<const Convolution3DBackwardData&>(def);
    OperatorNodeConfig config{conv.make_name()};
    mgb_assert(inputs.size() == 2);
    return opr::Convolution3DBackwardData::make(
            inputs[0], inputs[1], conv.param(), conv.policy(), config);
}

OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace convolution3d_backward_data
}  // namespace

}  // namespace imperative
}  // namespace mgb