/**
 * \file imperative/src/impl/ops/elemwise.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/utility.h"

#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"

namespace mgb {
namespace imperative {

namespace {

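// Recover an imperative Elemwise OpDef from a graph operator node; the
// elemwise mode is the only state that needs to be carried over.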
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Elemwise>();
    return Elemwise::make(node->param().mode);
}

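// Lower the imperative op onto the computing graph as an opr::Elemwise node.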
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& elemwise_opr = def.cast_final_safe<Elemwise>();
    OperatorNodeConfig config{elemwise_opr.make_name()};
    return opr::Elemwise::make(inputs, elemwise_opr.mode, config);
}

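// Infer output descs without running the kernel. The bool in the returned
// tuple reports whether inference succeeded: if any input shape is still
// unknown (ndim == 0), a placeholder layout is returned together with false.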
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<Elemwise>();
    auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
    mgb_assert(
            inputs.size() == trait.arity, "%s expects %u inputs; got %zu actually",
            trait.name, trait.arity, inputs.size());
    TensorShapeArray inp_shapes;
    DType out_dt;
    CompNode out_cn;
    for (size_t i = 0; i < inputs.size(); ++i) {
        auto&& t = inputs[i];
        if (!i) {
            out_cn = t.comp_node;
            out_dt = t.layout.dtype;
        } else {
            mgb_assert(t.comp_node == out_cn);
            mgb_assert(t.layout.dtype == out_dt);
        }
        if (t.layout.ndim > 0) {
            inp_shapes.push_back(t.layout);
        } else {
            TensorLayout out_layout;
            out_layout.ndim = 0;
            out_layout.dtype = out_dt;
            return {{{out_layout, out_cn}}, false};
        }
    }
    // copied from megdnn::ElemwiseForward::check_dtype
    switch (out_dt.category()) {
        case DTypeCategory::FLOAT:
            mgb_assert(trait.allow_float, "unsupported mode %s for float\n", trait.name);
            break;
        case DTypeCategory::INT:
            mgb_assert(trait.allow_int, "unsupported mode %s for int\n", trait.name);
            break;
        case DTypeCategory::BOOL:
            mgb_assert(trait.allow_bool, "unsupported mode %s for bool\n", trait.name);
            break;
        default:
            // Quantized dtypes can also be handled by this op,
            // but the scales need to be the same.
            break;
    }

    auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
    return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
}

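// Choose where to evaluate: inputs that are tiny and whose values are already
// known on the host are computed directly on CPU (DEFAULT_CPU); anything
// empty, shape-unknown, or larger than the threshold goes to a device KERNEL.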
DispatchMode decide_dispatch_mode(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    bool host_computable = true;
    constexpr int size_threshold = TensorShape::MAX_NDIM;
    for (auto&& inp : inputs) {
        if (inp.value.empty() || inp.value.layout().ndim == 0 ||
            inp.value.layout().total_nr_elems() > size_threshold) {
            host_computable = false;
            break;
        }
    }
    return host_computable ? DEFAULT_CPU : KERNEL;
}

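// Run the elemwise kernel directly on DeviceTensorNDs via
// opr::Elemwise::perform; the output tensor must already be allocated.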
void apply_on_device_tensornd(
        const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
        SmallVector<DeviceTensorND>* outputs) {
    auto&& op_def = def.cast_final_safe<Elemwise>();
    auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
    mgb_assert(
            inputs.size() == trait.arity, "%s expects %u inputs; got %zu actually",
            trait.name, trait.arity, inputs.size());
    auto&& dnn_opr =
            opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0].comp_node());
    opr::Elemwise::perform(op_def.mode, (*outputs)[0], inputs, dnn_opr);
}

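// Execute with pre-allocated outputs: unwrap the TensorPtrs into
// DeviceTensorNDs and reuse the DeviceTensorND path above.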
void execute(
        const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
        SmallVector<TensorPtr> workspace) {
    mgb_assert(outputs.size() == 1);
    SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        inp_tensornds[i] = inputs[i]->dev_tensor();
    }
    SmallVector<DeviceTensorND> out_tensornds = {outputs[0]->dev_tensor()};
    apply_on_device_tensornd(def, inp_tensornds, &out_tensornds);
}

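// Describe the output buffer so the executor can allocate it up front; the
// output inherits dtype and comp_node from the first input.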
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems) {
    auto&& op_def = def.cast_final_safe<Elemwise>();
    TensorShapeArray inp_shapes(inputs_tensors.size());
    for (size_t i = 0; i < inputs_tensors.size(); ++i) {
        inp_shapes[i] = inputs_tensors[i]->layout();
    }
    TensorShape shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
    SmallVector<MemoryDesc> outputs = {
            {{shape, inputs_tensors[0]->dtype()},
             0,
             inputs_tensors[0]->comp_node(),
             StorageIdentifier::make(1)}};
    return {outputs, {}};
}

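// Eager-mode entry point: compute the broadcast output shape, allocate the
// output through the blob manager (defragmenting on allocation failure, as
// the helper's name suggests), then run the kernel.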
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    auto&& op_def = def.cast_final_safe<Elemwise>();
    SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
    TensorShapeArray inp_shapes(inputs.size());
    for (unsigned i = 0; i < inputs.size(); ++i) {
        inp_tensornds[i] = inputs[i]->dev_tensor();
        inp_shapes[i] = inputs[i]->layout();
    }
    TensorShape shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
    DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(
            inp_tensornds[0].comp_node(), {shape, inp_tensornds[0].layout().dtype});
    SmallVector<DeviceTensorND> oup_tensornds = {out};
    apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds);
    return {Tensor::make(oup_tensornds[0])};
}

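// An Elemwise variant whose output is forced to share storage with the input
// at param.inplace_index (set_fwd_in2out_writable_force), so the kernel
// writes its result back into that input. This backs the graph-level
// implementation of InplaceAdd below.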
MGB_DEFINE_OPR_CLASS(
        ForceInplaceElemwise,
        cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>)  //{
public:
    struct Param {
        using Mode = megdnn::Elemwise::Param::Mode;
        Mode mode;
        size_t inplace_index;
    };
    using Mode = Param::Mode;
    ForceInplaceElemwise(
            const VarNodeArray& inputs, Param param, OperatorNodeConfig config = {})
            : Super(inputs[0]->owner_graph(), config, "device_add_update", inputs),
              m_param{param} {
        for (auto* input : inputs) {
            add_input({input});
        }
        add_output(None)
                ->set_fwd_in2out_writable_force(input(param.inplace_index))
                .add_flag(VarNode::Flag::NO_MEM_RECLAIM);
    }
    static SymbolVar make(const VarNodeArray& inputs, Param param) {
        return SymbolVar{inputs[0]}.insert_single_output_opr<ForceInplaceElemwise>(
                inputs, param);
    }
    static cg::OperatorNodeBase* shallow_copy(
            const serialization::OprShallowCopyContext& ctx,
            const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
            const OperatorNodeConfig& config);

protected:
    NodeProp* do_make_node_prop() const override {
        auto ret = Super::do_make_node_prop();
        ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR);
        return ret;
    }
    void create_megdnn_opr() override {
        auto opr = DnnOprCaller<megdnn::Elemwise>::create_operator(comp_node());
        opr->param().mode = m_param.mode;
        set_megdnn_opr(std::move(opr));
    }
    void scn_do_execute() override {
        auto to_dnnnd = [&](auto* var) { return var->dev_tensor().as_megdnn(); };
        megdnn::TensorNDArray inputs_dnnnd;
        for (auto* input : input()) {
            inputs_dnnnd.push_back(to_dnnnd(input));
        }
        mgb_assert(
                input(m_param.inplace_index)
                        ->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC),
                "ForceInplaceElemwise cannot be applied to an internal tensor");
        auto* out_dest = output(0);
        auto* opr = static_cast<megdnn::Elemwise*>(megdnn_opr());
        opr->exec(std::move(inputs_dnnnd), to_dnnnd(out_dest));
    }
    void init_output_static_infer_desc() override {
        using namespace cg::static_infer;

        owner_graph()->static_infer_manager().register_shape_infer(
                output(0),
                ShapeInferDesc::make_identity(input(m_param.inplace_index)));
    }

private:
    Param m_param;
    void record_execute_deps(ExecDependencyArray& deps) override {
        record_megdnn_opr(deps);
    }
};

MGB_DYN_TYPE_OBJ_FINAL_IMPL(ForceInplaceElemwise);

cg::OperatorNodeBase* ForceInplaceElemwise::shallow_copy(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    auto&& opr = opr_.cast_final_safe<ForceInplaceElemwise>();
    auto* graph = ctx.owner_graph(opr, inputs);
    return graph->insert_opr(
            std::make_unique<ForceInplaceElemwise>(inputs, opr.m_param, config));
}

MGB_REG_OPR_SHALLOW_COPY(ForceInplaceElemwise, ForceInplaceElemwise::shallow_copy);

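// Graph-level InplaceAdd: dest <- alpha * dest + beta * delta. This is
// expressed as FUSE_MUL_ADD4(a, b, c, d) = a * b + c * d with the inputs
// reordered to {alpha, dest, beta, delta} and inplace_index = 1, so the
// result is written straight into dest's storage.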
cg::OperatorNodeBase* apply_inplace_add_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto dest = inputs[0], delta = inputs[1], alpha = inputs[2], beta = inputs[3];
    auto mode = ForceInplaceElemwise::Param::Mode::FUSE_MUL_ADD4;
    return ForceInplaceElemwise::make({alpha, dest, beta, delta}, {mode, 1})
            .node()
            ->owner_opr();
}

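// Eager-mode InplaceAdd. The assertion below checks that dest's blob is not
// shared with any other live tensor before it is mutated in place; alpha and
// beta are read back to the host as float scalars for the AddUpdate kernel.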
SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    mgb_assert(
            inputs[0]->blob().use_count() == 2 && inputs[0]->blob()->storage().unique(),
            "This inplace modification may change the elements of other tensors. "
            "Please set MEGENGINE_INPLACE_UPDATE to 0 to ensure the program runs "
            "correctly.");
    auto dest = inputs[0], delta = inputs[1], alpha = inputs[2], beta = inputs[3];
    auto tensor_to_scalar = [](const TensorPtr& tensor) -> float {
        return *tensor->get_value().ptr<float>();
    };
    DnnOprCaller<megdnn::AddUpdate> caller{dest->comp_node()};
    caller.op->param() = {tensor_to_scalar(alpha), tensor_to_scalar(beta)};
    caller.op->exec(dest->dev_tensor().as_megdnn(), delta->dev_tensor().as_megdnn());
    return {std::make_shared<Tensor>(dest->blob(), dest->offset(), dest->layout())};
}

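// InplaceAdd with pre-allocated outputs: the result reuses dest's storage,
// so running the in-place kernel is all that is needed.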
void execute_inplace(
        const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
        SmallVector<TensorPtr> workspace) {
    apply_inplace_add_on_physical_tensor(def, inputs);
}

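// The in-place output aliases dest's memory, hence the StorageIdentifier
// pointing at the first input's MemoryDesc rather than a fresh allocation.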
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
infer_inplace_output_mem_desc(
        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems) {
    auto dest = inputs_tensors[0];
    SmallVector<MemoryDesc> outputs = {
            {dest->layout(), 0, dest->comp_node(),
             StorageIdentifier::make(&inputs_mems[0])}};
    return {outputs, {}};
}

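// Shape inference for InplaceAdd: expects exactly four inputs (dest, delta,
// alpha, beta) on a single comp_node; alpha and beta must be float32
// scalars, and the result keeps dest's layout.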
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    mgb_assert(inputs.size() == 4, "invalid number of inputs for inplace_add");
    CompNode cn;
    for (auto&& input : inputs) {
        if (!cn.valid()) {
            cn = input.comp_node;
        } else {
            mgb_assert(input.comp_node == cn, "inputs should be on the same comp_node");
        }
    }
    auto dest = inputs[0], delta = inputs[1], alpha = inputs[2], beta = inputs[3];
    bool succeed = dest.layout.ndim != 0;
    if (succeed) {
        mgb_assert(
                delta.layout.ndim == 0 || dest.layout.eq_shape(delta.layout),
                "dest and delta must have the same shape");
        mgb_assert(
                alpha.layout.ndim == 0 || alpha.layout.eq_shape({1}),
                "alpha should be a scalar");
        mgb_assert(
                beta.layout.ndim == 0 || beta.layout.eq_shape({1}),
                "beta should be a scalar");
    }
    mgb_assert(alpha.layout.dtype == dtype::Float32(), "alpha should be float32");
    mgb_assert(beta.layout.dtype == dtype::Float32(), "beta should be float32");
    // the inplace op changes the value of dest, so the result desc carries only
    // dest's layout and comp_node
    return {{{dest.layout, dest.comp_node}}, succeed};
}

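// Hook the implementations above into the op-trait registry.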
OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
        .make_from_op_node(make_from_op_node)
        .decide_dispatch_mode(decide_dispatch_mode)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_device_tensornd(apply_on_device_tensornd)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_mem_desc(infer_output_mem_desc)
        .execute(execute)
        .fallback();

OP_TRAIT_REG(InplaceAdd, InplaceAdd, opr::AddUpdate)
        .apply_on_var_node(apply_inplace_add_on_var_node)
        .apply_on_physical_tensor(apply_inplace_add_on_physical_tensor)
        .infer_output_attrs_fallible(infer_inplace_add_output_attrs_fallible)
        .infer_output_mem_desc(infer_inplace_output_mem_desc)
        .execute(execute_inplace)
        .fallback();
}  // anonymous namespace

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}