/**
 * \file example/c_example/main.c
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "lite-c/global_c.h"
#include "lite-c/network_c.h"
#include "lite-c/tensor_c.h"

#include <float.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/**
 * Evaluate a lite C API call exactly once; if it returns non-zero, print the
 * library's last error message to stderr and return -1 from the *enclosing*
 * function. Only usable inside functions that return int.
 */
#define LITE_CAPI_CHECK(_expr)                                         \
    do {                                                               \
        int _ret = (_expr);                                            \
        if (_ret) {                                                    \
            fprintf(stderr, "error msg: %s\n", LITE_get_last_error()); \
            return -1;                                                 \
        }                                                              \
    } while (0)

/**
 * Print the maximum and the sum of a float buffer.
 *
 * \param data  pointer to \p count floats (not dereferenced when count is 0)
 * \param count number of elements in \p data
 */
static void print_max_and_sum(const float* data, size_t count) {
    //! start below every finite float so all-negative outputs report the true max
    float max = -FLT_MAX;
    //! accumulate in double to limit rounding error on large outputs
    double sum = 0.0;
    for (size_t i = 0; i < count; i++) {
        sum += data[i];
        if (data[i] > max) {
            max = data[i];
        }
    }
    printf("max=%e, sum=%e\n", max, sum);
}

/**
 * Load the model at \p mode_path into \p c_network, fill the input tensor
 * named "data" with the byte value 5, and run one forward pass followed by
 * LITE_wait.
 *
 * \return 0 on success, -1 if any lite C API call fails (error already printed)
 */
static int load_and_forward(LiteNetwork c_network, const char* mode_path) {
    LITE_CAPI_CHECK(LITE_load_model_from_path(c_network, mode_path));

    //! set input data to input tensor
    LiteTensor c_input_tensor;
    LITE_CAPI_CHECK(LITE_get_io_tensor(c_network, "data", LITE_IO, &c_input_tensor));

    void* dst_ptr;
    size_t length_in_byte;
    LITE_CAPI_CHECK(
            LITE_get_tensor_total_size_in_byte(c_input_tensor, &length_in_byte));
    LITE_CAPI_CHECK(LITE_get_tensor_memory(c_input_tensor, &dst_ptr));
    //! copy or forward data to network
    LITE_memset(dst_ptr, 5, length_in_byte);

    LITE_CAPI_CHECK(LITE_forward(c_network));
    LITE_CAPI_CHECK(LITE_wait(c_network));
    return 0;
}

/**
 * Fetch the network's first output tensor and print its element count, the
 * maximum and the sum of its values.
 *
 * In IPC debug mode the tensor memory is first copied with
 * LITE_copy_server_tensor_memory into a locally allocated buffer, which is
 * released on every path.
 *
 * \return 0 on success, -1 on allocation or lite C API failure
 */
static int report_first_output(LiteNetwork c_network) {
    const char* output_name;
    LITE_CAPI_CHECK(LITE_get_output_name(c_network, 0, &output_name));

    LiteTensor c_output_tensor;
    LITE_CAPI_CHECK(
            LITE_get_io_tensor(c_network, output_name, LITE_IO, &c_output_tensor));

    void* output_ptr;
    size_t length_output_in_byte;
    LITE_CAPI_CHECK(LITE_get_tensor_memory(c_output_tensor, &output_ptr));
    LITE_CAPI_CHECK(LITE_get_tensor_total_size_in_byte(
            c_output_tensor, &length_output_in_byte));

    //! NOTE(review): assumes the output tensor holds float32 elements
    size_t out_length = length_output_in_byte / sizeof(float);
    printf("length=%zu\n", out_length);

    if (!LITE_is_enable_ipc_debug_mode()) {
        print_max_and_sum((const float*)output_ptr, out_length);
        return 0;
    }

    float* copy_ptr = malloc(length_output_in_byte);
    if (copy_ptr == NULL && length_output_in_byte != 0) {
        fprintf(stderr, "error msg: failed to allocate %zu bytes\n",
                length_output_in_byte);
        return -1;
    }
    int ret = LITE_copy_server_tensor_memory(
            output_ptr, copy_ptr, length_output_in_byte);
    if (ret == 0) {
        print_max_and_sum(copy_ptr, out_length);
    }
    //! free before the check so a failed copy does not leak the buffer
    free(copy_ptr);
    LITE_CAPI_CHECK(ret);
    return 0;
}

/**
 * Create a network with the default config and IO, load the model at
 * \p mode_path, run it on a constant input, and print statistics of the
 * first output.
 *
 * The network is destroyed on every path after it has been created.
 *
 * \param mode_path path to the model file
 * \return 0 on success, -1 on any failure
 */
int basic_c_interface(const char* mode_path) {
    //! create the network; nothing to release yet if this fails
    LiteNetwork c_network;
    LITE_CAPI_CHECK(
            LITE_make_network(&c_network, *default_config(), *default_network_io()));

    int ret = load_and_forward(c_network, mode_path);
    if (ret == 0) {
        ret = report_first_output(c_network);
    }

    LITE_destroy_network(c_network);
    return ret;
}

/**
 * Entry point: lite_c_examples <is_enable_fork_debug_model> <model file>
 *
 * argv[1] must be an integer; a non-zero value enables lite IPC debug mode
 * before the model is run. Returns -1 on bad usage, otherwise the result of
 * basic_c_interface.
 */
int main(int argc, char** argv) {
    if (argc < 3) {
        printf("usage: lite_c_examples is_enable_fork_debug_model <model file> , just "
               "test C interface "
               "build.\n");
        return -1;
    }

    //! strtol (unlike atoi) lets us reject malformed flags instead of reading them as 0
    char* end = NULL;
    long enable_ipc_debug = strtol(argv[1], &end, 10);
    if (end == argv[1] || *end != '\0') {
        fprintf(stderr, "invalid is_enable_fork_debug_model value: %s\n", argv[1]);
        return -1;
    }

    if (enable_ipc_debug) {
        LITE_enable_lite_ipc_debug();
    }
    return basic_c_interface(argv[2]);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}