sparse_format.cpp 3.9 KB
Newer Older
W
wangguibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
/***************************************************************************
 * 
 * Copyright (c) 2018 Baidu.com, Inc. All Rights Reserved
 * 
 **************************************************************************/
 
/**
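 * @file sparse_format.cpp
 * @author wanlijin01(wanlijin01@baidu.com)
 * @date 2018/07/09 20:12:44
 * @brief Demo client that repeatedly sends sparse-format requests to the
 *        "sparse_service" predictor and logs the responses.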
 *  
 **/
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>

#include <chrono>
#include <fstream>
#include <sstream>
#include <string>

#include "common.h"
#include "predictor_sdk.h"
#include "sparse_service.pb.h"
#include "builtin_format.pb.h"

using baidu::paddle_serving::sdk_cpp::Predictor;
using baidu::paddle_serving::sdk_cpp::PredictorApi;
using baidu::paddle_serving::predictor::sparse_service::Request;
using baidu::paddle_serving::predictor::sparse_service::Response;
using baidu::paddle_serving::predictor::format::SparsePrediction;
using baidu::paddle_serving::predictor::format::SparseInstance;

/**
 * Append two fixed sparse instances to @p req for the demo.
 *
 * For each instance the keys are added first, then a single shape entry of
 * 2000, then as many values as there are keys. (Presumably keys/values are
 * index/value pairs of a sparse vector of width 2000 — confirm against the
 * sparse_service / builtin_format schema.)
 *
 * @param req Request that receives the instances; existing ones are kept.
 * @return Always 0.
 */
int create_req(Request& req) {
    static const int kDim = 2000;

    // Instance #1
    static const int kFirstKeys[] = {26, 182, 232};
    static const int kFirstValues[] = {1, 1, 1};
    SparseInstance* first = req.mutable_instances()->Add();
    for (int key : kFirstKeys) {
        first->add_keys(key);
    }
    first->add_shape(kDim);
    for (int value : kFirstValues) {
        first->add_values(value);
    }

    // Instance #2
    static const int kSecondKeys[] = {0, 182, 232, 299};
    static const int kSecondValues[] = {13, 1, 1, 1};
    SparseInstance* second = req.mutable_instances()->Add();
    for (int key : kSecondKeys) {
        second->add_keys(key);
    }
    second->add_shape(kDim);
    for (int value : kSecondValues) {
        second->add_values(value);
    }

    return 0;
}

void print_res(
        const Request& req,
        const Response& res,
        std::string route_tag,
        uint64_t elapse_ms) {

    for (uint32_t i = 0; i < res.predictions_size(); ++i) {
        const SparsePrediction &prediction = res.predictions(i);
        std::ostringstream oss;
        for (uint32_t j = 0; j < prediction.categories_size(); ++j) {
            oss << prediction.categories(j) << " ";
        }
        LOG(INFO) << "Receive result " << oss.str();
    }

    LOG(INFO) 
        << "Succ call predictor[sparse_format], the tag is: " 
        << route_tag << ", elapse_ms: " << elapse_ms;
}

int main(int argc, char** argv) {
    PredictorApi api;

    // initialize logger instance
    struct stat st_buf;
    int ret = 0;
    if ((ret = stat("./log", &st_buf)) != 0) {
            mkdir("./log", 0777);
            ret = stat("./log", &st_buf);
            if (ret != 0) {
                    LOG(WARNING) << "Log path ./log not exist, and create fail";
                    return -1;
            }
    }
    FLAGS_log_dir = "./log";
    google::InitGoogleLogging(strdup(argv[0]));
     
    if (api.create("./conf", "predictors.prototxt") != 0) {
        LOG(ERROR) << "Failed create predictors api!"; 
        return -1;
    }

    Request req;
    Response res;

    api.thrd_initialize();

    while (true) {
        timeval start;
        gettimeofday(&start, NULL);

        api.thrd_clear();

        Predictor* predictor = api.fetch_predictor("sparse_service");
        if (!predictor) {
            LOG(ERROR) << "Failed fetch predictor: sparse_service"; 
            return -1;
        }

        req.Clear();
        res.Clear();

        if (create_req(req) != 0) {
            return -1;
        }

        butil::IOBufBuilder debug_os;
        if (predictor->debug(&req, &res, &debug_os) != 0) {
            LOG(ERROR) << "failed call predictor with req:"
                        << req.ShortDebugString();
            return -1;
        }

        butil::IOBuf debug_buf;
        debug_os.move_to(debug_buf);
        LOG(INFO) << "Debug string: " << debug_buf;

        timeval end;
        gettimeofday(&end, NULL);

        uint64_t elapse_ms = (end.tv_sec * 1000 + end.tv_usec / 1000)
            - (start.tv_sec * 1000 + start.tv_usec / 1000);
    
        print_res(req, res, predictor->tag(), elapse_ms);
        res.Clear();

        usleep(50);

    } // while (true)

    api.thrd_finalize();
    api.destroy();

    google::ShutdownGoogleLogging();

    return 0;
}

/* vim: set expandtab ts=4 sw=4 sts=4 tw=100: */