提交 ed1c6e8b 编写于 作者: W wuzewu

Update module api to support gpu and batch_size config

上级 04308cee
......@@ -16,7 +16,7 @@ if __name__ == "__main__":
inputs = {"text": test_text}
# execute predict and print the result
results = lac.lexical_analysis(data=inputs)
results = lac.lexical_analysis(data=inputs, use_gpu=True, batch_size=10)
for result in results:
if six.PY2:
print(
......
......@@ -21,6 +21,7 @@ import argparse
import json
import os
import sys
import ast
import six
import pandas
......@@ -80,6 +81,18 @@ class RunCommand(BaseCommand):
default=config['default'],
help=config['help'])
self.arg_config_group.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=False,
help="whether use GPU for prediction")
self.arg_config_group.add_argument(
'--batch_size',
type=int,
default=1,
help="batch size for prediction")
self.arg_config_group.add_argument(
'--config',
type=str,
......@@ -224,7 +237,11 @@ class RunCommand(BaseCommand):
return False
results = self.module(
sign_name=self.module.default_signature.name, data=data, **config)
sign_name=self.module.default_signature.name,
data=data,
use_gpu=self.args.use_gpu,
batch_size=self.args.batch_size,
**config)
if six.PY2:
try:
......
......@@ -434,7 +434,7 @@ class Module(object):
for key, value in self.extra_info.items():
utils.from_pyobj_to_module_attr(value, extra_info.map.data[key])
def __call__(self, sign_name, data, **kwargs):
def __call__(self, sign_name, data, use_gpu=False, batch_size=1, **kwargs):
self.check_processor()
def _get_reader_and_feeder(data_format, data, place):
......@@ -463,15 +463,13 @@ class Module(object):
with fluid.program_guard(program):
result = []
index = 0
if "PADDLEHUB_CUDA_ENABLE" in os.environ:
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
except:
use_gpu = False
if "PADDLEHUB_BATCH_SIZE" in os.environ:
batch_size = os.environ["PADDLEHUB_BATCH_SIZE"]
else:
batch_size = 1
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place=place)
data = self.processor.preprocess(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册