// atlas_runtime_op.cpp
#include "megbrain/opr/atlas_runtime_op.h"
#include <memory>
#include "megbrain/common.h"
#include "megbrain/graph/operator_node.h"
#include "megdnn/basic_types.h"
#include "megdnn/dtype.h"

#if MGB_ATLAS
#include "acl/acl_mdl.h"

using namespace mgb;
using namespace opr;

namespace {
/**
 * \brief get mgb shape from acl shape, batch from mgb
 */
M
Megvii Engine Team 已提交
18
TensorShape acl_shape_to_mgb_shape_for_output(aclmdlIODims acl_shape, size_t batch) {
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
    TensorShape ret;
    ret.ndim = acl_shape.dimCount;
    for (size_t i = 0; i < ret.ndim; ++i) {
        ret[i] = acl_shape.dims[i];
    }
    ret[0] = batch;
    return ret;
}

/**
 * \brief deduce the input shape from aclFormat and aipp config.
 *
 * \param acl_shape shape from om file
 * \param batch batchsize from mgb
 * \param enable_dynamic_batch True if set dynamic batch size
 * \param om_format layout format from om file
 * \param aipp_input_fmt input_format in static aipp config of om file
 */
TensorShape acl_shape_to_mgb_shape_for_input(
        aclmdlIODims acl_shape, size_t batch, bool enable_dynamic_batch,
        aclFormat om_format, AtlasRuntimeOpr::AippInputFormat aipp_input_fmt) {
    TensorShape result;
    result.ndim = acl_shape.dimCount;
    mgb_assert(
            result.ndim == 4, "Unexpected ndim form aclmdlIODims expected 4, but got %zu",
            result.ndim);
    for (size_t dim = 0; dim < result.ndim; ++dim) {
        result[dim] = acl_shape.dims[dim];
    }
    if (enable_dynamic_batch) {
        //! a dynamic-batch om model marks dim 0 as -1; megbrain supplies
        //! the real batch
        mgb_assert(
                result[0] == static_cast<size_t>(-1),
                "batch size expected to be -1 when enable dynamic "
                "batchsize, got: %zu\n",
                result[0]);
        result[0] = batch;
    } else {
        //! static batch: the om model and megbrain must already agree
        mgb_assert(
                result[0] == batch,
                "batchsize mismatch if no dynamic batchsize enabled, "
                "expected: %zu got: %zu\n",
                result[0], batch);
    }

    //! static aipp implies NHWC layout in the om model
    if (aipp_input_fmt != AtlasRuntimeOpr::AippInputFormat::NO_AIPP) {
        mgb_assert(
                om_format == ACL_FORMAT_NHWC,
                "om format should be NHWC if enable aipp");
    }

    return result;
}

//! map an ACL element type onto the corresponding megbrain DType;
//! throws for types megbrain has no equivalent for
DType acl_dtype_to_mgb_dtype(aclDataType data_type) {
    switch (data_type) {
        case ACL_FLOAT:
            return dtype::Float32();
        case ACL_FLOAT16:
#if !MEGDNN_DISABLE_FLOAT16
            return dtype::Float16();
#else
            //! the build has fp16 compiled out, so the om model cannot be run
            mgb_throw(MegBrainError, "Float16 support is disabled at compile time.");
#endif
        case ACL_INT8:
            return dtype::Int8();
        case ACL_UINT8:
            return dtype::Uint8();
        case ACL_INT16:
            return dtype::Int16();
        case ACL_INT32:
            return dtype::Int32();
        default:
            mgb_throw(
                    MegBrainError, "aclDataType %x is not supported by MegBrain.",
                    static_cast<int>(data_type));
    }
}

/**
 * \brief generate batch size which match the batch_choice
 */
M
Megvii Engine Team 已提交
100 101
SmallVector<size_t> gen_batch_vec(
        size_t origin_batch, const SmallVector<size_t>& batch_choices) {
102 103 104 105 106 107 108 109 110 111 112 113
    SmallVector<size_t> ret;
    size_t idx = 0;
    size_t nr_batch_choices = batch_choices.size();
    size_t batch = origin_batch;
    while (idx < nr_batch_choices) {
        size_t val = batch_choices[idx];
        while (batch >= batch_choices[idx]) {
            ret.push_back(val);
            batch -= val;
        }
        idx++;
    }
M
Megvii Engine Team 已提交
114 115 116
    mgb_assert(
            batch == 0, "Invalid batch size %zu, can not be generate by batch choices",
            origin_batch);
117 118 119 120 121 122 123 124 125

    return ret;
}

class PtrGetter {
public:
    PtrGetter(const VarNodeArray& vars) {
        for (auto&& var : vars) {
            m_ptrs.push_back(var->dev_tensor().raw_ptr());
M
Megvii Engine Team 已提交
126 127
            m_batch_in_bytes.push_back(
                    var->layout().stride[0] * var->layout().dtype.size());
128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
        }
    }

