#include "megbrain/opr/atlas_runtime_op.h"
#include <memory>
#include "megbrain/common.h"
#include "megbrain/graph/operator_node.h"
#include "megdnn/basic_types.h"
#include "megdnn/dtype.h"

#if MGB_ATLAS
#include "acl/acl_mdl.h"

using namespace mgb;
using namespace opr;

namespace {
/**
 * \brief deduce the megbrain output shape from the acl shape, with the
 * batch size taken from megbrain
 */
TensorShape acl_shape_to_mgb_shape_for_output(
        aclmdlDesc* model_desc, size_t output_idx, size_t output_dtype_size,
        aclmdlIODims acl_shape, size_t batch) {
    TensorShape ret;
    ret.ndim = acl_shape.dimCount;
    for (size_t i = 0; i < ret.ndim; ++i) {
        ret[i] = acl_shape.dims[i];
    }
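    //! When the om file marks the output batch as dynamic (dims[0] == -1),
    //! the batch is deduced from the total output byte size. Illustrative
    //! example: for a float32 output of shape (-1, 3, 32, 32),
    //! aclmdlGetOutputSizeByIndex() returning 24576 bytes gives
    //! chw = 4 * 3 * 32 * 32 = 12288, hence batch = 24576 / 12288 = 2.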
    if (acl_shape.dims[0] == -1) {
        batch = aclmdlGetOutputSizeByIndex(model_desc, output_idx);
        size_t chw = output_dtype_size;
        for (size_t i = 1; i < ret.ndim; ++i) {
            chw *= ret[i];
        }
        mgb_assert(
                batch % chw == 0,
                "when the input batch is static but the output batch is dynamic, "
                "the output batch must be deduced: the size returned by "
                "aclmdlGetOutputSizeByIndex should be evenly divisible by the "
                "product of the dtype size and the non-batch dims. expected "
                "remainder 0, but got %zu\n",
                batch % chw);
        batch /= chw;
    }
    ret[0] = batch;
    return ret;
}

/**
 * \brief deduce the input shape from the aclFormat and the aipp config.
 *
 * \param acl_shape shape from the om file
 * \param batch batch size from megbrain
 * \param enable_dynamic_batch true if dynamic batch size is enabled
 * \param om_format layout format from the om file
 * \param aipp_input_fmt input_format in the static aipp config of the om file
 */
TensorShape acl_shape_to_mgb_shape_for_input(
        aclmdlIODims acl_shape, size_t batch, bool enable_dynamic_batch,
        aclFormat om_format, AtlasRuntimeOpr::AippInputFormat aipp_input_fmt) {
    MGB_MARK_USED_VAR(aipp_input_fmt);
    TensorShape ret;
    ret.ndim = acl_shape.dimCount;
    for (size_t i = 0; i < ret.ndim; ++i) {
        ret[i] = acl_shape.dims[i];
    }
    if (enable_dynamic_batch) {
        mgb_assert(
                ret[0] == static_cast<size_t>(-1),
                "batch size is expected to be -1 when dynamic batch size is "
                "enabled, got: %zu\n",
                ret[0]);
        ret[0] = batch;
    } else {
        mgb_assert(
                ret[0] == batch,
                "batch size mismatch when dynamic batch size is disabled, "
                "expected: %zu got: %zu\n",
                ret[0], batch);
    }

    mgb_assert(om_format != ACL_FORMAT_UNDEFINED, "om input format should be defined");

    return ret;
}

DType acl_dtype_to_mgb_dtype(aclDataType data_type) {
    switch (data_type) {
        case ACL_UINT8:
            return dtype::Uint8();
        case ACL_FLOAT16:
#if !MEGDNN_DISABLE_FLOAT16
            return dtype::Float16();
#else
            mgb_throw(MegBrainError, "Float16 support is disabled at compile time.");
#endif
        case ACL_FLOAT:
            return dtype::Float32();
        case ACL_INT8:
            return dtype::Int8();
        case ACL_INT16:
            return dtype::Int16();
        case ACL_INT32:
            return dtype::Int32();
        default:
            mgb_throw(
                    MegBrainError, "aclDataType %x is not supported by MegBrain.",
                    static_cast<int>(data_type));
    }
}

/**
 * \brief greedily decompose origin_batch into a sequence of batches drawn
 * from batch_choices
 */
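//! Illustrative example: with batch_choices sorted descending as {8, 4, 2}
//! and origin_batch = 14, the greedy loop yields {8, 4, 2}; an origin_batch
//! of 7 leaves a remainder of 1 and trips the assertion below.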
SmallVector<size_t> gen_batch_vec(
        size_t origin_batch, const SmallVector<size_t>& batch_choices) {
    SmallVector<size_t> ret;
    size_t idx = 0;
    size_t nr_batch_choices = batch_choices.size();
    size_t batch = origin_batch;
    while (idx < nr_batch_choices) {
        size_t val = batch_choices[idx];
        while (batch >= batch_choices[idx]) {
            ret.push_back(val);
            batch -= val;
        }
        idx++;
    }
    mgb_assert(
            batch == 0,
            "invalid batch size %zu: it can not be decomposed into the given "
            "batch choices",
            origin_batch);

    return ret;
}

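/**
 * \brief walk the device buffers of a VarNodeArray one sub-batch at a time.
 *
 * get(batch, idx) returns the current pointer of var idx together with the
 * byte size of `batch` rows (batch * stride[0] * dtype size), then advances
 * the stored pointer past that slice, so consecutive calls yield
 * consecutive sub-batches.
 */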
class PtrGetter {
public:
    PtrGetter(const VarNodeArray& vars) {
        for (auto&& var : vars) {
            m_ptrs.push_back(var->dev_tensor().raw_ptr());
            m_batch_in_bytes.push_back(
                    var->layout().stride[0] * var->layout().dtype.size());
        }
    }

    std::pair<void*, size_t> get(size_t batch, size_t idx) {
        std::pair<void*, size_t> ret;
        ret.first = m_ptrs[idx];
        ret.second = batch * m_batch_in_bytes[idx];
        m_ptrs[idx] = reinterpret_cast<void*>(
                reinterpret_cast<uintptr_t>(ret.first) + ret.second);
        return ret;
    }

private:
    SmallVector<void*> m_ptrs;
    SmallVector<size_t> m_batch_in_bytes;
};

}  // namespace

/* ====================== AtlasRuntimeOpr ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(AtlasRuntimeOpr);
AtlasRuntimeOpr::AtlasRuntimeOpr(
        SharedBuffer buf, const std::pair<uint32_t, aclmdlDesc*>& model,
        const VarNodeArray& inputs, const OperatorNodeConfig& config)
        : Super(inputs[0]->owner_graph(), config, "atlas_runtime", inputs),
          m_buffer{std::move(buf)},
          m_model_id{model.first},
          m_model_desc{model.second} {
    mgb_assert(
            inputs[0]->comp_node().device_type() == CompNode::DeviceType::ATLAS,
            "AtlasRuntimeOpr can only be used on atlas comp node; "
            "got %s",
            inputs[0]->comp_node().to_string().c_str());
    mgb_assert(
            m_buffer.data() != nullptr ||
            (m_model_id != INVALID_MODEL_ID && m_model_desc != nullptr));

    for (auto i : inputs) {
        add_input({i});
    }
    if (m_model_id == INVALID_MODEL_ID && m_model_desc == nullptr) {
        MGB_ATLAS_CHECK(
                aclmdlLoadFromMem(m_buffer.data(), m_buffer.size(), &m_model_id));
        m_model_desc = aclmdlCreateDesc();
        MGB_ATLAS_CHECK(aclmdlGetDesc(m_model_desc, m_model_id));
        m_is_model_holder = true;
    }

    //! aipp input format
    m_aipp_input_format = SmallVector<AippInputFormat>(inputs.size());
    aclAippInfo aipp_info;
    for (size_t i = 0; i < inputs.size(); ++i) {
        aclError acl_err = aclmdlGetFirstAippInfo(m_model_id, i, &aipp_info);
        if (ACL_ERROR_NONE == acl_err) {
            switch (aipp_info.inputFormat) {
                case ACL_YUV420SP_U8:
                    m_aipp_input_format[i] = AippInputFormat::YUV420SP_U8;
                    break;
                case ACL_RGB888_U8:
                    m_aipp_input_format[i] = AippInputFormat::RGB888_U8;
                    break;
                default:
                    mgb_throw(
                            MegBrainError,
                            "unsupported aclAippInputFormat for input %zu", i);
            }
        } else if (
                ACL_ERROR_NOT_STATIC_AIPP == acl_err ||
                ACL_ERROR_GE_AIPP_NOT_EXIST == acl_err) {
            m_aipp_input_format[i] = AippInputFormat::NO_AIPP;
        } else {
            MGB_ATLAS_CHECK(acl_err);
        }
    }

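    //! om models built with dynamic shape support expose one extra input
    //! named ACL_DYNAMIC_TENSOR_NAME; its index is needed to query the
    //! dynamic HW and dynamic batch configuration below.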
    size_t dynamic_index;
    auto errcode = aclmdlGetInputIndexByName(
            m_model_desc, ACL_DYNAMIC_TENSOR_NAME, &dynamic_index);
    if (errcode == ACL_ERROR_NONE) {
        aclmdlHW hw_info;
        MGB_ATLAS_CHECK(aclmdlGetDynamicHW(m_model_desc, dynamic_index, &hw_info));
        mgb_assert(hw_info.hwCount == 0, "dynamic HW is currently not supported");
    }

    //! dynamic batch size
    aclmdlBatch acl_batch;
    MGB_ATLAS_CHECK(aclmdlGetDynamicBatch(m_model_desc, &acl_batch));
    if (acl_batch.batchCount) {
        size_t dynamic_data_size;
        dynamic_data_size = aclmdlGetInputSizeByIndex(m_model_desc, dynamic_index);
        m_dyn_batch_tensor = DeviceTensorND(
                inputs[0]->comp_node(), {{dynamic_data_size}, dtype::Uint8()});

        for (size_t i = 0; i < acl_batch.batchCount; ++i) {
            m_dyn_batch_choices.push_back(static_cast<size_t>(acl_batch.batch[i]));
        }
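        //! sort descending so that gen_batch_vec() can greedily take the
        //! largest supported batch first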
        std::sort(
                m_dyn_batch_choices.begin(), m_dyn_batch_choices.end(),
                std::greater<>());
    }

    //! add output
    size_t nr_outputs = aclmdlGetNumOutputs(m_model_desc);
    using F = VarNode::Flag;
    if (nr_outputs == 1) {
        add_output(None);
    } else {
        for (size_t i = 0; i < nr_outputs; ++i) {
            add_output(ssprintf("o%zu", i));
        }
    }
    if (!m_dyn_batch_choices.empty()) {
        /**
         * \warning If enable dynamic batchsize, the memory of output
         * should be the largest be the size with the largest batch_size, so we
         * set the flag to SYS_MEM_ALLOC.
         */
        for (size_t i = 0; i < nr_outputs; ++i) {
            output(i)->add_flag(F::NO_SYS_MEM_ALLOC).add_flag(F::NO_MEM_RECLAIM);
        }
    }
    add_equivalence_component<mgb::ScalarHash<const void*>>(m_buffer.data());
}

AtlasRuntimeOpr::~AtlasRuntimeOpr() {
    if (m_is_model_holder) {
        MGB_ATLAS_CHECK(aclmdlUnload(m_model_id));
        MGB_ATLAS_CHECK(aclmdlDestroyDesc(m_model_desc));
    }
}

void AtlasRuntimeOpr::scn_do_execute() {
    auto&& acl_env = CompNodeEnv::from_comp_node(input(0)->comp_node()).atlas_env();
    acl_env.activate();

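    //! outputs were flagged NO_SYS_MEM_ALLOC when dynamic batch is enabled,
    //! so allocate them here with the worst-case size reported by the om file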
    if (!m_dyn_batch_choices.empty()) {
        for (size_t i = 0; i < output().size(); i++) {
            auto output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i);
            auto ovar = output(i);
            output_size = std::max<size_t>(
                    output_size, ovar->dtype().size(ovar->shape().total_nr_elems()));
            ovar->shape_alloc(ovar->shape(), output_size);
        }
    }

    PtrGetter input_getter(input());
    PtrGetter output_getter(output());

    bool enable_dynamic_batch = !m_dyn_batch_choices.empty();
    size_t nr_inputs = aclmdlGetNumInputs(m_model_desc);
    size_t nr_outputs = aclmdlGetNumOutputs(m_model_desc);
    size_t input_batch = input(0)->layout()[0];

    if (enable_dynamic_batch) {
        mgb_assert(
                nr_inputs == input().size() + 1,
                "the om model should have exactly one more input than "
                "megbrain (the extra one carries the dynamic batch size)");
    }
    SmallVector<size_t> batches_each_run;
    if (enable_dynamic_batch) {
        batches_each_run = gen_batch_vec(input_batch, m_dyn_batch_choices);
    } else {
        batches_each_run.push_back(input_batch);
    }

    for (auto&& batch : batches_each_run) {
        //! prepare input
        auto model_inputs = aclmdlCreateDataset();
        mgb_assert(model_inputs != nullptr, "failed to create atlas input dataset.");
        for (size_t i = 0; i < input().size(); i++) {
            auto value_pair = input_getter.get(batch, i);
            auto input_size = aclmdlGetInputSizeByIndex(m_model_desc, i);
            //! FIXME: if dynamic batch size and dynamic aipp are both
            //! enabled, the input size should be the one returned by
            //! aclmdlGetInputSizeByIndex.
            if (enable_dynamic_batch) {
                mgb_assert(
                        input_size ==
                                value_pair.second / batch * m_dyn_batch_choices[0],
                        "input %zu size mismatch, expected: %zu got: %zu", i,
                        input_size, value_pair.second / batch * m_dyn_batch_choices[0]);
            }
            aclDataBuffer* input_db =
                    aclCreateDataBuffer(value_pair.first, value_pair.second);
            mgb_assert(
                    input_db != nullptr,
                    "failed to create atlas input data buffer for input "
                    "%zu:%s.",
                    i, input(i)->cname());
            aclmdlAddDatasetBuffer(model_inputs, input_db);
        }
        //! append the extra unit tensor that tells ACL which batch size to run
        if (enable_dynamic_batch) {
            aclDataBuffer* input_db = aclCreateDataBuffer(
                    reinterpret_cast<void*>(m_dyn_batch_tensor.raw_ptr()),
                    m_dyn_batch_tensor.layout().span().dist_byte());
            mgb_assert(
                    input_db != nullptr,
                    "failed to create atlas input data buffer for dynamic "
                    "batch tensor.");
            MGB_ATLAS_CHECK(aclmdlAddDatasetBuffer(model_inputs, input_db));

            MGB_ATLAS_CHECK(aclmdlSetDynamicBatchSize(
                    m_model_id, model_inputs, input().size(),
                    static_cast<uint64_t>(batch)));
        }

        //! prepare output
        auto model_outputs = aclmdlCreateDataset();
        mgb_assert(model_outputs != nullptr, "failed to create atlas output dataset.");
        for (size_t i = 0; i < nr_outputs; i++) {
            auto value_pair = output_getter.get(batch, i);
            size_t output_size = value_pair.second;
            if (enable_dynamic_batch || m_dyn_batch_output[i]) {
                output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i);
            }
            aclDataBuffer* output_db =
                    aclCreateDataBuffer(value_pair.first, output_size);
            mgb_assert(
                    output_db != nullptr,
                    "failed to create atlas output data buffer for output "
                    "%zu:%s.",
                    i, output(i)->cname());
            aclmdlAddDatasetBuffer(model_outputs, output_db);

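            //! for outputs whose batch dim is dynamic in the om file, attach
            //! a tensor desc holding the max shape so the shape actually
            //! produced can be queried back after aclmdlExecute()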
            if (m_dyn_batch_output[i]) {
                auto tensor_ndim = output(i)->shape().ndim;
                std::vector<int64_t> tensor_shape(tensor_ndim, 0);
                for (size_t j = 0; j < tensor_ndim; j++) {
                    tensor_shape[j] = output(i)->shape()[j];
                }
                aclTensorDesc* tensorDesc = aclCreateTensorDesc(
                        aclmdlGetOutputDataType(m_model_desc, i), tensor_ndim,
                        tensor_shape.data(), aclmdlGetOutputFormat(m_model_desc, i));
                aclmdlSetDatasetTensorDesc(model_outputs, tensorDesc, i);
            }
        }
        MGB_ATLAS_CHECK(aclmdlExecute(m_model_id, model_inputs, model_outputs));

        for (size_t i = 0; i < nr_inputs; ++i) {
            aclDataBuffer* db_ptr = aclmdlGetDatasetBuffer(model_inputs, i);
            MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr));
        }
        for (size_t i = 0; i < nr_outputs; ++i) {
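            //! for dynamic-batch outputs, re-wrap the existing storage with
            //! the shape the model actually produced before releasing the
            //! data buffer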
            if (m_dyn_batch_output[i]) {
                const DeviceTensorND old_dev_tensor = output(i)->dev_tensor();

                auto new_output_desc = aclmdlGetDatasetTensorDesc(model_outputs, i);

                TensorShape new_shape;
                new_shape.ndim = aclGetTensorDescNumDims(new_output_desc);
                mgb_assert(
                        new_shape.ndim == old_dev_tensor.layout().ndim,
                        "for static input batch and dynamic output batch, the output "
                        "ndim should be consistent with the one before calling "
                        "aclmdlExecute(); expected %zu, but got %zu",
                        old_dev_tensor.layout().ndim, new_shape.ndim);
                for (size_t j = 0; j < new_shape.ndim; j++) {
                    new_shape.shape[j] = aclGetTensorDescDim(new_output_desc, j);
                }

                TensorLayout new_layout{
                        new_shape, old_dev_tensor.dtype(), old_dev_tensor.format()};
                DeviceTensorND new_dev_tensor{
                        old_dev_tensor.comp_node(), new_layout, old_dev_tensor.dtype(),
                        old_dev_tensor.format()};
                new_dev_tensor.reset(old_dev_tensor.storage(), new_layout);
                output(i)->force_assign_dev_tensor_from_tensor(new_dev_tensor);
            }
            aclDataBuffer* db_ptr = aclmdlGetDatasetBuffer(model_outputs, i);
            MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr));
        }
        MGB_ATLAS_CHECK(aclmdlDestroyDataset(model_inputs));
        MGB_ATLAS_CHECK(aclmdlDestroyDataset(model_outputs));
    }
}

void AtlasRuntimeOpr::get_output_var_shape(
        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
    size_t nr_inputs = aclmdlGetNumInputs(m_model_desc);
    size_t batch_size = inp_shape[0][0];
    //! enable dynamic batchsize
    if (!m_dyn_batch_choices.empty()) {
        mgb_assert(!gen_batch_vec(batch_size, m_dyn_batch_choices).empty());
        mgb_assert(
                nr_inputs == inp_shape.size() + 1,
                "the om model should have exactly one more input than "
                "megbrain (the extra one carries the dynamic batch size)");
    }
    for (size_t i = 0; i < inp_shape.size(); ++i) {
        aclmdlIODims input_dims;
        MGB_ATLAS_CHECK(aclmdlGetInputDimsV2(m_model_desc, i, &input_dims));
        auto om_format = aclmdlGetInputFormat(m_model_desc, i);
        TensorShape shape_from_om = acl_shape_to_mgb_shape_for_input(
                input_dims, batch_size, !m_dyn_batch_choices.empty(), om_format,
                m_aipp_input_format[i]);
        mgb_assert(
                shape_from_om.eq_shape(inp_shape[i]),
                "shape mismatch of input %zu, expected: %s got: %s", i,
                shape_from_om.to_string().c_str(), inp_shape[i].to_string().c_str());
    }

    for (size_t i = 0; i < out_shape.size(); ++i) {
        aclmdlIODims output_dims;
        MGB_ATLAS_CHECK(aclmdlGetOutputDims(m_model_desc, i, &output_dims));
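        //! record which outputs declare a dynamic batch dim (-1) in the om
        //! file; scn_do_execute() queries their real shape after execution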
        out_shape[i] = acl_shape_to_mgb_shape_for_output(
                m_model_desc, i, output(i)->dtype().size(), output_dims, batch_size);
        m_dyn_batch_output.push_back(output_dims.dims[0] == -1);
    }
}

void AtlasRuntimeOpr::add_input_layout_constraint() {
    //! default contiguous
    for (auto i : input()) {
        i->add_layout_constraint_contiguous();
    }
}

void AtlasRuntimeOpr::init_output_dtype() {
    DType dt_acl, dt_input;
    for (size_t i = 0; i < input().size(); ++i) {
        dt_acl = acl_dtype_to_mgb_dtype(aclmdlGetInputDataType(m_model_desc, i));
        dt_input = input(i)->dtype();
        mgb_assert(
                dt_acl.valid() && dt_input.valid() &&
                        dt_acl.enumv() == dt_input.enumv(),
                "dtype mismatch of input %zu: expected %s, "
                "got %s",
                i, dt_acl.name(), dt_input.name());
    }

    for (size_t i = 0; i < output().size(); ++i) {
        dt_acl = acl_dtype_to_mgb_dtype(aclmdlGetOutputDataType(m_model_desc, i));
        mgb_assert(
                dt_acl.valid(),
                "output dtype checking failed: invalid dtype returned.");
        if (dt_acl.enumv() == DTypeEnum::QuantizedS8) {
            mgb_assert(
                    output(i)->dtype().valid(),
                    "user should specify scale of output tensor of "
                    "AtlasRuntimeOpr.");
        }
        if (!output(i)->dtype().valid())
            output(i)->dtype(dt_acl);
    }
}

SymbolVarArray AtlasRuntimeOpr::make(
        SharedBuffer buf, const SymbolVarArray& src, const OperatorNodeConfig& config) {
    VarNodeArray var_node_array = cg::to_var_node_array(src);
    auto atlas_runtime_opr = std::make_unique<AtlasRuntimeOpr>(
            std::move(buf), std::pair<uint32_t, aclmdlDesc*>{INVALID_MODEL_ID, nullptr},
            var_node_array, config);
    auto ret =
            cg::to_symbol_var_array(src[0].node()
                                            ->owner_graph()
                                            ->insert_opr(std::move(atlas_runtime_opr))
                                            ->output());
    return ret;
}

SymbolVarArray AtlasRuntimeOpr::make(
        const void* buf, size_t size, const SymbolVarArray& src,
        const OperatorNodeConfig& config) {
    mgb_throw_if(
            !CompNode::get_device_count(CompNode::DeviceType::ATLAS), SystemError,
            "can not create AtlasRuntimeOpr when atlas is not "
            "available");
    std::shared_ptr<uint8_t> shptr{new uint8_t[size], [](uint8_t* p) { delete[] p; }};
    memcpy(shptr.get(), buf, size);
    SharedBuffer buffer{std::move(shptr), size};
    return make(std::move(buffer), src, config);
}
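
//! Usage sketch (illustrative; the file-reading helper and variable names
//! are hypothetical, not part of this API):
//!
//!     std::vector<uint8_t> om_buf = load_file("model.om");
//!     SymbolVarArray outs = AtlasRuntimeOpr::make(
//!             om_buf.data(), om_buf.size(), {input_var});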

SymbolVarArray AtlasRuntimeOpr::make(
        const SharedBuffer buf, const std::pair<uint32_t, aclmdlDesc*>& model,
        const SymbolVarArray& src, const OperatorNodeConfig& config) {
    VarNodeArray var_node_array = cg::to_var_node_array(src);
    auto atlas_runtime_opr =
            std::make_unique<AtlasRuntimeOpr>(buf, model, var_node_array, config);
    auto ret =
            cg::to_symbol_var_array(src[0].node()
                                            ->owner_graph()
                                            ->insert_opr(std::move(atlas_runtime_opr))
                                            ->output());
    return ret;
}

constexpr uint32_t AtlasRuntimeOpr::INVALID_MODEL_ID;

#endif  // MGB_ATLAS

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}