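// Imperative op-trait implementations for tensor-manipulation ops:
// GetVarShape, ParamPackSplit, ParamPackConcat, Split and MaskedFill.
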
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h"

#include "../dnn_op_helper.h"
#include "../op_trait.h"

namespace mgb::imperative {

namespace get_var_shape {
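// Graph-building path: lower GetVarShape to the graph operator of the same
// name.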
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();
    OperatorNodeConfig config{op_def.make_name()};
    return opr::GetVarShape::make(inputs, op_def.param(), config).node()->owner_opr();
}

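// Prefer host execution (DEFAULT_CPU) when every input already carries a
// host-readable value; otherwise dispatch to the device KERNEL path.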
DispatchMode decide_dispatch_mode(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    bool host_computable = true;
    for (auto&& inp : inputs) {
        // FIXME(czh): remove the value check once the proxy graph supports
        // apply_on_device_tensornd and the output Tensor is created before
        // add_task; then a valid layout implies the value is ready.
        if (inp.value.empty() || inp.value.layout().ndim == 0) {
            host_computable = false;
            break;
        }
    }
    return host_computable ? DEFAULT_CPU : KERNEL;
}

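// Host-side kernel: writes the input shape (broadcast-deduced when there are
// multiple inputs), or one axis of it, into an Int32 tensor on default_cpu.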
void apply_on_device_tensornd(
        const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
        SmallVector<DeviceTensorND>* outputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();

    TensorShape shp;
    if (inputs.size() == 1) {
        shp = inputs[0].layout();
    } else {
        TensorShapeArray src(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            src[i] = inputs[i].layout();
        }
        megdnn::Elemwise::deduce_shape(src, shp);
    }

    mgb_assert(shp.ndim != 0, "input shape invalid");
    mgb_assert(
            (*outputs)[0].comp_node() == CompNode::default_cpu(),
            "GetVarShape's apply_on_device_tensornd should receive default_cpu "
            "outputs.");

    HostTensorND hv;
    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
        hv = HostTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
        auto* ptr = hv.ptr<dt_int32>();
        for (size_t i = 0; i < shp.ndim; ++i) {
            ptr[i] = shp.shape[i];
        }
    } else {
        int32_t axis = op_def.axis;
        if (axis < 0) {
            axis += shp.ndim;
        }
        mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
        hv = HostTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
        auto* ptr = hv.ptr<dt_int32>();
        ptr[0] = shp.shape[axis];
    }
    (*outputs)[0] = DeviceTensorND::make_proxy(hv);
}

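// Helper for apply_on_physical_tensor: runs apply_on_device_tensornd on the
// host and re-binds the resulting host tensor to the first input's comp_node.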
HostTensorND get_var_shape_host_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<DeviceTensorND> input_tensornds;
    for (auto&& inp : inputs) {
        input_tensornds.push_back(inp->dev_tensor(false));
    }
    SmallVector<DeviceTensorND> output_tensornds = {
            {CompNode::default_cpu(), dtype::Int32()}};
    apply_on_device_tensornd(def, input_tensornds, &output_tensornds);
    // restore to input comp_node
    return HostTensorND::make_proxy(output_tensornds[0])
            .proxy_to_comp_node(inputs[0]->comp_node());
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    return {Tensor::make(get_var_shape_host_tensor(def, inputs))};
}

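// Fallible shape inference: when every input shape is known, the result is
// computed eagerly and attached to the output descriptor as a value; the
// returned bool indicates whether inference succeeded.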
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();
    auto&& desc = inputs[0];
    TensorShape shp;
    if (inputs.size() == 1) {
        shp = desc.layout;
    } else {
        TensorShapeArray src(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            src[i] = inputs[i].layout;
            if (!src[i].ndim) {
                return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
            }
        }
        megdnn::Elemwise::deduce_shape(src, shp);
    }
    if (!shp.ndim) {
        return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
    }
    DeviceTensorND value;
    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
        value = DeviceTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
        auto* ptr = value.ptr<dt_int32>();
        for (size_t i = 0; i < shp.ndim; ++i) {
            ptr[i] = shp[i];
        }
    } else {
        int32_t axis = op_def.axis;
        if (axis < 0) {
            axis += shp.ndim;
        }
        mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
        value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
        auto* ptr = value.ptr<dt_int32>();
        ptr[0] = shp[axis];
    }
    return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
}

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::GetVarShape>();
    return GetVarShape::make(node->param());
}

OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
        .make_from_op_node(make_from_op_node)
        .decide_dispatch_mode(decide_dispatch_mode)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_device_tensornd(apply_on_device_tensornd)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace get_var_shape

namespace param_pack {
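// Convert the nested std::vector shape list stored in the OpDef into a
// TensorShapeArray.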
TensorShapeArray get_shapes(const std::vector<std::vector<size_t>>& shapes) {
    TensorShapeArray ret;
    for (auto&& i : shapes) {
        SmallVector<size_t> shape(i.begin(), i.end());
        TensorShape shp(shape);
        ret.push_back(shp);
    }
    return ret;
}

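// ParamPackSplit: slice the flat parameter buffer back into individual
// parameter tensors. `offsets` stores one [begin, end) element-offset pair
// per part, so offsets[2 * i] is where part i starts.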
cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& param = def.cast_final_safe<ParamPackSplit>();
    auto&& graph = inputs[0]->owner_graph();

    auto&& shapes = get_shapes(param.shapes);
    OperatorNodeConfig config(param.make_name());
    cg::OperatorNodeBase* opr =
            graph->insert_opr(std::make_unique<mgb::opr::ParamPackSplit>(
                    inputs[0], param.offsets, shapes, config));
    return opr;
}

SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& param = def.cast_final_safe<ParamPackSplit>();
    mgb_assert(
            inputs.size() == 1, "ParamPackSplit takes 1 input, got %lu",
            inputs.size());
    auto&& inp = inputs[0];
    auto&& shp = inp->layout();
    mgb_assert(shp.ndim == 1, "ParamPackSplit input shape invalid, ndim should be 1");
    mgb_assert(param.shapes.size() * 2 == param.offsets.size());
    SmallVector<TensorPtr> ret;
    auto&& shapes = get_shapes(param.shapes);
    size_t dtype_size = inputs[0]->layout().dtype.size();
    for (size_t i = 0; i < shapes.size(); ++i) {
        // memory forward: each output is a zero-copy sub-tensor view into the
        // packed input buffer
        ret.push_back(inputs[0]->sub(param.offsets[i * 2] * dtype_size, shapes[i]));
    }
    return ret;
}

OP_TRAIT_REG(ParamPackSplit, ParamPackSplit, mgb::opr::ParamPackSplit)
        .apply_on_var_node(param_pack_split_apply_on_var_node)
        .apply_on_physical_tensor(param_pack_split_apply_on_physical_tensor)
        .fallback();

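// ParamPackConcat: concatenate many parameter tensors into one flat buffer;
// the last input carries the offsets table.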
cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& param = def.cast_final_safe<ParamPackConcat>();
    auto&& graph = inputs[0]->owner_graph();

    VarNodeArray inps(inputs.begin(), inputs.end() - 1);
    OperatorNodeConfig config{param.make_name()};
    cg::OperatorNodeBase* opr =
            graph->insert_opr(std::make_unique<mgb::opr::ParamPackConcat>(
                    inps, inputs.back(), param.offsets, config));
    return opr;
}

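// Physical-tensor path: gather the raw device pointers of all parts into a
// host-side pointer table, then call the megdnn ParamPackConcat kernel.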
SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    // OpDef type check only; the result is unused here
    def.cast_final_safe<ParamPackConcat>();
    mgb_assert(
            inputs.size() > 1,
            "ParamPackConcat takes at least one input tensor plus an offsets "
            "tensor");
    auto comp_node = inputs.front()->comp_node();
    auto dtype = inputs.front()->dtype();
    size_t nr_inputs = inputs.size() - 1;
    size_t nr_elems = 0;
    for (size_t i = 0; i < nr_inputs; ++i) {
        auto& input = inputs[i];
        mgb_assert(
                comp_node == input->comp_node(),
                "inputs for param_pack_concat must be on the same comp_node");
        mgb_assert(
                dtype == input->dtype(),
                "inputs for param_pack_concat must have the same dtype");
        nr_elems += input->layout().total_nr_elems();
    }
    auto dest_layout = TensorLayout({nr_elems}, dtype);
    auto output = Tensor::make(dest_layout, comp_node);
    // FIXME: add param to ParamPackConcat
    DnnOprCaller<megdnn::ParamPackConcat> caller{comp_node};
    HostTensorStorage srcs_storage{comp_node};
    srcs_storage.ensure_size(sizeof(void*) * nr_inputs);
    // the Int32 layout is nominal: the storage actually holds raw pointers
    TensorLayout srcs_layout = TensorLayout{{nr_inputs}, dtype::Int32()};
    HostTensorND srcs_tensornd;
    srcs_tensornd.reset(srcs_storage, srcs_layout);
    auto* srcs_raw_ptr = reinterpret_cast<void**>(srcs_storage.ptr());
    for (size_t i = 0; i < nr_inputs; ++i) {
        srcs_raw_ptr[i] = inputs[i]->dnn_tensor().raw_ptr();
    }
    caller.exec_with_ws(srcs_tensornd.as_megdnn(), inputs.back(), output);
    async_release(srcs_tensornd);
    return {output};
}

OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
        .apply_on_var_node(param_pack_concat_apply_on_var_node)
        .apply_on_physical_tensor(param_pack_concat_apply_on_physical_tensor)
        .fallback();
}  // namespace param_pack

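// Split: only OpDef recovery and graph lowering are implemented here;
// everything else goes through the registered fallback.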
namespace split {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    using Options = opr::Split::Options;
    auto* node = &node_->cast_final_safe<opr::Split>();
    auto&& opt = node->options();
    int axis = opt.axis;
    mgb_assert(
            opt.method == Options::Method::SPECIFY,
            "only Split with SPECIFY output shapes is supported");
    mgb_assert(opt.partition.size() == opt.nr_part);
    return Split::make(axis, 0);
}

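// A non-zero `nsections` requests an even split along `axis` (Method
// CALL_BACK); otherwise the trailing inputs are symbolic partition sizes
// forwarded with Method SPECIFY.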
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    using Options = opr::Split::Options;
    auto&& sp = static_cast<const Split&>(def);
    OperatorNodeConfig config{sp.make_name()};
    opr::Split::Options opt;
    if (sp.nsections) {
        opt = Options::make_average(sp.axis, sp.nsections);
        opt.method = Options::Method::CALL_BACK;
    } else {
        opt.axis = sp.axis;
        opt.method = Options::Method::SPECIFY;
        mgb_assert(inputs.size() > 1);
        opt.nr_part = inputs.size() - 1;
        opt.partition.resize(opt.nr_part);
        for (size_t i = 1; i < inputs.size(); ++i)
            opt.partition[i - 1] = inputs[i];
    }
    return opr::Split::make(inputs[0], opt, config);
}

OP_TRAIT_REG(Split, Split, opr::Split)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .fallback();

}  // namespace split

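// MaskedFill: fill the positions of the input selected by a boolean mask with
// the scalar value carried in the op's param.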
namespace masked_fill {

cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op_def = def.cast_final_safe<MaskedFill>();
    OperatorNodeConfig config{op_def.make_name()};
    mgb_assert(inputs.size() == 2);
    return opr::MaskedFill::make(inputs[0], inputs[1], op_def.param(), config)
            .node()
            ->owner_opr();
}

SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
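    // require the mask (input 1) to be contiguous; other inputs are left
    // unconstrained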
    layout_checker[1] = [](const TensorLayout& layout) {
        return layout.is_contiguous();
    };
    return layout_checker;
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    return {{{{input_descs[0].layout, input_descs[0].layout.dtype},
              input_descs[0].comp_node}},
            input_descs[0].layout.ndim != 0};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op = def.cast_final_safe<MaskedFill>();
    auto&& inp = inputs[0];
    auto&& mask = inputs[1];

    TensorLayout outlayout(inp->layout(), inp->layout().dtype);

    auto output = Tensor::make(outlayout, inp->comp_node());

    DnnOprCaller<megdnn::MaskedFill> dnn_opr{inp->comp_node(), op.param()};
    dnn_opr.exec_with_ws(inp, mask, output);
    return {output};
}

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::MaskedFill>();
    return MaskedFill::make(node->param());
}

OP_TRAIT_REG(MaskedFill, MaskedFill, mgb::opr::MaskedFill)
        .get_input_layout_constraint(get_input_layout_constraint)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .apply_on_var_node(apply_on_var_node)
        .make_from_op_node(make_from_op_node)
        .fallback();

}  // namespace masked_fill

}  // namespace mgb::imperative