/**
 * \file imperative/src/impl/ops/tensor_manip.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/tensor_manip.h"
#include "../op_trait.h"

namespace mgb::imperative {

namespace get_var_shape {
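// Lower the imperative GetVarShape OpDef into the computing graph: build an
// opr::GetVarShape node from the inputs and return its owner operator.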
cg::OperatorNodeBase* apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();
    OperatorNodeConfig config{op_def.make_name()};
    return opr::GetVarShape::make(inputs, op_def.param(), config).node()->owner_opr();
}

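// Decide where GetVarShape runs: if every input already carries a host-side
// value with a valid layout, the shape can be read out on the default CPU;
// otherwise fall back to the device kernel path.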
DispatchMode decide_dispatch_mode(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    bool host_computable = true;
    for (auto&& inp : inputs) {
        // FIXME(czh): remove the value check after the proxy graph's
        // apply_on_device_tensornd is supported and the output Tensor
        // is made before add_task; then, if the layout is valid,
        // ptr->layout must be ready.
        if (inp.value.empty() || inp.value.layout().ndim == 0) {
            host_computable = false;
            break;
        }
    }
    return host_computable ? DEFAULT_CPU : KERNEL;
}

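// Write the input's shape into the single Int32 output on the default CPU:
// the whole shape vector when axis == INVALID_AXIS, otherwise the extent of
// one axis (negative axes count from the end).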
void apply_on_device_tensornd(
        const OpDef& def,
        const SmallVector<DeviceTensorND>& inputs,
        SmallVector<DeviceTensorND>* outputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();
    mgb_assert(inputs.size() == 1, "GetVarShape takes 1 input, got %lu", inputs.size());
    auto&& inp = inputs[0];
    auto&& shp = inp.layout();
    mgb_assert(shp.ndim != 0, "input shape invalid");
    mgb_assert((*outputs)[0].comp_node() == CompNode::default_cpu(),
        "GetVarShape's apply_on_device_tensornd should receive default_cpu outputs.");

    HostTensorND hv;
    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
        hv = HostTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
        auto* ptr = hv.ptr<dt_int32>();
        for (size_t i = 0; i < shp.ndim; ++i) {
            ptr[i] = shp.shape[i];
        }
    } else {
        int32_t axis = op_def.axis;
        if (axis < 0) {
            axis += shp.ndim;
        }
        mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
        hv = HostTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
        auto* ptr = hv.ptr<dt_int32>();
        ptr[0] = shp.shape[axis];
    }
    (*outputs)[0] = DeviceTensorND::make_proxy(hv);
}

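// Eager execution: compute the shape on the default CPU via
// apply_on_device_tensornd, then proxy the host result back to the input's
// comp_node.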
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs) {
    SmallVector<DeviceTensorND> input_tensornds;
    input_tensornds.reserve(inputs.size());
    for (auto&& inp : inputs) {
        input_tensornds.push_back(inp->dev_tensor());
    }
    SmallVector<DeviceTensorND> output_tensornds = {{CompNode::default_cpu(), dtype::Int32()}};

    apply_on_device_tensornd(def, input_tensornds, &output_tensornds);

    // restore to input comp_node
    HostTensorND host_tensornd = HostTensorND::make_proxy(output_tensornds[0])
        .proxy_to_comp_node(inputs[0]->comp_node());
    return {Tensor::make(std::move(host_tensornd))};
}

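// Infer the output descriptor without executing the op. If the input layout
// is still unknown, only the dtype (Int32) and comp_node can be reported and
// the flag is false; otherwise the shape value is computed eagerly and
// returned together with its layout, with the flag set to true.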
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();
    mgb_assert(inputs.size() == 1, "GetVarShape takes 1 input, got %lu", inputs.size());
    auto&& desc = inputs[0];
    if (!desc.layout.ndim) {
        return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
    }
    DeviceTensorND value;
    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
        value = DeviceTensorND(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32());
        auto* ptr = value.ptr<dt_int32>();
        for (size_t i = 0; i < desc.layout.ndim; ++i) {
            ptr[i] = desc.layout[i];
        }
    } else {
        int32_t axis = op_def.axis;
        if (axis < 0) {
            axis += desc.layout.ndim;
        }
        mgb_assert(axis >= 0 && axis < (int32_t)desc.layout.ndim);
        value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
        auto* ptr = value.ptr<dt_int32>();
        ptr[0] = desc.layout[axis];
    }
    return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
}

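// Rebuild the imperative OpDef from an existing opr::GetVarShape graph node.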
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::GetVarShape>();
    return GetVarShape::make(node->param());
}

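// Register all GetVarShape handlers with the op-trait dispatch table.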
OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
    .make_from_op_node(make_from_op_node)
    .decide_dispatch_mode(decide_dispatch_mode)
    .infer_output_attrs_fallible(infer_output_attrs_fallible)
    .apply_on_var_node(apply_on_var_node)
    .apply_on_device_tensornd(apply_on_device_tensornd)
    .apply_on_physical_tensor(apply_on_physical_tensor)
    .fallback();
} // get_var_shape

namespace param_pack {
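// Convert the nested std::vector shapes stored in the OpDef param into a
// TensorShapeArray usable by the graph operators.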
TensorShapeArray get_shapes(const std::vector<std::vector<size_t>>& shapes) {
    TensorShapeArray ret;
    for (auto&& i : shapes) {
        SmallVector<size_t> shape(i.begin(), i.end());
        TensorShape shp(shape);
        ret.push_back(shp);
    }
    return ret;
}

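// Insert a ParamPackSplit operator that splits the packed 1-D input var back
// into the individual parameter vars described by param.shapes and
// param.offsets.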
cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& param = def.cast_final_safe<ParamPackSplit>();
    auto&& graph = inputs[0]->owner_graph();

    auto&& shapes = get_shapes(param.shapes);
    OperatorNodeConfig config(param.make_name());
    cg::OperatorNodeBase* opr =
            graph->insert_opr(std::make_unique<mgb::opr::ParamPackSplit>(
                    inputs[0], param.offsets, shapes, config));
    return opr;
}

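// Eager split: each output is a zero-copy sub-view of the packed 1-D input.
// The offsets appear to be stored as (begin, end) element pairs per output;
// only the begin offset is used here, converted to a byte offset via the
// dtype size.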
SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs) {
    auto&& param = def.cast_final_safe<ParamPackSplit>();
    mgb_assert(inputs.size() == 1, "ParamPackSplit takes 1 input, got %lu", inputs.size());
    auto&& inp = inputs[0];
    auto&& shp = inp->layout();
    mgb_assert(shp.ndim == 1, "ParamPackSplit input shape invalid, ndim should be 1");
    mgb_assert(param.shapes.size() * 2 == param.offsets.size());
    SmallVector<TensorPtr> ret;
    auto&& shapes = get_shapes(param.shapes);
    size_t dtype_size = inputs[0]->layout().dtype.size();
    for (size_t i = 0; i < shapes.size(); ++i) {
        ret.push_back(
                inputs[0]->sub(param.offsets[i * 2] * dtype_size, shapes[i]));
    }
    return ret;
}

OP_TRAIT_REG(ParamPackSplit, ParamPackSplit, mgb::opr::ParamPackSplit)
        .apply_on_var_node(param_pack_split_apply_on_var_node)
        .apply_on_physical_tensor(param_pack_split_apply_on_physical_tensor)
        .fallback();

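// Insert a ParamPackConcat operator: all inputs except the last are the vars
// to be packed, while the last input is passed separately (presumably the
// device offsets tensor) together with param.offsets.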
cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& param = def.cast_final_safe<ParamPackConcat>();
    auto&& graph = inputs[0]->owner_graph();

    VarNodeArray inps(inputs.begin(), inputs.end() - 1);
    OperatorNodeConfig config{param.make_name()};
    cg::OperatorNodeBase* opr =
            graph->insert_opr(std::make_unique<mgb::opr::ParamPackConcat>(
                    inps, inputs.back(), param.offsets, config));
    return opr;
}

OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
        .apply_on_var_node(param_pack_concat_apply_on_var_node)
        .fallback();
} // param_pack

} // namespace mgb::imperative