/**
 * \file imperative/src/impl/ops/tensor_manip.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/opr/tensor_manip.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h"

#include "../async_releaser.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"

namespace mgb::imperative {

namespace get_var_shape {
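// Lowering for graph mode: emit an opr::GetVarShape node that owns the result var.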
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();
    OperatorNodeConfig config{op_def.make_name()};
    return opr::GetVarShape::make(inputs, op_def.param(), config).node()->owner_opr();
}

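// Dispatch on the host (DEFAULT_CPU) only when every input already carries a
// host-readable value with a valid layout; otherwise fall back to KERNEL dispatch.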
DispatchMode decide_dispatch_mode(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    bool host_computable = true;
    for (auto&& inp : inputs) {
        // FIXME(czh): remove the value check after proxy graph's
        // apply_on_device_tensornd is supported and the output Tensor
        // is made before add_task; then, if the layout is valid,
        // ptr->layout must be ready.
        if (inp.value.empty() || inp.value.layout().ndim == 0) {
            host_computable = false;
            break;
        }
    }
    return host_computable ? DEFAULT_CPU : KERNEL;
}

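// Writes the input shape (or a single axis of it) into an Int32 tensor on the
// default CPU comp node; with multiple inputs the broadcasted shape is deduced
// via megdnn::Elemwise::deduce_shape.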
void apply_on_device_tensornd(
        const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
        SmallVector<DeviceTensorND>* outputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();

    TensorShape shp;
    if (inputs.size() == 1) {
        shp = inputs[0].layout();
    } else {
        TensorShapeArray src(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            src[i] = inputs[i].layout();
        }
        megdnn::Elemwise::deduce_shape(src, shp);
    }

    mgb_assert(shp.ndim != 0, "input shape invalid");
    mgb_assert(
            (*outputs)[0].comp_node() == CompNode::default_cpu(),
            "GetVarShape's apply_on_device_tensornd should receive default_cpu "
            "outputs.");

    HostTensorND hv;
    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
        hv = HostTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
        auto* ptr = hv.ptr<dt_int32>();
        for (size_t i = 0; i < shp.ndim; ++i) {
            ptr[i] = shp.shape[i];
        }
    } else {
        int32_t axis = op_def.axis;
        if (axis < 0) {
            axis += shp.ndim;
        }
        mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
        hv = HostTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
        auto* ptr = hv.ptr<dt_int32>();
        ptr[0] = shp.shape[axis];
    }
    (*outputs)[0] = DeviceTensorND::make_proxy(hv);
}

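// Shared helper for eager execution: computes the shape on the default CPU comp
// node and proxies the resulting host tensor back to the input's comp node.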
HostTensorND get_var_shape_host_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<DeviceTensorND> input_tensornds;
    for (auto&& inp : inputs) {
        input_tensornds.push_back(inp->dev_tensor());
    }
    SmallVector<DeviceTensorND> output_tensornds = {
            {CompNode::default_cpu(), dtype::Int32()}};
    apply_on_device_tensornd(def, input_tensornds, &output_tensornds);
    // restore to input comp_node
    return HostTensorND::make_proxy(output_tensornds[0])
            .proxy_to_comp_node(inputs[0]->comp_node());
}

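// Eager execution simply wraps the host-computed shape in a new Tensor.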
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    return {Tensor::make(std::move(get_var_shape_host_tensor(def, inputs)))};
}

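// Fallible inference: always returns an Int32 layout; when the input shapes are
// already known, the concrete shape value is attached and the bool is true.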
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();
    auto&& desc = inputs[0];
    TensorShape shp;
    if (inputs.size() == 1) {
        shp = desc.layout;
    } else {
        TensorShapeArray src(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            src[i] = inputs[i].layout;
        }
        megdnn::Elemwise::deduce_shape(src, shp);
    }
    if (!shp.ndim) {
        return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
    }
    DeviceTensorND value;
    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
        value = DeviceTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
        auto* ptr = value.ptr<dt_int32>();
        for (size_t i = 0; i < shp.ndim; ++i) {
            ptr[i] = shp[i];
        }
    } else {
        int32_t axis = op_def.axis;
        if (axis < 0) {
            axis += shp.ndim;
        }
        mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
        value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
        auto* ptr = value.ptr<dt_int32>();
        ptr[0] = shp[axis];
    }
    return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
}

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::GetVarShape>();
    return GetVarShape::make(node->param());
}

OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
        .make_from_op_node(make_from_op_node)
        .decide_dispatch_mode(decide_dispatch_mode)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_device_tensornd(apply_on_device_tensornd)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace get_var_shape

namespace param_pack {
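// Converts the nested std::vector shape description into a TensorShapeArray.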
TensorShapeArray get_shapes(const std::vector<std::vector<size_t>>& shapes) {
    TensorShapeArray ret;
    for (auto&& i : shapes) {
        SmallVector<size_t> shape(i.begin(), i.end());
        TensorShape shp(shape);
        ret.push_back(shp);
    }
    return ret;
}

cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& param = def.cast_final_safe<ParamPackSplit>();
    auto&& graph = inputs[0]->owner_graph();

    auto&& shapes = get_shapes(param.shapes);
    OperatorNodeConfig config(param.make_name());
    cg::OperatorNodeBase* opr =
            graph->insert_opr(std::make_unique<mgb::opr::ParamPackSplit>(
                    inputs[0], param.offsets, shapes, config));
    return opr;
}

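// Splits the packed 1-D tensor into per-parameter views at the recorded offsets
// (scaled by the dtype size); this is a pure memory forward, no data is copied.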
SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& param = def.cast_final_safe<ParamPackSplit>();
    mgb_assert(
            inputs.size() == 1, "ParamPackSplit takes 1 input, got %lu", inputs.size());
    auto&& inp = inputs[0];
    auto&& shp = inp->layout();
    mgb_assert(shp.ndim == 1, "ParamPackSplit input shape invalid, ndim should be 1");
    mgb_assert(param.shapes.size() * 2 == param.offsets.size());
    SmallVector<TensorPtr> ret;
    auto&& shapes = get_shapes(param.shapes);
    size_t dtype_size = inputs[0]->layout().dtype.size();
    for (size_t i = 0; i < shapes.size(); ++i) {
        // memory forward
        ret.push_back(inputs[0]->sub(param.offsets[i * 2] * dtype_size, shapes[i]));
    }
    return ret;
}

OP_TRAIT_REG(ParamPackSplit, ParamPackSplit, mgb::opr::ParamPackSplit)
        .apply_on_var_node(param_pack_split_apply_on_var_node)
        .apply_on_physical_tensor(param_pack_split_apply_on_physical_tensor)
        .fallback();

cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& param = def.cast_final_safe<ParamPackConcat>();
    auto&& graph = inputs[0]->owner_graph();

    VarNodeArray inps(inputs.begin(), inputs.end() - 1);
    OperatorNodeConfig config{param.make_name()};
    cg::OperatorNodeBase* opr =
            graph->insert_opr(std::make_unique<mgb::opr::ParamPackConcat>(
                    inps, inputs.back(), param.offsets, config));
    return opr;
}

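// Concatenates all data inputs into one flat output: a host-side table of device
// pointers is built, the megdnn ParamPackConcat kernel consumes it together with
// the offsets tensor, and the table is kept alive via AsyncReleaser until the
// (possibly asynchronous) kernel has finished.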
SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    def.cast_final_safe<ParamPackConcat>();
    mgb_assert(
            inputs.size() > 1,
            "param_pack_concat should have at least one input tensor besides the offsets");
    auto comp_node = inputs.front()->comp_node();
    auto dtype = inputs.front()->dtype();
    size_t nr_inputs = inputs.size() - 1;
    size_t nr_elems = 0;
    for (size_t i = 0; i < nr_inputs; ++i) {
        auto& input = inputs[i];
        mgb_assert(
                comp_node == input->comp_node(),
                "inputs for param_pack_concat must be in the same comp_node");
        mgb_assert(
                dtype == input->dtype(),
                "inputs for param_pack_concat must have the same dtype");
        nr_elems += input->layout().total_nr_elems();
    }
    auto dest_layout = TensorLayout({nr_elems}, dtype);
    auto output = Tensor::make(dest_layout, comp_node);
    auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node);
    size_t srcs_size = sizeof(void*) * nr_inputs;
    void** srcs_raw_ptr = (void**)comp_node.alloc_host(srcs_size);
    std::shared_ptr<dt_byte> srcs_ptr = {
            (dt_byte*)srcs_raw_ptr,
            [comp_node](dt_byte* ptr) { comp_node.free_host(ptr); }};
    TensorLayout srcs_layout = TensorLayout{{nr_inputs}, dtype::Int32()};
    size_t ws_size;
    {
        TensorShapeArray src_shapes;
        for (size_t i = 0; i < nr_inputs; ++i) {
            src_shapes.push_back(inputs[i]->shape());
        }
        ws_size = caller.op->get_workspace_in_bytes(
                src_shapes, inputs.back()->shape(), TensorShape{});
    }
    for (size_t i = 0; i < nr_inputs; ++i) {
        srcs_raw_ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr();
    }
    HostTensorStorage srcs_storage;
    srcs_storage.reset(comp_node, srcs_size, srcs_ptr);
    caller.op->exec(
            {srcs_raw_ptr, srcs_layout}, inputs.back()->dev_tensor().as_megdnn(),
            output->dev_tensor().as_megdnn(),
            caller.create_workspace({{ws_size}, dtype::Byte()}));
    AsyncReleaser::inst()->add(
            HostTensorND{comp_node, srcs_layout}.storage(srcs_storage));
    return {output};
}

OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
        .apply_on_var_node(param_pack_concat_apply_on_var_node)
        .apply_on_physical_tensor(param_pack_concat_apply_on_physical_tensor)
        .fallback();
}  // namespace param_pack

namespace split {
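// Rebuilds a Split OpDef from a graph node; only Options::Method::SPECIFY
// (explicit partition sizes) round-trips here.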
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    using Options = opr::Split::Options;
    auto* node = &node_->cast_final_safe<opr::Split>();
    auto&& opt = node->options();
    int axis = opt.axis;
    mgb_assert(
            opt.method == Options::Method::SPECIFY,
            "only Split with SPECIFY output shapes is supported");
    mgb_assert(opt.partition.size() == opt.nr_part);
    return Split::make(axis, 0);
}

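// nsections != 0 requests an even split along the axis (average partition via
// CALL_BACK); otherwise the trailing input vars give each partition's size
// (Method::SPECIFY).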
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    using Options = opr::Split::Options;
    auto&& sp = static_cast<const Split&>(def);
    OperatorNodeConfig config{sp.make_name()};
    opr::Split::Options opt;
    if (sp.nsections) {
        opt = Options::make_average(sp.axis, sp.nsections);
        opt.method = Options::Method::CALL_BACK;
    } else {
        opt.axis = sp.axis;
        opt.method = Options::Method::SPECIFY;
        mgb_assert(inputs.size() > 1);
        opt.nr_part = inputs.size() - 1;
        opt.partition.resize(opt.nr_part);
        for (size_t i = 1; i < inputs.size(); ++i)
            opt.partition[i - 1] = inputs[i];
    }
    return opr::Split::make(inputs[0], opt, config);
}

OP_TRAIT_REG(Split, Split, opr::Split)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .fallback();

}  // namespace split

}  // namespace mgb::imperative