提交 ed1c6e8b 编写于 作者: W wuzewu

Update module api to support gpu and batch_size config

上级 04308cee
...@@ -16,7 +16,7 @@ if __name__ == "__main__": ...@@ -16,7 +16,7 @@ if __name__ == "__main__":
inputs = {"text": test_text} inputs = {"text": test_text}
# execute predict and print the result # execute predict and print the result
results = lac.lexical_analysis(data=inputs) results = lac.lexical_analysis(data=inputs, use_gpu=True, batch_size=10)
for result in results: for result in results:
if six.PY2: if six.PY2:
print( print(
......
...@@ -21,6 +21,7 @@ import argparse ...@@ -21,6 +21,7 @@ import argparse
import json import json
import os import os
import sys import sys
import ast
import six import six
import pandas import pandas
...@@ -80,6 +81,18 @@ class RunCommand(BaseCommand): ...@@ -80,6 +81,18 @@ class RunCommand(BaseCommand):
default=config['default'], default=config['default'],
help=config['help']) help=config['help'])
self.arg_config_group.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=False,
help="whether use GPU for prediction")
self.arg_config_group.add_argument(
'--batch_size',
type=int,
default=1,
help="batch size for prediction")
self.arg_config_group.add_argument( self.arg_config_group.add_argument(
'--config', '--config',
type=str, type=str,
...@@ -224,7 +237,11 @@ class RunCommand(BaseCommand): ...@@ -224,7 +237,11 @@ class RunCommand(BaseCommand):
return False return False
results = self.module( results = self.module(
sign_name=self.module.default_signature.name, data=data, **config) sign_name=self.module.default_signature.name,
data=data,
use_gpu=self.args.use_gpu,
batch_size=self.args.batch_size,
**config)
if six.PY2: if six.PY2:
try: try:
......
...@@ -434,7 +434,7 @@ class Module(object): ...@@ -434,7 +434,7 @@ class Module(object):
for key, value in self.extra_info.items(): for key, value in self.extra_info.items():
utils.from_pyobj_to_module_attr(value, extra_info.map.data[key]) utils.from_pyobj_to_module_attr(value, extra_info.map.data[key])
def __call__(self, sign_name, data, **kwargs): def __call__(self, sign_name, data, use_gpu=False, batch_size=1, **kwargs):
self.check_processor() self.check_processor()
def _get_reader_and_feeder(data_format, data, place): def _get_reader_and_feeder(data_format, data, place):
...@@ -463,15 +463,13 @@ class Module(object): ...@@ -463,15 +463,13 @@ class Module(object):
with fluid.program_guard(program): with fluid.program_guard(program):
result = [] result = []
index = 0 index = 0
if "PADDLEHUB_CUDA_ENABLE" in os.environ: try:
place = fluid.CUDAPlace(0) _places = os.environ["CUDA_VISIBLE_DEVICES"]
else: int(_places[0])
place = fluid.CPUPlace() except:
use_gpu = False
if "PADDLEHUB_BATCH_SIZE" in os.environ: place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
batch_size = os.environ["PADDLEHUB_BATCH_SIZE"]
else:
batch_size = 1
exe = fluid.Executor(place=place) exe = fluid.Executor(place=place)
data = self.processor.preprocess( data = self.processor.preprocess(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册