#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/utility.h"

#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"

namespace mgb {
namespace imperative {

namespace {

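// Rebuild an imperative Elemwise OpDef from an existing opr::Elemwise graph
// node; registered below as the op trait's make_from_op_node hook.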
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Elemwise>();
    return Elemwise::make(node->param().mode);
}

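// Graph-mode lowering: insert a plain opr::Elemwise node carrying the same
// mode and the op's name.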
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& elemwise_opr = def.cast_final_safe<Elemwise>();
    OperatorNodeConfig config{elemwise_opr.make_name()};
    return opr::Elemwise::make(inputs, elemwise_opr.mode, config);
}

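// Fallible shape/dtype inference: all inputs must share comp_node and dtype,
// and the dtype category must be supported by the mode. If any input shape is
// still unknown (ndim == 0), return an undetermined layout together with
// `false` to signal that inference is incomplete.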
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<Elemwise>();
    auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
    mgb_assert(
            inputs.size() == trait.arity, "%s expects %u inputs; got %zu actually",
            trait.name, trait.arity, inputs.size());
    TensorShapeArray inp_shapes;
    DType out_dt;
    CompNode out_cn;
    for (size_t i = 0; i < inputs.size(); ++i) {
        auto&& t = inputs[i];
        if (!i) {
            out_cn = t.comp_node;
            out_dt = t.layout.dtype;
        } else {
            mgb_assert(t.comp_node == out_cn);
            mgb_assert(t.layout.dtype == out_dt);
        }
        if (t.layout.ndim > 0) {
            inp_shapes.push_back(t.layout);
        } else {
            TensorLayout out_layout;
            out_layout.ndim = 0;
            out_layout.dtype = out_dt;
            return {{{out_layout, out_cn}}, false};
        }
    }
    // copied from megdnn::ElemwiseForward::check_dtype
    switch (out_dt.category()) {
        case DTypeCategory::FLOAT:
            mgb_assert(trait.allow_float, "unsupported mode %s for float\n", trait.name);
            break;
        case DTypeCategory::INT:
            mgb_assert(trait.allow_int, "unsupported mode %s for int\n", trait.name);
            break;
        case DTypeCategory::BOOL:
            mgb_assert(trait.allow_bool, "unsupported mode %s for bool\n", trait.name);
            break;
        default:
            // Quantized Dtype could also be handled by this op,
            // but scales need to be the same.
            break;
    }

    auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
    return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
}

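// Dispatch on DEFAULT_CPU only when every input already carries a small
// host-side value (at most TensorShape::MAX_NDIM elements), so the result can
// be computed eagerly on the host; otherwise use the KERNEL path.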
DispatchMode decide_dispatch_mode(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    bool host_computable = true;
    constexpr int size_threshold = TensorShape::MAX_NDIM;
    for (auto&& inp : inputs) {
        if (inp.value.empty() || inp.value.layout().ndim == 0 ||
            inp.value.layout().total_nr_elems() > size_threshold) {
            host_computable = false;
            break;
        }
    }
    return host_computable ? DEFAULT_CPU : KERNEL;
}

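// Eager execution on pre-existing DeviceTensorNDs (used by the DEFAULT_CPU
// dispatch path); broadcasting and kernel invocation are delegated to
// opr::Elemwise::perform.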
void apply_on_device_tensornd(
        const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
        SmallVector<DeviceTensorND>* outputs) {
    auto&& op_def = def.cast_final_safe<Elemwise>();
    auto&& trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
    mgb_assert(
            inputs.size() == trait.arity, "%s expects %u inputs; got %zu actually",
            trait.name, trait.arity, inputs.size());
    DnnOprCaller<megdnn::Elemwise> dnn_opr(inputs[0].comp_node());
    opr::Elemwise::perform(op_def.mode, (*outputs)[0], inputs, dnn_opr.op);
}

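// Main eager execution path: deduce the broadcast output layout, allocate the
// output through the blob manager, then run the megdnn elemwise kernel.
// FUSE_MUL_ADD3/FUSE_MUL_ADD4 and quantized inputs are routed through
// opr::Elemwise::perform_dnn instead of calling exec() directly.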
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto comp_node = inputs[0]->comp_node();
    using Mode = Elemwise::Mode;
    using TensorND = megdnn::TensorND;
    auto&& op_def = def.cast_final_safe<Elemwise>();
    SmallVector<TensorND> inp_tensornds;
    TensorShapeArray inp_shapes(inputs.size());
    inp_tensornds.reserve(inputs.size());

    TensorLayout layout{inputs[0]->layout().dtype};
    bool is_empty = false;
    for (unsigned i = 0; i < inputs.size(); ++i) {
        if (inputs[i]->layout().is_empty()) {
            is_empty = true;
        }
        inp_tensornds.push_back(inputs[i]->dnn_tensor());
        inp_shapes[i] = inputs[i]->layout();
    }
    megdnn::Elemwise::deduce_shape(inp_shapes, layout);
    layout.init_contiguous_stride();

    DeviceTensorND out =
            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, layout);
    if (is_empty) {
        return {Tensor::make(out)};
    }
    DnnOprCaller<megdnn::Elemwise> dnn_opr(comp_node);

    dnn_opr.op->param() = op_def.param();
    if (dnn_opr.op->param().mode == Mode::FUSE_MUL_ADD3 ||
        dnn_opr.op->param().mode == Mode::FUSE_MUL_ADD4 ||
        (inp_tensornds.size() &&
         inp_tensornds[0].layout.dtype.category() == DTypeCategory::QUANTIZED)) {
        opr::Elemwise::perform_dnn(comp_node, out, inp_tensornds, dnn_opr.op);
    } else {
        dnn_opr.op->exec(inp_tensornds, out.as_megdnn());
    }

    return {Tensor::make(out)};
}

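// Internal operator that forces its output to share (and overwrite) the
// storage of the input selected by `inplace_index`; it backs the graph-mode
// lowering of InplaceAdd below.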
MGB_DEFINE_OPR_CLASS(
        ForceInplaceElemwise,
        cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) // {
public:
    struct Param {
        using Mode = megdnn::Elemwise::Param::Mode;
        Mode mode;
        size_t inplace_index;
    };
    using Mode = Param::Mode;
    ForceInplaceElemwise(
            const VarNodeArray& inputs, Param param, OperatorNodeConfig config = {})
            : Super(inputs[0]->owner_graph(), config, "device_add_update", inputs),
              m_param{param} {
        for (auto* input : inputs) {
            add_input({input});
        }
        add_output(None)
                ->set_fwd_in2out_writable_force(input(param.inplace_index))
                .add_flag(VarNode::Flag::NO_MEM_RECLAIM);
    }
    static SymbolVar make(const VarNodeArray& inputs, Param param) {
        return SymbolVar{inputs[0]}.insert_single_output_opr<ForceInplaceElemwise>(
                inputs, param);
    }
    static cg::OperatorNodeBase* shallow_copy(
            const serialization::OprShallowCopyContext& ctx,
            const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
            const OperatorNodeConfig& config);

protected:
    NodeProp* do_make_node_prop() const override {
        auto ret = Super::do_make_node_prop();
        ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR);
        return ret;
    }
    void create_megdnn_opr() override {
        auto opr = DnnOprCaller<megdnn::Elemwise>::create_operator(comp_node());
        opr->param().mode = m_param.mode;
        set_megdnn_opr(std::move(opr));
    }
    void scn_do_execute() override {
        auto to_dnnnd = [&](auto* var) { return var->dev_tensor().as_megdnn(); };
        megdnn::TensorNDArray inputs_dnnnd;
        for (auto* input : input()) {
            inputs_dnnnd.push_back(to_dnnnd(input));
        }
        mgb_assert(
                input(m_param.inplace_index)
                        ->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC),
                "ForceInplaceElemwise cannot be applied to an internal tensor");
        auto* out_dest = output(0);
        auto* opr = static_cast<megdnn::Elemwise*>(megdnn_opr());
        opr->exec(std::move(inputs_dnnnd), to_dnnnd(out_dest));
    }
    void init_output_static_infer_desc() override {
        using namespace cg::static_infer;

        owner_graph()->static_infer_manager().register_shape_infer(
                output(0), ShapeInferDesc::make_identity(input(m_param.inplace_index)));
    }

private:
    Param m_param;
    void record_execute_deps(ExecDependencyArray& deps) override {
        record_megdnn_opr(deps);
    }
};

MGB_DYN_TYPE_OBJ_FINAL_IMPL(ForceInplaceElemwise);

cg::OperatorNodeBase* ForceInplaceElemwise::shallow_copy(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    auto&& opr = opr_.cast_final_safe<ForceInplaceElemwise>();
    auto* graph = ctx.owner_graph(opr, inputs);
    return graph->insert_opr(
            std::make_unique<ForceInplaceElemwise>(inputs, opr.m_param, config));
}

MGB_REG_OPR_SHALLOW_COPY(ForceInplaceElemwise, ForceInplaceElemwise::shallow_copy);

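// Graph-mode lowering of InplaceAdd(dest, delta, alpha, beta): emit a
// FUSE_MUL_ADD4 ForceInplaceElemwise that writes into `dest`
// (inplace_index = 1 in the reordered input list).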
cg::OperatorNodeBase* apply_inplace_add_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto dest = inputs[0], delta = inputs[1], alpha = inputs[2], beta = inputs[3];
    auto mode = ForceInplaceElemwise::Param::Mode::FUSE_MUL_ADD4;
    return ForceInplaceElemwise::make({alpha, dest, beta, delta}, {mode, 1})
            .node()
            ->owner_opr();
}

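// Eager InplaceAdd: compute dest = alpha * dest + beta * delta via megdnn
// AddUpdate. If dest's storage is shared with other tensors, copy it first so
// the in-place update cannot clobber them.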
SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto dest = inputs[0], delta = inputs[1], alpha = inputs[2], beta = inputs[3];
    if (!inputs[0]->storage_is_unique()) {
        mgb_log_warn(
                "This inplace modification may change the elements of other tensors. "
                "Falling back to a non-inplace update.");

        DeviceTensorStorage storage;
        storage.reset(dest->comp_node(), dest->blob()->size(), dest->blob()->storage());
        storage = storage.sub(dest->offset());
        DeviceTensorND dv;
        dv.reset(storage, dest->layout());

        DeviceTensorND dv_new;
        dv_new.copy_from(dv);
        dest = Tensor::make(dv_new);
    }
    auto tensor_to_scalar = [](const TensorPtr& tensor) -> float {
        return *tensor->get_value().ptr<float>();
    };
    DnnOprCaller<megdnn::AddUpdate> caller{dest->comp_node()};
    caller.op->param() = {tensor_to_scalar(alpha), tensor_to_scalar(beta)};
    caller.op->exec(dest->dev_tensor().as_megdnn(), delta->dev_tensor().as_megdnn());
    return {std::make_shared<Tensor>(dest->blob(), dest->offset(), dest->layout())};
}

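// Shape/dtype inference for InplaceAdd: expects exactly four inputs
// (dest, delta, alpha, beta) on a single comp_node, with float32 scalar
// alpha/beta; the output simply reuses dest's layout.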
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    mgb_assert(inputs.size() == 4, "invalid number of inputs for inplace_add");
    CompNode cn;
    for (auto&& input : inputs) {
        if (!cn.valid()) {
            cn = input.comp_node;
        } else {
            mgb_assert(input.comp_node == cn, "inputs should be on the same comp_node");
        }
    }
    auto dest = inputs[0], delta = inputs[1], alpha = inputs[2], beta = inputs[3];
    bool succeed = dest.layout.ndim != 0;
    if (succeed) {
        mgb_assert(
                delta.layout.ndim == 0 || dest.layout.eq_shape(delta.layout),
                "dest and delta must have the same shape");
        mgb_assert(
                alpha.layout.ndim == 0 || alpha.layout.eq_shape({1}),
                "alpha should be scalar");
        mgb_assert(
                beta.layout.ndim == 0 || beta.layout.eq_shape({1}),
                "beta should be scalar");
    }
    mgb_assert(alpha.layout.dtype == dtype::Float32(), "alpha should be float32");
    mgb_assert(beta.layout.dtype == dtype::Float32(), "beta should be float32");
    // inplace op result's desc value is changed
    return {{{dest.layout, dest.comp_node}}, succeed};
}

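// Register the handlers above as the imperative op trait for Elemwise.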
OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
        .make_from_op_node(make_from_op_node)
        .decide_dispatch_mode(decide_dispatch_mode)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_device_tensornd(apply_on_device_tensornd)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

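// Register the InplaceAdd handlers defined above.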
OP_TRAIT_REG(InplaceAdd, InplaceAdd, opr::AddUpdate)
        .apply_on_var_node(apply_inplace_add_on_var_node)
        .apply_on_physical_tensor(apply_inplace_add_on_physical_tensor)
        .infer_output_attrs_fallible(infer_inplace_add_output_attrs_fallible)
        .fallback();
}  // anonymous namespace

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}