// atlas_runtime_op.cpp
#include "megbrain/opr/atlas_runtime_op.h"
#include <memory>
#include "megbrain/common.h"
#include "megbrain/graph/operator_node.h"
#include "megdnn/basic_types.h"
#include "megdnn/dtype.h"

#if MGB_ATLAS
#include "acl/acl_mdl.h"

using namespace mgb;
using namespace opr;

namespace {
/**
 * \brief get mgb shape from acl shape, batch from mgb
 */
M
Megvii Engine Team 已提交
18
TensorShape acl_shape_to_mgb_shape_for_output(aclmdlIODims acl_shape, size_t batch) {
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
    TensorShape ret;
    ret.ndim = acl_shape.dimCount;
    for (size_t i = 0; i < ret.ndim; ++i) {
        ret[i] = acl_shape.dims[i];
    }
    ret[0] = batch;
    return ret;
}

/**
 * \brief deduce the input shape from aclFormat and aipp config.
 *
 * \param acl_shape shape from om file
 * \param batch batchsize from mgb
 * \param enable_dynamic_batch True if set dynamic batch size
 * \param om_format layout format from om file
 * \param aipp_input_fmt input_format in static aipp config of om file
 */
TensorShape acl_shape_to_mgb_shape_for_input(
        aclmdlIODims acl_shape, size_t batch, bool enable_dynamic_batch,
        aclFormat om_format, AtlasRuntimeOpr::AippInputFormat aipp_input_fmt) {
    MGB_MARK_USED_VAR(aipp_input_fmt);
    //! copy every dim reported by the om model into a megbrain shape
    TensorShape result;
    result.ndim = acl_shape.dimCount;
    for (size_t dim = 0; dim < result.ndim; ++dim) {
        result[dim] = acl_shape.dims[dim];
    }
    //! dim 0 is the batch: with dynamic batch the om file stores -1 and the
    //! real value comes from megbrain; otherwise both sides must agree
    if (!enable_dynamic_batch) {
        mgb_assert(
                result[0] == batch,
                "batchsize mismatch if no dynamic batchsize enabled, "
                "expected: %zu got: %zu\n",
                result[0], batch);
    } else {
        mgb_assert(
                result[0] == static_cast<size_t>(-1),
                "batch size expected to be -1 when enable dynamic "
                "batchsize, got: %zu\n",
                result[0]);
        result[0] = batch;
    }

    mgb_assert(om_format != ACL_FORMAT_UNDEFINED, "om input format should be defined");

    return result;
}

/**
 * \brief map an aclDataType onto the corresponding megbrain DType.
 *
 * Throws MegBrainError for acl types megbrain has no equivalent for.
 */
DType acl_dtype_to_mgb_dtype(aclDataType data_type) {
    if (data_type == ACL_UINT8) {
        return dtype::Uint8();
    } else if (data_type == ACL_FLOAT16) {
#if !MEGDNN_DISABLE_FLOAT16
        return dtype::Float16();
#else
        mgb_throw(MegBrainError, "Float16 support is disabled at compile time.");
#endif
    } else if (data_type == ACL_FLOAT) {
        return dtype::Float32();
    } else if (data_type == ACL_INT8) {
        return dtype::Int8();
    } else if (data_type == ACL_INT16) {
        return dtype::Int16();
    } else if (data_type == ACL_INT32) {
        return dtype::Int32();
    } else {
        mgb_throw(
                MegBrainError, "aclDataType %x is not supported by MegBrain.",
                static_cast<int>(data_type));
    }
}

/**
 * \brief generate batch size which match the batch_choice
 */
M
Megvii Engine Team 已提交
94 95
SmallVector<size_t> gen_batch_vec(
        size_t origin_batch, const SmallVector<size_t>& batch_choices) {
96 97 98 99 100 101 102 103 104 105 106 107
    SmallVector<size_t> ret;
    size_t idx = 0;
    size_t nr_batch_choices = batch_choices.size();
    size_t batch = origin_batch;
    while (idx < nr_batch_choices) {
        size_t val = batch_choices[idx];
        while (batch >= batch_choices[idx]) {
            ret.push_back(val);
            batch -= val;
        }
        idx++;
    }
M
Megvii Engine Team 已提交
108 109 110
    mgb_assert(
            batch == 0, "Invalid batch size %zu, can not be generate by batch choices",
            origin_batch);
111 112 113 114 115 116 117 118 119

    return ret;
}

class PtrGetter {
public:
    PtrGetter(const VarNodeArray& vars) {
        for (auto&& var : vars) {
            m_ptrs.push_back(var->dev_tensor().raw_ptr());
M
Megvii Engine Team 已提交
120 121
            m_batch_in_bytes.push_back(
                    var->layout().stride[0] * var->layout().dtype.size());
122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
        }
    }

