/**
 * \file imperative/src/impl/ops/dnn/convolution.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/imperative/ops/autogen.h"
#include "../op_trait.h"

namespace mgb {
namespace imperative {

namespace {
namespace convolution {
22 23 24 25 26
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Convolution>();
    return Convolution::make(node->param(), node->execution_policy());
}

M
Megvii Engine Team 已提交
27
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
28 29
    auto&& conv = static_cast<const Convolution&>(def);
    OperatorNodeConfig config{conv.make_name()};
M
Megvii Engine Team 已提交
30 31
    return opr::Convolution::make(
            inputs[0], inputs[1], conv.param(), conv.policy(), config);
32 33 34
}

OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
M
Megvii Engine Team 已提交
35 36 37 38 39
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace convolution
}  // namespace
namespace {
namespace convolution_backward_data {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
44 45
    auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
    OperatorNodeConfig config{conv.make_name()};
46 47 48 49 50
    DType output_dtype = conv.dtype;
    if (output_dtype.valid()) {
        config.output_dtype(output_dtype);
    }

51
    if (inputs.size() == 2) {
M
Megvii Engine Team 已提交
52 53
        return opr::ConvolutionBackwardData::make(
                inputs[0], inputs[1], conv.param(), conv.policy(), config);
54 55
    } else {
        mgb_assert(inputs.size() == 3);
M
Megvii Engine Team 已提交
56 57
        return opr::ConvolutionBackwardData::make(
                inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
58 59 60 61
    }
}

OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData)
M
Megvii Engine Team 已提交
62 63 64 65
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace convolution_backward_data
}  // namespace
namespace {
namespace convolution3d {
69 70 71 72 73
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Convolution3D>();
    return Convolution3D::make(node->param(), node->execution_policy());
}

M
Megvii Engine Team 已提交
74
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
75 76 77 78 79
    auto&& conv = static_cast<const Convolution3D&>(def);
    return opr::Convolution3D::make(inputs[0], inputs[1], conv.param(), conv.policy());
}

OP_TRAIT_REG(Convolution3D, Convolution3D, opr::Convolution3D)
M
Megvii Engine Team 已提交
80 81 82 83 84
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace convolution3d
}  // namespace
namespace {
namespace convolution3d_backward_data {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
89 90 91
    auto&& conv = static_cast<const Convolution3DBackwardData&>(def);
    OperatorNodeConfig config{conv.make_name()};
    mgb_assert(inputs.size() == 2);
M
Megvii Engine Team 已提交
92 93
    return opr::Convolution3DBackwardData::make(
            inputs[0], inputs[1], conv.param(), conv.policy(), config);
94 95 96
}

OP_TRAIT_REG(Convolution3DBackwardData, Convolution3DBackwardData)
M
Megvii Engine Team 已提交
97 98 99 100
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace convolution3d_backward_data
}  // namespace
}  // namespace imperative
}  // namespace mgb