// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include #include #include "core/sdk-cpp/builtin_format.pb.h" #include "core/sdk-cpp/general_model_service.pb.h" #include "core/sdk-cpp/include/common.h" #include "core/sdk-cpp/include/predictor_sdk.h" using baidu::paddle_serving::sdk_cpp::Predictor; using baidu::paddle_serving::sdk_cpp::PredictorApi; DECLARE_bool(profile_client); DECLARE_bool(profile_server); // given some input data, pack into pb, and send request namespace baidu { namespace paddle_serving { namespace general_model { class ModelRes { public: ModelRes() {} ~ModelRes() {} public: const std::vector>& get_int64_by_name( const std::string& name) { return _int64_map[name]; } const std::vector>& get_float_by_name( const std::string& name) { return _float_map[name]; } public: std::map>> _int64_map; std::map>> _float_map; }; class PredictorRes { public: PredictorRes() {} ~PredictorRes() {} public: void clear() { _models.clear();} const std::vector>& get_int64_by_name( const int model_idx, const std::string& name) { return _models[model_idx].get_int64_by_name(name); } const std::vector>& get_float_by_name( const int model_idx, const std::string& name) { return _models[model_idx].get_float_by_name(name); } void set_variant_tag(const std::string& variant_tag) { _variant_tag = variant_tag; } const std::string& variant_tag() { return _variant_tag; } int model_num() {return _models.size();} std::vector _models; private: std::string _variant_tag; }; class PredictorClient { public: PredictorClient() {} ~PredictorClient() {} void init_gflags(std::vector argv); int init(const std::string& client_conf); void set_predictor_conf(const std::string& conf_path, const std::string& conf_file); int create_predictor_by_desc(const std::string& sdk_desc); int create_predictor(); int destroy_predictor(); int predict(const std::vector>& float_feed, const std::vector& float_feed_name, const std::vector>& int_feed, const std::vector& int_feed_name, const std::vector& fetch_name, PredictorRes& predict_res, // NOLINT const int& pid); int batch_predict( const std::vector>>& float_feed_batch, const std::vector& float_feed_name, const std::vector>>& int_feed_batch, const std::vector& int_feed_name, const std::vector& fetch_name, PredictorRes& predict_res_batch, // NOLINT const int& pid); private: PredictorApi _api; Predictor* _predictor; std::string _predictor_conf; std::string _predictor_path; std::string _conf_file; std::map _feed_name_to_idx; std::map _fetch_name_to_idx; std::map _fetch_name_to_var_name; std::map _fetch_name_to_type; std::vector> _shape; std::vector _type; std::vector _last_request_ts; }; } // namespace general_model } // namespace paddle_serving } // namespace baidu /* vim: set expandtab ts=4 sw=4 sts=4 tw=100: */