Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
0c554a59
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0c554a59
编写于
12月 12, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
差异文件
merge develop
test=develop
上级
ca84c2ca
6951ef9a
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
108 addition
and
93 deletion
+108
-93
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+18
-10
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+8
-5
paddle/fluid/framework/ngraph_bridge.cc
paddle/fluid/framework/ngraph_bridge.cc
+14
-15
paddle/fluid/framework/ngraph_bridge.h
paddle/fluid/framework/ngraph_bridge.h
+0
-3
paddle/fluid/framework/ngraph_operator.cc
paddle/fluid/framework/ngraph_operator.cc
+28
-31
paddle/fluid/framework/ngraph_operator.h
paddle/fluid/framework/ngraph_operator.h
+3
-6
paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc
...le/fluid/inference/analysis/passes/ir_graph_build_pass.cc
+4
-3
paddle/fluid/inference/tests/api/CMakeLists.txt
paddle/fluid/inference/tests/api/CMakeLists.txt
+1
-1
paddle/fluid/inference/tests/api/tester_helper.h
paddle/fluid/inference/tests/api/tester_helper.h
+14
-2
paddle/fluid/inference/tests/api/trt_models_tester.cc
paddle/fluid/inference/tests/api/trt_models_tester.cc
+3
-0
paddle/fluid/inference/utils/benchmark.cc
paddle/fluid/inference/utils/benchmark.cc
+1
-1
paddle/fluid/inference/utils/visualizer.cc
paddle/fluid/inference/utils/visualizer.cc
+5
-5
paddle/fluid/operators/activation_op.h
paddle/fluid/operators/activation_op.h
+7
-8
paddle/fluid/operators/distributed/brpc_client.cc
paddle/fluid/operators/distributed/brpc_client.cc
+1
-1
paddle/fluid/operators/distributed/grpc_client.cc
paddle/fluid/operators/distributed/grpc_client.cc
+1
-2
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
0c554a59
...
...
@@ -131,11 +131,13 @@ cc_test(version_test SRCS version_test.cc DEPS version)
cc_library
(
proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version
)
if
(
NOT WIN32
)
cc_library
(
ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph
)
cc_library
(
ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler
)
endif
(
NOT WIN32
)
if
(
WITH_NGRAPH
)
if
(
NOT WIN32
)
cc_library
(
ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph
)
cc_library
(
ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler ngraph
)
endif
(
NOT WIN32
)
endif
(
WITH_NGRAPH
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc
)
nv_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry
)
...
...
@@ -171,14 +173,20 @@ if(WITH_DISTRIBUTE)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
else
()
if
(
NOT WIN32
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator variable_helper garbage_collector
)
else
(
NOT WIN32
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper garbage_collector
)
endif
(
NOT WIN32
)
if
(
WITH_NGRAPH
)
if
(
NOT WIN32
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph ngraph_operator variable_helper
)
else
(
NOT WIN32
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper
)
endif
(
NOT WIN32
)
else
(
WITH_NGRAPH
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper
)
endif
(
WITH_NGRAPH
)
cc_test
(
test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op
)
endif
()
target_link_libraries
(
executor garbage_collector
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
graph build_strategy
...
...
paddle/fluid/framework/executor.cc
浏览文件 @
0c554a59
...
...
@@ -18,7 +18,6 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/ngraph_operator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
...
...
@@ -27,6 +26,10 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_NGRAPH
#include "paddle/fluid/framework/ngraph_operator.h"
#endif
DECLARE_bool
(
benchmark
);
DEFINE_bool
(
use_mkldnn
,
false
,
"Use MKLDNN to run"
);
DEFINE_bool
(
use_ngraph
,
false
,
"Use NGRAPH to run"
);
...
...
@@ -131,11 +134,11 @@ static void DeleteUnusedTensors(
static
void
EnableFusedOp
(
ExecutorPrepareContext
*
ctx
)
{
#ifdef PADDLE_WITH_NGRAPH
VLOG
(
3
)
<<
"use_ngraph=True"
;
auto
intervals
=
FusedOperator
::
Fused
OpIntervals
(
&
ctx
->
ops_
);
auto
intervals
=
NgraphOperator
::
Ngraph
OpIntervals
(
&
ctx
->
ops_
);
for
(
auto
&
interval
:
intervals
)
{
auto
*
fused_op
=
new
FusedOperator
(
ctx
->
prog_
,
ctx
->
block_id_
,
interval
.
at
(
0
),
interval
.
at
(
1
));
*
interval
[
0
]
=
std
::
unique_ptr
<
OperatorBase
>
(
fused
_op
);
auto
*
ng_op
=
new
NgraphOperator
(
ctx
->
prog_
,
ctx
->
block_id_
,
interval
.
at
(
0
)
,
interval
.
at
(
1
));
*
interval
[
0
]
=
std
::
unique_ptr
<
OperatorBase
>
(
ng
_op
);
}
for
(
auto
it
=
intervals
.
rbegin
();
it
!=
intervals
.
rend
();
++
it
)
{
ctx
->
ops_
.
erase
(
it
->
at
(
0
)
+
1
,
it
->
at
(
1
));
...
...
paddle/fluid/framework/ngraph_bridge.cc
浏览文件 @
0c554a59
...
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#include <algorithm>
#include <functional>
#include <vector>
...
...
@@ -27,14 +26,15 @@ namespace paddle {
namespace
framework
{
static
std
::
shared_ptr
<
ngraph
::
Node
>
GetNode
(
const
std
::
shared_ptr
<
OperatorBase
>&
op
,
const
std
::
string
prm
,
const
std
::
shared_ptr
<
OperatorBase
>&
op
,
const
std
::
string
name
,
const
VariableNameMap
&
var_map
,
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
ngb_node_map
)
{
auto
&
var_names
=
var_map
.
at
(
prm
);
auto
&
var_names
=
var_map
.
at
(
name
);
PADDLE_ENFORCE_EQ
(
var_names
.
size
(),
1
,
"op %s prm %s expects one associated var"
,
op
->
Type
(),
prm
);
"op %s name %s expects one associated var"
,
op
->
Type
(),
name
);
if
(
ngb_node_map
->
find
(
var_names
[
0
])
!=
ngb_node_map
->
end
())
{
return
(
*
ngb_node_map
)[
var_names
[
0
]];
}
else
{
...
...
@@ -43,42 +43,42 @@ static std::shared_ptr<ngraph::Node> GetNode(
}
static
std
::
shared_ptr
<
ngraph
::
Node
>
GetInputNode
(
const
std
::
shared_ptr
<
OperatorBase
>&
op
,
const
std
::
string
prm
,
const
std
::
shared_ptr
<
OperatorBase
>&
op
,
const
std
::
string
name
,
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
ngb_node_map
)
{
return
GetNode
(
op
,
prm
,
op
->
Inputs
(),
ngb_node_map
);
return
GetNode
(
op
,
name
,
op
->
Inputs
(),
ngb_node_map
);
}
static
std
::
shared_ptr
<
ngraph
::
Node
>
GetOutputNode
(
const
std
::
shared_ptr
<
OperatorBase
>&
op
,
const
std
::
string
prm
,
const
std
::
shared_ptr
<
OperatorBase
>&
op
,
const
std
::
string
name
,
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
ngb_node_map
)
{
return
GetNode
(
op
,
prm
,
op
->
Outputs
(),
ngb_node_map
);
return
GetNode
(
op
,
name
,
op
->
Outputs
(),
ngb_node_map
);
}
static
void
SetOutputNode
(
const
std
::
shared_ptr
<
OperatorBase
>&
op
,
const
std
::
string
prm
,
const
std
::
shared_ptr
<
OperatorBase
>&
op
,
const
std
::
string
name
,
std
::
shared_ptr
<
ngraph
::
Node
>
node
,
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
ngb_node_map
)
{
auto
&
var_names
=
op
->
Outputs
().
at
(
prm
);
auto
&
var_names
=
op
->
Outputs
().
at
(
name
);
if
(
var_names
.
size
()
==
1
)
{
(
*
ngb_node_map
)[
var_names
[
0
]]
=
node
;
}
else
if
(
var_names
.
size
()
==
0
)
{
(
*
ngb_node_map
)[
""
]
=
node
;
}
else
{
PADDLE_THROW
(
"
prm %s has more than 1 var_names."
,
prm
);
PADDLE_THROW
(
"
name %s has more than 1 var_names."
,
name
);
}
}
static
bool
HasOutput
(
const
std
::
shared_ptr
<
OperatorBase
>&
op
,
const
std
::
string
prm
)
{
const
std
::
string
name
)
{
auto
&
outputs
=
op
->
Outputs
();
if
(
outputs
.
find
(
prm
)
==
outputs
.
end
())
return
false
;
return
outputs
.
at
(
prm
).
size
()
>
0
;
if
(
outputs
.
find
(
name
)
==
outputs
.
end
())
return
false
;
return
outputs
.
at
(
name
).
size
()
>
0
;
}
template
<
typename
T
>
...
...
@@ -118,4 +118,3 @@ void NgraphBridge::BuildNgNode(const std::shared_ptr<OperatorBase>& op) {
}
// namespace framework
}
// namespace paddle
#endif
paddle/fluid/framework/ngraph_bridge.h
浏览文件 @
0c554a59
...
...
@@ -14,8 +14,6 @@ limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_NGRAPH
#include <algorithm>
#include <map>
#include <string>
...
...
@@ -53,4 +51,3 @@ class NgraphBridge {
}
// namespace framework
}
// namespace paddle
#endif
paddle/fluid/framework/ngraph_operator.cc
浏览文件 @
0c554a59
...
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#include <glog/logging.h>
#include <algorithm>
...
...
@@ -58,16 +57,16 @@ typedef enum { /* nGraph support state on ops */
}
op_state
;
// perform graph build through bridge and execute computation
class
Ngraph
Operator
{
class
Ngraph
Engine
{
public:
explicit
Ngraph
Operator
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>&
ops
,
const
std
::
unordered_map
<
std
::
string
,
ngraph
::
element
::
Type
>&
var_type_map
,
const
std
::
unordered_set
<
std
::
string
>&
persist
,
const
std
::
unordered_set
<
std
::
string
>&
fetches
,
const
std
::
unordered_set
<
std
::
string
>&
post_op_inputs
,
op_state
ng_op_state
)
explicit
Ngraph
Engine
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>&
ops
,
const
std
::
unordered_map
<
std
::
string
,
ngraph
::
element
::
Type
>&
var_type_map
,
const
std
::
unordered_set
<
std
::
string
>&
persist
,
const
std
::
unordered_set
<
std
::
string
>&
fetches
,
const
std
::
unordered_set
<
std
::
string
>&
post_op_inputs
,
op_state
ng_op_state
)
:
scope_
(
scope
),
place_
(
place
),
fused_ops_
(
ops
),
...
...
@@ -132,7 +131,7 @@ class NgraphOperator {
};
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
>>
FusedOperator
::
Fused
OpIntervals
(
NgraphOperator
::
Ngraph
OpIntervals
(
std
::
vector
<
std
::
unique_ptr
<
paddle
::
framework
::
OperatorBase
>>*
ops
)
{
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
>>
intervals
;
...
...
@@ -185,7 +184,7 @@ FusedOperator::FusedOpIntervals(
return
intervals
;
}
FusedOperator
::
Fused
Operator
(
NgraphOperator
::
Ngraph
Operator
(
const
ProgramDesc
&
prog
,
size_t
block_id
,
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
start
,
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
end
,
...
...
@@ -215,7 +214,7 @@ FusedOperator::FusedOperator(
Process
();
}
void
Fused
Operator
::
Process
()
{
void
Ngraph
Operator
::
Process
()
{
auto
&
bdesc
=
pdesc_
.
Block
(
block_
);
for
(
auto
&
var
:
bdesc
.
AllVars
())
{
if
(
!
(
var
->
GetType
()
==
proto
::
VarType
::
SELECTED_ROWS
||
...
...
@@ -251,8 +250,8 @@ void FusedOperator::Process() {
}
}
void
Fused
Operator
::
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
void
Ngraph
Operator
::
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
op_state
ng_op_state
=
PARTIAL_TEST
;
auto
&
bdesc
=
pdesc_
.
Block
(
block_
);
for
(
auto
*
op
:
bdesc
.
AllOps
())
{
...
...
@@ -266,19 +265,19 @@ void FusedOperator::RunImpl(const Scope& scope,
ng_op_state
=
ng_op_state
==
PARTIAL_TEST
?
FULL_TEST
:
FULL_TRAIN
;
}
Ngraph
Operator
ngraph_op
(
scope
,
place
,
fused_ops_
,
var_type_map_
,
persistables_
,
fetches_
,
post_op_inputs_
,
ng_op_state
);
ngraph_
op
.
Run
(
scope
,
place
);
Ngraph
Engine
ngraph_engine
(
scope
,
place
,
fused_ops_
,
var_type_map_
,
persistables_
,
fetches_
,
post_op_inputs_
,
ng_op_state
);
ngraph_
engine
.
Run
(
scope
,
place
);
}
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Function
>>
Ngraph
Operator
::
func_cache_
=
{};
Ngraph
Engine
::
func_cache_
=
{};
std
::
shared_ptr
<
ngraph
::
runtime
::
Backend
>
Ngraph
Operator
::
backend_
=
std
::
shared_ptr
<
ngraph
::
runtime
::
Backend
>
Ngraph
Engine
::
backend_
=
ngraph
::
runtime
::
Backend
::
create
(
"CPU"
);
void
Ngraph
Operator
::
GetNgInputShape
(
std
::
shared_ptr
<
OperatorBase
>
op
)
{
void
Ngraph
Engine
::
GetNgInputShape
(
std
::
shared_ptr
<
OperatorBase
>
op
)
{
op
->
RuntimeInferShape
(
scope_
,
place_
);
for
(
auto
&
var_name_item
:
op
->
Inputs
())
{
for
(
auto
&
var_name
:
var_name_item
.
second
)
{
...
...
@@ -301,7 +300,7 @@ void NgraphOperator::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
}
}
void
Ngraph
Operator
::
BuildNgNodes
()
{
void
Ngraph
Engine
::
BuildNgNodes
()
{
for
(
auto
&
var_name
:
var_out_
)
{
if
(
var_node_map_
->
find
(
var_name
)
==
var_node_map_
->
end
())
{
auto
*
var
=
scope_
.
FindVar
(
var_name
);
...
...
@@ -323,7 +322,7 @@ void NgraphOperator::BuildNgNodes() {
}
}
void
Ngraph
Operator
::
BuildNgIO
()
{
void
Ngraph
Engine
::
BuildNgIO
()
{
std
::
unordered_set
<
std
::
string
>
inputs
;
std
::
unordered_set
<
std
::
string
>
outputs
;
...
...
@@ -395,7 +394,7 @@ void NgraphOperator::BuildNgIO() {
}
}
void
Ngraph
Operator
::
BuildNgFunction
()
{
void
Ngraph
Engine
::
BuildNgFunction
()
{
BuildNgNodes
();
ngraph_function_
=
nullptr
;
ngraph
::
NodeVector
func_outputs
;
...
...
@@ -416,7 +415,7 @@ void NgraphOperator::BuildNgFunction() {
std
::
make_shared
<
ngraph
::
Function
>
(
func_outputs
,
func_inputs
);
}
std
::
shared_ptr
<
std
::
string
>
Ngraph
Operator
::
GetCacheKey
()
{
std
::
shared_ptr
<
std
::
string
>
Ngraph
Engine
::
GetCacheKey
()
{
auto
cache_key
=
std
::
make_shared
<
std
::
string
>
(
""
);
*
cache_key
+=
std
::
to_string
(
fused_ops_
.
size
());
for
(
auto
&
op
:
fused_ops_
)
{
...
...
@@ -444,7 +443,7 @@ std::shared_ptr<std::string> NgraphOperator::GetCacheKey() {
return
cache_key
;
}
void
Ngraph
Operator
::
GetNgFunction
()
{
void
Ngraph
Engine
::
GetNgFunction
()
{
bool
cache_on
=
true
;
if
(
cache_on
)
{
std
::
string
cache_key_val
=
*
GetCacheKey
();
...
...
@@ -459,8 +458,7 @@ void NgraphOperator::GetNgFunction() {
}
}
void
NgraphOperator
::
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
void
NgraphEngine
::
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Tensor
>>
t_in
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Tensor
>>
t_out
;
...
...
@@ -545,7 +543,6 @@ void NgraphOperator::Run(const Scope& scope,
}
backend_
->
call
(
ngraph_function_
,
t_out
,
t_in
);
}
// Ngraph
Operator
::RunImpl
}
// Ngraph
Engine
::RunImpl
}
// namespace framework
}
// namespace paddle
#endif
paddle/fluid/framework/ngraph_operator.h
浏览文件 @
0c554a59
...
...
@@ -14,8 +14,6 @@ limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_NGRAPH
#include <algorithm>
#include <string>
#include <unordered_map>
...
...
@@ -34,14 +32,14 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
class
Fused
Operator
:
public
OperatorBase
{
class
Ngraph
Operator
:
public
OperatorBase
{
public:
static
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
>>
Fused
OpIntervals
(
Ngraph
OpIntervals
(
std
::
vector
<
std
::
unique_ptr
<
paddle
::
framework
::
OperatorBase
>>*
ops
);
explicit
Fused
Operator
(
explicit
Ngraph
Operator
(
const
ProgramDesc
&
prog
,
size_t
block_id
,
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
start
,
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
end
,
...
...
@@ -64,4 +62,3 @@ class FusedOperator : public OperatorBase {
};
}
// namespace framework
}
// namespace paddle
#endif
paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc
浏览文件 @
0c554a59
...
...
@@ -44,9 +44,10 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
argument
->
SetMainProgram
(
program
.
release
());
}
else
if
(
argument
->
model_program_path_valid
()
&&
argument
->
model_params_path_valid
())
{
auto
program
=
LoadModel
(
argument
->
model_program_path
(),
argument
->
model_params_path
(),
argument
->
scope_ptr
(),
place
,
argument
->
model_from_memory
());
auto
program
=
LoadModel
(
argument
->
model_program_path
(),
argument
->
model_params_path
(),
argument
->
scope_ptr
(),
place
,
argument
->
model_from_memory_valid
()
&&
argument
->
model_from_memory
());
argument
->
SetMainProgram
(
program
.
release
());
}
else
{
PADDLE_THROW
(
...
...
paddle/fluid/inference/tests/api/CMakeLists.txt
浏览文件 @
0c554a59
set
(
INFERENCE_EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor
)
set
(
INFERENCE_EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor
benchmark
)
if
(
WITH_GPU AND TENSORRT_FOUND
)
set
(
INFERENCE_EXTRA_DEPS
${
INFERENCE_EXTRA_DEPS
}
analysis
${
analysis_deps
}
ir_pass_manager analysis_predictor
)
...
...
paddle/fluid/inference/tests/api/tester_helper.h
浏览文件 @
0c554a59
...
...
@@ -30,8 +30,10 @@
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/tests/api/config_printer.h"
#include "paddle/fluid/inference/tests/test_helper.h"
#include "paddle/fluid/inference/utils/benchmark.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_string
(
model_name
,
""
,
"model name"
);
DEFINE_string
(
infer_model
,
""
,
"model path"
);
DEFINE_string
(
infer_data
,
""
,
"data file"
);
DEFINE_int32
(
batch_size
,
1
,
"batch size."
);
...
...
@@ -40,6 +42,8 @@ DEFINE_bool(test_all_data, false, "Test the all dataset in data file.");
DEFINE_int32
(
num_threads
,
1
,
"Running the inference program in multi-threads."
);
DEFINE_bool
(
use_analysis
,
true
,
"Running the inference program in analysis mode."
);
DEFINE_bool
(
record_benchmark
,
false
,
"Record benchmark after profiling the model"
);
DECLARE_bool
(
profile
);
DECLARE_int32
(
paddle_num_threads
);
...
...
@@ -192,8 +196,16 @@ void TestOneThreadPrediction(
predictor
->
Run
(
inputs
[
j
],
outputs
,
batch_size
);
}
}
PrintTime
(
batch_size
,
num_times
,
1
,
0
,
run_timer
.
toc
()
/
num_times
,
inputs
.
size
());
double
latency
=
run_timer
.
toc
()
/
num_times
;
PrintTime
(
batch_size
,
num_times
,
1
,
0
,
latency
,
inputs
.
size
());
if
(
FLAGS_record_benchmark
)
{
Benchmark
benchmark
;
benchmark
.
SetName
(
FLAGS_model_name
);
benchmark
.
SetBatchSize
(
batch_size
);
benchmark
.
SetLatency
(
latency
);
benchmark
.
PersistToFile
(
"benchmark_record.txt"
);
}
}
}
...
...
paddle/fluid/inference/tests/api/trt_models_tester.cc
浏览文件 @
0c554a59
...
...
@@ -135,6 +135,9 @@ TEST(TensorRT_resnext50, compare) {
TEST
(
TensorRT_resnext50
,
profile
)
{
std
::
string
model_dir
=
FLAGS_infer_model
+
"/resnext50"
;
// Set FLAGS_record_benchmark to true to record benchmark to file.
// FLAGS_record_benchmark=true;
FLAGS_model_name
=
"resnext50"
;
profile
(
model_dir
,
/* use_analysis */
true
,
FLAGS_use_tensorrt
);
}
...
...
paddle/fluid/inference/utils/benchmark.cc
浏览文件 @
0c554a59
...
...
@@ -30,7 +30,7 @@ std::string Benchmark::SerializeToString() const {
ss
<<
'\n'
;
ss
<<
name_
<<
"
\t
"
;
ss
<<
batch_size_
<<
"
\t
"
;
ss
<<
batch_size_
<<
"
\t
\t
"
;
ss
<<
num_threads_
<<
"
\t
"
;
ss
<<
latency_
<<
"
\t
"
;
ss
<<
1000.0
/
latency_
;
...
...
paddle/fluid/inference/utils/visualizer.cc
浏览文件 @
0c554a59
...
...
@@ -26,9 +26,6 @@ DEFINE_string(model_dir, "", "model directory");
DEFINE_string
(
model_program_path
,
""
,
"model program path"
);
DEFINE_string
(
model_params_path
,
""
,
"model params path"
);
USE_PASS
(
graph_viz_pass
);
USE_PASS
(
graph_to_program_pass
);
using
paddle
::
inference
::
analysis
::
Argument
;
namespace
paddle
{
...
...
@@ -40,7 +37,6 @@ void Visualizer::SetArgument(Argument *argument) { argument_ = argument; }
bool
Visualizer
::
Run
()
{
paddle
::
framework
::
InitDevices
(
false
);
paddle
::
inference
::
analysis
::
Analyzer
().
Run
(
argument_
);
return
true
;
}
...
...
@@ -77,7 +73,7 @@ int main(int argc, char *argv[]) {
// Only 1 pass, default filename is 0_ir_origin.dot
// For more details, looking for paddle::inference::analysis::IRPassManager
argument
.
SetIrAnalysisPasses
({
"graph_viz_pass"
});
argument
.
SetIrAnalysisPasses
({
"
infer_clean_graph_pass"
,
"
graph_viz_pass"
});
std
::
unique_ptr
<
paddle
::
framework
::
Scope
>
scope
{
new
paddle
::
framework
::
Scope
()};
...
...
@@ -90,3 +86,7 @@ int main(int argc, char *argv[]) {
return
0
;
}
USE_PASS
(
infer_clean_graph_pass
);
USE_PASS
(
graph_viz_pass
);
USE_PASS
(
graph_to_program_pass
);
paddle/fluid/operators/activation_op.h
浏览文件 @
0c554a59
...
...
@@ -301,23 +301,22 @@ template <typename T>
struct
GeluFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
auto
temp
=
((
x
*
static_cast
<
T
>
(
M_SQRT1_2
)).
erf
()).
template
cast
<
T
>().
eval
();
auto
temp
=
(
x
*
static_cast
<
T
>
(
M_SQRT1_2
)).
erf
();
out
.
device
(
d
)
=
x
*
static_cast
<
T
>
(
0.5
)
*
(
static_cast
<
T
>
(
1
)
+
temp
);
}
};
template
<
typename
T
>
struct
GeluGradFunctor
:
BaseActivationFunctor
<
T
>
{
bool
Inplace
()
const
{
return
IsInplace
(
"gelu"
);
}
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Out
out
,
dOut
dout
,
dX
dx
)
const
{
auto
temp
=
(
static_cast
<
T
>
(
0.5
*
M_2_SQRTPI
*
M_SQRT1_2
)
*
x
*
((
-
static_cast
<
T
>
(
0.5
)
*
x
.
square
()).
exp
()))
.
template
cast
<
T
>()
.
eval
();
dx
.
device
(
d
)
=
dout
*
(
out
/
x
+
temp
);
auto
first
=
static_cast
<
T
>
(
0.5
)
*
(
static_cast
<
T
>
(
1
)
+
((
x
*
static_cast
<
T
>
(
M_SQRT1_2
)).
erf
()));
auto
second
=
static_cast
<
T
>
(
0.5
*
M_2_SQRTPI
*
M_SQRT1_2
)
*
x
*
(
-
static_cast
<
T
>
(
0.5
)
*
x
.
square
()).
exp
();
dx
.
device
(
d
)
=
dout
*
(
first
+
second
);
}
};
...
...
paddle/fluid/operators/distributed/brpc_client.cc
浏览文件 @
0c554a59
...
...
@@ -158,7 +158,7 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
for
(
int
i
=
0
;
i
<
FLAGS_brpc_channel_num
;
++
i
)
{
std
::
shared_ptr
<
ChannelContext
>
c
(
new
ChannelContext
());
if
(
c
->
channel
.
Init
(
ep
.
c_str
(),
&
options
)
!=
0
)
{
LOG
(
ERROR
)
<<
"Fail to initialize channel"
;
LOG
(
FATAL
)
<<
"Fail to initialize channel"
;
return
nullptr
;
}
...
...
paddle/fluid/operators/distributed/grpc_client.cc
浏览文件 @
0c554a59
...
...
@@ -390,8 +390,7 @@ void GRPCClient::Proceed() {
VLOG
(
3
)
<<
c
->
GetVarHandlePtr
()
->
String
()
<<
" process"
;
c
->
Process
();
}
else
if
(
c
->
status_
.
error_code
()
==
grpc
::
StatusCode
::
DEADLINE_EXCEEDED
)
{
// FIXME(gongwb): parse error_details?
LOG
(
ERROR
)
<<
c
->
GetVarHandlePtr
()
->
String
()
LOG
(
FATAL
)
<<
c
->
GetVarHandlePtr
()
->
String
()
<<
" meets grpc error, error_code:"
<<
c
->
status_
.
error_code
()
<<
" error_message:"
<<
c
->
status_
.
error_message
()
<<
" error_details:"
<<
c
->
status_
.
error_details
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录