    std::pair<void*, size_t> get(size_t batch, size_t idx) {
        std::pair<void*, size_t> ret;
        ret.first = m_ptrs[idx];
        ret.second = batch * m_batch_in_bytes[idx];
        m_ptrs[idx] = reinterpret_cast<void*>(
                reinterpret_cast<uintptr_t>(ret.first) + ret.second);
        return ret;
    }

private:
    SmallVector<void*> m_ptrs;
    SmallVector<size_t> m_batch_in_bytes;
};

};  // namespace

/* ====================== AtlasRuntimeOpr ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(AtlasRuntimeOpr);
M
Megvii Engine Team 已提交
149 150 151
AtlasRuntimeOpr::AtlasRuntimeOpr(
        SharedBuffer buf, const std::pair<uint32_t, aclmdlDesc*>& model,
        const VarNodeArray& inputs, const OperatorNodeConfig& config)
152 153 154 155 156 157 158 159 160
        : Super(inputs[0]->owner_graph(), config, "atlas_runtime", inputs),
          m_buffer{std::move(buf)},
          m_model_id{model.first},
          m_model_desc{model.second} {
    mgb_assert(
            inputs[0]->comp_node().device_type() == CompNode::DeviceType::ATLAS,
            "AtlasRuntimeOpr can only be used on atlas comp node; "
            "got %s",
            inputs[0]->comp_node().to_string().c_str());
M
Megvii Engine Team 已提交
161 162 163
    mgb_assert(
            m_buffer.data() != nullptr ||
            (m_model_id != INVALID_MODEL_ID && m_model_desc != nullptr));
164 165 166 167 168

    for (auto i : inputs) {
        add_input({i});
    }
    if (m_model_id == INVALID_MODEL_ID && m_model_desc == nullptr) {
M
Megvii Engine Team 已提交
169 170
        MGB_ATLAS_CHECK(
                aclmdlLoadFromMem(m_buffer.data(), m_buffer.size(), &m_model_id));
171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189
        m_model_desc = aclmdlCreateDesc();
        MGB_ATLAS_CHECK(aclmdlGetDesc(m_model_desc, m_model_id));
        m_is_model_holder = true;
    }

    //! aipp input format
    m_aipp_input_format = SmallVector<AippInputFormat>(inputs.size());
    aclAippInfo aipp_info;
    for (size_t i = 0; i < inputs.size(); ++i) {
        aclError acl_err = aclmdlGetFirstAippInfo(m_model_id, i, &aipp_info);
        if (ACL_ERROR_NONE == acl_err) {
            switch (aipp_info.inputFormat) {
                case ACL_YUV420SP_U8:
                    m_aipp_input_format[i] = AippInputFormat::YUV420SP_U8;
                    break;
                case ACL_RGB888_U8:
                    m_aipp_input_format[i] = AippInputFormat::RGB888_U8;
                    break;
                default:
M
Megvii Engine Team 已提交
190 191 192
                    mgb_throw(
                            MegBrainError,
                            "Unsupported aclAippInputFormat for input %zu. ", i);
193 194 195 196 197 198 199 200 201 202 203 204 205
            }
        } else if (ACL_ERROR_NOT_STATIC_AIPP == acl_err) {
            m_aipp_input_format[i] = AippInputFormat::NO_AIPP;
        } else {
            MGB_ATLAS_CHECK(acl_err);
        }
    }

    size_t dynamic_index;
    auto errcode = aclmdlGetInputIndexByName(
            m_model_desc, ACL_DYNAMIC_TENSOR_NAME, &dynamic_index);
    if (errcode == ACL_ERROR_NONE) {
        aclmdlHW hw_info;
M
Megvii Engine Team 已提交
206
        MGB_ATLAS_CHECK(aclmdlGetDynamicHW(m_model_desc, dynamic_index, &hw_info));
207 208 209 210 211 212 213 214
        mgb_assert(hw_info.hwCount == 0, "Currently not support dynamic HW");
    }

    //! dynamic batch size
    aclmdlBatch acl_batch;
    MGB_ATLAS_CHECK(aclmdlGetDynamicBatch(m_model_desc, &acl_batch));
    if (acl_batch.batchCount) {
        size_t dynamic_data_size;
M
Megvii Engine Team 已提交
215
        dynamic_data_size = aclmdlGetInputSizeByIndex(m_model_desc, dynamic_index);
216 217 218 219
        m_dyn_batch_tensor = DeviceTensorND(
                inputs[0]->comp_node(), {{dynamic_data_size}, dtype::Uint8()});

        for (size_t i = 0; i < acl_batch.batchCount; ++i) {
M
Megvii Engine Team 已提交
220
            m_dyn_batch_choices.push_back(static_cast<size_t>(acl_batch.batch[i]));
221
        }
M
Megvii Engine Team 已提交
222 223 224
        std::sort(
                m_dyn_batch_choices.begin(), m_dyn_batch_choices.end(),
                std::greater<>());
225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243
    }

    //! add output
    size_t nr_outputs = aclmdlGetNumOutputs(m_model_desc);
    using F = VarNode::Flag;
    if (nr_outputs == 1) {
        add_output(None);
    } else {
        for (size_t i = 0; i < nr_outputs; ++i) {
            add_output(ssprintf("o%zu", i));
        }
    }
    if (!m_dyn_batch_choices.empty()) {
        /**
         * \warning If enable dynamic batchsize, the memory of output
         * should be the largest be the size with the largest batch_size, so we
         * set the flag to SYS_MEM_ALLOC.
         */
        for (size_t i = 0; i < nr_outputs; ++i) {
M
Megvii Engine Team 已提交
244
            output(i)->add_flag(F::NO_SYS_MEM_ALLOC).add_flag(F::NO_MEM_RECLAIM);
245 246 247 248 249 250 251 252 253 254 255 256 257
        }
    }
    add_equivalence_component<mgb::ScalarHash<const void*>>(m_buffer.data());
};

