提交 ad7c458a 编写于 作者: G guru4elephant

add general model init

上级 e55c15fe
......@@ -2,6 +2,7 @@ LIST(APPEND protofiles
${CMAKE_CURRENT_LIST_DIR}/proto/server_configure.proto
${CMAKE_CURRENT_LIST_DIR}/proto/sdk_configure.proto
${CMAKE_CURRENT_LIST_DIR}/proto/inferencer_configure.proto
${CMAKE_CURRENT_LIST_DIR}/proto/general_model_config.proto
)
PROTOBUF_GENERATE_CPP(configure_proto_srcs configure_proto_hdrs ${protofiles})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package baidu.paddle_serving.configure;
message Shape {
repeated int32 shape = 1;
};
message GeneralModelConfig {
repeated bool is_lod_feed = 1;
repeated int32 feed_type = 2;
repeated Shape feed_shape = 4;
};
......@@ -31,6 +31,8 @@ DECLARE_string(logger_path);
DECLARE_string(logger_file);
DECLARE_string(resource_path);
DECLARE_string(resource_file);
DECLARE_string(general_model_path);
DECLARE_string(general_model_file);
DECLARE_bool(enable_mc_cache);
DECLARE_bool(enable_nshead_protocol);
DECLARE_string(nshead_protocol);
......@@ -40,6 +42,7 @@ DECLARE_int32(reload_interval_s);
DECLARE_bool(enable_model_toolkit);
DECLARE_string(enable_protocol_list);
DECLARE_bool(enable_cube);
DECLARE_bool(enable_general_model);
// STATIC Variables
extern const char* START_OP_NAME;
......
......@@ -54,6 +54,7 @@
#include "configure/include/configure_parser.h"
#include "configure/server_configure.pb.h"
#include "configure/general_model_config.pb.h"
#include "predictor/common/constant.h"
#include "predictor/common/types.h"
......
......@@ -21,6 +21,8 @@ namespace paddle_serving {
namespace predictor {
using configure::ResourceConf;
using configure::GeneralModelConf;
using configure::Shape;
using rec::mcube::CubeAPI;
// __thread bool p_thread_initialized = false;
......@@ -96,6 +98,40 @@ int Resource::initialize(const std::string& path, const std::string& file) {
return 0;
}
int Resource::general_model_initialize(
const std::string& path, const std::string & file) {
if (!FLAGS_enable_general_model) {
return 0;
}
GeneralModelConf model_config;
if (configure::read_proto_conf(path, file, &model_config) != 0) {
LOG(ERROR) << "Failed initialize resource from: " << path << "/" << file;
return -1;
}
_config.reset(new PaddleGeneralModelConfig());
_config->_feed_type.resize(model_config.is_feed_type_size());
_config->_is_lod_feed.resize(model_config.is_lod_feed_size());
_config->_capacity.resize(model_config.feed_shape_size());
_config->_feed_shape.resize(model_config.feed_shape_size());
for (int i = 0; i < model_config.is_lod_feed_size(); ++i) {
_config->feed_type[i] = model_config.feed_type(i);
if (model_config.is_lod_feed(i)) {
_config->_feed_shape[i] = {-1};
_config->_is_lod_feed[i] = true;
} else {
_config->capacity[i] = 1;
_config->_is_lod_feed[i] = false;
for (int j = 0; j < model_config.feed_shape(i).shape_size(); ++j) {
int dim = model_cnofig.feed_shape(i).shape(j);
_config->_feed_shape[i].push_back(dim);
_config->capacity[i] *= dim;
}
}
}
}
int Resource::cube_initialize(const std::string& path,
const std::string& file) {
// cube
......
......@@ -15,6 +15,7 @@
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "cube/cube-api/include/cube_api.h"
#include "kvdb/paddle_rocksdb.h"
#include "predictor/common/inner_common.h"
......@@ -25,6 +26,26 @@ namespace baidu {
namespace paddle_serving {
namespace predictor {
class PaddleGeneralModelConfig {
PaddleGeneralModelConfig();
~PaddleGeneralModelConfig();
void load_config(std::string);
public:
std::vector<int> _feed_type; // 0 int64, 1 float
std::vector<bool> _is_lod_feed; // true lod tensor
std::vector<int> _capacity; // capacity for each tensor
/*
feed_shape_ for feeded variable
feed_shape_[i][j] represents the jth dim for ith input Tensor
if is_lod_feed_[i] == False, feed_shape_[i][0] = -1
*/
std::vector<std::vector<int>> _feed_shape;
};
class BaseRdDict;
struct DynamicResource {
DynamicResource();
......@@ -55,6 +76,10 @@ class Resource {
int initialize(const std::string& path, const std::string& file);
int cube_initialize(const std::string& path, const std::string& file);
int general_model_initialize(
const std::string& path, const std::string & file);
int thread_initialize();
int thread_clear();
......@@ -73,6 +98,7 @@ class Resource {
private:
int thread_finalize() { return 0; }
std::shared_ptr<RocksDBWrapper> db;
std::shared_ptr<PaddleGeneralModelConfig> _config;
THREAD_KEY_T _tls_bspec_key;
};
......
......@@ -216,6 +216,17 @@ int main(int argc, char** argv) {
LOG(INFO) << "Succ initialize cube";
#ifndef BCLOUD
if (Resource::instance().general_model_initialize(
FLAGS_general_model_path, FLAGS_general_model_file) != 0) {
LOG(ERROR) << "Failed to initialize general model conf: "
<< FLAGS_general_model_path << "/"
<< FLAGS_general_model_file;
return -1;
}
LOG(INFO) << "Succ initialize general model"
// FATAL messages are output to stderr
FLAGS_stderrthreshold = 3;
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册