// layer_norm.cpp
#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"

#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"

namespace mgb::imperative {
namespace layer_norm {

/**
 * Builds the graph-mode LayerNorm operator for `def` on the given var nodes.
 *
 * @param def    the LayerNorm op definition; its param and name are forwarded to
 *               opr::LayerNorm::make
 * @param inputs {data} when param.affine is false, or {data, inputs[1], inputs[2]}
 *               when param.affine is true; any other combination trips the assertion
 * @return the owner operator of the first output var returned by
 *         opr::LayerNorm::make
 */
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const LayerNorm&>(def);
    size_t nr_inp = inputs.size();
    auto p = op.param();
    mgb_assert(
            (nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine),
            "num of inputs of LayerNorm should be 3 (affine=true) or 1 "
            "(affine=false) but you give %zu (affine=%d)",
            nr_inp, static_cast<int>(p.affine));
    OperatorNodeConfig config{op.make_name()};
    // make() returns the operator's output vars; output 0 is only used to reach
    // the operator node that owns them.
    if (nr_inp == 3) {
        return opr::LayerNorm::make(inputs[0], inputs[1], inputs[2], p, config)[0]
                .node()
                ->owner_opr();
    }
    return opr::LayerNorm::make(inputs[0], p, config)[0].node()->owner_opr();
}

/**
 * Infers the output descriptors {out, mean, rstd} without running the kernel.
 *
 * All three outputs are placed on the comp node of inputs[0].
 * - If the input shape is not known yet (layout.ndim == 0), only dtypes are
 *   reported: the input dtype for `out` and Float32 for `mean`/`rstd`. The
 *   returned flag is false in this case.
 * - Otherwise the layouts come from megdnn::LayerNorm::deduce_layout_fwd_impl
 *   and the flag is true.
 *
 * @param def    the LayerNorm op definition
 * @param inputs 1 descriptor when param.affine is false, 3 when it is true
 */
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<LayerNorm>();
    size_t nr_inp = inputs.size();
    auto p = op_def.param();
    mgb_assert(
            (nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine),
            "num of inputs of LayerNorm should be 3 (affine=true) or 1 "
            "(affine=false) but you give %zu (affine=%d)",
            nr_inp, static_cast<int>(p.affine));

    auto&& inp = inputs[0];
    auto& inp_cn = inp.comp_node;

    // Shape not known yet: report dtypes only and mark the result as not validated.
    if (inp.layout.ndim == 0) {
        return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}},
                 {TensorLayout{dtype::Float32()}, inp_cn, {}},
                 {TensorLayout{dtype::Float32()}, inp_cn, {}}},
                false};
    }

    TensorLayout oup_layout, mean_layout, rstd_layout;
    megdnn::LayerNorm::deduce_layout_fwd_impl(
            inp.layout, p, oup_layout, mean_layout, rstd_layout);
    return {{{oup_layout, inp_cn, {}},
             {mean_layout, inp_cn, {}},
             {rstd_layout, inp_cn, {}}},
            true};
}

/**
 * Eagerly executes LayerNorm on concrete tensors.
 *
 * Steps:
 * 1. Allocates the three outputs (out, mean, rstd) on the comp node of inputs[0],
 *    using layouts deduced by megdnn::LayerNorm::deduce_layout_fwd_impl.
 * 2. Queries the megdnn kernel for its workspace size.
 * 3. Runs the kernel.
 *
 * @param def          the LayerNorm op definition; its param is forwarded to megdnn
 * @param inputs       {data} when param.affine is false, or {data, inputs[1],
 *                     inputs[2]} when param.affine is true
 * @param output_descs not read by this implementation (layouts are re-deduced)
 * @param validated    not read by this implementation
 * @return {out, mean, rstd}
 */
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op_def = def.cast_final_safe<LayerNorm>();
    size_t nr_inp = inputs.size();
    auto p = op_def.param();

    mgb_assert(
            (nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine),
            "num of inputs of LayerNorm should be 3 (affine=true) or 1 "
            "(affine=false) but you give %zu (affine=%d)",
            nr_inp, static_cast<int>(p.affine));

    auto cn = inputs[0]->comp_node();
    DnnOprCaller<megdnn::LayerNorm> caller(cn);
    auto&& dnn_opr = caller.op;
    dnn_opr->param() = p;

    // Fetch each input's megdnn view once. When affine is false the two affine
    // parameter tensors stay default-constructed (empty layout, no data). This is
    // exactly what megdnn received before in the non-affine case.
    auto src = inputs[0]->dnn_tensor();
    megdnn::TensorND weight;
    megdnn::TensorND bias;
    if (p.affine) {
        weight = inputs[1]->dnn_tensor();
        bias = inputs[2]->dnn_tensor();
    }

    TensorLayout oup_layout, mean_layout, rstd_layout;
    megdnn::LayerNorm::deduce_layout_fwd_impl(
            src.layout, p, oup_layout, mean_layout, rstd_layout);

    auto out = Tensor::make(oup_layout, cn);
    auto mean = Tensor::make(mean_layout, cn);
    auto rstd = Tensor::make(rstd_layout, cn);

    auto wk_size = dnn_opr->get_workspace_in_bytes(
            src.layout, weight.layout, bias.layout, oup_layout, mean_layout,
            rstd_layout);
    auto dnn_wk = caller.create_workspace(wk_size);

    dnn_opr->exec(
            src, weight, bias, out->dnn_tensor(), mean->dnn_tensor(),
            rstd->dnn_tensor(), dnn_wk);
    return {out, mean, rstd};
}

// Registers the LayerNorm op trait. Graph-mode construction, fallible output
// inference and eager execution are wired to the functions defined in this
// namespace. fallback() is chained last, presumably to fill any trait entries
// not set explicitly (TODO confirm against op_trait.h).
OP_TRAIT_REG(LayerNorm, LayerNorm)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

}  // namespace layer_norm
}  // namespace mgb::imperative