    std::pair<void*, size_t> get(size_t batch, size_t idx) {
        std::pair<void*, size_t> ret;
        ret.first = m_ptrs[idx];
        ret.second = batch * m_batch_in_bytes[idx];
        m_ptrs[idx] = reinterpret_cast<void*>(
                reinterpret_cast<uintptr_t>(ret.first) + ret.second);
        return ret;
    }

private:
    SmallVector<void*> m_ptrs;
    SmallVector<size_t> m_batch_in_bytes;
};

};  // namespace

/* ====================== AtlasRuntimeOpr ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(AtlasRuntimeOpr);
M
Megvii Engine Team 已提交
143 144 145
AtlasRuntimeOpr::AtlasRuntimeOpr(
        SharedBuffer buf, const std::pair<uint32_t, aclmdlDesc*>& model,
        const VarNodeArray& inputs, const OperatorNodeConfig& config)
146 147 148 149 150 151 152 153 154
        : Super(inputs[0]->owner_graph(), config, "atlas_runtime", inputs),
          m_buffer{std::move(buf)},
          m_model_id{model.first},
          m_model_desc{model.second} {
    mgb_assert(
            inputs[0]->comp_node().device_type() == CompNode::DeviceType::ATLAS,
            "AtlasRuntimeOpr can only be used on atlas comp node; "
            "got %s",
            inputs[0]->comp_node().to_string().c_str());
M
Megvii Engine Team 已提交
155 156 157
    mgb_assert(
            m_buffer.data() != nullptr ||
            (m_model_id != INVALID_MODEL_ID && m_model_desc != nullptr));
158 159 160 161 162

    for (auto i : inputs) {
        add_input({i});
    }
    if (m_model_id == INVALID_MODEL_ID && m_model_desc == nullptr) {
M
Megvii Engine Team 已提交
163 164
        MGB_ATLAS_CHECK(
                aclmdlLoadFromMem(m_buffer.data(), m_buffer.size(), &m_model_id));
165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
        m_model_desc = aclmdlCreateDesc();
        MGB_ATLAS_CHECK(aclmdlGetDesc(m_model_desc, m_model_id));
        m_is_model_holder = true;
    }

    //! aipp input format
    m_aipp_input_format = SmallVector<AippInputFormat>(inputs.size());
    aclAippInfo aipp_info;
    for (size_t i = 0; i < inputs.size(); ++i) {
        aclError acl_err = aclmdlGetFirstAippInfo(m_model_id, i, &aipp_info);
        if (ACL_ERROR_NONE == acl_err) {
            switch (aipp_info.inputFormat) {
                case ACL_YUV420SP_U8:
                    m_aipp_input_format[i] = AippInputFormat::YUV420SP_U8;
                    break;
                case ACL_RGB888_U8:
                    m_aipp_input_format[i] = AippInputFormat::RGB888_U8;
                    break;
                default:
M
Megvii Engine Team 已提交
184 185 186
                    mgb_throw(
                            MegBrainError,
                            "Unsupported aclAippInputFormat for input %zu. ", i);
187
            }
188 189 190
        } else if (
                ACL_ERROR_NOT_STATIC_AIPP == acl_err ||
                ACL_ERROR_GE_AIPP_NOT_EXIST == acl_err) {
191 192 193 194 195 196 197 198 199 200 201
            m_aipp_input_format[i] = AippInputFormat::NO_AIPP;
        } else {
            MGB_ATLAS_CHECK(acl_err);
        }
    }

    size_t dynamic_index;
    auto errcode = aclmdlGetInputIndexByName(
            m_model_desc, ACL_DYNAMIC_TENSOR_NAME, &dynamic_index);
    if (errcode == ACL_ERROR_NONE) {
        aclmdlHW hw_info;
M
Megvii Engine Team 已提交
202
        MGB_ATLAS_CHECK(aclmdlGetDynamicHW(m_model_desc, dynamic_index, &hw_info));
203 204 205 206 207 208 209 210
        mgb_assert(hw_info.hwCount == 0, "Currently not support dynamic HW");
    }

    //! dynamic batch size
    aclmdlBatch acl_batch;
    MGB_ATLAS_CHECK(aclmdlGetDynamicBatch(m_model_desc, &acl_batch));
    if (acl_batch.batchCount) {
        size_t dynamic_data_size;
M
Megvii Engine Team 已提交
211
        dynamic_data_size = aclmdlGetInputSizeByIndex(m_model_desc, dynamic_index);
212 213 214 215
        m_dyn_batch_tensor = DeviceTensorND(
                inputs[0]->comp_node(), {{dynamic_data_size}, dtype::Uint8()});

        for (size_t i = 0; i < acl_batch.batchCount; ++i) {
M
Megvii Engine Team 已提交
216
            m_dyn_batch_choices.push_back(static_cast<size_t>(acl_batch.batch[i]));
217
        }
M
Megvii Engine Team 已提交
218 219 220
        std::sort(
                m_dyn_batch_choices.begin(), m_dyn_batch_choices.end(),
                std::greater<>());
221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
    }

    //! add output
    size_t nr_outputs = aclmdlGetNumOutputs(m_model_desc);
    using F = VarNode::Flag;
    if (nr_outputs == 1) {
        add_output(None);
    } else {
        for (size_t i = 0; i < nr_outputs; ++i) {
            add_output(ssprintf("o%zu", i));
        }
    }
    if (!m_dyn_batch_choices.empty()) {
        /**
         * \warning If enable dynamic batchsize, the memory of output
         * should be the largest be the size with the largest batch_size, so we
         * set the flag to SYS_MEM_ALLOC.
         */
        for (size_t i = 0; i < nr_outputs; ++i) {
M
Megvii Engine Team 已提交
240
            output(i)->add_flag(F::NO_SYS_MEM_ALLOC).add_flag(F::NO_MEM_RECLAIM);
241 242 243 244 245 246 247 248 249 250 251 252 253
        }
    }
    add_equivalence_component<mgb::ScalarHash<const void*>>(m_buffer.data());
};

