Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Serving
提交
ce43cb50
S
Serving
项目概览
PaddlePaddle
/
Serving
大约 2 年 前同步成功
通知
187
Star
833
Fork
253
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
105
列表
看板
标记
里程碑
合并请求
10
Wiki
2
Wiki
分析
仓库
DevOps
项目成员
Pages
S
Serving
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
105
Issue
105
列表
看板
标记
里程碑
合并请求
10
合并请求
10
Pages
分析
分析
仓库分析
DevOps
Wiki
2
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ce43cb50
编写于
7月 01, 2020
作者:
W
wangjiawei04
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update pipeline
上级
b9782cd9
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
42 addition
and
22 deletion
+42
-22
python/pipeline/channel.py
python/pipeline/channel.py
+28
-8
python/pipeline/operator.py
python/pipeline/operator.py
+6
-6
python/pipeline/pipeline_client.py
python/pipeline/pipeline_client.py
+5
-5
python/pipeline/pipeline_server.py
python/pipeline/pipeline_server.py
+2
-2
python/pipeline/profiler.py
python/pipeline/profiler.py
+1
-1
未找到文件。
python/pipeline/channel.py
浏览文件 @
ce43cb50
...
@@ -27,7 +27,7 @@ import logging
...
@@ -27,7 +27,7 @@ import logging
import
enum
import
enum
import
copy
import
copy
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
class
ChannelDataEcode
(
enum
.
Enum
):
class
ChannelDataEcode
(
enum
.
Enum
):
...
@@ -102,12 +102,32 @@ class ChannelData(object):
...
@@ -102,12 +102,32 @@ class ChannelData(object):
def
check_npdata
(
npdata
):
def
check_npdata
(
npdata
):
ecode
=
ChannelDataEcode
.
OK
.
value
ecode
=
ChannelDataEcode
.
OK
.
value
error_info
=
None
error_info
=
None
if
isinstance
(
npdata
,
list
):
# batch data
for
sample
in
npdata
:
if
not
isinstance
(
sample
,
dict
):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be dict, but get {}."
.
format
(
type
(
sample
))
break
for
_
,
value
in
sample
.
items
():
if
not
isinstance
(
value
,
np
.
ndarray
):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be np.ndarray, but get {}."
.
format
(
type
(
value
))
return
ecode
,
error_info
elif
isinstance
(
npdata
,
dict
):
# batch_size = 1
for
_
,
value
in
npdata
.
items
():
for
_
,
value
in
npdata
.
items
():
if
not
isinstance
(
value
,
np
.
ndarray
):
if
not
isinstance
(
value
,
np
.
ndarray
):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
error_info
=
"the value of data must "
\
"be np.ndarray, but get {}."
.
format
(
type
(
value
))
"be np.ndarray, but get {}."
.
format
(
type
(
value
))
break
break
else
:
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be dict, but get {}."
.
format
(
type
(
npdata
))
return
ecode
,
error_info
return
ecode
,
error_info
def
parse
(
self
):
def
parse
(
self
):
...
...
python/pipeline/operator.py
浏览文件 @
ce43cb50
...
@@ -25,7 +25,7 @@ from .proto import pipeline_service_pb2
...
@@ -25,7 +25,7 @@ from .proto import pipeline_service_pb2
from
.channel
import
ThreadChannel
,
ProcessChannel
,
ChannelDataEcode
,
ChannelData
,
ChannelDataType
from
.channel
import
ThreadChannel
,
ProcessChannel
,
ChannelDataEcode
,
ChannelData
,
ChannelDataType
from
.util
import
NameGenerator
from
.util
import
NameGenerator
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
_op_name_gen
=
NameGenerator
(
"Op"
)
_op_name_gen
=
NameGenerator
(
"Op"
)
...
@@ -142,7 +142,7 @@ class Op(object):
...
@@ -142,7 +142,7 @@ class Op(object):
_LOGGER
.
debug
(
self
.
_log
(
"get call_result"
))
_LOGGER
.
debug
(
self
.
_log
(
"get call_result"
))
return
call_result
return
call_result
def
postprocess
(
self
,
fetch_dict
):
def
postprocess
(
self
,
input_dict
,
fetch_dict
):
return
fetch_dict
return
fetch_dict
def
stop
(
self
):
def
stop
(
self
):
...
@@ -267,10 +267,10 @@ class Op(object):
...
@@ -267,10 +267,10 @@ class Op(object):
midped_data
=
preped_data
midped_data
=
preped_data
return
midped_data
,
error_channeldata
return
midped_data
,
error_channeldata
def
_run_postprocess
(
self
,
midped_data
,
data_id
,
log_func
):
def
_run_postprocess
(
self
,
input_dict
,
midped_data
,
data_id
,
log_func
):
output_data
,
error_channeldata
=
None
,
None
output_data
,
error_channeldata
=
None
,
None
try
:
try
:
postped_data
=
self
.
postprocess
(
midped_data
)
postped_data
=
self
.
postprocess
(
input_dict
,
midped_data
)
except
Exception
as
e
:
except
Exception
as
e
:
error_info
=
log_func
(
e
)
error_info
=
log_func
(
e
)
_LOGGER
.
error
(
error_info
)
_LOGGER
.
error
(
error_info
)
...
@@ -359,8 +359,8 @@ class Op(object):
...
@@ -359,8 +359,8 @@ class Op(object):
# postprocess
# postprocess
self
.
_profiler_record
(
"{}-postp#{}_0"
.
format
(
op_info_prefix
,
tid
))
self
.
_profiler_record
(
"{}-postp#{}_0"
.
format
(
op_info_prefix
,
tid
))
output_data
,
error_channeldata
=
self
.
_run_postprocess
(
midped_data
,
output_data
,
error_channeldata
=
self
.
_run_postprocess
(
data_id
,
log
)
parsed_data
,
midped_data
,
data_id
,
log
)
self
.
_profiler_record
(
"{}-postp#{}_1"
.
format
(
op_info_prefix
,
tid
))
self
.
_profiler_record
(
"{}-postp#{}_1"
.
format
(
op_info_prefix
,
tid
))
if
error_channeldata
is
not
None
:
if
error_channeldata
is
not
None
:
self
.
_push_to_output_channels
(
error_channeldata
,
self
.
_push_to_output_channels
(
error_channeldata
,
...
...
python/pipeline/pipeline_client.py
浏览文件 @
ce43cb50
...
@@ -20,7 +20,7 @@ import functools
...
@@ -20,7 +20,7 @@ import functools
from
.proto
import
pipeline_service_pb2
from
.proto
import
pipeline_service_pb2
from
.proto
import
pipeline_service_pb2_grpc
from
.proto
import
pipeline_service_pb2_grpc
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
class
PipelineClient
(
object
):
class
PipelineClient
(
object
):
...
@@ -52,7 +52,7 @@ class PipelineClient(object):
...
@@ -52,7 +52,7 @@ class PipelineClient(object):
return
{
"ecode"
:
resp
.
ecode
,
"error_info"
:
resp
.
error_info
}
return
{
"ecode"
:
resp
.
ecode
,
"error_info"
:
resp
.
error_info
}
fetch_map
=
{
"ecode"
:
resp
.
ecode
}
fetch_map
=
{
"ecode"
:
resp
.
ecode
}
for
idx
,
key
in
enumerate
(
resp
.
key
):
for
idx
,
key
in
enumerate
(
resp
.
key
):
if
key
not
in
fetch
:
if
fetch
is
not
None
and
key
not
in
fetch
:
continue
continue
data
=
resp
.
value
[
idx
]
data
=
resp
.
value
[
idx
]
try
:
try
:
...
@@ -62,16 +62,16 @@ class PipelineClient(object):
...
@@ -62,16 +62,16 @@ class PipelineClient(object):
fetch_map
[
key
]
=
data
fetch_map
[
key
]
=
data
return
fetch_map
return
fetch_map
def
predict
(
self
,
feed_dict
,
fetch
,
asyn
=
False
):
def
predict
(
self
,
feed_dict
,
fetch
=
None
,
asyn
=
False
):
if
not
isinstance
(
feed_dict
,
dict
):
if
not
isinstance
(
feed_dict
,
dict
):
raise
TypeError
(
raise
TypeError
(
"feed must be dict type with format: {name: value}."
)
"feed must be dict type with format: {name: value}."
)
if
not
isinstance
(
fetch
,
list
):
if
fetch
is
not
None
and
not
isinstance
(
fetch
,
list
):
raise
TypeError
(
"fetch must be list type with format: [name]."
)
raise
TypeError
(
"fetch must be list type with format: [name]."
)
req
=
self
.
_pack_request_package
(
feed_dict
)
req
=
self
.
_pack_request_package
(
feed_dict
)
if
not
asyn
:
if
not
asyn
:
resp
=
self
.
_stub
.
inference
(
req
)
resp
=
self
.
_stub
.
inference
(
req
)
return
self
.
_unpack_response_package
(
resp
)
return
self
.
_unpack_response_package
(
resp
,
fetch
)
else
:
else
:
call_future
=
self
.
_stub
.
inference
.
future
(
req
)
call_future
=
self
.
_stub
.
inference
.
future
(
req
)
return
PipelinePredictFuture
(
return
PipelinePredictFuture
(
...
...
python/pipeline/pipeline_server.py
浏览文件 @
ce43cb50
...
@@ -45,7 +45,7 @@ from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcod
...
@@ -45,7 +45,7 @@ from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcod
from
.profiler
import
TimeProfiler
from
.profiler
import
TimeProfiler
from
.util
import
NameGenerator
from
.util
import
NameGenerator
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
_profiler
=
TimeProfiler
()
_profiler
=
TimeProfiler
()
...
@@ -384,7 +384,7 @@ class PipelineServer(object):
...
@@ -384,7 +384,7 @@ class PipelineServer(object):
def
prepare_server
(
self
,
yml_file
):
def
prepare_server
(
self
,
yml_file
):
with
open
(
yml_file
)
as
f
:
with
open
(
yml_file
)
as
f
:
yml_config
=
yaml
.
load
(
f
.
read
())
yml_config
=
yaml
.
load
(
f
.
read
()
,
Loader
=
yaml
.
FullLoader
)
self
.
_port
=
yml_config
.
get
(
'port'
,
8080
)
self
.
_port
=
yml_config
.
get
(
'port'
,
8080
)
if
not
self
.
_port_is_available
(
self
.
_port
):
if
not
self
.
_port_is_available
(
self
.
_port
):
raise
SystemExit
(
"Prot {} is already used"
.
format
(
self
.
_port
))
raise
SystemExit
(
"Prot {} is already used"
.
format
(
self
.
_port
))
...
...
python/pipeline/profiler.py
浏览文件 @
ce43cb50
...
@@ -24,7 +24,7 @@ else:
...
@@ -24,7 +24,7 @@ else:
raise
Exception
(
"Error Python version"
)
raise
Exception
(
"Error Python version"
)
import
time
import
time
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
class
TimeProfiler
(
object
):
class
TimeProfiler
(
object
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录