diff --git a/core/cube/CMakeLists.txt b/core/cube/CMakeLists.txt
index f9dc4d2c2508720f450b4aee3aba5dfdd7ccd43b..a61d2df92a92bc26fabd4a3cf87c6db1dc1cc3f0 100644
--- a/core/cube/CMakeLists.txt
+++ b/core/cube/CMakeLists.txt
@@ -11,10 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License
-
-#execute_process(COMMAND go env -w GO111MODULE=off)
 add_subdirectory(cube-server)
 add_subdirectory(cube-api)
 add_subdirectory(cube-builder)
-#add_subdirectory(cube-transfer)
-#add_subdirectory(cube-agent)
+add_subdirectory(cube-transfer)
+add_subdirectory(cube-agent)
diff --git a/core/cube/cube-agent/CMakeLists.txt b/core/cube/cube-agent/CMakeLists.txt
index 30158aa506e53ec8a37d10aef4f29bfcd5a60d06..701f0c8a55e3326e1327f3b1f68458f99c60143b 100644
--- a/core/cube/cube-agent/CMakeLists.txt
+++ b/core/cube/cube-agent/CMakeLists.txt
@@ -15,7 +15,6 @@
 set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
 
 project(cube-agent Go)
-
 include(cmake/golang.cmake)
 
 ExternalGoProject_Add(agent-docopt-go github.com/docopt/docopt-go)
diff --git a/core/cube/cube-transfer/CMakeLists.txt b/core/cube/cube-transfer/CMakeLists.txt
index 78e47c5b840631a3092f3a799e2424d370444a2e..2e9d3dede03c5b27bcd0e24eaa6584df343c09e2 100644
--- a/core/cube/cube-transfer/CMakeLists.txt
+++ b/core/cube/cube-transfer/CMakeLists.txt
@@ -18,11 +18,9 @@
 project(cube-transfer Go)
 
 include(cmake/golang.cmake)
 
-ExternalGoProject_Add(rfw github.com/mipearson/rfw)
-ExternalGoProject_Add(docopt-go github.com/docopt/docopt-go)
-add_custom_target(logex
-  COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get github.com/Badangel/logex
-  DEPENDS rfw)
+ExternalGoProject_Add(transfer-rfw github.com/mipearson/rfw)
+ExternalGoProject_Add(transfer-docopt-go github.com/docopt/docopt-go)
+ExternalGoProject_Add(transfer-logex github.com/Badangel/logex)
 
 add_subdirectory(src)
 install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/conf DESTINATION ${PADDLE_SERVING_INSTALL_DIR})
diff --git a/core/cube/cube-transfer/src/CMakeLists.txt b/core/cube/cube-transfer/src/CMakeLists.txt
index 62d3f7ef7759a0d2a09eb4fe32a064694ece5408..b71278537a2ee03468019e7bd7e5ec4d786becf2 100644
--- a/core/cube/cube-transfer/src/CMakeLists.txt
+++ b/core/cube/cube-transfer/src/CMakeLists.txt
@@ -14,6 +14,6 @@
 
 set(SOURCE_FILE cube-transfer.go)
 add_go_executable(cube-transfer ${SOURCE_FILE})
-add_dependencies(cube-transfer docopt-go)
-add_dependencies(cube-transfer rfw)
-add_dependencies(cube-transfer logex)
+add_dependencies(cube-transfer transfer-docopt-go)
+add_dependencies(cube-transfer transfer-rfw)
+add_dependencies(cube-transfer transfer-logex)
diff --git a/core/general-server/op/general_dist_kv_infer_op.cpp b/core/general-server/op/general_dist_kv_infer_op.cpp
index 6cfd88788063a13d60f0f8ada29711760ecae174..063af3b51112029e9f1cb941d1c9668728c418fb 100644
--- a/core/general-server/op/general_dist_kv_infer_op.cpp
+++ b/core/general-server/op/general_dist_kv_infer_op.cpp
@@ -37,7 +37,149 @@
 using baidu::paddle_serving::predictor::general_model::Request;
 using baidu::paddle_serving::predictor::InferManager;
 using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
 
-int GeneralDistKVInferOp::inference() { return 0; }
+// DistKV infer op: seek cube first, then call paddle inference.
+// op seq: general_reader -> dist_kv_infer -> general_response
+int GeneralDistKVInferOp::inference() {
+  VLOG(2) << "Going to run inference";
+  const std::vector<std::string> pre_node_names = pre_names();
+  if (pre_node_names.size() != 1) {
+    LOG(ERROR) << "This op(" << op_name()
+               << ") can only have one predecessor op, but received "
+               << pre_node_names.size();
+    return -1;
+  }
+  const std::string pre_name = pre_node_names[0];
+
+  const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name);
+  if (!input_blob) {
+    LOG(ERROR) << "input_blob is nullptr, error";
+    return -1;
+  }
+  uint64_t log_id = input_blob->GetLogId();
+  VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name;
+
+  GeneralBlob *output_blob = mutable_data<GeneralBlob>();
+  if (!output_blob) {
+    LOG(ERROR) << "(logid=" << log_id << ") output_blob is nullptr, error";
+    return -1;
+  }
+  output_blob->SetLogId(log_id);
+
+  if (!input_blob) {
+    LOG(ERROR) << "(logid=" << log_id
+               << ") Failed mutable depended argument, op:" << pre_name;
+    return -1;
+  }
+
+  const TensorVector *in = &input_blob->tensor_vector;
+  TensorVector *out = &output_blob->tensor_vector;
+  std::vector<uint64_t> keys;
+  std::vector<rec::mcube::CubeValue> values;
+  int sparse_count = 0;  // number of sparse inputs; these seek cube
+  int dense_count = 0;   // number of dense inputs; these go straight to paddle infer
+  std::vector<std::pair<int64_t *, size_t>> dataptr_size_pairs;
+  size_t key_len = 0;
+  for (size_t i = 0; i < in->size(); ++i) {
+    if (in->at(i).dtype != paddle::PaddleDType::INT64) {
+      ++dense_count;
+      continue;
+    }
+    ++sparse_count;
+    size_t elem_num = 1;
+    for (size_t s = 0; s < in->at(i).shape.size(); ++s) {
+      elem_num *= in->at(i).shape[s];
+    }
+    key_len += elem_num;
+    int64_t *data_ptr = static_cast<int64_t *>(in->at(i).data.data());
+    dataptr_size_pairs.push_back(std::make_pair(data_ptr, elem_num));
+  }
+  keys.resize(key_len);
+  VLOG(3) << "(logid=" << log_id
+          << ") cube number of keys to look up: " << key_len;
+  int key_idx = 0;
+  for (size_t i = 0; i < dataptr_size_pairs.size(); ++i) {
+    std::copy(dataptr_size_pairs[i].first,
+              dataptr_size_pairs[i].first + dataptr_size_pairs[i].second,
+              keys.begin() + key_idx);
+    key_idx += dataptr_size_pairs[i].second;
+  }
+  rec::mcube::CubeAPI *cube = rec::mcube::CubeAPI::instance();
+  std::vector<std::string> table_names = cube->get_table_names();
+  if (table_names.size() == 0) {
+    LOG(ERROR) << "cube init error or cube config not given.";
+    return -1;
+  }
+  // gather keys and seek the cube servers; results are put into values
+  int ret = cube->seek(table_names[0], keys, &values);
+  VLOG(3) << "(logid=" << log_id << ") cube seek status: " << ret;
+  if (values.size() != keys.size() || values[0].buff.size() == 0) {
+    LOG(ERROR) << "cube value return null";
+  }
+  // EMBEDDING_SIZE is the length of the sparse vector; users can define the length here.
+  size_t EMBEDDING_SIZE = values[0].buff.size() / sizeof(float);
+  TensorVector sparse_out;
+  sparse_out.resize(sparse_count);
+  TensorVector dense_out;
+  dense_out.resize(dense_count);
+  int cube_val_idx = 0;
+  int sparse_idx = 0;
+  int dense_idx = 0;
+  std::unordered_map<int, int> in_out_map;
+  baidu::paddle_serving::predictor::Resource &resource =
+      baidu::paddle_serving::predictor::Resource::instance();
+  std::shared_ptr<PaddleGeneralModelConfig> model_config =
+      resource.get_general_model_config().front();
+  // copy data to tensor
+  for (size_t i = 0; i < in->size(); ++i) {
+    if (in->at(i).dtype != paddle::PaddleDType::INT64) {
+      dense_out[dense_idx] = in->at(i);
+      ++dense_idx;
+      continue;
+    }
+    sparse_out[sparse_idx].lod.resize(in->at(i).lod.size());
+    for (size_t x = 0; x < sparse_out[sparse_idx].lod.size(); ++x) {
+      sparse_out[sparse_idx].lod[x].resize(in->at(i).lod[x].size());
+      std::copy(in->at(i).lod[x].begin(),
+                in->at(i).lod[x].end(),
+                sparse_out[sparse_idx].lod[x].begin());
+    }
+    sparse_out[sparse_idx].dtype = paddle::PaddleDType::FLOAT32;
+    sparse_out[sparse_idx].shape.push_back(
+        sparse_out[sparse_idx].lod[0].back());
+    sparse_out[sparse_idx].shape.push_back(EMBEDDING_SIZE);
+    sparse_out[sparse_idx].name = model_config->_feed_name[i];
+    sparse_out[sparse_idx].data.Resize(sparse_out[sparse_idx].lod[0].back() *
+                                       EMBEDDING_SIZE * sizeof(float));
+    float *dst_ptr = static_cast<float *>(sparse_out[sparse_idx].data.data());
+    for (int x = 0; x < sparse_out[sparse_idx].lod[0].back(); ++x) {
+      float *data_ptr = dst_ptr + x * EMBEDDING_SIZE;
+      memcpy(data_ptr,
+             values[cube_val_idx].buff.data(),
+             values[cube_val_idx].buff.size());
+      cube_val_idx++;
+    }
+    ++sparse_idx;
+  }
+  VLOG(3) << "(logid=" << log_id << ") sparse tensor load success.";
+  TensorVector infer_in;
+  infer_in.insert(infer_in.end(), dense_out.begin(), dense_out.end());
+  infer_in.insert(infer_in.end(), sparse_out.begin(), sparse_out.end());
+  int batch_size = input_blob->_batch_size;
+  output_blob->_batch_size = batch_size;
+  Timer timeline;
+  int64_t start = timeline.TimeStampUS();
+  timeline.Start();
+  // call paddle inference here
+  if (InferManager::instance().infer(
+          engine_name().c_str(), &infer_in, out, batch_size)) {
+    LOG(ERROR) << "(logid=" << log_id
+               << ") Failed do infer in fluid model: " << engine_name();
+    return -1;
+  }
+  int64_t end = timeline.TimeStampUS();
+
+  CopyBlobInfo(input_blob, output_blob);
+  AddBlobInfo(output_blob, start);
+  AddBlobInfo(output_blob, end);
+  return 0;
+}
 
 DEFINE_OP(GeneralDistKVInferOp);
 
 }  // namespace serving
diff --git a/doc/HTTP_SERVICE_CN.md b/doc/HTTP_SERVICE_CN.md
index f442e1e8ca4a68123106dbd25fd7b93cf171caa2..ff7082b0c6c2f091a199420be45ce83403befdd4 100644
--- a/doc/HTTP_SERVICE_CN.md
+++ b/doc/HTTP_SERVICE_CN.md
@@ -12,7 +12,7 @@ The BRPC-Server will try to deserialize Proto-format data out of the JSON string.
 
 ### Http+protobuf mode
 Every language provides ProtoBuf support. If you are familiar with ProtoBuf, you can serialize the data with ProtoBuf first, put the serialized bytes into the HTTP request body, and set Content-Type: application/proto, so that the service is accessed with an http/h2+protobuf binary payload.
-
+Our measurements show that as the payload grows, the size and the deserialization cost of JSON-style HTTP requests rise sharply, so Http+protobuf mode is recommended when your payload is large. We will add this capability to the framework's HttpClient later; it is not supported yet.
 
 **In theory, serialization/deserialization performance ranks from high to low: protobuf > http/h2+protobuf > http**
 
@@ -109,7 +109,7 @@ repeated int32 numbers = 1;
 
 ### HTTP compression
 
-gzip compression is supported, but gzip is not a particularly fast way to compress and decompress. When the payload is small, gzip can cost more than it saves; consider gzip only when the data is larger than about 512 bytes.
+gzip compression is supported, but gzip is not a particularly fast way to compress and decompress. When the payload is small, gzip can cost more than it saves; consider gzip only when the data is larger than about 512 bytes. Our measurements show that compression brings little benefit when the payload is under 50 KB.
 
 #### Compressing the client request body
 
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index de1fe2843bd32a9cc9e2aa567c0ddddd7457c67c..589420ad45ae7f347c8e7b9b25c5cc0034830263 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -81,7 +81,6 @@ if (SERVER)
   if(WITH_LITE)
     set(VERSION_SUFFIX 2)
   endif()
-
   add_custom_command(
       OUTPUT ${PADDLE_SERVING_BINARY_DIR}/.timestamp
       COMMAND cp -r
diff --git a/python/examples/criteo_ctr_with_cube/README.md b/python/examples/criteo_ctr_with_cube/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..493b3d72c1fff9275c2a99cfee45efd4bef1af4c
--- /dev/null
+++ b/python/examples/criteo_ctr_with_cube/README.md
@@ -0,0 +1,72 @@
+## Criteo CTR with Sparse Parameter Indexing Service
+
+([简体中文](./README_CN.md)|English)
+
+### Get Sample Dataset
+
+Go to the directory `python/examples/criteo_ctr_with_cube` and run
+```
+sh get_data.sh
+```
+
+### Download Model and Sparse Parameter Sequence Files
+```
+wget https://paddle-serving.bj.bcebos.com/unittest/ctr_cube_unittest.tar.gz
+tar xf ctr_cube_unittest.tar.gz
+mv models/ctr_client_conf ./
+mv models/ctr_serving_model_kv ./
+mv models/data ./cube/
+```
+After these commands, the folders ./ctr_serving_model_kv and ./ctr_client_conf will be in the current directory.
+
+### Start Sparse Parameter Indexing Service
+```
+wget https://paddle-serving.bj.bcebos.com/others/cube_app.tar.gz
+tar xf cube_app.tar.gz
+mv cube_app/cube* ./cube/
+sh cube_prepare.sh &
+```
+
+Here, the sparse parameters of the model are loaded into Cube, the sparse parameter indexing service.
+
+### Start RPC Predictor, with 4 serving threads (configurable in test_server.py)
+
+```
+python test_server.py ctr_serving_model_kv
+```
+
+### Run Prediction
+
+```
+python test_client.py ctr_client_conf/serving_client_conf.prototxt ./raw_data
+```
+
+### Benchmark
+
+CPU: Intel(R) Xeon(R) CPU 6148 @ 2.40GHz
+
+Model: [Criteo CTR](https://github.com/PaddlePaddle/Serving/blob/develop/python/examples/criteo_ctr_with_cube/network_conf.py)
+
+server core/thread num: 4/8
+
+Run
+```
+bash benchmark.sh
+```
+Each client thread sends 1000 batches.
+
+| client thread num | prepro | client infer | op0   | op1   | op2    | postpro | avg_latency | qps   |
+| ----------------- | ------ | ------------ | ----- | ----- | ------ | ------- | ----------- | ----- |
+| 1                 | 0.035  | 1.596        | 0.021 | 0.518 | 0.0024 | 0.0025  | 6.774       | 147.7 |
+| 2                 | 0.034  | 1.780        | 0.027 | 0.463 | 0.0020 | 0.0023  | 6.931       | 288.3 |
+| 4                 | 0.038  | 2.954        | 0.025 | 0.455 | 0.0019 | 0.0027  | 8.378       | 477.5 |
+| 8                 | 0.044  | 8.230        | 0.028 | 0.464 | 0.0023 | 0.0034  | 14.191      | 563.8 |
+| 16                | 0.048  | 21.037       | 0.028 | 0.455 | 0.0025 | 0.0041  | 27.236      | 587.5 |
+
+The average latency per thread:
+
+![avg cost](../../../doc/criteo-cube-benchmark-avgcost.png)
+
+The QPS per thread:
+
+![qps](../../../doc/criteo-cube-benchmark-qps.png)
diff --git a/python/examples/criteo_ctr_with_cube/README_CN.md b/python/examples/criteo_ctr_with_cube/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..7a0eb43c203aafeb38b64d249954cdabf7bf7a38
--- /dev/null
+++ b/python/examples/criteo_ctr_with_cube/README_CN.md
@@ -0,0 +1,70 @@
+## CTR Prediction Service with Sparse Parameter Indexing Service
+(简体中文|[English](./README.md))
+
+### Get Sample Dataset
+Go to the directory `python/examples/criteo_ctr_with_cube` and run
+```
+sh get_data.sh
+```
+
+### Download Model and Sparse Parameter Sequence Files
+```
+wget https://paddle-serving.bj.bcebos.com/unittest/ctr_cube_unittest.tar.gz
+tar xf ctr_cube_unittest.tar.gz
+mv models/ctr_client_conf ./
+mv models/ctr_serving_model_kv ./
+mv models/data ./cube/
+```
+After these commands, the folders ctr_serving_model_kv and ctr_client_conf will be in the current directory.
+
+### Start Sparse Parameter Indexing Service
+```
+wget https://paddle-serving.bj.bcebos.com/others/cube_app.tar.gz
+tar xf cube_app.tar.gz
+mv cube_app/cube* ./cube/
+sh cube_prepare.sh &
+```
+
+Here, the sparse parameters of the model are stored in Cube, the sparse parameter indexing service.
+
+### Start the RPC Prediction Service, with 4 server threads (configurable in test_server.py)
+
+```
+python test_server.py ctr_serving_model_kv
+```
+
+### Run Prediction
+
+```
+python test_client.py ctr_client_conf/serving_client_conf.prototxt ./raw_data
+```
+
+### Benchmark
+
+Device: Intel(R) Xeon(R) CPU 6148 @ 2.40GHz
+
+Model: [Criteo CTR](https://github.com/PaddlePaddle/Serving/blob/develop/python/examples/criteo_ctr_with_cube/network_conf.py)
+
+server core/thread num: 4/8
+
+Run
+```
+bash benchmark.sh
+```
+Each client thread sends 1000 batches.
+
+| client thread num | prepro | client infer | op0   | op1   | op2    | postpro | avg_latency | qps   |
+| ----------------- | ------ | ------------ | ----- | ----- | ------ | ------- | ----------- | ----- |
+| 1                 | 0.035  | 1.596        | 0.021 | 0.518 | 0.0024 | 0.0025  | 6.774       | 147.7 |
+| 2                 | 0.034  | 1.780        | 0.027 | 0.463 | 0.0020 | 0.0023  | 6.931       | 288.3 |
+| 4                 | 0.038  | 2.954        | 0.025 | 0.455 | 0.0019 | 0.0027  | 8.378       | 477.5 |
+| 8                 | 0.044  | 8.230        | 0.028 | 0.464 | 0.0023 | 0.0034  | 14.191      | 563.8 |
+| 16                | 0.048  | 21.037       | 0.028 | 0.455 | 0.0025 | 0.0041  | 27.236      | 587.5 |
+
+The average latency per thread:
+
+![avg cost](../../../doc/criteo-cube-benchmark-avgcost.png)
+
+The QPS per thread:
+
+![qps](../../../doc/criteo-cube-benchmark-qps.png)
diff --git a/python/examples/criteo_ctr_with_cube/criteo_reader.py b/python/examples/criteo_ctr_with_cube/criteo_reader.py
new file mode 100755
index 0000000000000000000000000000000000000000..2a80af78a9c2033bf246f703ca70a817ab786af3
--- /dev/null
+++ b/python/examples/criteo_ctr_with_cube/criteo_reader.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# pylint: disable=doc-string-missing
+
+import sys
+import paddle.fluid.incubate.data_generator as dg
+
+
+class CriteoDataset(dg.MultiSlotDataGenerator):
+    def setup(self, sparse_feature_dim):
+        self.cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+        self.cont_max_ = [
+            20, 600, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50
+        ]
+        self.cont_diff_ = [
+            20, 603, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50
+        ]
+        self.hash_dim_ = sparse_feature_dim
+        # here, training data are lines with line_index < train_idx_
+        self.train_idx_ = 41256555
+        self.continuous_range_ = range(1, 14)
+        self.categorical_range_ = range(14, 40)
+
+    def _process_line(self, line):
+        features = line.rstrip('\n').split('\t')
+        dense_feature = []
+        sparse_feature = []
+        # min-max normalize the 13 continuous features; missing values become 0.0
+        for idx in self.continuous_range_:
+            if features[idx] == '':
+                dense_feature.append(0.0)
+            else:
+                dense_feature.append((float(features[idx]) - self.cont_min_[idx - 1]) / \
+                    self.cont_diff_[idx - 1])
+        # hash the 26 categorical features into [0, hash_dim_)
+        for idx in self.categorical_range_:
+            sparse_feature.append(
+                [hash(str(idx) + features[idx]) % self.hash_dim_])
+
+        return dense_feature, sparse_feature, [int(features[0])]
+
+    def infer_reader(self, filelist, batch, buf_size):
+        def local_iter():
+            for fname in filelist:
+                with open(fname.strip(), "r") as fin:
+                    for line in fin:
+                        dense_feature, sparse_feature, label = self._process_line(
+                            line)
+                        yield [dense_feature] + sparse_feature + [label]
+
+        import paddle
+        batch_iter = paddle.batch(
+            paddle.reader.shuffle(
+                local_iter, buf_size=buf_size),
+            batch_size=batch)
+        return batch_iter
+
+    def generate_sample(self, line):
+        def data_iter():
+            dense_feature, sparse_feature, label = self._process_line(line)
+            feature_name = ["dense_input"]
+            for idx in self.categorical_range_:
+                feature_name.append("C" + str(idx - 13))
+            feature_name.append("label")
+            yield zip(feature_name, [dense_feature] + sparse_feature + [label])
+
+        return data_iter
+
+
+if __name__ == "__main__":
+    criteo_dataset = CriteoDataset()
+    criteo_dataset.setup(int(sys.argv[1]))
+    criteo_dataset.run_from_stdin()
diff --git a/python/examples/criteo_ctr_with_cube/get_data.sh b/python/examples/criteo_ctr_with_cube/get_data.sh
new file mode 100755
index 0000000000000000000000000000000000000000..1f244b3a4aa81488bb493825576ba30c4b3bba22
--- /dev/null
+++ b/python/examples/criteo_ctr_with_cube/get_data.sh
@@ -0,0 +1,2 @@
+wget --no-check-certificate https://paddle-serving.bj.bcebos.com/data/ctr_prediction/ctr_data.tar.gz
+tar -zxvf ctr_data.tar.gz
diff --git a/python/examples/criteo_ctr_with_cube/local_train.py b/python/examples/criteo_ctr_with_cube/local_train.py
new file mode 100755
index 0000000000000000000000000000000000000000..555e2e929c170c24a3175a88144ff74356d82514
--- /dev/null
+++ b/python/examples/criteo_ctr_with_cube/local_train.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# pylint: disable=doc-string-missing
+
+from __future__ import print_function
+
+from args import parse_args
+import os
+import paddle.fluid as fluid
+import paddle
+import sys
+from network_conf import dnn_model
+
+dense_feature_dim = 13
+
+paddle.enable_static()
+
+
+def train():
+    args = parse_args()
+    sparse_only = args.sparse_only
+    if not os.path.isdir(args.model_output_dir):
+        os.mkdir(args.model_output_dir)
+    dense_input = fluid.layers.data(
+        name="dense_input", shape=[dense_feature_dim], dtype='float32')
+    sparse_input_ids = [
+        fluid.layers.data(
+            name="C" + str(i), shape=[1], lod_level=1, dtype="int64")
+        for i in range(1, 27)
+    ]
+    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+
+    nn_input = dense_input
+    predict_y, loss, auc_var, batch_auc_var, infer_vars = dnn_model(
+        nn_input, sparse_input_ids, label, args.embedding_size,
+        args.sparse_feature_dim)
+
+    optimizer = fluid.optimizer.SGD(learning_rate=1e-4)
+    optimizer.minimize(loss)
+
+    exe = fluid.Executor(fluid.CPUPlace())
+    exe.run(fluid.default_startup_program())
+    dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
+    dataset.set_use_var([dense_input] + sparse_input_ids + [label])
+
+    python_executable = "python3.6"
+    pipe_command = "{} criteo_reader.py {}".format(python_executable,
+                                                   args.sparse_feature_dim)
+
+    dataset.set_pipe_command(pipe_command)
+    dataset.set_batch_size(128)
+    thread_num = 10
+    dataset.set_thread(thread_num)
+
+    whole_filelist = [
+        "raw_data/part-%d" % x for x in range(len(os.listdir("raw_data")))
+    ]
+
+    print(whole_filelist)
+    dataset.set_filelist(whole_filelist[:100])
+    dataset.load_into_memory()
+    fluid.layers.Print(auc_var)
+    epochs = 1
+    for i in range(epochs):
+        exe.train_from_dataset(
+            program=fluid.default_main_program(), dataset=dataset, debug=True)
+        print("epoch {} finished".format(i))
+
+    import paddle_serving_client.io as server_io
+    # feed/fetch vars for the regular (dense) serving model
+    feed_var_dict = {}
+    feed_var_dict['dense_input'] = dense_input
+    for i, sparse in enumerate(sparse_input_ids):
+        feed_var_dict["embedding_{}.tmp_0".format(i)] = sparse
+    fetch_var_dict = {"prob": predict_y}
+
+    # feed vars for the kv serving model: embeddings are looked up in cube
+    feed_kv_dict = {}
+    feed_kv_dict['dense_input'] = dense_input
+    for i, emb in enumerate(infer_vars):
+        feed_kv_dict["embedding_{}.tmp_0".format(i)] = emb
+
+    server_io.save_model("ctr_serving_model", "ctr_client_conf", feed_var_dict,
+                         fetch_var_dict, fluid.default_main_program())
+
+    server_io.save_model("ctr_serving_model_kv", "ctr_client_conf_kv",
+                         feed_kv_dict, fetch_var_dict,
+                         fluid.default_main_program())
+
+
+if __name__ == '__main__':
+    train()
diff --git a/python/examples/criteo_ctr_with_cube/network_conf.py b/python/examples/criteo_ctr_with_cube/network_conf.py
new file mode 100755
index 0000000000000000000000000000000000000000..2975533a72ad21d6dd5896446fd06c1f9bdfe8b4
--- /dev/null
+++ b/python/examples/criteo_ctr_with_cube/network_conf.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# pylint: disable=doc-string-missing
+
+import paddle.fluid as fluid
+import math
+
+
+def dnn_model(dense_input, sparse_inputs, label, embedding_size,
+              sparse_feature_dim):
+    def embedding_layer(input):
+        # return both the raw embedding (fed directly in the kv model,
+        # where cube supplies the values) and its sequence-pooled sum
+        emb = fluid.layers.embedding(
+            input=input,
+            is_sparse=True,
+            is_distributed=False,
+            size=[sparse_feature_dim, embedding_size],
+            param_attr=fluid.ParamAttr(
+                name="SparseFeatFactors",
+                initializer=fluid.initializer.Uniform()))
+        x = fluid.layers.sequence_pool(input=emb, pool_type='sum')
+        return emb, x
+
+    def mlp_input_tensor(emb_sums, dense_tensor):
+        return fluid.layers.concat(emb_sums + [dense_tensor], axis=1)
+
+    def mlp(mlp_input):
+        fc1 = fluid.layers.fc(input=mlp_input,
+                              size=400,
+                              act='relu',
+                              param_attr=fluid.ParamAttr(
+                                  initializer=fluid.initializer.Normal(
+                                      scale=1 / math.sqrt(mlp_input.shape[1]))))
+        fc2 = fluid.layers.fc(input=fc1,
+                              size=400,
+                              act='relu',
+                              param_attr=fluid.ParamAttr(
+                                  initializer=fluid.initializer.Normal(
+                                      scale=1 / math.sqrt(fc1.shape[1]))))
+        fc3 = fluid.layers.fc(input=fc2,
+                              size=400,
+                              act='relu',
+                              param_attr=fluid.ParamAttr(
+                                  initializer=fluid.initializer.Normal(
+                                      scale=1 / math.sqrt(fc2.shape[1]))))
+        pre = fluid.layers.fc(input=fc3,
+                              size=2,
+                              act='softmax',
+                              param_attr=fluid.ParamAttr(
+                                  initializer=fluid.initializer.Normal(
+                                      scale=1 / math.sqrt(fc3.shape[1]))))
+        return pre
+
+    emb_pair_sums = list(map(embedding_layer, sparse_inputs))
+    emb_sums = [x[1] for x in emb_pair_sums]
+    infer_vars = [x[0] for x in emb_pair_sums]
+    mlp_in = mlp_input_tensor(emb_sums, dense_input)
+    predict = mlp(mlp_in)
+    cost = fluid.layers.cross_entropy(input=predict, label=label)
+    avg_cost = fluid.layers.reduce_sum(cost)
+    accuracy = fluid.layers.accuracy(input=predict, label=label)
+    auc_var, batch_auc_var, auc_states = \
+        fluid.layers.auc(input=predict, label=label, num_thresholds=2**12, slide_steps=20)
+    return predict, avg_cost, auc_var, batch_auc_var, infer_vars
diff --git a/python/examples/criteo_ctr_with_cube/test_client.py b/python/examples/criteo_ctr_with_cube/test_client.py
new file mode 100755
index 0000000000000000000000000000000000000000..bef04807e9b5d5c2cdc316828ed6f960f0eeb0f8
--- /dev/null
+++ b/python/examples/criteo_ctr_with_cube/test_client.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# pylint: disable=doc-string-missing
+
+from paddle_serving_client import Client
+import sys
+import os
+import criteo
+import time
+from paddle_serving_client.metric import auc
+import numpy as np
+
+client = Client()
+client.load_client_config(sys.argv[1])
+client.connect(["127.0.0.1:9292"])
+
+batch = 1
+buf_size = 100
+dataset = criteo.CriteoDataset()
+dataset.setup(1000001)
+test_filelists = ["{}/part-0".format(sys.argv[2])]
+reader = dataset.infer_reader(test_filelists, batch, buf_size)
+label_list = []
+prob_list = []
+start = time.time()
+# create the batch iterator once; calling reader() inside the loop would
+# restart the reader and resend the first batch on every iteration
+batch_iter = reader()
+for ei in range(10000):
+    data = next(batch_iter)
+    feed_dict = {}
+    feed_dict['dense_input'] = data[0][0]
+    for i in range(1, 27):
+        feed_dict["embedding_{}.tmp_0".format(i - 1)] = np.array(data[0][
+            i]).reshape(-1)
+        feed_dict["embedding_{}.tmp_0.lod".format(i - 1)] = [0, len(data[0][i])]
+    fetch_map = client.predict(feed=feed_dict, fetch=["prob"])
+    print(fetch_map)
+    prob_list.append(fetch_map['prob'][0][1])
+    label_list.append(data[0][-1][0])
+
+end = time.time()
+print(auc(label_list, prob_list))
+print(end - start)
diff --git a/python/examples/criteo_ctr_with_cube/test_server.py b/python/examples/criteo_ctr_with_cube/test_server.py
new file mode 100755
index 0000000000000000000000000000000000000000..479c602910b5afa52b35a66d00316f54905c0741
--- /dev/null
+++ b/python/examples/criteo_ctr_with_cube/test_server.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# pylint: disable=doc-string-missing
+
+import os
+import sys
+from paddle_serving_server import OpMaker
+from paddle_serving_server import OpSeqMaker
+from paddle_serving_server import Server
+
+op_maker = OpMaker()
+read_op = op_maker.create('general_reader')
+# general_dist_kv_infer looks up sparse parameters in cube before inference
+general_dist_kv_infer_op = op_maker.create('general_dist_kv_infer')
+response_op = op_maker.create('general_response')
+
+op_seq_maker = OpSeqMaker()
+op_seq_maker.add_op(read_op)
+op_seq_maker.add_op(general_dist_kv_infer_op)
+op_seq_maker.add_op(response_op)
+
+server = Server()
+server.set_op_sequence(op_seq_maker.get_op_sequence())
+server.set_num_threads(4)
+server.load_model_config(sys.argv[1])
+server.prepare_server(
+    workdir="work_dir1",
+    port=9292,
+    device="cpu",
+    cube_conf="./cube/conf/cube.conf")
+server.run_server()
diff --git a/python/examples/fit_a_line/test_client.py b/python/examples/fit_a_line/test_client.py
old mode 100644
new mode 100755
diff --git a/python/examples/fit_a_line/test_httpclient.py b/python/examples/fit_a_line/test_httpclient.py
index ad6e8502f436bfcdd313ee68610c9ad6c832407e..b1c7057a64e4f37ea14d19ea32740d71eff42146 100644
--- a/python/examples/fit_a_line/test_httpclient.py
+++ b/python/examples/fit_a_line/test_httpclient.py
@@ -34,10 +34,7 @@ test_reader = paddle.batch(
 for data in test_reader():
     new_data = np.zeros((1, 13)).astype("float32")
     new_data[0] = data[0][0]
-    lst_data = []
-    for i in range(200):
-        lst_data.append(data[0][0])
     fetch_map = client.predict(
-        feed={"x": lst_data}, fetch=fetch_list, batch=True)
+        feed={"x": new_data}, fetch=fetch_list, batch=True)
     print(fetch_map)
     break
diff --git a/python/paddle_serving_client/httpclient.py b/python/paddle_serving_client/httpclient.py
old mode 100644
new mode 100755
index 356d27bd5b094db036615dc2e61a491d9d726678..dc120686e42c7e4368e8cd32216c8b63c9d56782
--- a/python/paddle_serving_client/httpclient.py
+++ b/python/paddle_serving_client/httpclient.py
@@ -57,6 +57,7 @@ def data_bytes_number(datalist):
     else:
         raise ValueError(
             "In the Function data_bytes_number(), data must be list.")
+    return total_bytes_number
 
 
 class HttpClient(object):
@@ -141,6 +142,15 @@ class HttpClient(object):
         else:
             self.http_timeout_ms = http_timeout_ms
 
+    def set_ip(self, ip):
+        self.ip = ip
+
+    def set_service_name(self, service_name):
+        self.service_name = service_name
+
+    def set_port(self, port):
+        self.port = port
+
     def set_request_compress(self, try_request_gzip):
         self.try_request_gzip = try_request_gzip
 
@@ -294,9 +304,10 @@ class HttpClient(object):
                 raise ValueError(
                     "feedvar is string-type,feed, feed can`t be a single int or others."
                 )
-
-            total_data_number = total_data_number + data_bytes_number(
-                data_value)
+            # if the request is not compressed, there is no need to count the data size
+            if self.try_request_gzip:
+                total_data_number = total_data_number + data_bytes_number(
+                    data_value)
             Request["tensor"][index]["elem_type"] = elem_type
             Request["tensor"][index]["shape"] = shape
             Request["tensor"][index][data_key] = data_value
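
For reference, a minimal usage sketch of the `HttpClient` changes above. The methods shown (`load_client_config`, `set_ip`, `set_port`, `set_request_compress`, `predict`) are the ones added or exercised in this patch; the server address, the client config path, and the `price` fetch variable are illustrative assumptions borrowed from the fit_a_line example, and `set_service_name` is omitted because its value depends on the deployment.

```python
# Minimal sketch, not part of the patch. Assumes a fit_a_line-style server
# is listening on 127.0.0.1:9393 and exposes a fetch var named "price".
import numpy as np
from paddle_serving_client.httpclient import HttpClient

client = HttpClient()
client.load_client_config("uci_housing_client/serving_client_conf.prototxt")
client.set_ip("127.0.0.1")  # setter added in this patch
client.set_port(9393)       # setter added in this patch
# Per the doc change above, gzip helps only for larger payloads
# (roughly > 512 bytes; little gain was measured below ~50 KB).
client.set_request_compress(True)

new_data = np.zeros((1, 13)).astype("float32")  # one batch of 13 features
fetch_map = client.predict(feed={"x": new_data}, fetch=["price"], batch=True)
print(fetch_map)
```

Note that the `return total_bytes_number` fix above matters for this path: without it, `data_bytes_number()` returned `None`, so the byte counting done when compression is enabled would fail.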