AtlasRuntimeOpr::~AtlasRuntimeOpr() {
    //! only unload/destroy when this opr loaded the model itself; models
    //! passed in from outside remain owned by the caller
    if (!m_is_model_holder)
        return;
    MGB_ATLAS_CHECK(aclmdlUnload(m_model_id));
    MGB_ATLAS_CHECK(aclmdlDestroyDesc(m_model_desc));
}

void AtlasRuntimeOpr::scn_do_execute() {
M
Megvii Engine Team 已提交
254
    auto&& acl_env = CompNodeEnv::from_comp_node(input(0)->comp_node()).atlas_env();
255 256 257 258 259 260
    acl_env.activate();

    if (!m_dyn_batch_choices.empty()) {
        for (size_t i = 0; i < output().size(); i++) {
            auto output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i);
            auto ovar = output(i);
261
            output_size = std::max<size_t>(
M
Megvii Engine Team 已提交
262
                    output_size, ovar->dtype().size(ovar->shape().total_nr_elems()));
263 264 265 266 267 268 269 270 271 272 273 274 275
            ovar->shape_alloc(ovar->shape(), output_size);
        }
    }

    PtrGetter input_getter(input());
    PtrGetter output_getter(output());

    bool enable_dynamic_batch = !m_dyn_batch_choices.empty();
    size_t nr_inputs = aclmdlGetNumInputs(m_model_desc);
    size_t nr_outputs = aclmdlGetNumOutputs(m_model_desc);
    size_t input_batch = input(0)->layout()[0];

    if (enable_dynamic_batch) {
M
Megvii Engine Team 已提交
276 277 278 279
        mgb_assert(
                nr_inputs == input().size() + 1,
                "nr inputs got from om model should be one more than got "
                "from megbrain");
280 281 282 283 284 285 286 287 288 289 290
    }
    SmallVector<size_t> batches_each_run;
    if (enable_dynamic_batch) {
        batches_each_run = gen_batch_vec(input_batch, m_dyn_batch_choices);
    } else {
        batches_each_run.push_back(input_batch);
    }

    for (auto&& batch : batches_each_run) {
        //! prepare input
        auto model_inputs = aclmdlCreateDataset();
M
Megvii Engine Team 已提交
291
        mgb_assert(model_inputs != nullptr, "failed to create atlas input dataset.");
292 293 294
        for (size_t i = 0; i < input().size(); i++) {
            auto value_pair = input_getter.get(batch, i);
            auto input_size = aclmdlGetInputSizeByIndex(m_model_desc, i);
295 296
            //! FIXME iff enable dynamic batchsize and dynamic aipp, the input
            //! size should be the size of aclmdlGetInputSizeByIndex.
297
            if (enable_dynamic_batch) {
M
Megvii Engine Team 已提交
298 299 300 301 302
                mgb_assert(
                        input_size ==
                                value_pair.second / batch * m_dyn_batch_choices[0],
                        "input %zu size mismatch, expected: %zu got: %zu", i,
                        input_size, value_pair.second / batch * m_dyn_batch_choices[0]);
303 304 305
            }
            aclDataBuffer* input_db =
                    aclCreateDataBuffer(value_pair.first, value_pair.second);
M
Megvii Engine Team 已提交
306 307 308 309 310
            mgb_assert(
                    input_db != nullptr,
                    "failed to create atlas input data buffer for input "
                    "%zu:%s.",
                    i, input(i)->cname());
311 312 313 314 315 316 317
            aclmdlAddDatasetBuffer(model_inputs, input_db);
        }
        //! append unit tensor for dynamic batch
        if (enable_dynamic_batch) {
            aclDataBuffer* input_db = aclCreateDataBuffer(
                    reinterpret_cast<void*>(m_dyn_batch_tensor.raw_ptr()),
                    m_dyn_batch_tensor.layout().span().dist_byte());
M
Megvii Engine Team 已提交
318 319 320 321
            mgb_assert(
                    input_db != nullptr,
                    "failed to create atlas input data buffer for dynamic "
                    "batch tensor.");
322 323 324 325 326 327 328 329 330
            MGB_ATLAS_CHECK(aclmdlAddDatasetBuffer(model_inputs, input_db));

            MGB_ATLAS_CHECK(aclmdlSetDynamicBatchSize(
                    m_model_id, model_inputs, input().size(),
                    static_cast<uint64_t>(batch)));
        }

        //! prepare output
        auto model_outputs = aclmdlCreateDataset();
M
Megvii Engine Team 已提交
331
        mgb_assert(model_outputs != nullptr, "failed to create atlas output dataset.");
332 333
        for (size_t i = 0; i < nr_outputs; i++) {
            auto value_pair = output_getter.get(batch, i);
334 335 336 337
            size_t output_size = value_pair.second;
            if (enable_dynamic_batch) {
                output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i);
            }
338
            aclDataBuffer* output_db =
339
                    aclCreateDataBuffer(value_pair.first, output_size);
M
Megvii Engine Team 已提交
340 341 342 343 344
            mgb_assert(
                    output_db != nullptr,
                    "failed to create atlas output data buffer for output "
                    "%zu:%s.",
                    i, output(i)->cname());
345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361
            aclmdlAddDatasetBuffer(model_outputs, output_db);
        }
        MGB_ATLAS_CHECK(aclmdlExecute(m_model_id, model_inputs, model_outputs));

        for (size_t i = 0; i < nr_inputs; ++i) {
            aclDataBuffer* db_ptr = aclmdlGetDatasetBuffer(model_inputs, i);
            MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr));
        }
        for (size_t i = 0; i < nr_outputs; ++i) {
            aclDataBuffer* db_ptr = aclmdlGetDatasetBuffer(model_outputs, i);
            MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr));
        }
        MGB_ATLAS_CHECK(aclmdlDestroyDataset(model_inputs));
        MGB_ATLAS_CHECK(aclmdlDestroyDataset(model_outputs));
    }
}