AtlasRuntimeOpr::~AtlasRuntimeOpr() {
    //! only the instance that loaded the model is responsible for releasing it
    if (!m_is_model_holder)
        return;
    MGB_ATLAS_CHECK(aclmdlUnload(m_model_id));
    MGB_ATLAS_CHECK(aclmdlDestroyDesc(m_model_desc));
}

void AtlasRuntimeOpr::scn_do_execute() {
M
Megvii Engine Team 已提交
258
    auto&& acl_env = CompNodeEnv::from_comp_node(input(0)->comp_node()).atlas_env();
259 260 261 262 263 264
    acl_env.activate();

    if (!m_dyn_batch_choices.empty()) {
        for (size_t i = 0; i < output().size(); i++) {
            auto output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i);
            auto ovar = output(i);
265
            output_size = std::max<size_t>(
M
Megvii Engine Team 已提交
266
                    output_size, ovar->dtype().size(ovar->shape().total_nr_elems()));
267 268 269 270 271 272 273 274 275 276 277 278 279
            ovar->shape_alloc(ovar->shape(), output_size);
        }
    }

    PtrGetter input_getter(input());
    PtrGetter output_getter(output());

    bool enable_dynamic_batch = !m_dyn_batch_choices.empty();
    size_t nr_inputs = aclmdlGetNumInputs(m_model_desc);
    size_t nr_outputs = aclmdlGetNumOutputs(m_model_desc);
    size_t input_batch = input(0)->layout()[0];

    if (enable_dynamic_batch) {
M
Megvii Engine Team 已提交
280 281 282 283
        mgb_assert(
                nr_inputs == input().size() + 1,
                "nr inputs got from om model should be one more than got "
                "from megbrain");
284 285 286 287 288 289 290 291 292 293 294
    }
    SmallVector<size_t> batches_each_run;
    if (enable_dynamic_batch) {
        batches_each_run = gen_batch_vec(input_batch, m_dyn_batch_choices);
    } else {
        batches_each_run.push_back(input_batch);
    }

    for (auto&& batch : batches_each_run) {
        //! prepare input
        auto model_inputs = aclmdlCreateDataset();
M
Megvii Engine Team 已提交
295
        mgb_assert(model_inputs != nullptr, "failed to create atlas input dataset.");
296 297 298
        for (size_t i = 0; i < input().size(); i++) {
            auto value_pair = input_getter.get(batch, i);
            auto input_size = aclmdlGetInputSizeByIndex(m_model_desc, i);
299 300
            //! FIXME iff enable dynamic batchsize and dynamic aipp, the input
            //! size should be the size of aclmdlGetInputSizeByIndex.
301
            if (enable_dynamic_batch) {
M
Megvii Engine Team 已提交
302 303 304 305 306
                mgb_assert(
                        input_size ==
                                value_pair.second / batch * m_dyn_batch_choices[0],
                        "input %zu size mismatch, expected: %zu got: %zu", i,
                        input_size, value_pair.second / batch * m_dyn_batch_choices[0]);
307 308 309
            }
            aclDataBuffer* input_db =
                    aclCreateDataBuffer(value_pair.first, value_pair.second);
M
Megvii Engine Team 已提交
310 311 312 313 314
            mgb_assert(
                    input_db != nullptr,
                    "failed to create atlas input data buffer for input "
                    "%zu:%s.",
                    i, input(i)->cname());
315 316 317 318 319 320 321
            aclmdlAddDatasetBuffer(model_inputs, input_db);
        }
        //! append unit tensor for dynamic batch
        if (enable_dynamic_batch) {
            aclDataBuffer* input_db = aclCreateDataBuffer(
                    reinterpret_cast<void*>(m_dyn_batch_tensor.raw_ptr()),
                    m_dyn_batch_tensor.layout().span().dist_byte());
M
Megvii Engine Team 已提交
322 323 324 325
            mgb_assert(
                    input_db != nullptr,
                    "failed to create atlas input data buffer for dynamic "
                    "batch tensor.");
326 327 328 329 330 331 332 333 334
            MGB_ATLAS_CHECK(aclmdlAddDatasetBuffer(model_inputs, input_db));

            MGB_ATLAS_CHECK(aclmdlSetDynamicBatchSize(
                    m_model_id, model_inputs, input().size(),
                    static_cast<uint64_t>(batch)));
        }

        //! prepare output
        auto model_outputs = aclmdlCreateDataset();
M
Megvii Engine Team 已提交
335
        mgb_assert(model_outputs != nullptr, "failed to create atlas output dataset.");
336 337
        for (size_t i = 0; i < nr_outputs; i++) {
            auto value_pair = output_getter.get(batch, i);
338 339 340 341
            size_t output_size = value_pair.second;
            if (enable_dynamic_batch) {
                output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i);
            }
342
            aclDataBuffer* output_db =
343
                    aclCreateDataBuffer(value_pair.first, output_size);
M
Megvii Engine Team 已提交
344 345 346 347 348
            mgb_assert(
                    output_db != nullptr,
                    "failed to create atlas output data buffer for output "
                    "%zu:%s.",
                    i, output(i)->cname());
349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365
            aclmdlAddDatasetBuffer(model_outputs, output_db);
        }
        MGB_ATLAS_CHECK(aclmdlExecute(m_model_id, model_inputs, model_outputs));

        for (size_t i = 0; i < nr_inputs; ++i) {
            aclDataBuffer* db_ptr = aclmdlGetDatasetBuffer(model_inputs, i);
            MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr));
        }
        for (size_t i = 0; i < nr_outputs; ++i) {
            aclDataBuffer* db_ptr = aclmdlGetDatasetBuffer(model_outputs, i);
            MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr));
        }
        MGB_ATLAS_CHECK(aclmdlDestroyDataset(model_inputs));
        MGB_ATLAS_CHECK(aclmdlDestroyDataset(model_outputs));
    }
}

