Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Serving
提交
527e0cb2
S
Serving
项目概览
PaddlePaddle
/
Serving
大约 1 年 前同步成功
通知
186
Star
833
Fork
253
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
105
列表
看板
标记
里程碑
合并请求
10
Wiki
2
Wiki
分析
仓库
DevOps
项目成员
Pages
S
Serving
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
105
Issue
105
列表
看板
标记
里程碑
合并请求
10
合并请求
10
Pages
分析
分析
仓库分析
DevOps
Wiki
2
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
527e0cb2
编写于
7月 01, 2020
作者:
W
wangjiawei04
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update pipeline
上级
c8ba9440
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
42 addition
and
22 deletion
+42
-22
python/pipeline/channel.py
python/pipeline/channel.py
+28
-8
python/pipeline/operator.py
python/pipeline/operator.py
+6
-6
python/pipeline/pipeline_client.py
python/pipeline/pipeline_client.py
+5
-5
python/pipeline/pipeline_server.py
python/pipeline/pipeline_server.py
+2
-2
python/pipeline/profiler.py
python/pipeline/profiler.py
+1
-1
未找到文件。
python/pipeline/channel.py
浏览文件 @
527e0cb2
...
...
@@ -27,7 +27,7 @@ import logging
import
enum
import
copy
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
class
ChannelDataEcode
(
enum
.
Enum
):
...
...
@@ -102,12 +102,32 @@ class ChannelData(object):
def
check_npdata
(
npdata
):
ecode
=
ChannelDataEcode
.
OK
.
value
error_info
=
None
if
isinstance
(
npdata
,
list
):
# batch data
for
sample
in
npdata
:
if
not
isinstance
(
sample
,
dict
):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be dict, but get {}."
.
format
(
type
(
sample
))
break
for
_
,
value
in
sample
.
items
():
if
not
isinstance
(
value
,
np
.
ndarray
):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be np.ndarray, but get {}."
.
format
(
type
(
value
))
return
ecode
,
error_info
elif
isinstance
(
npdata
,
dict
):
# batch_size = 1
for
_
,
value
in
npdata
.
items
():
if
not
isinstance
(
value
,
np
.
ndarray
):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be np.ndarray, but get {}."
.
format
(
type
(
value
))
break
else
:
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be dict, but get {}."
.
format
(
type
(
npdata
))
return
ecode
,
error_info
def
parse
(
self
):
...
...
python/pipeline/operator.py
浏览文件 @
527e0cb2
...
...
@@ -25,7 +25,7 @@ from .proto import pipeline_service_pb2
from
.channel
import
ThreadChannel
,
ProcessChannel
,
ChannelDataEcode
,
ChannelData
,
ChannelDataType
from
.util
import
NameGenerator
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
_op_name_gen
=
NameGenerator
(
"Op"
)
...
...
@@ -142,7 +142,7 @@ class Op(object):
_LOGGER
.
debug
(
self
.
_log
(
"get call_result"
))
return
call_result
def
postprocess
(
self
,
fetch_dict
):
def
postprocess
(
self
,
input_dict
,
fetch_dict
):
return
fetch_dict
def
stop
(
self
):
...
...
@@ -267,10 +267,10 @@ class Op(object):
midped_data
=
preped_data
return
midped_data
,
error_channeldata
def
_run_postprocess
(
self
,
midped_data
,
data_id
,
log_func
):
def
_run_postprocess
(
self
,
input_dict
,
midped_data
,
data_id
,
log_func
):
output_data
,
error_channeldata
=
None
,
None
try
:
postped_data
=
self
.
postprocess
(
midped_data
)
postped_data
=
self
.
postprocess
(
input_dict
,
midped_data
)
except
Exception
as
e
:
error_info
=
log_func
(
e
)
_LOGGER
.
error
(
error_info
)
...
...
@@ -359,8 +359,8 @@ class Op(object):
# postprocess
self
.
_profiler_record
(
"{}-postp#{}_0"
.
format
(
op_info_prefix
,
tid
))
output_data
,
error_channeldata
=
self
.
_run_postprocess
(
midped_data
,
data_id
,
log
)
output_data
,
error_channeldata
=
self
.
_run_postprocess
(
parsed_data
,
midped_data
,
data_id
,
log
)
self
.
_profiler_record
(
"{}-postp#{}_1"
.
format
(
op_info_prefix
,
tid
))
if
error_channeldata
is
not
None
:
self
.
_push_to_output_channels
(
error_channeldata
,
...
...
python/pipeline/pipeline_client.py
浏览文件 @
527e0cb2
...
...
@@ -20,7 +20,7 @@ import functools
from
.proto
import
pipeline_service_pb2
from
.proto
import
pipeline_service_pb2_grpc
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
class
PipelineClient
(
object
):
...
...
@@ -52,7 +52,7 @@ class PipelineClient(object):
return
{
"ecode"
:
resp
.
ecode
,
"error_info"
:
resp
.
error_info
}
fetch_map
=
{
"ecode"
:
resp
.
ecode
}
for
idx
,
key
in
enumerate
(
resp
.
key
):
if
key
not
in
fetch
:
if
fetch
is
not
None
and
key
not
in
fetch
:
continue
data
=
resp
.
value
[
idx
]
try
:
...
...
@@ -62,16 +62,16 @@ class PipelineClient(object):
fetch_map
[
key
]
=
data
return
fetch_map
def
predict
(
self
,
feed_dict
,
fetch
,
asyn
=
False
):
def
predict
(
self
,
feed_dict
,
fetch
=
None
,
asyn
=
False
):
if
not
isinstance
(
feed_dict
,
dict
):
raise
TypeError
(
"feed must be dict type with format: {name: value}."
)
if
not
isinstance
(
fetch
,
list
):
if
fetch
is
not
None
and
not
isinstance
(
fetch
,
list
):
raise
TypeError
(
"fetch must be list type with format: [name]."
)
req
=
self
.
_pack_request_package
(
feed_dict
)
if
not
asyn
:
resp
=
self
.
_stub
.
inference
(
req
)
return
self
.
_unpack_response_package
(
resp
)
return
self
.
_unpack_response_package
(
resp
,
fetch
)
else
:
call_future
=
self
.
_stub
.
inference
.
future
(
req
)
return
PipelinePredictFuture
(
...
...
python/pipeline/pipeline_server.py
浏览文件 @
527e0cb2
...
...
@@ -45,7 +45,7 @@ from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcod
from
.profiler
import
TimeProfiler
from
.util
import
NameGenerator
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
_profiler
=
TimeProfiler
()
...
...
@@ -384,7 +384,7 @@ class PipelineServer(object):
def
prepare_server
(
self
,
yml_file
):
with
open
(
yml_file
)
as
f
:
yml_config
=
yaml
.
load
(
f
.
read
())
yml_config
=
yaml
.
load
(
f
.
read
()
,
Loader
=
yaml
.
FullLoader
)
self
.
_port
=
yml_config
.
get
(
'port'
,
8080
)
if
not
self
.
_port_is_available
(
self
.
_port
):
raise
SystemExit
(
"Prot {} is already used"
.
format
(
self
.
_port
))
...
...
python/pipeline/profiler.py
浏览文件 @
527e0cb2
...
...
@@ -24,7 +24,7 @@ else:
raise
Exception
(
"Error Python version"
)
import
time
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
class
TimeProfiler
(
object
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录