提交 a8cd270e 编写于 作者: M MRXLT

add use_mkl && fix imdb script

上级 e269af44
......@@ -21,6 +21,7 @@ logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def load_vocab(filename):
vocab = {}
with open(filename) as f:
......@@ -31,17 +32,19 @@ def load_vocab(filename):
vocab["<unk>"] = len(vocab)
return vocab
if __name__ == "__main__":
vocab = load_vocab('imdb.vocab')
dict_dim = len(vocab)
data = fluid.layers.data(name="words", shape=[1], dtype="int64", lod_level=1)
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
dataset = fluid.DatasetFactory().create_dataset()
filelist = ["train_data/%s" % x for x in os.listdir("train_data")]
dataset.set_use_var([data, label])
pipe_command = "/home/users/dongdaxiang/paddle_whls/custom_op/paddle_release_home/python/bin/python imdb_reader.py"
pipe_command = "python imdb_reader.py"
dataset.set_pipe_command(pipe_command)
dataset.set_batch_size(4)
dataset.set_filelist(filelist)
......@@ -59,16 +62,14 @@ if __name__ == "__main__":
import paddle_serving_client.io as serving_io
for i in range(epochs):
exe.train_from_dataset(program=fluid.default_main_program(),
dataset=dataset, debug=False)
exe.train_from_dataset(
program=fluid.default_main_program(), dataset=dataset, debug=False)
logger.info("TRAIN --> pass: {}".format(i))
if i == 5:
serving_io.save_model("serving_server_model",
"serving_client_conf",
{"words": data, "label": label},
{"cost": avg_cost, "acc": acc,
"prediction": prediction},
fluid.default_main_program())
serving_io.save_model("serving_server_model", "serving_client_conf",
{"words": data,
"label": label}, {
"cost": avg_cost,
"acc": acc,
"prediction": prediction
}, fluid.default_main_program())
......@@ -23,7 +23,6 @@ def batch_predict(batch_size=4):
client = Client()
client.load_client_config(conf_file)
client.connect(["127.0.0.1:9292"])
start = time.time()
fetch = ["acc", "cost", "prediction"]
feed_batch = []
for line in sys.stdin:
......@@ -44,8 +43,6 @@ def batch_predict(batch_size=4):
for i in range(len(feed_batch)):
print("{} {}".format(fetch_batch[i]["prediction"][1], feed_batch[i][
"label"][0]))
cost = time.time() - start
print("total cost : {}".format(cost))
if __name__ == '__main__':
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from paddle_serving_server_gpu import OpMaker
from paddle_serving_server_gpu import OpSeqMaker
from paddle_serving_server_gpu import Server
op_maker = OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
op_seq_maker = OpSeqMaker()
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op)
server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(12)
server.load_model_config(sys.argv[1])
port = int(sys.argv[2])
server.prepare_server(workdir="work_dir1", port=port, device="gpu")
server.run_server()
......@@ -29,7 +29,6 @@ op_seq_maker.add_op(general_infer_op)
op_seq_maker.add_op(general_response_op)
server = Server()
server.set_vlog_level(3)
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(4)
......
......@@ -89,6 +89,7 @@ class Server(object):
self.module_path = os.path.dirname(paddle_serving_server.__file__)
self.cur_path = os.getcwd()
self.use_local_bin = False
self.mkl_flag = False
def set_max_concurrency(self, concurrency):
self.max_concurrency = concurrency
......@@ -172,16 +173,16 @@ class Server(object):
# check config here
# print config here
def use_mkl(self):
self.mkl_flag = True
def get_device_version(self):
avx_flag = False
mkl_flag = False
mkl_flag = self.mkl_flag
openblas_flag = False
r = os.system("cat /proc/cpuinfo | grep avx > /dev/null 2>&1")
if r == 0:
avx_flag = True
r = os.system("which mkl")
if r == 0:
mkl_flag = True
if avx_flag:
if mkl_flag:
device_version = "serving-cpu-avx-mkl-"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册