From 82ef25c64110b94dbd8c4322bed9fc7f273cd4ad Mon Sep 17 00:00:00 2001 From: wangguibao Date: Thu, 14 Mar 2019 17:39:47 +0800 Subject: [PATCH] Add text_classification demo Change-Id: Iaf667d45b09aae30d38317bdf8e43f3ff41d016a --- cmake/external/opencv.cmake | 1 + cmake/paddlepaddle.cmake | 2 +- .../include/fluid_cpu_engine.h | 43 ++- predictor/src/pdserving.cpp | 2 + sdk-cpp/CMakeLists.txt | 20 +- sdk-cpp/conf/predictors.prototxt | 15 + sdk-cpp/demo/text_classification.cpp | 266 ++++++++++++++++++ sdk-cpp/proto/text_classification.proto | 37 +++ serving/CMakeLists.txt | 14 + serving/conf/model_toolkit.prototxt | 10 + serving/conf/service.prototxt | 5 + serving/conf/workflow.prototxt | 9 +- serving/op/text_classification_op.cpp | 124 ++++++++ serving/op/text_classification_op.h | 39 +++ serving/proto/CMakeLists.txt | 1 + serving/proto/text_classification.proto | 37 +++ 16 files changed, 599 insertions(+), 26 deletions(-) create mode 100644 sdk-cpp/demo/text_classification.cpp create mode 100644 sdk-cpp/proto/text_classification.proto create mode 100644 serving/op/text_classification_op.cpp create mode 100644 serving/op/text_classification_op.h create mode 100644 serving/proto/text_classification.proto diff --git a/cmake/external/opencv.cmake b/cmake/external/opencv.cmake index 31d58efc..07b0afdb 100644 --- a/cmake/external/opencv.cmake +++ b/cmake/external/opencv.cmake @@ -42,6 +42,7 @@ ExternalProject_Add( -DBUILD_PERF_TESTS=OFF -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} -DWITH_EIGEN=OFF + -DWITH_CUDA=OFF -DWITH_JPEG=ON -DBUILD_JPEG=ON -DWITH_PNG=ON diff --git a/cmake/paddlepaddle.cmake b/cmake/paddlepaddle.cmake index 0fcb8b6c..5a490c8e 100644 --- a/cmake/paddlepaddle.cmake +++ b/cmake/paddlepaddle.cmake @@ -30,7 +30,7 @@ ExternalProject_Add( ${EXTERNAL_PROJECT_LOG_ARGS} # TODO(wangguibao): change to de newst repo when they changed. GIT_REPOSITORY "https://github.com/PaddlePaddle/Paddle" - GIT_TAG "v1.2.0" + GIT_TAG "v1.3.0" PREFIX ${PADDLE_SOURCES_DIR} UPDATE_COMMAND "" BINARY_DIR ${CMAKE_BINARY_DIR}/Paddle diff --git a/inferencer-fluid-cpu/include/fluid_cpu_engine.h b/inferencer-fluid-cpu/include/fluid_cpu_engine.h index 84920edf..928d5f87 100644 --- a/inferencer-fluid-cpu/include/fluid_cpu_engine.h +++ b/inferencer-fluid-cpu/include/fluid_cpu_engine.h @@ -118,15 +118,15 @@ class FluidCpuAnalysisCore : public FluidFamilyCore { return -1; } - paddle::contrib::AnalysisConfig analysis_config; - analysis_config.param_file = data_path + "/__params__"; - analysis_config.prog_file = data_path + "/__model__"; - analysis_config.use_gpu = false; - analysis_config.device = 0; - analysis_config.specify_input_name = true; + paddle::AnalysisConfig analysis_config; + analysis_config.SetParamsFile(data_path + "/__params__"); + analysis_config.SetProgFile(data_path + "/__model__"); + analysis_config.DisableGpu(); + analysis_config.SetCpuMathLibraryNumThreads(1); + analysis_config.SwitchSpecifyInputNames(true); AutoLock lock(GlobalPaddleCreateMutex::instance()); - _core = paddle::CreatePaddlePredictor( - analysis_config); + _core = + paddle::CreatePaddlePredictor(analysis_config); if (NULL == _core.get()) { LOG(ERROR) << "create paddle predictor failed, path: " << data_path; return -1; @@ -174,14 +174,14 @@ class FluidCpuAnalysisDirCore : public FluidFamilyCore { return -1; } - paddle::contrib::AnalysisConfig analysis_config; - analysis_config.model_dir = data_path; - analysis_config.use_gpu = false; - analysis_config.device = 0; - analysis_config.specify_input_name = true; + paddle::AnalysisConfig analysis_config; + analysis_config.SetModel(data_path); + analysis_config.DisableGpu(); + analysis_config.SwitchSpecifyInputNames(true); + analysis_config.SetCpuMathLibraryNumThreads(1); AutoLock lock(GlobalPaddleCreateMutex::instance()); - _core = paddle::CreatePaddlePredictor( - analysis_config); + _core = + paddle::CreatePaddlePredictor(analysis_config); if (NULL == _core.get()) { LOG(ERROR) << "create paddle predictor failed, path: " << data_path; return -1; @@ -478,15 +478,14 @@ class FluidCpuAnalysisDirWithSigmoidCore : public FluidCpuWithSigmoidCore { return -1; } - paddle::contrib::AnalysisConfig analysis_config; - analysis_config.model_dir = data_path; - analysis_config.use_gpu = false; - analysis_config.device = 0; - analysis_config.specify_input_name = true; + paddle::AnalysisConfig analysis_config; + analysis_config.SetModel(data_path); + analysis_config.DisableGpu(); + analysis_config.SwitchSpecifyInputNames(true); + analysis_config.SetCpuMathLibraryNumThreads(1); AutoLock lock(GlobalPaddleCreateMutex::instance()); _core->_fluid_core = - paddle::CreatePaddlePredictor( - analysis_config); + paddle::CreatePaddlePredictor(analysis_config); if (NULL == _core.get()) { LOG(ERROR) << "create paddle predictor failed, path: " << data_path; return -1; diff --git a/predictor/src/pdserving.cpp b/predictor/src/pdserving.cpp index 5d88ba35..c4afc9c9 100644 --- a/predictor/src/pdserving.cpp +++ b/predictor/src/pdserving.cpp @@ -123,6 +123,8 @@ int main(int argc, char** argv) { } } google::InitGoogleLogging(strdup(argv[0])); + FLAGS_logbufsecs = 0; + FLAGS_logbuflevel = -1; LOG(INFO) << "Succ initialize logger"; diff --git a/sdk-cpp/CMakeLists.txt b/sdk-cpp/CMakeLists.txt index 9d875a1d..8e1ad709 100644 --- a/sdk-cpp/CMakeLists.txt +++ b/sdk-cpp/CMakeLists.txt @@ -1,3 +1,17 @@ +if (NOT EXISTS + ${CMAKE_CURRENT_LIST_DIR}/data/text_classification/test_set.txt) + execute_process(COMMAND wget + --no-check-certificate + https://paddle-serving.bj.bcebos.com/data/text_classification/test_set.tar.gz + --output-document + ${CMAKE_CURRENT_LIST_DIR}/data/text_classification/test_set.tar.gz) + + execute_process(COMMAND ${CMAKE_COMMAND} -E tar xzf + "${CMAKE_CURRENT_LIST_DIR}/data/text_classification/test_set.tar.gz" + WORKING_DIRECTORY + ${CMAKE_CURRENT_LIST_DIR}/data/text_classification + ) +endif() include(src/CMakeLists.txt) include(proto/CMakeLists.txt) add_library(sdk-cpp ${sdk_cpp_srcs}) @@ -43,8 +57,8 @@ install(TARGETS ximage ${PADDLE_SERVING_INSTALL_DIR}/demo/client/image_classification/bin) install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/conf DESTINATION ${PADDLE_SERVING_INSTALL_DIR}/demo/client/image_classification/) -install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/data DESTINATION - ${PADDLE_SERVING_INSTALL_DIR}/demo/client/image_classification/) +install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/data/images DESTINATION + ${PADDLE_SERVING_INSTALL_DIR}/demo/client/image_classification/data) install(TARGETS echo RUNTIME DESTINATION ${PADDLE_SERVING_INSTALL_DIR}/demo/client/echo/bin) @@ -74,3 +88,5 @@ install(TARGETS text_classification ${PADDLE_SERVING_INSTALL_DIR}/demo/client/text_classification/bin) install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/conf DESTINATION ${PADDLE_SERVING_INSTALL_DIR}/demo/client/text_classification/) +install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/data/text_classification DESTINATION + ${PADDLE_SERVING_INSTALL_DIR}/demo/client/text_classification/data) diff --git a/sdk-cpp/conf/predictors.prototxt b/sdk-cpp/conf/predictors.prototxt index ef552051..da59ad7b 100644 --- a/sdk-cpp/conf/predictors.prototxt +++ b/sdk-cpp/conf/predictors.prototxt @@ -94,3 +94,18 @@ predictors { } } } + +predictors { + name: "text_classification" + service_name: "baidu.paddle_serving.predictor.text_classification.TextClassificationService" + endpoint_router: "WeightedRandomRender" + weighted_random_render_conf { + variant_weight_list: "50" + } + variants { + tag: "var1" + naming_conf { + cluster: "list://127.0.0.1:8010" + } + } +} diff --git a/sdk-cpp/demo/text_classification.cpp b/sdk-cpp/demo/text_classification.cpp new file mode 100644 index 00000000..3e549fa5 --- /dev/null +++ b/sdk-cpp/demo/text_classification.cpp @@ -0,0 +1,266 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include +#include "predictor/builtin_format.pb.h" +#include "sdk-cpp/include/common.h" +#include "sdk-cpp/include/predictor_sdk.h" +#include "sdk-cpp/text_classification.pb.h" + +using baidu::paddle_serving::sdk_cpp::Predictor; +using baidu::paddle_serving::sdk_cpp::PredictorApi; +using baidu::paddle_serving::predictor::text_classification::TextReqInstance; +using baidu::paddle_serving::predictor::text_classification::TextResInstance; +using baidu::paddle_serving::predictor::text_classification::Request; +using baidu::paddle_serving::predictor::text_classification::Response; + +const char *g_test_file = "./data/text_classification/test_set.txt"; +DEFINE_int32(batch_size, 50, "Set the batch size of test file."); + +// Text Classification Data Feed +// +// Input format: +// ([termid list], truth_label) +// Where 'termid list' is a variant length id list, `truth label` is a single +// number (0 or 1) +// +const int MAX_LINE_SIZE = 1024 * 1024; +std::vector g_pred_labels; +const float g_decision_boundary = 0.500; + +class DataFeed { + public: + virtual ~DataFeed() {} + virtual void init(); + virtual bool set_file(const char *filename); + std::vector> &get_test_input() { return _test_input; } + std::vector &get_labels() { return _test_label; } + uint32_t sample_id() { return _current_sample_id; } + void set_sample_id(uint32_t sample_id) { _current_sample_id = sample_id; } + + private: + std::vector> _test_input; + std::vector _test_label; + uint32_t _current_sample_id; + int _batch_size; + std::shared_ptr _batch_id_buffer; + std::shared_ptr _label_buffer; +}; + +void DataFeed::init() { + _batch_id_buffer.reset(new int[10240 * 1024], [](int *p) { delete[] p; }); + _label_buffer.reset(new int[10240 * 1024], [](int *p) { delete[] p; }); +} + +bool DataFeed::set_file(const char *filename) { + std::ifstream ifs(filename); + char *line = new char[MAX_LINE_SIZE]; + int len = 0; + char *sequence_begin_ptr = NULL; + char *sequence_end_ptr = NULL; + char *id_begin_ptr = NULL; + char *id_end_ptr = NULL; + char *label_ptr = NULL; + int label = -1; + int id = -1; + while (!ifs.eof()) { + std::vector vec; + ifs.getline(line, MAX_LINE_SIZE); + len = strlen(line); + if (line[0] != '(' || line[len - 1] != ')') { + continue; + } + line[len - 1] = '\0'; + + sequence_begin_ptr = strchr(line, '(') + 1; + if (*sequence_begin_ptr != '[') { + continue; + } + + sequence_end_ptr = strchr(sequence_begin_ptr, ']'); + if (sequence_end_ptr == NULL) { + continue; + } + *sequence_end_ptr = '\0'; + + id_begin_ptr = sequence_begin_ptr; + while (id_begin_ptr != NULL) { + id_begin_ptr++; + id_end_ptr = strchr(id_begin_ptr, ','); + if (id_end_ptr != NULL) { + *id_end_ptr = '\0'; + } + id = atoi(id_begin_ptr); + id_begin_ptr = id_end_ptr; + vec.push_back(id); + } + + label_ptr = strchr(sequence_end_ptr + 1, ','); + if (label_ptr == NULL) { + continue; + } + *label_ptr = '\0'; + + label_ptr++; + label = atoi(label_ptr); + + _test_input.push_back(vec); + _test_label.push_back(label); + } + + ifs.close(); + + std::cout << "read record" << _test_input.size() << std::endl; + + return 0; +} + +int create_req(std::shared_ptr data_feed, Request &req) { // NOLINT + std::vector> &inputs = data_feed->get_test_input(); + uint32_t current_sample_id = data_feed->sample_id(); + int idx = 0; + + for (idx = 0; + idx < FLAGS_batch_size && current_sample_id + idx < inputs.size(); + ++idx) { + TextReqInstance *req_instance = req.add_instances(); + std::vector &sample = inputs.at(current_sample_id + idx); + for (auto x : sample) { + req_instance->add_ids(x); + } + } + + if (idx < FLAGS_batch_size) { + return -1; + } + + data_feed->set_sample_id(current_sample_id + FLAGS_batch_size); + return 0; +} + +void extract_res(const Request &req, const Response &res) { + uint32_t sample_size = res.predictions_size(); + std::string err_string; + for (uint32_t si = 0; si < sample_size; ++si) { + const TextResInstance &res_instance = res.predictions(si); + + if (res_instance.class_1_prob() < g_decision_boundary) { + g_pred_labels.push_back(0); + } else if (res_instance.class_1_prob() >= g_decision_boundary) { + g_pred_labels.push_back(1); + } + } +} + +int main(int argc, char **argv) { + PredictorApi api; + + // initialize logger instance + struct stat st_buf; + int ret = 0; + if ((ret = stat("./log", &st_buf)) != 0) { + mkdir("./log", 0777); + ret = stat("./log", &st_buf); + if (ret != 0) { + LOG(WARNING) << "Log path ./log not exist, and create fail"; + return -1; + } + } + FLAGS_log_dir = "./log"; + google::InitGoogleLogging(strdup(argv[0])); + FLAGS_logbufsecs = 0; + FLAGS_logbuflevel = -1; + + g_pred_labels.clear(); + + std::shared_ptr local_feed(new DataFeed()); + local_feed->init(); + local_feed->set_file(g_test_file); + + if (api.create("./conf", "predictors.prototxt") != 0) { + LOG(ERROR) << "Failed create predictors api!"; + return -1; + } + + Request req; + Response res; + + api.thrd_initialize(); + + uint64_t elapse_ms = 0; + while (true) { + api.thrd_clear(); + + Predictor *predictor = api.fetch_predictor("text_classification"); + if (!predictor) { + LOG(ERROR) << "Failed fetch predictor: text_classification"; + return -1; + } + + req.Clear(); + res.Clear(); + + if (create_req(local_feed, req) != 0) { + break; + } + + timeval start; + gettimeofday(&start, NULL); + + if (predictor->inference(&req, &res) != 0) { + LOG(ERROR) << "failed call predictor with req:" << req.ShortDebugString(); + return -1; + } + + timeval end; + gettimeofday(&end, NULL); + + elapse_ms += (end.tv_sec * 1000 + end.tv_usec / 1000) - + (start.tv_sec * 1000 + start.tv_usec / 1000); + +#if 1 + LOG(INFO) << "single round elapse time " + << (end.tv_sec * 1000000 + end.tv_usec) - + (start.tv_sec * 1000000 + start.tv_usec); +#endif + extract_res(req, res); + res.Clear(); + } // while (true) + + int correct = 0; + std::vector &truth_label = local_feed->get_labels(); + for (int i = 0; i < g_pred_labels.size(); ++i) { + if (g_pred_labels[i] == truth_label[i]) { + ++correct; + } + } + + LOG(INFO) << "Elapse ms " << elapse_ms; + double qps = (static_cast(g_pred_labels.size()) / elapse_ms) * 1000; + + LOG(INFO) << "QPS: " << qps << "/s"; + LOG(INFO) << "Accuracy " + << static_cast(correct) / g_pred_labels.size(); + + api.thrd_finalize(); + api.destroy(); + + return 0; +} + +/* vim: set expandtab ts=4 sw=4 sts=4 tw=100: */ diff --git a/sdk-cpp/proto/text_classification.proto b/sdk-cpp/proto/text_classification.proto new file mode 100644 index 00000000..64643351 --- /dev/null +++ b/sdk-cpp/proto/text_classification.proto @@ -0,0 +1,37 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; +import "pds_option.proto"; +import "builtin_format.proto"; +package baidu.paddle_serving.predictor.text_classification; + +option cc_generic_services = true; + +message TextReqInstance { repeated int64 ids = 1; }; + +message Request { repeated TextReqInstance instances = 1; }; + +message TextResInstance { + required float class_0_prob = 1; + required float class_1_prob = 2; +}; + +message Response { repeated TextResInstance predictions = 1; }; + +service TextClassificationService { + rpc inference(Request) returns (Response); + rpc debug(Request) returns (Response); + option (pds.options).generate_stub = true; +}; diff --git a/serving/CMakeLists.txt b/serving/CMakeLists.txt index 928e7417..3619d91f 100644 --- a/serving/CMakeLists.txt +++ b/serving/CMakeLists.txt @@ -1,3 +1,17 @@ +if (NOT EXISTS + ${CMAKE_CURRENT_LIST_DIR}/data/model/paddle/fluid/text_classification_lstm) + execute_process(COMMAND wget + --no-check-certificate https://paddle-serving.bj.bcebos.com/data/text_classification/text_classification_lstm.tar.gz + --output-document + ${CMAKE_CURRENT_LIST_DIR}/data/model/paddle/fluid/text_classification_lstm.tar.gz) + + execute_process(COMMAND ${CMAKE_COMMAND} -E tar xzf + "${CMAKE_CURRENT_LIST_DIR}/data/model/paddle/fluid/text_classification_lstm.tar.gz" + WORKING_DIRECTORY + ${CMAKE_CURRENT_LIST_DIR}/data/model/paddle/fluid + ) +endif() + find_library(MKLML_LIBS NAMES libmklml_intel.so libiomp5.so) include(op/CMakeLists.txt) include(proto/CMakeLists.txt) diff --git a/serving/conf/model_toolkit.prototxt b/serving/conf/model_toolkit.prototxt index 2693e35c..269e3474 100644 --- a/serving/conf/model_toolkit.prototxt +++ b/serving/conf/model_toolkit.prototxt @@ -8,3 +8,13 @@ engines { batch_infer_size: 0 enable_batch_align: 0 } +engines { + name: "text_classification_bow" + type: "FLUID_CPU_ANALYSIS_DIR" + reloadable_meta: "./data/model/paddle/fluid_time_file" + reloadable_type: "timestamp_ne" + model_data_path: "./data/model/paddle/fluid/text_classification_lstm" + runtime_thread_num: 0 + batch_infer_size: 0 + enable_batch_align: 0 +} diff --git a/serving/conf/service.prototxt b/serving/conf/service.prototxt index 15a4156f..b9b2c52c 100644 --- a/serving/conf/service.prototxt +++ b/serving/conf/service.prototxt @@ -22,3 +22,8 @@ services { workflows: "workflow5" } +services { + name: "TextClassificationService" + workflows: "workflow6" +} + diff --git a/serving/conf/workflow.prototxt b/serving/conf/workflow.prototxt index 0300b370..33538654 100644 --- a/serving/conf/workflow.prototxt +++ b/serving/conf/workflow.prototxt @@ -51,7 +51,6 @@ workflows { } } } - workflows { name: "workflow5" workflow_type: "Sequence" @@ -60,4 +59,12 @@ workflows { type: "Int64TensorEchoOp" } } +workflows { + name: "workflow6" + workflow_type: "Sequence" + nodes { + name: "text_classify_op" + type: "TextClassificationOp" + } +} diff --git a/serving/op/text_classification_op.cpp b/serving/op/text_classification_op.cpp new file mode 100644 index 00000000..a47fd11a --- /dev/null +++ b/serving/op/text_classification_op.cpp @@ -0,0 +1,124 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "serving/op/text_classification_op.h" +#include +#include "predictor/framework/infer.h" +#include "predictor/framework/memory.h" + +namespace baidu { +namespace paddle_serving { +namespace serving { + +using baidu::paddle_serving::predictor::MempoolWrapper; +using baidu::paddle_serving::predictor::text_classification::TextResInstance; +using baidu::paddle_serving::predictor::text_classification::Response; +using baidu::paddle_serving::predictor::text_classification::TextReqInstance; +using baidu::paddle_serving::predictor::text_classification::Request; + +int TextClassificationOp::inference() { + const Request *req = dynamic_cast(get_request_message()); + + TensorVector *in = butil::get_object(); + uint32_t sample_size = req->instances_size(); + if (sample_size <= 0) { + LOG(WARNING) << "No instances need to inference!"; + return -1; + } + + paddle::PaddleTensor lod_tensor; + lod_tensor.dtype = paddle::PaddleDType::INT64; + std::vector> &lod = lod_tensor.lod; + lod.resize(1); + lod[0].push_back(0); + + for (uint32_t si = 0; si < sample_size; ++si) { + const TextReqInstance &req_instance = req->instances(si); + lod[0].push_back(lod[0].back() + req_instance.ids_size()); + } + + lod_tensor.shape = {lod[0].back(), 1}; + lod_tensor.data.Resize(lod[0].back() * sizeof(int64_t)); + + int offset = 0; + for (uint32_t si = 0; si < sample_size; ++si) { + // parse text sequence + int64_t *data_ptr = static_cast(lod_tensor.data.data()) + offset; + const TextReqInstance &req_instance = req->instances(si); + int id_count = req_instance.ids_size(); + memcpy(data_ptr, + req_instance.ids().data(), + sizeof(int64_t) * req_instance.ids_size()); + offset += req_instance.ids_size(); + } + + in->push_back(lod_tensor); + + TensorVector *out = butil::get_object(); + if (!out) { + LOG(ERROR) << "Failed get tls output object"; + return -1; + } + + // call paddle fluid model for inferencing + if (predictor::InferManager::instance().infer( + TEXT_CLASSIFICATION_MODEL_NAME, in, out, sample_size)) { + LOG(ERROR) << "Failed do infer in fluid model: " + << TEXT_CLASSIFICATION_MODEL_NAME; + return -1; + } + + if (out->size() != in->size()) { + LOG(ERROR) << "Output tensor size not equal that of input"; + return -1; + } + + Response *res = mutable_data(); + + for (size_t i = 0; i < out->size(); ++i) { + int dim1 = out->at(i).shape[0]; + int dim2 = out->at(i).shape[1]; + + if (out->at(i).dtype != paddle::PaddleDType::FLOAT32) { + LOG(ERROR) << "Expected data type float"; + return -1; + } + + float *data = static_cast(out->at(i).data.data()); + for (int j = 0; j < dim1; ++j) { + TextResInstance *res_instance = res->add_predictions(); + res_instance->set_class_0_prob(data[j * dim2]); + res_instance->set_class_1_prob(data[j * dim2 + 1]); + } + } + + for (size_t i = 0; i < in->size(); ++i) { + (*in)[i].shape.clear(); + } + in->clear(); + butil::return_object(in); + + for (size_t i = 0; i < out->size(); ++i) { + (*out)[i].shape.clear(); + } + out->clear(); + butil::return_object(out); + return 0; +} + +DEFINE_OP(TextClassificationOp); + +} // namespace serving +} // namespace paddle_serving +} // namespace baidu diff --git a/serving/op/text_classification_op.h b/serving/op/text_classification_op.h new file mode 100644 index 00000000..5f67a6e0 --- /dev/null +++ b/serving/op/text_classification_op.h @@ -0,0 +1,39 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/inference/paddle_inference_api.h" +#include "serving/text_classification.pb.h" + +namespace baidu { +namespace paddle_serving { +namespace serving { + +static const char* TEXT_CLASSIFICATION_MODEL_NAME = "text_classification_bow"; + +class TextClassificationOp + : public baidu::paddle_serving::predictor::OpWithChannel< + baidu::paddle_serving::predictor::text_classification::Response> { + public: + typedef std::vector TensorVector; + + DECLARE_OP(TextClassificationOp); + + int inference(); +}; + +} // namespace serving +} // namespace paddle_serving +} // namespace baidu diff --git a/serving/proto/CMakeLists.txt b/serving/proto/CMakeLists.txt index 58e2b2ef..dcf26f7d 100644 --- a/serving/proto/CMakeLists.txt +++ b/serving/proto/CMakeLists.txt @@ -4,6 +4,7 @@ LIST(APPEND protofiles ${CMAKE_CURRENT_LIST_DIR}/sparse_service.proto ${CMAKE_CURRENT_LIST_DIR}/echo_service.proto ${CMAKE_CURRENT_LIST_DIR}/int64tensor_service.proto + ${CMAKE_CURRENT_LIST_DIR}/text_classification.proto ) PROTOBUF_GENERATE_SERVING_CPP(PROTO_SRCS PROTO_HDRS ${protofiles}) diff --git a/serving/proto/text_classification.proto b/serving/proto/text_classification.proto new file mode 100644 index 00000000..d984ea80 --- /dev/null +++ b/serving/proto/text_classification.proto @@ -0,0 +1,37 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; +import "pds_option.proto"; +import "builtin_format.proto"; +package baidu.paddle_serving.predictor.text_classification; + +option cc_generic_services = true; + +message TextReqInstance { repeated int64 ids = 1; }; + +message Request { repeated TextReqInstance instances = 1; }; + +message TextResInstance { + required float class_0_prob = 1; + required float class_1_prob = 2; +}; + +message Response { repeated TextResInstance predictions = 1; }; + +service TextClassificationService { + rpc inference(Request) returns (Response); + rpc debug(Request) returns (Response); + option (pds.options).generate_impl = true; +}; -- GitLab