M
Megvii Engine Team 已提交
362 363
void AtlasRuntimeOpr::get_output_var_shape(
        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
364 365 366 367 368
    size_t nr_inputs = aclmdlGetNumInputs(m_model_desc);
    size_t batch_size = inp_shape[0][0];
    //! enable dynamic batchsize
    if (!m_dyn_batch_choices.empty()) {
        mgb_assert(!gen_batch_vec(batch_size, m_dyn_batch_choices).empty());
M
Megvii Engine Team 已提交
369 370 371 372
        mgb_assert(
                nr_inputs == inp_shape.size() + 1,
                "nr inputs got from om model should be one more than got "
                "from megbrain");
373 374 375 376 377 378 379 380
    }
    for (size_t i = 0; i < inp_shape.size(); ++i) {
        aclmdlIODims input_dims;
        MGB_ATLAS_CHECK(aclmdlGetInputDimsV2(m_model_desc, i, &input_dims));
        auto om_format = aclmdlGetInputFormat(m_model_desc, i);
        TensorShape shape_from_om = acl_shape_to_mgb_shape_for_input(
                input_dims, batch_size, !m_dyn_batch_choices.empty(), om_format,
                m_aipp_input_format[i]);
M
Megvii Engine Team 已提交
381 382 383 384
        mgb_assert(
                shape_from_om.eq_shape(inp_shape[i]),
                "shape mismatch of input %zu, expected: %s got: %s", i,
                shape_from_om.to_string().c_str(), inp_shape[i].to_string().c_str());
385 386 387 388 389
    }

    for (size_t i = 0; i < out_shape.size(); ++i) {
        aclmdlIODims output_dims;
        MGB_ATLAS_CHECK(aclmdlGetOutputDims(m_model_desc, i, &output_dims));
M
Megvii Engine Team 已提交
390
        out_shape[i] = acl_shape_to_mgb_shape_for_output(output_dims, batch_size);
391 392 393 394 395 396 397 398 399 400 401 402 403
    }
}