M
Megvii Engine Team 已提交
366 367
void AtlasRuntimeOpr::get_output_var_shape(
        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
368 369 370 371 372
    size_t nr_inputs = aclmdlGetNumInputs(m_model_desc);
    size_t batch_size = inp_shape[0][0];
    //! enable dynamic batchsize
    if (!m_dyn_batch_choices.empty()) {
        mgb_assert(!gen_batch_vec(batch_size, m_dyn_batch_choices).empty());
M
Megvii Engine Team 已提交
373 374 375 376
        mgb_assert(
                nr_inputs == inp_shape.size() + 1,
                "nr inputs got from om model should be one more than got "
                "from megbrain");
377 378 379 380 381 382 383 384
    }
    for (size_t i = 0; i < inp_shape.size(); ++i) {
        aclmdlIODims input_dims;
        MGB_ATLAS_CHECK(aclmdlGetInputDimsV2(m_model_desc, i, &input_dims));
        auto om_format = aclmdlGetInputFormat(m_model_desc, i);
        TensorShape shape_from_om = acl_shape_to_mgb_shape_for_input(
                input_dims, batch_size, !m_dyn_batch_choices.empty(), om_format,
                m_aipp_input_format[i]);
M
Megvii Engine Team 已提交
385 386 387 388
        mgb_assert(
                shape_from_om.eq_shape(inp_shape[i]),
                "shape mismatch of input %zu, expected: %s got: %s", i,
                shape_from_om.to_string().c_str(), inp_shape[i].to_string().c_str());
389 390 391 392 393
    }

    for (size_t i = 0; i < out_shape.size(); ++i) {
        aclmdlIODims output_dims;
        MGB_ATLAS_CHECK(aclmdlGetOutputDims(m_model_desc, i, &output_dims));
M
Megvii Engine Team 已提交
394
        out_shape[i] = acl_shape_to_mgb_shape_for_output(output_dims, batch_size);
395 396 397 398 399 400 401 402 403 404 405 406 407
    }
}

