Unverified commit 3a14c484, authored by huangjianhui, committed by GitHub

Merge branch 'develop' into develop

@@ -191,24 +191,44 @@ int GeneralDetectionOp::inference() {
   boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg);
-  for (int i = boxes.size() - 1; i >= 0; i--) {
-    crop_img = GetRotateCropImage(img, boxes[i]);
-    float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
-    this->resize_op_rec.Run(
-        crop_img, resize_img_rec, wh_ratio, this->use_tensorrt_);
-    this->normalize_op_.Run(
-        &resize_img_rec, this->mean_rec, this->scale_rec, this->is_scale_);
-    std::vector<float> output_rec(
-        1 * 3 * resize_img_rec.rows * resize_img_rec.cols, 0.0f);
-    this->permute_op_.Run(&resize_img_rec, output_rec.data());
+  float max_wh_ratio = 0.0f;
+  std::vector<cv::Mat> crop_imgs;
+  std::vector<cv::Mat> resize_imgs;
+  int max_resize_w = 0;
+  int max_resize_h = 0;
+  int box_num = boxes.size();
+  std::vector<std::vector<float>> output_rec;
+  for (int i = 0; i < box_num; ++i) {
+    cv::Mat line_img = GetRotateCropImage(img, boxes[i]);
+    float wh_ratio = float(line_img.cols) / float(line_img.rows);
+    max_wh_ratio = max_wh_ratio > wh_ratio ? max_wh_ratio : wh_ratio;
+    crop_imgs.push_back(line_img);
+  }
+  for (int i = 0; i < box_num; ++i) {
+    cv::Mat resize_img;
+    crop_img = crop_imgs[i];
+    this->resize_op_rec.Run(
+        crop_img, resize_img, max_wh_ratio, this->use_tensorrt_);
+    this->normalize_op_.Run(
+        &resize_img, this->mean_rec, this->scale_rec, this->is_scale_);
+    max_resize_w = std::max(max_resize_w, resize_img.cols);
+    max_resize_h = std::max(max_resize_h, resize_img.rows);
+    resize_imgs.push_back(resize_img);
+  }
+  int buf_size = 3 * max_resize_h * max_resize_w;
+  output_rec = std::vector<std::vector<float>>(box_num,
+      std::vector<float>(buf_size, 0.0f));
+  for (int i = 0; i < box_num; ++i) {
+    resize_img_rec = resize_imgs[i];
+    this->permute_op_.Run(&resize_img_rec, output_rec[i].data());
+  }
   // Inference.
-  output_shape = {1, 3, resize_img_rec.rows, resize_img_rec.cols};
+  output_shape = {box_num, 3, max_resize_h, max_resize_w};
   out_num = std::accumulate(
       output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
   databuf_size_out = out_num * sizeof(float);
@@ -217,17 +237,19 @@ int GeneralDetectionOp::inference() {
     LOG(ERROR) << "Malloc failed, size: " << databuf_size_out;
     return -1;
   }
-  memcpy(databuf_data_out, output_rec.data(), databuf_size_out);
+  int offset = buf_size * sizeof(float);
+  for (int i = 0; i < box_num; ++i) {
+    memcpy(databuf_data_out + i * offset, output_rec[i].data(), offset);
+  }
   databuf_char_out = reinterpret_cast<char*>(databuf_data_out);
   paddle::PaddleBuf paddleBuf(databuf_char_out, databuf_size_out);
   paddle::PaddleTensor tensor_out;
   tensor_out.name = "image";
   tensor_out.dtype = paddle::PaddleDType::FLOAT32;
-  tensor_out.shape = {1, 3, resize_img_rec.rows, resize_img_rec.cols};
+  tensor_out.shape = output_shape;
   tensor_out.data = paddleBuf;
   out->push_back(tensor_out);
   }
-  }
   out->erase(out->begin(), out->begin() + infer_outnum);
   int64_t end = timeline.TimeStampUS();
......
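Note: this hunk switches GeneralDetectionOp from one recognition inference per detected box to a single batched call. It first finds the maximum width/height ratio over all cropped text lines, resizes every crop against that shared ratio, then packs the normalized crops into one contiguous [box_num, 3, max_h, max_w] buffer, with crop i living at byte offset i * buf_size. A minimal NumPy sketch of the same packing scheme (names are illustrative, not from the source):

```
import numpy as np

def pack_batch(crops):
    """Pack variable-sized CHW float arrays into one zero-padded
    NCHW batch, mirroring the buf_size/offset logic in the hunk above."""
    max_h = max(c.shape[1] for c in crops)
    max_w = max(c.shape[2] for c in crops)
    batch = np.zeros((len(crops), 3, max_h, max_w), dtype=np.float32)
    for i, c in enumerate(crops):
        # top-left aligned, zero padding fills the rest
        batch[i, :, :c.shape[1], :c.shape[2]] = c
    return batch

# e.g. three text-line crops of differing widths
crops = [np.random.rand(3, 32, w).astype(np.float32) for w in (100, 320, 200)]
print(pack_batch(crops).shape)  # (3, 3, 32, 320)
```

Zero padding keeps every crop at the same stride, which is what lets the single memcpy per box above address the buffer arithmetically.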
@@ -63,7 +63,7 @@ class GeneralDetectionOp
   double det_db_thresh_ = 0.3;
   double det_db_box_thresh_ = 0.5;
-  double det_db_unclip_ratio_ = 2.0;
+  double det_db_unclip_ratio_ = 1.5;
   std::vector<float> mean_det = {0.485f, 0.456f, 0.406f};
   std::vector<float> scale_det = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
......
@@ -186,9 +186,9 @@ int GeneralDistKVInferOp::inference() {
   if (values.size() != keys.size() || values[0].buff.size() == 0) {
     LOG(ERROR) << "cube value return null";
   }
-  // size_t EMBEDDING_SIZE = values[0].buff.size() / sizeof(float);
+  size_t EMBEDDING_SIZE = values[0].buff.size() / sizeof(float);
   // size_t EMBEDDING_SIZE = (values[0].buff.size() - 10) / sizeof(float);
-  size_t EMBEDDING_SIZE = 9;
+  // size_t EMBEDDING_SIZE = 9;
   TensorVector sparse_out;
   sparse_out.resize(sparse_count);
   TensorVector dense_out;
@@ -241,7 +241,7 @@ int GeneralDistKVInferOp::inference() {
   // The data generated by pslib has 10 bytes of information to be filtered
   // out
-  memcpy(data_ptr, cur_val->buff.data() + 10, cur_val->buff.size() - 10);
+  memcpy(data_ptr, cur_val->buff.data(), cur_val->buff.size());
   // VLOG(3) << keys[cube_val_idx] << ":" << data_ptr[0] << ", " <<
   // data_ptr[1] << ", " <<data_ptr[2] << ", " <<data_ptr[3] << ", "
   // <<data_ptr[4] << ", " <<data_ptr[5] << ", " <<data_ptr[6] << ", "
......
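The first hunk derives EMBEDDING_SIZE from the first value returned by cube instead of hardcoding 9; the second stops skipping a 10-byte pslib header when copying the payload. The underlying arithmetic, assuming 4-byte floats (a hedged sketch, not the serving code itself):

```
import struct

def embedding_size(buff: bytes, header_bytes: int = 0) -> int:
    """Number of float32 elements in a cube value payload.
    header_bytes=10 corresponds to the old pslib-header variant."""
    payload = len(buff) - header_bytes
    assert payload % 4 == 0, "payload is not a whole number of float32s"
    return payload // 4

buff = struct.pack("<9f", *range(9))   # a toy 9-float payload (36 bytes)
print(embedding_size(buff))            # 9
```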
@@ -215,6 +215,7 @@ TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(
     LOG(ERROR) << "Failed get TaskT from object pool";
     return TaskHandler<TaskT>::valid_handle();
   }
+  task->clear();
   /*
   if (!BatchTasks<TaskT>::check_valid(in, out, _overrun)) {
......
@@ -99,7 +99,40 @@ struct Task {
     outLodTensorVector.clear();
   }
   ~Task() {
+    read_fd = -1;
+    write_fd = -1;
+    owner_tid = -1;
+    inVectorT_ptr = NULL;
+    outVectorT_ptr = NULL;
+    set_feed_lod_index.clear();
+    set_feed_nobatch_index.clear();
+    vector_fetch_lod_index.clear();
+    set_fetch_nobatch_index.clear();
+    rem = -1;
+    total_feed_batch = 0;
+    taskmeta_num = 0;
+    index.store(0, butil::memory_order_relaxed);
     THREAD_MUTEX_DESTROY(&task_mut);
+    fetch_init = false;
+    outLodTensorVector.clear();
+  }
+  void clear() {
+    read_fd = -1;
+    write_fd = -1;
+    owner_tid = -1;
+    inVectorT_ptr = NULL;
+    outVectorT_ptr = NULL;
+    set_feed_lod_index.clear();
+    set_feed_nobatch_index.clear();
+    vector_fetch_lod_index.clear();
+    set_fetch_nobatch_index.clear();
+    rem = -1;
+    total_feed_batch = 0;
+    taskmeta_num = 0;
+    index.store(0, butil::memory_order_relaxed);
+    THREAD_MUTEX_INIT(&task_mut, NULL);
+    fetch_init = false;
     outLodTensorVector.clear();
   }
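clear() repeats the destructor's field resets so that a Task handed out by the object pool starts from a known state; the schedule() hunk above calls it immediately after the pool returns an object. A small Python sketch of this reset-on-acquire pattern (illustrative only, not the bRPC/butil code):

```
class Task:
    def __init__(self):
        self.clear()

    def clear(self):
        # reset every field a previous use may have dirtied
        self.read_fd = -1
        self.write_fd = -1
        self.rem = -1
        self.taskmeta_num = 0
        self.out_lod_tensors = []

class Pool:
    def __init__(self):
        self._free = []

    def get(self):
        task = self._free.pop() if self._free else Task()
        task.clear()  # never trust a recycled object's state
        return task

    def put(self, task):
        self._free.append(task)
```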
@@ -323,7 +356,7 @@ struct Task {
     size_t feedvar_index = vector_fetch_lod_index[index];
     // PaddleTensor::resize clears the tensor on every call, so the total
     // length must be accumulated first.
     for (size_t taskmeta_index = 0; taskmeta_index < taskmeta_num;
-         ++taskmeta_num) {
+         ++taskmeta_index) {
       data_length +=
           outLodTensorVector[taskmeta_index][index].data.length();
       lod_length += outLodTensorVector[taskmeta_index][index].lod[0].size();
@@ -347,7 +380,7 @@ struct Task {
     size_t once_lod_length = 0;
     size_t last_lod_value = fetchVarTensor.lod[0][lod_length_offset];
     for (size_t taskmeta_index = 0; taskmeta_index < taskmeta_num;
-         ++taskmeta_num) {
+         ++taskmeta_index) {
       void* dst_ptr = fetchVarTensor.data.data() + data_length_offset;
       void* source_ptr =
           outLodTensorVector[taskmeta_index][index].data.data();
......
@@ -277,7 +277,7 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
     LOG(WARNING) << "Loading cube cache[" << next_idx << "] ...";
     std::string model_path = conf.model_dir();
     if (access(model_path.c_str(), F_OK) == 0) {
-      std::string cube_cache_path = model_path + "/" + "cube_cache";
+      std::string cube_cache_path = model_path + "/cube_cache";
       int reload_cache_ret = md->caches[next_idx]->reload_data(cube_cache_path);
       LOG(WARNING) << "Loading cube cache[" << next_idx << "] done.";
     } else {
@@ -437,7 +437,7 @@ class CloneDBReloadableInferEngine
     // create caches
     std::string model_path = conf.model_dir();
     if (access(model_path.c_str(), F_OK) == 0) {
-      std::string cube_cache_path = model_path + "cube_cache";
+      std::string cube_cache_path = model_path + "/cube_cache";
       int reload_cache_ret =
           md->caches[next_idx]->reload_data(cube_cache_path);
       LOG(WARNING) << "create cube cache[" << next_idx << "] done.";
......
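Both call sites now build the cache path identically; the CloneDBReloadableInferEngine variant previously produced `<model_dir>cube_cache` whenever model_dir lacked a trailing slash. The bug class in miniature (paths are illustrative):

```
import os

model_path = "serving_server"                  # no trailing slash
print(model_path + "cube_cache")               # serving_servercube_cache (wrong)
print(model_path + "/cube_cache")              # serving_server/cube_cache
print(os.path.join(model_path, "cube_cache"))  # safer: handles separators
```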
@@ -82,14 +82,14 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
   else if (resize_h / 32 < 1 + 1e-5)
     resize_h = 32;
   else
-    resize_h = (resize_h / 32) * 32;
+    resize_h = (resize_h / 32 - 1) * 32;
   if (resize_w % 32 == 0)
     resize_w = resize_w;
   else if (resize_w / 32 < 1 + 1e-5)
     resize_w = 32;
   else
-    resize_w = (resize_w / 32) * 32;
+    resize_w = (resize_w / 32 - 1) * 32;
   if (!use_tensorrt) {
     cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
     ratio_h = float(resize_h) / float(h);
......
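With integer division, `(x / 32 - 1) * 32` rounds a non-multiple-of-32 side down past one extra stride rather than merely truncating, so the detector input shrinks by up to 32 additional pixels per side. A quick check of the arithmetic (assuming integer dimensions, as in the C++):

```
def round_down_old(x, stride=32):
    return (x // stride) * stride       # truncate to the multiple below

def round_down_new(x, stride=32):
    return (x // stride - 1) * stride   # one full stride smaller

for x in (100, 65):
    print(x, round_down_old(x), round_down_new(x))
# 100 96 64
# 65  64 32
# exact multiples (x % 32 == 0) are handled by the earlier branch,
# so this else-branch never sees them
```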
[{
"dict_name": "test_dict",
"shard": 1,
"dup": 1,
"timeout": 200,
"retry": 3,
"backup_request": 100,
"type": "ipport_list",
"load_balancer": "rr",
"nodes": [{
"ipport_list": "list://127.0.0.1:8027"
}]
}]
--port=8027
--dict_split=1
--in_mem=true
--log_dir=./log/
@@ -16,7 +16,7 @@
 from paddle_serving_client import Client
 import sys
 import os
-import criteo as criteo
+import criteo_reader as criteo
 import time
 from paddle_serving_client.metric import auc
 import numpy as np
@@ -35,22 +35,23 @@ reader = dataset.infer_reader(test_filelists, batch, buf_size)
 label_list = []
 prob_list = []
 start = time.time()
-for ei in range(10000):
+for ei in range(100):
     if py_version == 2:
         data = reader().next()
     else:
         data = reader().__next__()
     feed_dict = {}
-    feed_dict['dense_input'] = data[0][0]
+    feed_dict['dense_input'] = np.array(data[0][0]).reshape(1, len(data[0][0]))
     for i in range(1, 27):
-        feed_dict["embedding_{}.tmp_0".format(i - 1)] = np.array(data[0][i]).reshape(-1)
+        feed_dict["embedding_{}.tmp_0".format(i - 1)] = np.array(data[0][i]).reshape(len(data[0][i]))
         feed_dict["embedding_{}.tmp_0.lod".format(i - 1)] = [0, len(data[0][i])]
-    fetch_map = client.predict(feed=feed_dict, fetch=["prob"])
+    fetch_map = client.predict(feed=feed_dict, fetch=["prob"], batch=True)
     print(fetch_map)
     prob_list.append(fetch_map['prob'][0][1])
     label_list.append(data[0][-1][0])
+print(auc(label_list, prob_list))
 end = time.time()
 print(end - start)
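The updated client feeds the dense slot as an explicit (1, D) batch and each sparse slot as a flat ID array plus a `.lod` list, where [0, n] marks a single variable-length sequence covering elements 0..n. A toy illustration of that feed layout (values are made up; Criteo has 13 dense features):

```
import numpy as np

dense = [0.1] * 13        # the 13 dense Criteo features
sparse = [7, 42, 42]      # one variable-length ID list

feed = {
    "dense_input": np.array(dense).reshape(1, len(dense)),   # (batch=1, 13)
    "embedding_0.tmp_0": np.array(sparse).reshape(len(sparse)),
    "embedding_0.tmp_0.lod": [0, len(sparse)],  # one sequence: elements [0, 3)
}
print(feed["dense_input"].shape, feed["embedding_0.tmp_0.lod"])
```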
@@ -16,5 +16,5 @@ This model support TensorRT, if you want a faster inference, please use `--use_t
 ### Perform prediction
 ```
-python3 test_client.py 000000570688.jpg
+python3 test_client.py 000000014439.jpg
 ```
@@ -18,5 +18,5 @@ python3 -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_
 ### Perform prediction
 ```
-python3 test_client.py 000000570688.jpg
+python3 test_client.py 000000014439.jpg
 ```
@@ -27,7 +27,7 @@ preprocess = Sequential([
     PadStride(128)
 ])
-postprocess = RCNNPostprocess("label_list.txt", "output")
+postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608])
 client = Client()
 client.load_client_config("serving_client/serving_client_conf.prototxt")
@@ -41,5 +41,6 @@ fetch_map = client.predict(
     },
     fetch=["save_infer_model/scale_0.tmp_1"],
     batch=False)
+print(fetch_map)
 fetch_map["image"] = sys.argv[1]
 postprocess(fetch_map)
@@ -16,5 +16,5 @@ This model support TensorRT, if you want a faster inference, please use `--use_t
 ### Perform prediction
 ```
-python3 test_client.py 000000570688.jpg
+python3 test_client.py 000000014439.jpg
 ```
@@ -18,5 +18,5 @@ python3 -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_
 ### Perform prediction
 ```
-python3 test_client.py 000000570688.jpg
+python3 test_client.py 000000014439.jpg
 ```
@@ -27,7 +27,7 @@ preprocess = Sequential([
     PadStride(128)
 ])
-postprocess = RCNNPostprocess("label_list.txt", "output")
+postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608])
 client = Client()
 client.load_client_config("serving_client/serving_client_conf.prototxt")
@@ -41,5 +41,6 @@ fetch_map = client.predict(
     },
     fetch=["save_infer_model/scale_0.tmp_1"],
     batch=False)
+print(fetch_map)
 fetch_map["image"] = sys.argv[1]
 postprocess(fetch_map)
@@ -28,7 +28,7 @@ from paddle_serving_client.utils import benchmark_args, show_latency
 def parse_benchmark(filein, fileout):
     with open(filein, "r") as fin:
-        res = yaml.load(fin)
+        res = yaml.load(fin, yaml.FullLoader)
         del_list = []
         for key in res["DAG"].keys():
             if "call" in key:
@@ -41,7 +41,7 @@ def parse_benchmark(filein, fileout):
 def gen_yml(device, gpu_id):
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 10}
     if device == "gpu":
......
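This same two-line fix recurs in every benchmark script below: PyYAML 5.1 deprecated calling yaml.load() without an explicit Loader, since unsafe loading can instantiate arbitrary Python objects from tagged input. That is also why the requirements.txt hunks at the end bump pyyaml>=5.1. Minimal before/after:

```
import yaml

doc = "dag:\n  tracer:\n    interval_s: 10\n"

# Before: emits YAMLLoadWarning on PyYAML >= 5.1
# config = yaml.load(doc)

# After: explicit loader, no warning; yaml.safe_load(doc) is stricter still
config = yaml.load(doc, yaml.FullLoader)
print(config["dag"]["tracer"]["interval_s"])  # 10
```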
@@ -28,7 +28,7 @@ from paddle_serving_client.utils import benchmark_args, show_latency
 def parse_benchmark(filein, fileout):
     with open(filein, "r") as fin:
-        res = yaml.load(fin)
+        res = yaml.load(fin, yaml.FullLoader)
         del_list = []
         for key in res["DAG"].keys():
             if "call" in key:
@@ -41,7 +41,7 @@ def parse_benchmark(filein, fileout):
 def gen_yml(device, gpu_id):
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 10}
     if device == "gpu":
......
@@ -28,7 +28,7 @@ from paddle_serving_client.utils import benchmark_args, show_latency
 def parse_benchmark(filein, fileout):
     with open(filein, "r") as fin:
-        res = yaml.load(fin)
+        res = yaml.load(fin, yaml.FullLoader)
         del_list = []
         for key in res["DAG"].keys():
             if "call" in key:
@@ -41,7 +41,7 @@ def parse_benchmark(filein, fileout):
 def gen_yml(device, gpu_id):
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 10}
     if device == "gpu":
......
@@ -28,7 +28,7 @@ from paddle_serving_client.utils import benchmark_args, show_latency
 def parse_benchmark(filein, fileout):
     with open(filein, "r") as fin:
-        res = yaml.load(fin)
+        res = yaml.load(fin, yaml.FullLoader)
         del_list = []
         for key in res["DAG"].keys():
             if "call" in key:
@@ -41,7 +41,7 @@ def parse_benchmark(filein, fileout):
 def gen_yml(device, gpu_id):
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 10}
     if device == "gpu":
......
@@ -28,7 +28,7 @@ from paddle_serving_client.utils import benchmark_args, show_latency
 def parse_benchmark(filein, fileout):
     with open(filein, "r") as fin:
-        res = yaml.load(fin)
+        res = yaml.load(fin, yaml.FullLoader)
         del_list = []
         for key in res["DAG"].keys():
             if "call" in key:
@@ -41,7 +41,7 @@ def parse_benchmark(filein, fileout):
 def gen_yml(device, gpu_id):
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 10}
     if device == "gpu":
......
@@ -28,7 +28,7 @@ from paddle_serving_client.utils import benchmark_args, show_latency
 def parse_benchmark(filein, fileout):
     with open(filein, "r") as fin:
-        res = yaml.load(fin)
+        res = yaml.load(fin, yaml.FullLoader)
         del_list = []
         for key in res["DAG"].keys():
             if "call" in key:
@@ -41,7 +41,7 @@ def parse_benchmark(filein, fileout):
 def gen_yml(device, gpu_id):
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 10}
     if device == "gpu":
......
@@ -28,7 +28,7 @@ from paddle_serving_client.utils import benchmark_args, show_latency
 def parse_benchmark(filein, fileout):
     with open(filein, "r") as fin:
-        res = yaml.load(fin)
+        res = yaml.load(fin, yaml.FullLoader)
         del_list = []
         for key in res["DAG"].keys():
             if "call" in key:
@@ -41,7 +41,7 @@ def parse_benchmark(filein, fileout):
 def gen_yml(device, gpu_id):
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 10}
     if device == "gpu":
......
@@ -28,7 +28,7 @@ from paddle_serving_client.utils import benchmark_args, show_latency
 def parse_benchmark(filein, fileout):
     with open(filein, "r") as fin:
-        res = yaml.load(fin)
+        res = yaml.load(fin, yaml.FullLoader)
         del_list = []
         for key in res["DAG"].keys():
             if "call" in key:
@@ -41,7 +41,7 @@ def parse_benchmark(filein, fileout):
 def gen_yml(device, gpu_id):
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 10}
     if device == "gpu":
......
@@ -28,7 +28,7 @@ from paddle_serving_client.utils import benchmark_args, show_latency
 def parse_benchmark(filein, fileout):
     with open(filein, "r") as fin:
-        res = yaml.load(fin)
+        res = yaml.load(fin, yaml.FullLoader)
         del_list = []
         for key in res["DAG"].keys():
             if "call" in key:
@@ -41,7 +41,7 @@ def parse_benchmark(filein, fileout):
 def gen_yml(device, gpu_id):
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 10}
     if device == "gpu":
......
@@ -28,7 +28,7 @@ from paddle_serving_client.utils import benchmark_args, show_latency
 def parse_benchmark(filein, fileout):
     with open(filein, "r") as fin:
-        res = yaml.load(fin)
+        res = yaml.load(fin, yaml.FullLoader)
         del_list = []
         for key in res["DAG"].keys():
             if "call" in key:
@@ -41,7 +41,7 @@ def parse_benchmark(filein, fileout):
 def gen_yml(device, gpu_id):
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 10}
     if device == "gpu":
......
@@ -28,7 +28,7 @@ from paddle_serving_client.utils import benchmark_args, show_latency
 def parse_benchmark(filein, fileout):
     with open(filein, "r") as fin:
-        res = yaml.load(fin)
+        res = yaml.load(fin, yaml.FullLoader)
         del_list = []
         for key in res["DAG"].keys():
             if "call" in key:
@@ -41,7 +41,7 @@ def parse_benchmark(filein, fileout):
 def gen_yml(device, gpu_id):
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 10}
     if device == "gpu":
......
@@ -28,7 +28,7 @@ from paddle_serving_client.utils import benchmark_args, show_latency
 def parse_benchmark(filein, fileout):
     with open(filein, "r") as fin:
-        res = yaml.load(fin)
+        res = yaml.load(fin, yaml.FullLoader)
         del_list = []
         for key in res["DAG"].keys():
             if "call" in key:
@@ -41,7 +41,7 @@ def parse_benchmark(filein, fileout):
 def gen_yml(device, gpu_id):
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 10}
     if device == "gpu":
......
@@ -33,7 +33,7 @@ def cv2_to_base64(image):
 def parse_benchmark(filein, fileout):
     with open(filein, "r") as fin:
-        res = yaml.load(fin)
+        res = yaml.load(fin, yaml.FullLoader)
         del_list = []
         for key in res["DAG"].keys():
             if "call" in key:
@@ -46,7 +46,7 @@ def parse_benchmark(filein, fileout):
 def gen_yml(device, gpu_id):
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 30}
     if device == "gpu":
......
@@ -25,7 +25,7 @@ class FasterRCNNOp(Op):
         self.img_preprocess = Sequential([
             BGR2RGB(), Div(255.0),
             Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], False),
-            Resize((640, 640)), Transpose((2, 0, 1))
+            Resize(640, 640), Transpose((2, 0, 1))
         ])
         self.img_postprocess = RCNNPostprocess("label_list.txt", "output")
......
@@ -33,7 +33,7 @@ def cv2_to_base64(image):
 def parse_benchmark(filein, fileout):
     with open(filein, "r") as fin:
-        res = yaml.load(fin)
+        res = yaml.load(fin, yaml.FullLoader)
         del_list = []
         for key in res["DAG"].keys():
             if "call" in key:
@@ -46,7 +46,7 @@ def parse_benchmark(filein, fileout):
 def gen_yml(device, gpu_id):
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 30}
     if device == "gpu":
......
@@ -33,7 +33,7 @@ def cv2_to_base64(image):
 def parse_benchmark(filein, fileout):
     with open(filein, "r") as fin:
-        res = yaml.load(fin)
+        res = yaml.load(fin, yaml.FullLoader)
         del_list = []
         for key in res["DAG"].keys():
             if "call" in key:
@@ -46,7 +46,7 @@ def parse_benchmark(filein, fileout):
 def gen_yml(device, gpu_id):
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 30}
     if device == "gpu":
......
@@ -54,7 +54,7 @@ from paddle_serving_client.utils import benchmark_args, show_latency
 def parse_benchmark(filein, fileout):
     with open(filein, "r") as fin:
-        res = yaml.load(fin)
+        res = yaml.load(fin, yaml.FullLoader)
         del_list = []
         for key in res["DAG"].keys():
             if "call" in key:
@@ -67,7 +67,7 @@ def parse_benchmark(filein, fileout):
 def gen_yml(device):
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 10}
     if device == "gpu":
......
@@ -28,7 +28,7 @@ from paddle_serving_client.utils import benchmark_args, show_latency
 def parse_benchmark(filein, fileout):
     with open(filein, "r") as fin:
-        res = yaml.load(fin)
+        res = yaml.load(fin, yaml.FullLoader)
         del_list = []
         for key in res["DAG"].keys():
             if "call" in key:
@@ -41,7 +41,7 @@ def parse_benchmark(filein, fileout):
 def gen_yml(device):
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 10}
     if device == "gpu":
......
@@ -27,7 +27,7 @@ from paddle_serving_client.utils import benchmark_args, show_latency
 def gen_yml():
     fin = open("config.yml", "r")
-    config = yaml.load(fin)
+    config = yaml.load(fin, yaml.FullLoader)
     fin.close()
     config["dag"]["tracer"] = {"interval_s": 5}
     with open("config2.yml", "w") as fout:
......
@@ -96,7 +96,7 @@ if __name__ == "__main__":
     args = parse_args()
     benchmark_cfg_filename = args.benchmark_cfg
     f = open(benchmark_cfg_filename, 'r')
-    benchmark_config = yaml.load(f)
+    benchmark_config = yaml.load(f, yaml.FullLoader)
     f.close()
     benchmark_log_filename = args.benchmark_log
     f = open(benchmark_log_filename, 'r')
......
@@ -274,7 +274,7 @@ class OpAnalyst(object):
         """
         import yaml
         with open(op_config_yaml) as f:
-            op_config = yaml.load(f)
+            op_config = yaml.load(f, yaml.FullLoader)
         # check that each model is deployed on a different card
         card_set = set()
......
@@ -341,7 +341,7 @@ class ServerYamlConfChecker(object):
                 " or yml_dict can be selected as the parameter.")
         if yml_file is not None:
             with io.open(yml_file, encoding='utf-8') as f:
-                conf = yaml.load(f.read())
+                conf = yaml.load(f.read(), yaml.FullLoader)
         elif yml_dict is not None:
             conf = yml_dict
         else:
......
@@ -7,7 +7,7 @@ protobuf>=3.12.2
 grpcio-tools>=1.28.1
 grpcio>=1.28.1
 func-timeout>=4.3.5
-pyyaml>=1.3.0
+pyyaml>=5.1
 flask>=1.1.2
 click==7.1.2
 itsdangerous==1.1.0
......
@@ -6,7 +6,7 @@ google>=2.0.3
 opencv-python==4.2.0.32
 protobuf>=3.12.2
 func-timeout>=4.3.5
-pyyaml>=1.3.0
+pyyaml>=5.1
 flask>=1.1.2
 click==7.1.2
 itsdangerous==1.1.0
......