void AtlasRuntimeOpr::add_input_layout_constraint() {
    //! default contiguous
    for (auto i : input()) {
        i->add_layout_constraint_contiguous();
    }
}

void AtlasRuntimeOpr::init_output_dtype() {
    DType dt_acl, dt_input;
    for (size_t i = 0; i < input().size(); ++i) {
M
Megvii Engine Team 已提交
404
        dt_acl = acl_dtype_to_mgb_dtype(aclmdlGetInputDataType(m_model_desc, i));
405
        dt_input = input(i)->dtype();
M
Megvii Engine Team 已提交
406 407 408 409 410 411
        mgb_assert(
                dt_acl.valid() && dt_input.valid() &&
                        dt_acl.enumv() == dt_input.enumv(),
                "dtype mismatch of input %zu: expected %s, "
                "got %s",
                i, dt_acl.name(), dt_input.name());
412 413 414
    }

    for (size_t i = 0; i < output().size(); ++i) {
M
Megvii Engine Team 已提交
415 416 417 418
        dt_acl = acl_dtype_to_mgb_dtype(aclmdlGetOutputDataType(m_model_desc, i));
        mgb_assert(
                dt_acl.valid(),
                "output dtype checking failed: invalid dtype returned.");
419
        if (dt_acl.enumv() == DTypeEnum::QuantizedS8) {
M
Megvii Engine Team 已提交
420 421 422 423
            mgb_assert(
                    output(i)->dtype().valid(),
                    "user should specify scale of output tensor of "
                    "AtlasRuntimeOpr.");
424 425 426 427 428 429
        }
        if (!output(i)->dtype().valid())
            output(i)->dtype(dt_acl);
    }
}

M
Megvii Engine Team 已提交
430 431
SymbolVarArray AtlasRuntimeOpr::make(
        SharedBuffer buf, const SymbolVarArray& src, const OperatorNodeConfig& config) {
432 433
    VarNodeArray var_node_array = cg::to_var_node_array(src);
    auto atlas_runtime_opr = std::make_unique<AtlasRuntimeOpr>(
M
Megvii Engine Team 已提交
434
            std::move(buf), std::pair<uint32_t, aclmdlDesc*>{INVALID_MODEL_ID, nullptr},
435
            var_node_array, config);
M
Megvii Engine Team 已提交
436 437 438 439 440
    auto ret =
            cg::to_symbol_var_array(src[0].node()
                                            ->owner_graph()
                                            ->insert_opr(std::move(atlas_runtime_opr))
                                            ->output());
441 442 443
    return ret;
}

M
Megvii Engine Team 已提交
444 445 446 447 448 449 450 451
SymbolVarArray AtlasRuntimeOpr::make(
        const void* buf, size_t size, const SymbolVarArray& src,
        const OperatorNodeConfig& config) {
    mgb_throw_if(
            !CompNode::get_device_count(CompNode::DeviceType::ATLAS), SystemError,
            "can not create AtlasRuntimeOpr when atlas is not "
            "available");
    std::shared_ptr<uint8_t> shptr{new uint8_t[size], [](uint8_t* p) { delete[] p; }};
452 453 454 455 456 457 458 459 460
    memcpy(shptr.get(), buf, size);
    SharedBuffer buffer{std::move(shptr), size};
    return make(std::move(buffer), src, config);
}

//! Build the opr around an acl model that is already loaded; the opr does
//! not take ownership of it (see m_is_model_holder in the ctor).
SymbolVarArray AtlasRuntimeOpr::make(
        const SharedBuffer buf, const std::pair<uint32_t, aclmdlDesc*>& model,
        const SymbolVarArray& src, const OperatorNodeConfig& config) {
    auto inputs = cg::to_var_node_array(src);
    auto opr = std::make_unique<AtlasRuntimeOpr>(buf, model, inputs, config);
    auto* graph = src[0].node()->owner_graph();
    return cg::to_symbol_var_array(graph->insert_opr(std::move(opr))->output());
}

constexpr uint32_t AtlasRuntimeOpr::INVALID_MODEL_ID;

#endif  // MGB_ATLAS

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}