//! the om model consumes plain device buffers, so every input is required
//! to be contiguous (default constraint, no special layout handling)
void AtlasRuntimeOpr::add_input_layout_constraint() {
    //! default contiguous
    for (auto i : input()) {
        i->add_layout_constraint_contiguous();
    }
}

void AtlasRuntimeOpr::init_output_dtype() {
    DType dt_acl, dt_input;
    for (size_t i = 0; i < input().size(); ++i) {
M
Megvii Engine Team 已提交
408
        dt_acl = acl_dtype_to_mgb_dtype(aclmdlGetInputDataType(m_model_desc, i));
409
        dt_input = input(i)->dtype();
M
Megvii Engine Team 已提交
410 411 412 413 414 415
        mgb_assert(
                dt_acl.valid() && dt_input.valid() &&
                        dt_acl.enumv() == dt_input.enumv(),
                "dtype mismatch of input %zu: expected %s, "
                "got %s",
                i, dt_acl.name(), dt_input.name());
416 417 418
    }

    for (size_t i = 0; i < output().size(); ++i) {
M
Megvii Engine Team 已提交
419 420 421 422
        dt_acl = acl_dtype_to_mgb_dtype(aclmdlGetOutputDataType(m_model_desc, i));
        mgb_assert(
                dt_acl.valid(),
                "output dtype checking failed: invalid dtype returned.");
423
        if (dt_acl.enumv() == DTypeEnum::QuantizedS8) {
M
Megvii Engine Team 已提交
424 425 426 427
            mgb_assert(
                    output(i)->dtype().valid(),
                    "user should specify scale of output tensor of "
                    "AtlasRuntimeOpr.");
428 429 430 431 432 433
        }
        if (!output(i)->dtype().valid())
            output(i)->dtype(dt_acl);
    }
}

M
Megvii Engine Team 已提交
434 435
SymbolVarArray AtlasRuntimeOpr::make(
        SharedBuffer buf, const SymbolVarArray& src, const OperatorNodeConfig& config) {
436 437
    VarNodeArray var_node_array = cg::to_var_node_array(src);
    auto atlas_runtime_opr = std::make_unique<AtlasRuntimeOpr>(
M
Megvii Engine Team 已提交
438
            std::move(buf), std::pair<uint32_t, aclmdlDesc*>{INVALID_MODEL_ID, nullptr},
439
            var_node_array, config);
M
Megvii Engine Team 已提交
440 441 442 443 444
    auto ret =
            cg::to_symbol_var_array(src[0].node()
                                            ->owner_graph()
                                            ->insert_opr(std::move(atlas_runtime_opr))
                                            ->output());
445 446 447
    return ret;
}

M
Megvii Engine Team 已提交
448 449 450 451 452 453 454 455
SymbolVarArray AtlasRuntimeOpr::make(
        const void* buf, size_t size, const SymbolVarArray& src,
        const OperatorNodeConfig& config) {
    mgb_throw_if(
            !CompNode::get_device_count(CompNode::DeviceType::ATLAS), SystemError,
            "can not create AtlasRuntimeOpr when atlas is not "
            "available");
    std::shared_ptr<uint8_t> shptr{new uint8_t[size], [](uint8_t* p) { delete[] p; }};
456 457 458 459 460 461 462 463 464
    memcpy(shptr.get(), buf, size);
    SharedBuffer buffer{std::move(shptr), size};
    return make(std::move(buffer), src, config);
}

SymbolVarArray AtlasRuntimeOpr::make(
        const SharedBuffer buf, const std::pair<uint32_t, aclmdlDesc*>& model,
        const SymbolVarArray& src, const OperatorNodeConfig& config) {
    auto var_nodes = cg::to_var_node_array(src);
    auto opr = std::make_unique<AtlasRuntimeOpr>(buf, model, var_nodes, config);
    //! the graph takes ownership of the operator on insertion
    auto* inserted = src[0].node()->owner_graph()->insert_opr(std::move(opr));
    return cg::to_symbol_var_array(inserted->output());
}

constexpr uint32_t AtlasRuntimeOpr::INVALID_MODEL_ID;

#endif  // MGB_ATLAS

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}