Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
53d558cd
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
53d558cd
编写于
1月 28, 2019
作者:
J
JiabinYang
浏览文件
操作
浏览文件
下载
差异文件
test=develop, polish code and merge develop
上级
8e3da976
10bc9ffc
变更
21
显示空白变更内容
内联
并排
Showing
21 changed file
with
305 addition
and
195 deletion
+305
-195
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+0
-4
paddle/fluid/framework/details/execution_strategy.h
paddle/fluid/framework/details/execution_strategy.h
+1
-1
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+5
-4
paddle/fluid/inference/analysis/argument.h
paddle/fluid/inference/analysis/argument.h
+3
-1
paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
...e/fluid/inference/analysis/passes/memory_optimize_pass.cc
+117
-71
paddle/fluid/inference/analysis/passes/memory_optimize_pass.h
...le/fluid/inference/analysis/passes/memory_optimize_pass.h
+4
-2
paddle/fluid/inference/api/analysis_config.cc
paddle/fluid/inference/api/analysis_config.cc
+20
-4
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+13
-11
paddle/fluid/inference/api/analysis_predictor.h
paddle/fluid/inference/api/analysis_predictor.h
+1
-1
paddle/fluid/inference/api/paddle_analysis_config.h
paddle/fluid/inference/api/paddle_analysis_config.h
+5
-13
paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
+21
-3
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+5
-0
paddle/fluid/operators/distributed/rpc_server.cc
paddle/fluid/operators/distributed/rpc_server.cc
+22
-15
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
+7
-1
paddle/fluid/operators/grid_sampler_op.cc
paddle/fluid/operators/grid_sampler_op.cc
+8
-6
paddle/fluid/operators/ngraph/CMakeLists.txt
paddle/fluid/operators/ngraph/CMakeLists.txt
+1
-0
paddle/fluid/operators/ngraph/ngraph_bridge.cc
paddle/fluid/operators/ngraph/ngraph_bridge.cc
+18
-18
paddle/fluid/operators/ngraph/ngraph_bridge.h
paddle/fluid/operators/ngraph/ngraph_bridge.h
+6
-6
paddle/fluid/operators/ngraph/ngraph_engine.cc
paddle/fluid/operators/ngraph/ngraph_engine.cc
+6
-7
python/paddle/fluid/contrib/int8_inference/utility.py
python/paddle/fluid/contrib/int8_inference/utility.py
+29
-5
python/paddle/fluid/contrib/tests/test_calibration.py
python/paddle/fluid/contrib/tests/test_calibration.py
+13
-22
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
53d558cd
...
@@ -129,10 +129,6 @@ cc_test(version_test SRCS version_test.cc DEPS version)
...
@@ -129,10 +129,6 @@ cc_test(version_test SRCS version_test.cc DEPS version)
cc_library
(
proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version
)
cc_library
(
proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version
)
if
(
WITH_NGRAPH
)
cc_library
(
ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph
)
endif
(
WITH_NGRAPH
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc
)
nv_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry
)
nv_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry
)
...
...
paddle/fluid/framework/details/execution_strategy.h
浏览文件 @
53d558cd
...
@@ -28,7 +28,7 @@ struct ExecutionStrategy {
...
@@ -28,7 +28,7 @@ struct ExecutionStrategy {
// If we set this to 1, we will delete all variables when finish a batch. and
// If we set this to 1, we will delete all variables when finish a batch. and
// this will loss 15%+ performance.
// this will loss 15%+ performance.
// Please be aware about this parameters.
// Please be aware about this parameters.
size_t
num_iteration_per_drop_scope_
{
1
00
};
size_t
num_iteration_per_drop_scope_
{
1
};
ExecutorType
type_
{
kDefault
};
ExecutorType
type_
{
kDefault
};
bool
dry_run_
{
false
};
bool
dry_run_
{
false
};
};
};
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
53d558cd
...
@@ -1072,8 +1072,9 @@ Scope* OperatorWithKernel::PrepareData(
...
@@ -1072,8 +1072,9 @@ Scope* OperatorWithKernel::PrepareData(
proto
::
VarType
::
Type
OperatorWithKernel
::
IndicateDataType
(
proto
::
VarType
::
Type
OperatorWithKernel
::
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
{
const
ExecutionContext
&
ctx
)
const
{
proto
::
VarType
::
Type
defaut_data_type
=
static_cast
<
proto
::
VarType
::
Type
>
(
-
1
);
proto
::
VarType
::
Type
dafault_data_type
=
proto
::
VarType
::
Type
data_type
=
defaut_data_type
;
static_cast
<
proto
::
VarType
::
Type
>
(
-
1
);
proto
::
VarType
::
Type
data_type
=
dafault_data_type
;
for
(
auto
&
input
:
this
->
inputs_
)
{
for
(
auto
&
input
:
this
->
inputs_
)
{
const
std
::
vector
<
const
Variable
*>
vars
=
ctx
.
MultiInputVar
(
input
.
first
);
const
std
::
vector
<
const
Variable
*>
vars
=
ctx
.
MultiInputVar
(
input
.
first
);
for
(
size_t
i
=
0
;
i
<
vars
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
vars
.
size
();
++
i
)
{
...
@@ -1092,7 +1093,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
...
@@ -1092,7 +1093,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
input
.
first
,
i
);
input
.
first
,
i
);
proto
::
VarType
::
Type
tmp
=
t
->
type
();
proto
::
VarType
::
Type
tmp
=
t
->
type
();
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
tmp
==
data_type
||
data_type
==
d
efau
t_data_type
,
tmp
==
data_type
||
data_type
==
d
afaul
t_data_type
,
"DataType of Paddle Op %s must be the same. Get (%d) != (%d)"
,
"DataType of Paddle Op %s must be the same. Get (%d) != (%d)"
,
Type
(),
DataTypeToString
(
data_type
),
DataTypeToString
(
tmp
));
Type
(),
DataTypeToString
(
data_type
),
DataTypeToString
(
tmp
));
data_type
=
tmp
;
data_type
=
tmp
;
...
@@ -1100,7 +1101,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
...
@@ -1100,7 +1101,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
}
}
}
}
}
}
PADDLE_ENFORCE
(
data_type
!=
d
efau
t_data_type
,
PADDLE_ENFORCE
(
data_type
!=
d
afaul
t_data_type
,
"DataType should be indicated by input"
);
"DataType should be indicated by input"
);
return
data_type
;
return
data_type
;
}
}
...
...
paddle/fluid/inference/analysis/argument.h
浏览文件 @
53d558cd
...
@@ -133,7 +133,9 @@ struct Argument {
...
@@ -133,7 +133,9 @@ struct Argument {
// Memory optimized related.
// Memory optimized related.
DECL_ARGUMENT_FIELD
(
enable_memory_optim
,
EnableMemoryOptim
,
bool
);
DECL_ARGUMENT_FIELD
(
enable_memory_optim
,
EnableMemoryOptim
,
bool
);
DECL_ARGUMENT_FIELD
(
memory_optim_force_update
,
MemoryOptimForceUpdate
,
bool
);
DECL_ARGUMENT_FIELD
(
static_memory_optim
,
StaticMemoryOptim
,
bool
);
DECL_ARGUMENT_FIELD
(
static_memory_optim_force_update
,
StaticMemoryOptimForceUpdate
,
bool
);
// Indicate which kind of sort algorithm is used for operators, the memory
// Indicate which kind of sort algorithm is used for operators, the memory
// optimization relays on the sort algorithm.
// optimization relays on the sort algorithm.
DECL_ARGUMENT_FIELD
(
memory_optim_sort_kind
,
MemoryOptimSortKind
,
int
);
DECL_ARGUMENT_FIELD
(
memory_optim_sort_kind
,
MemoryOptimSortKind
,
int
);
...
...
paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
浏览文件 @
53d558cd
...
@@ -444,6 +444,26 @@ std::vector<std::map<std::string, std::vector<int>>> DeseralizeBatchVarShapes(
...
@@ -444,6 +444,26 @@ std::vector<std::map<std::string, std::vector<int>>> DeseralizeBatchVarShapes(
return
batch_shapes
;
return
batch_shapes
;
}
}
// Replace the -1 in shape to a real number to fake the shape.
std
::
vector
<
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>>
FakeBatchVarShapes
(
const
framework
::
ProgramDesc
&
program
)
{
std
::
vector
<
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>>
res
;
res
.
emplace_back
();
auto
&
record
=
res
.
front
();
const
int
fake_batch_size
=
3
;
for
(
auto
*
var
:
program
.
Block
(
0
).
AllVars
())
{
if
(
var
->
GetType
()
==
framework
::
proto
::
VarType
::
Type
::
VarType_Type_LOD_TENSOR
)
{
auto
shape
=
var
->
GetShape
();
for
(
auto
&
v
:
shape
)
{
if
(
v
<
0
)
v
=
fake_batch_size
;
}
record
[
var
->
Name
()].
assign
(
shape
.
begin
(),
shape
.
end
());
}
}
return
res
;
}
// Calculate the average dim of each tensor from the batch shape cache.
// Calculate the average dim of each tensor from the batch shape cache.
std
::
unordered_map
<
std
::
string
,
size_t
>
GetBatchAverageSize
(
std
::
unordered_map
<
std
::
string
,
size_t
>
GetBatchAverageSize
(
const
std
::
vector
<
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>>&
batches
)
{
const
std
::
vector
<
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>>&
batches
)
{
...
@@ -478,6 +498,7 @@ std::vector<std::unordered_set<std::string>> AnalysisBatchShapesByBatchSize(
...
@@ -478,6 +498,7 @@ std::vector<std::unordered_set<std::string>> AnalysisBatchShapesByBatchSize(
std
::
unordered_map
<
std
::
string
,
std
::
stringstream
>
var_batchsize_hashes
;
std
::
unordered_map
<
std
::
string
,
std
::
stringstream
>
var_batchsize_hashes
;
for
(
auto
&
batch
:
batches
)
{
for
(
auto
&
batch
:
batches
)
{
for
(
auto
&
ele
:
batch
)
{
for
(
auto
&
ele
:
batch
)
{
PADDLE_ENFORCE
(
!
ele
.
second
.
empty
());
int
batch_size
=
ele
.
second
.
front
();
int
batch_size
=
ele
.
second
.
front
();
// TODO(Superjomn) might consume large memory here, use combine hash.
// TODO(Superjomn) might consume large memory here, use combine hash.
var_batchsize_hashes
[
ele
.
first
]
<<
batch_size
;
var_batchsize_hashes
[
ele
.
first
]
<<
batch_size
;
...
@@ -538,9 +559,21 @@ std::vector<std::unordered_set<std::string>> AnalysisBatchShapesBySimilarSize(
...
@@ -538,9 +559,21 @@ std::vector<std::unordered_set<std::string>> AnalysisBatchShapesBySimilarSize(
std
::
string
MemoryOptimizePass
::
repr
()
const
{
return
"memory optimize pass"
;
}
std
::
string
MemoryOptimizePass
::
repr
()
const
{
return
"memory optimize pass"
;
}
std
::
pair
<
size_t
,
size_t
>
GetRange
(
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
ave_size
)
{
auto
res
=
std
::
make_pair
(
std
::
numeric_limits
<
size_t
>::
max
(),
std
::
numeric_limits
<
size_t
>::
min
());
for
(
auto
&
item
:
ave_size
)
{
res
.
first
=
std
::
min
(
item
.
second
,
res
.
first
);
res
.
second
=
std
::
max
(
item
.
second
,
res
.
second
);
}
return
res
;
}
void
MemoryOptimizePass
::
RunImpl
(
Argument
*
argument
)
{
void
MemoryOptimizePass
::
RunImpl
(
Argument
*
argument
)
{
// When force update, should not optimize memory.
// When force update, should not optimize memory.
if
(
!
argument
->
enable_memory_optim
()
||
argument
->
memory_optim_force_update
())
if
(
!
argument
->
enable_memory_optim
()
||
argument
->
static_memory_optim_force_update
())
return
;
return
;
graph_
=
argument
->
main_graph_ptr
();
graph_
=
argument
->
main_graph_ptr
();
...
@@ -549,11 +582,26 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
...
@@ -549,11 +582,26 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
argument
->
model_program_path_valid
()
?
argument
->
model_program_path
()
argument
->
model_program_path_valid
()
?
argument
->
model_program_path
()
:
""
);
:
""
);
VLOG
(
3
)
<<
"Load memory cache from "
<<
path
;
VLOG
(
3
)
<<
"Load memory cache from "
<<
path
;
if
(
inference
::
IsFileExists
(
path
))
{
std
::
vector
<
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>>
batches
;
VLOG
(
4
)
<<
"Performing memory optimize"
;
auto
batches
=
DeseralizeBatchVarShapes
(
path
);
if
(
argument
->
static_memory_optim
()
&&
inference
::
IsFileExists
(
path
))
{
string
::
PrettyLogInfo
(
"--- Performing static memory optimize"
);
batches
=
DeseralizeBatchVarShapes
(
path
);
}
else
{
string
::
PrettyLogInfo
(
"--- Performing dynamic memory optimize"
);
batches
=
FakeBatchVarShapes
(
argument
->
main_program
());
}
auto
var_batch_ave_size
=
GetBatchAverageSize
(
batches
);
auto
var_batch_ave_size
=
GetBatchAverageSize
(
batches
);
// Get min and max memory size.
const
auto
range
=
GetRange
(
var_batch_ave_size
);
const
int
cluster_size
=
std
::
max
(
static_cast
<
int
>
((
range
.
second
-
range
.
first
)
/
100
/*cluster num*/
),
1024
);
const
int
cluster_size1
=
std
::
max
(
static_cast
<
int
>
((
range
.
second
-
range
.
first
)
/
1000
/*cluster num*/
),
1024
);
std
::
unordered_map
<
std
::
string
,
Node
*>
tensor_nodes
;
std
::
unordered_map
<
std
::
string
,
Node
*>
tensor_nodes
;
space_table_t
space_table
;
space_table_t
space_table
;
CollectVarMemorySize
(
var_batch_ave_size
,
&
tensor_nodes
,
&
space_table
);
CollectVarMemorySize
(
var_batch_ave_size
,
&
tensor_nodes
,
&
space_table
);
...
@@ -564,6 +612,8 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
...
@@ -564,6 +612,8 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
std
::
vector
<
std
::
function
<
MemoryAllocation
()
>>
strategies
;
std
::
vector
<
std
::
function
<
MemoryAllocation
()
>>
strategies
;
for
(
int
sort_kind
=
0
;
sort_kind
<
2
;
sort_kind
++
)
{
for
(
int
sort_kind
=
0
;
sort_kind
<
2
;
sort_kind
++
)
{
if
(
argument
->
static_memory_optim
())
{
// This strategy only make scene in static memory optimize.
strategies
.
emplace_back
([
&
,
sort_kind
]
{
strategies
.
emplace_back
([
&
,
sort_kind
]
{
auto
clustered_vars_by_batch_size
=
auto
clustered_vars_by_batch_size
=
AnalysisBatchShapesByBatchSize
(
batches
);
AnalysisBatchShapesByBatchSize
(
batches
);
...
@@ -572,22 +622,23 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
...
@@ -572,22 +622,23 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
space_table
,
&
reuse_table
,
sort_kind
,
&
allocation
);
space_table
,
&
reuse_table
,
sort_kind
,
&
allocation
);
return
allocation
;
return
allocation
;
});
});
}
strategies
.
emplace_back
([
&
,
sort_kind
]
{
strategies
.
emplace_back
([
&
,
sort_kind
]
{
auto
clustered_vars_by_ave_size
=
AnalysisBatchShapesBySimilarSize
(
auto
clustered_vars_by_ave_size
=
space_table
,
batches
,
1024
);
// interval 1kb
AnalysisBatchShapesBySimilarSize
(
space_table
,
batches
,
cluster_size
);
MemoryAllocation
allocation
;
MemoryAllocation
allocation
;
MakeReusePlan
(
clustered_vars_by_ave_size
,
var_batch_ave_siz
e
,
MakeReusePlan
(
clustered_vars_by_ave_size
,
var_batch_ave_size
,
space_tabl
e
,
space_table
,
&
reuse_table
,
sort_kind
,
&
allocation
);
&
reuse_table
,
sort_kind
,
&
allocation
);
return
allocation
;
return
allocation
;
});
});
strategies
.
emplace_back
([
&
,
sort_kind
]
{
strategies
.
emplace_back
([
&
,
sort_kind
]
{
auto
clustered_vars_by_ave_size
=
AnalysisBatchShapesBySimilarSize
(
auto
clustered_vars_by_ave_size
=
space_table
,
batches
,
1024
*
1024
);
// interval 1MB
AnalysisBatchShapesBySimilarSize
(
space_table
,
batches
,
cluster_size1
);
MemoryAllocation
allocation
;
MemoryAllocation
allocation
;
MakeReusePlan
(
clustered_vars_by_ave_size
,
var_batch_ave_siz
e
,
MakeReusePlan
(
clustered_vars_by_ave_size
,
var_batch_ave_size
,
space_tabl
e
,
space_table
,
&
reuse_table
,
sort_kind
,
&
allocation
);
&
reuse_table
,
sort_kind
,
&
allocation
);
return
allocation
;
return
allocation
;
});
});
...
@@ -596,8 +647,8 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
...
@@ -596,8 +647,8 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
space_table
,
batches
,
space_table
,
batches
,
std
::
numeric_limits
<
int
>::
max
());
// no intervals
std
::
numeric_limits
<
int
>::
max
());
// no intervals
MemoryAllocation
allocation
;
MemoryAllocation
allocation
;
MakeReusePlan
(
clustered_vars_by_ave_size
,
var_batch_ave_siz
e
,
MakeReusePlan
(
clustered_vars_by_ave_size
,
var_batch_ave_size
,
space_tabl
e
,
space_table
,
&
reuse_table
,
sort_kind
,
&
allocation
);
&
reuse_table
,
sort_kind
,
&
allocation
);
return
allocation
;
return
allocation
;
});
});
}
}
...
@@ -615,19 +666,15 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
...
@@ -615,19 +666,15 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
}
}
}
}
if
(
!
best_strategy
)
{
if
(
!
best_strategy
)
{
LOG
(
ERROR
)
LOG
(
ERROR
)
<<
"This model makes poor memory optimize, skip memory optimize"
;
<<
"This model makes poor memory optimize, skip memory optimize"
;
return
;
return
;
}
}
auto
memory_allocation
=
(
*
best_strategy
)();
auto
memory_allocation
=
(
*
best_strategy
)();
string
::
PrettyLogH2
(
string
::
PrettyLogInfo
(
"--- Saved %.2f%s memory for workspace(temporary variables)"
,
"--- Saved %.2f%s memory for workspace(temporary variables)"
,
memory_allocation
.
GetSavingRatio
()
*
100
,
"%"
);
memory_allocation
.
GetSavingRatio
()
*
100
,
"%"
);
string
::
PrettyLogDetail
(
"--- Allocated %d MB"
,
memory_allocation
.
allocated
/
1024.
/
1024.
);
string
::
PrettyLogDetail
(
"--- Saved %d MB"
,
memory_allocation
.
saved
/
1024.
/
1024.
);
argument
->
main_graph
().
Set
(
framework
::
ir
::
kGraphToProgramVarsToRemove
,
argument
->
main_graph
().
Set
(
framework
::
ir
::
kGraphToProgramVarsToRemove
,
new
std
::
unordered_set
<
std
::
string
>
);
new
std
::
unordered_set
<
std
::
string
>
);
auto
&
vars2remove
=
auto
&
vars2remove
=
...
@@ -636,7 +683,6 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
...
@@ -636,7 +683,6 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
PerformReusePlan
(
reuse_table
,
memory_allocation
.
sort_kind
,
&
vars2remove
);
PerformReusePlan
(
reuse_table
,
memory_allocation
.
sort_kind
,
&
vars2remove
);
argument
->
SetMemoryOptimSortKind
(
memory_allocation
.
sort_kind
);
argument
->
SetMemoryOptimSortKind
(
memory_allocation
.
sort_kind
);
}
}
}
float
MemoryOptimizePass
::
MemoryAllocation
::
GetSavingRatio
()
const
{
float
MemoryOptimizePass
::
MemoryAllocation
::
GetSavingRatio
()
const
{
...
...
paddle/fluid/inference/analysis/passes/memory_optimize_pass.h
浏览文件 @
53d558cd
...
@@ -13,9 +13,11 @@
...
@@ -13,9 +13,11 @@
// limitations under the License.
// limitations under the License.
#pragma once
#pragma once
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/
inference/analysis/passes/memory_optimize_pass
.h"
#include "paddle/fluid/
platform/port
.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
...
paddle/fluid/inference/api/analysis_config.cc
浏览文件 @
53d558cd
...
@@ -95,7 +95,8 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) {
...
@@ -95,7 +95,8 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) {
CP_MEMBER
(
memory_pool_init_size_mb_
);
CP_MEMBER
(
memory_pool_init_size_mb_
);
CP_MEMBER
(
enable_memory_optim_
);
CP_MEMBER
(
enable_memory_optim_
);
CP_MEMBER
(
memory_optim_force_update_
);
CP_MEMBER
(
static_memory_optim_
);
CP_MEMBER
(
static_memory_optim_force_update_
);
// TensorRT releated.
// TensorRT releated.
CP_MEMBER
(
use_tensorrt_
);
CP_MEMBER
(
use_tensorrt_
);
CP_MEMBER
(
tensorrt_workspace_size_
);
CP_MEMBER
(
tensorrt_workspace_size_
);
...
@@ -238,7 +239,8 @@ std::string contrib::AnalysisConfig::SerializeInfoCache() {
...
@@ -238,7 +239,8 @@ std::string contrib::AnalysisConfig::SerializeInfoCache() {
ss
<<
tensorrt_min_subgraph_size_
;
ss
<<
tensorrt_min_subgraph_size_
;
ss
<<
enable_memory_optim_
;
ss
<<
enable_memory_optim_
;
ss
<<
memory_optim_force_update_
;
ss
<<
static_memory_optim_
;
ss
<<
static_memory_optim_force_update_
;
ss
<<
use_mkldnn_
;
ss
<<
use_mkldnn_
;
for
(
auto
&
item
:
mkldnn_enabled_op_types_
)
ss
<<
item
;
for
(
auto
&
item
:
mkldnn_enabled_op_types_
)
ss
<<
item
;
...
@@ -278,9 +280,11 @@ float contrib::AnalysisConfig::fraction_of_gpu_memory_for_pool() const {
...
@@ -278,9 +280,11 @@ float contrib::AnalysisConfig::fraction_of_gpu_memory_for_pool() const {
#endif
#endif
}
}
void
contrib
::
AnalysisConfig
::
EnableMemoryOptim
(
bool
force_update_cache
)
{
void
contrib
::
AnalysisConfig
::
EnableMemoryOptim
(
bool
static_optim
,
bool
force_update_static_cache
)
{
enable_memory_optim_
=
true
;
enable_memory_optim_
=
true
;
memory_optim_force_update_
=
force_update_cache
;
static_memory_optim_
=
static_optim
;
static_memory_optim_force_update_
=
force_update_static_cache
;
Update
();
Update
();
}
}
...
@@ -300,4 +304,16 @@ void contrib::AnalysisConfig::SetModelBuffer(const char *prog_buffer,
...
@@ -300,4 +304,16 @@ void contrib::AnalysisConfig::SetModelBuffer(const char *prog_buffer,
Update
();
Update
();
}
}
NativeConfig
contrib
::
AnalysisConfig
::
ToNativeConfig
()
const
{
NativeConfig
config
;
config
.
model_dir
=
model_dir_
;
config
.
prog_file
=
prog_file_
;
config
.
param_file
=
params_file_
;
config
.
use_gpu
=
use_gpu_
;
config
.
device
=
device_id_
;
config
.
fraction_of_gpu_memory
=
fraction_of_gpu_memory_for_pool
();
config
.
specify_input_name
=
specify_input_name_
;
return
config
;
}
}
// namespace paddle
}
// namespace paddle
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
53d558cd
...
@@ -298,15 +298,15 @@ void AnalysisPredictor::GetFetchOne(const framework::LoDTensor &fetch,
...
@@ -298,15 +298,15 @@ void AnalysisPredictor::GetFetchOne(const framework::LoDTensor &fetch,
bool
AnalysisPredictor
::
GetFetch
(
std
::
vector
<
PaddleTensor
>
*
outputs
,
bool
AnalysisPredictor
::
GetFetch
(
std
::
vector
<
PaddleTensor
>
*
outputs
,
framework
::
Scope
*
scope
)
{
framework
::
Scope
*
scope
)
{
VLOG
(
3
)
<<
"Predictor::get_fetch"
;
VLOG
(
3
)
<<
"Predictor::get_fetch"
;
outputs
->
resize
(
fetchs_
.
size
());
outputs
->
resize
(
fetch
e
s_
.
size
());
for
(
size_t
i
=
0
;
i
<
fetchs_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
fetch
e
s_
.
size
();
++
i
)
{
int
idx
=
boost
::
get
<
int
>
(
fetchs_
[
i
]
->
GetAttr
(
"col"
));
int
idx
=
boost
::
get
<
int
>
(
fetch
e
s_
[
i
]
->
GetAttr
(
"col"
));
PADDLE_ENFORCE
((
size_t
)
idx
==
i
);
PADDLE_ENFORCE
((
size_t
)
idx
==
i
);
framework
::
LoDTensor
&
fetch
=
framework
::
LoDTensor
&
fetch
=
framework
::
GetFetchVariable
(
*
scope
,
"fetch"
,
idx
);
framework
::
GetFetchVariable
(
*
scope
,
"fetch"
,
idx
);
auto
type
=
fetch
.
type
();
auto
type
=
fetch
.
type
();
auto
output
=
&
(
outputs
->
at
(
i
));
auto
output
=
&
(
outputs
->
at
(
i
));
output
->
name
=
fetchs_
[
idx
]
->
Input
(
"X"
)[
0
];
output
->
name
=
fetch
e
s_
[
idx
]
->
Input
(
"X"
)[
0
];
if
(
type
==
framework
::
proto
::
VarType
::
FP32
)
{
if
(
type
==
framework
::
proto
::
VarType
::
FP32
)
{
GetFetchOne
<
float
>
(
fetch
,
output
);
GetFetchOne
<
float
>
(
fetch
,
output
);
output
->
dtype
=
PaddleDType
::
FLOAT32
;
output
->
dtype
=
PaddleDType
::
FLOAT32
;
...
@@ -327,7 +327,9 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
...
@@ -327,7 +327,9 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
argument_
.
SetUseGPU
(
config_
.
use_gpu
());
argument_
.
SetUseGPU
(
config_
.
use_gpu
());
argument_
.
SetGPUDeviceId
(
config_
.
gpu_device_id
());
argument_
.
SetGPUDeviceId
(
config_
.
gpu_device_id
());
argument_
.
SetEnableMemoryOptim
(
config_
.
enable_memory_optim
());
argument_
.
SetEnableMemoryOptim
(
config_
.
enable_memory_optim
());
argument_
.
SetMemoryOptimForceUpdate
(
config_
.
memory_optim_force_update_
);
argument_
.
SetStaticMemoryOptim
(
config_
.
static_memory_optim_
);
argument_
.
SetStaticMemoryOptimForceUpdate
(
config_
.
static_memory_optim_force_update_
);
argument_
.
SetModelFromMemory
(
config_
.
model_from_memory_
);
argument_
.
SetModelFromMemory
(
config_
.
model_from_memory_
);
// Analyze inference_program
// Analyze inference_program
if
(
!
config_
.
model_dir
().
empty
())
{
if
(
!
config_
.
model_dir
().
empty
())
{
...
@@ -422,10 +424,10 @@ void AnalysisPredictor::PrepareFeedFetch() {
...
@@ -422,10 +424,10 @@ void AnalysisPredictor::PrepareFeedFetch() {
feed_names_
[
op
->
Output
(
"Out"
)[
0
]]
=
idx
;
feed_names_
[
op
->
Output
(
"Out"
)[
0
]]
=
idx
;
}
else
if
(
op
->
Type
()
==
"fetch"
)
{
}
else
if
(
op
->
Type
()
==
"fetch"
)
{
int
idx
=
boost
::
get
<
int
>
(
op
->
GetAttr
(
"col"
));
int
idx
=
boost
::
get
<
int
>
(
op
->
GetAttr
(
"col"
));
if
(
fetchs_
.
size
()
<=
static_cast
<
size_t
>
(
idx
))
{
if
(
fetch
e
s_
.
size
()
<=
static_cast
<
size_t
>
(
idx
))
{
fetchs_
.
resize
(
idx
+
1
);
fetch
e
s_
.
resize
(
idx
+
1
);
}
}
fetchs_
[
idx
]
=
op
;
fetch
e
s_
[
idx
]
=
op
;
}
}
}
}
}
}
...
@@ -638,12 +640,12 @@ bool AnalysisPredictor::need_collect_var_shapes_for_memory_optim() {
...
@@ -638,12 +640,12 @@ bool AnalysisPredictor::need_collect_var_shapes_for_memory_optim() {
// check if the cache exists
// check if the cache exists
if
(
!
config_
.
enable_memory_optim
())
{
if
(
!
config_
.
enable_memory_optim
())
{
need
=
false
;
need
=
false
;
}
else
if
(
config_
.
enable_memory_optim
()
&&
}
else
if
(
config_
.
static_memory_optim_
&&
!
inference
::
IsFileExists
(
inference
::
analysis
::
GetMemoryCachePath
(
!
inference
::
IsFileExists
(
inference
::
analysis
::
GetMemoryCachePath
(
config_
.
model_dir
(),
config_
.
prog_file
())))
{
config_
.
model_dir
(),
config_
.
prog_file
())))
{
need
=
true
;
need
=
true
;
}
else
if
(
config_
.
enable_memory_optim
()
&&
}
else
if
(
config_
.
static_memory_optim_
&&
config_
.
memory_optim_force_update_
)
{
config_
.
static_
memory_optim_force_update_
)
{
need
=
true
;
need
=
true
;
}
}
...
...
paddle/fluid/inference/api/analysis_predictor.h
浏览文件 @
53d558cd
...
@@ -115,7 +115,7 @@ class AnalysisPredictor : public PaddlePredictor {
...
@@ -115,7 +115,7 @@ class AnalysisPredictor : public PaddlePredictor {
std
::
shared_ptr
<
framework
::
ProgramDesc
>
inference_program_
;
std
::
shared_ptr
<
framework
::
ProgramDesc
>
inference_program_
;
std
::
vector
<
framework
::
OpDesc
*>
feeds_
;
std
::
vector
<
framework
::
OpDesc
*>
feeds_
;
std
::
map
<
std
::
string
,
size_t
>
feed_names_
;
std
::
map
<
std
::
string
,
size_t
>
feed_names_
;
std
::
vector
<
framework
::
OpDesc
*>
fetchs_
;
std
::
vector
<
framework
::
OpDesc
*>
fetch
e
s_
;
// Memory buffer for feed inputs. The temporary LoDTensor will cause serious
// Memory buffer for feed inputs. The temporary LoDTensor will cause serious
// concurrency problems, wrong results and memory leak, so cache them.
// concurrency problems, wrong results and memory leak, so cache them.
std
::
vector
<
framework
::
LoDTensor
>
feed_tensors_
;
std
::
vector
<
framework
::
LoDTensor
>
feed_tensors_
;
...
...
paddle/fluid/inference/api/paddle_analysis_config.h
浏览文件 @
53d558cd
...
@@ -162,17 +162,7 @@ struct AnalysisConfig {
...
@@ -162,17 +162,7 @@ struct AnalysisConfig {
/** Transform the AnalysisConfig to NativeConfig.
/** Transform the AnalysisConfig to NativeConfig.
*/
*/
NativeConfig
ToNativeConfig
()
const
{
NativeConfig
ToNativeConfig
()
const
;
NativeConfig
config
;
config
.
model_dir
=
model_dir_
;
config
.
prog_file
=
prog_file_
;
config
.
param_file
=
params_file_
;
config
.
use_gpu
=
use_gpu_
;
config
.
device
=
device_id_
;
config
.
fraction_of_gpu_memory
=
fraction_of_gpu_memory_for_pool
();
config
.
specify_input_name
=
specify_input_name_
;
return
config
;
}
/** Specify the operator type list to use MKLDNN acceleration.
/** Specify the operator type list to use MKLDNN acceleration.
* @param op_list the operator type list.
* @param op_list the operator type list.
*/
*/
...
@@ -195,7 +185,8 @@ struct AnalysisConfig {
...
@@ -195,7 +185,8 @@ struct AnalysisConfig {
/** Turn on memory optimize
/** Turn on memory optimize
* NOTE still in development, will release latter.
* NOTE still in development, will release latter.
*/
*/
void
EnableMemoryOptim
(
bool
force_update_cache
=
false
);
void
EnableMemoryOptim
(
bool
static_optim
=
false
,
bool
force_update_static_cache
=
false
);
/** Tell whether the memory optimization is activated. */
/** Tell whether the memory optimization is activated. */
bool
enable_memory_optim
()
const
;
bool
enable_memory_optim
()
const
;
...
@@ -241,7 +232,8 @@ struct AnalysisConfig {
...
@@ -241,7 +232,8 @@ struct AnalysisConfig {
// memory reuse related.
// memory reuse related.
bool
enable_memory_optim_
{
false
};
bool
enable_memory_optim_
{
false
};
bool
memory_optim_force_update_
{
false
};
bool
static_memory_optim_
{
false
};
bool
static_memory_optim_force_update_
{
false
};
bool
use_mkldnn_
{
false
};
bool
use_mkldnn_
{
false
};
std
::
unordered_set
<
std
::
string
>
mkldnn_enabled_op_types_
;
std
::
unordered_set
<
std
::
string
>
mkldnn_enabled_op_types_
;
...
...
paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
浏览文件 @
53d558cd
...
@@ -253,7 +253,7 @@ void compare(bool use_mkldnn = false) {
...
@@ -253,7 +253,7 @@ void compare(bool use_mkldnn = false) {
}
}
// Compare result of NativeConfig and AnalysisConfig with memory optimization.
// Compare result of NativeConfig and AnalysisConfig with memory optimization.
TEST
(
Analyzer_dam
,
compare_with_memory_optim
)
{
TEST
(
Analyzer_dam
,
compare_with_
static_
memory_optim
)
{
// The small dam will core in CI, but works in local.
// The small dam will core in CI, but works in local.
if
(
FLAGS_max_turn_num
==
9
)
{
if
(
FLAGS_max_turn_num
==
9
)
{
contrib
::
AnalysisConfig
cfg
,
cfg1
;
contrib
::
AnalysisConfig
cfg
,
cfg1
;
...
@@ -263,7 +263,7 @@ TEST(Analyzer_dam, compare_with_memory_optim) {
...
@@ -263,7 +263,7 @@ TEST(Analyzer_dam, compare_with_memory_optim) {
SetInput
(
&
input_slots_all
);
SetInput
(
&
input_slots_all
);
// Run the first time to force to update memory cache
// Run the first time to force to update memory cache
SetConfig
(
&
cfg
);
SetConfig
(
&
cfg
);
cfg
.
EnableMemoryOptim
(
true
);
cfg
.
EnableMemoryOptim
(
true
,
true
/*force update*/
);
CompareNativeAndAnalysis
(
CompareNativeAndAnalysis
(
reinterpret_cast
<
const
PaddlePredictor
::
Config
*>
(
&
cfg
),
reinterpret_cast
<
const
PaddlePredictor
::
Config
*>
(
&
cfg
),
...
@@ -271,7 +271,7 @@ TEST(Analyzer_dam, compare_with_memory_optim) {
...
@@ -271,7 +271,7 @@ TEST(Analyzer_dam, compare_with_memory_optim) {
// Run second time to use the memory cache and perform memory optimization.
// Run second time to use the memory cache and perform memory optimization.
SetConfig
(
&
cfg1
);
SetConfig
(
&
cfg1
);
cfg1
.
EnableMemoryOptim
();
cfg1
.
EnableMemoryOptim
(
true
,
false
/*do not force update*/
);
CompareNativeAndAnalysis
(
CompareNativeAndAnalysis
(
reinterpret_cast
<
const
PaddlePredictor
::
Config
*>
(
&
cfg1
),
reinterpret_cast
<
const
PaddlePredictor
::
Config
*>
(
&
cfg1
),
...
@@ -279,6 +279,24 @@ TEST(Analyzer_dam, compare_with_memory_optim) {
...
@@ -279,6 +279,24 @@ TEST(Analyzer_dam, compare_with_memory_optim) {
}
}
}
}
TEST
(
Analyzer_dam
,
compare_with_dynamic_memory_optim
)
{
// The small dam will core in CI, but works in local.
if
(
FLAGS_max_turn_num
==
9
)
{
contrib
::
AnalysisConfig
cfg
,
cfg1
;
DataRecord
data
(
FLAGS_infer_data
,
FLAGS_batch_size
);
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
SetInput
(
&
input_slots_all
);
// Run the first time to force to update memory cache
SetConfig
(
&
cfg
);
cfg
.
EnableMemoryOptim
();
CompareNativeAndAnalysis
(
reinterpret_cast
<
const
PaddlePredictor
::
Config
*>
(
&
cfg
),
input_slots_all
);
}
}
TEST
(
Analyzer_dam
,
compare
)
{
compare
();
}
TEST
(
Analyzer_dam
,
compare
)
{
compare
();
}
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
53d558cd
...
@@ -54,6 +54,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
...
@@ -54,6 +54,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
// Async
// Async
if
(
!
sync_mode_
)
{
if
(
!
sync_mode_
)
{
VLOG
(
3
)
<<
"async process var: "
<<
varname
;
VLOG
(
3
)
<<
"async process var: "
<<
varname
;
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
PADDLE_THROW
(
"async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE"
);
}
try
{
try
{
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
varname
].
get
(),
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
varname
].
get
(),
scope
);
scope
);
...
...
paddle/fluid/operators/distributed/rpc_server.cc
浏览文件 @
53d558cd
...
@@ -39,27 +39,33 @@ void RPCServer::SavePort() const {
...
@@ -39,27 +39,33 @@ void RPCServer::SavePort() const {
port_file
.
open
(
file_path
);
port_file
.
open
(
file_path
);
port_file
<<
selected_port_
;
port_file
<<
selected_port_
;
port_file
.
close
();
port_file
.
close
();
VLOG
(
4
)
<<
"selected port written to "
<<
file_path
;
VLOG
(
3
)
<<
"selected port written to "
<<
file_path
;
}
}
void
RPCServer
::
WaitBarrier
(
const
std
::
string
&
rpc_name
)
{
void
RPCServer
::
WaitBarrier
(
const
std
::
string
&
rpc_name
)
{
VLOG
(
3
)
<<
"WaitBarrier in: "
<<
rpc_name
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
barrier_cond_
.
wait
(
lock
,
[
this
,
&
rpc_name
]
{
barrier_cond_
.
wait
(
lock
,
[
this
,
&
rpc_name
]
{
return
((
barrier_counter_
[
rpc_name
]
==
client_num_
&&
client_num_
!=
0
)
||
return
((
barrier_counter_
[
rpc_name
]
==
client_num_
&&
client_num_
!=
0
)
||
exit_flag_
.
load
());
exit_flag_
.
load
());
});
});
VLOG
(
3
)
<<
"
batch_barrier_: "
<<
rpc_name
<<
" "
VLOG
(
3
)
<<
"
WaitBarrier out: "
<<
rpc_name
<<
barrier_counter_
[
rpc_name
];
<<
" counter: "
<<
barrier_counter_
[
rpc_name
];
}
}
void
RPCServer
::
IncreaseBatchBarrier
(
const
std
::
string
rpc_name
)
{
void
RPCServer
::
IncreaseBatchBarrier
(
const
std
::
string
rpc_name
)
{
VLOG
(
4
)
<<
"RPCServer begin IncreaseBatchBarrier "
<<
rpc_name
;
VLOG
(
3
)
<<
"RPCServer begin IncreaseBatchBarrier "
<<
rpc_name
;
// barrier msg should make sure that it's in the right cond(send|recv)
WaitCond
(
rpc_name
);
int
b
=
0
;
int
b
=
0
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
b
=
++
barrier_counter_
[
rpc_name
];
b
=
++
barrier_counter_
[
rpc_name
];
VLOG
(
3
)
<<
rpc_name
<<
" barrier_counter: "
<<
b
;
if
(
b
>=
client_num_
)
{
if
(
b
>=
client_num_
)
{
lock
.
unlock
();
lock
.
unlock
();
VLOG
(
3
)
<<
"BatchBarrier counter reach "
<<
client_num_
<<
" for "
<<
rpc_name
;
barrier_cond_
.
notify_all
();
barrier_cond_
.
notify_all
();
lock
.
lock
();
lock
.
lock
();
}
}
...
@@ -71,7 +77,7 @@ void RPCServer::Complete() {
...
@@ -71,7 +77,7 @@ void RPCServer::Complete() {
client_num_
--
;
client_num_
--
;
need_reset_all_vars_
=
true
;
need_reset_all_vars_
=
true
;
VLOG
(
4
)
<<
"decrease client_num to: "
<<
client_num_
;
VLOG
(
3
)
<<
"decrease client_num to: "
<<
client_num_
;
if
(
cur_cond_
.
load
()
==
rpc_cond_map_
[
kRequestGet
])
{
if
(
cur_cond_
.
load
()
==
rpc_cond_map_
[
kRequestGet
])
{
barrier_counter_
[
kRequestGet
]
--
;
barrier_counter_
[
kRequestGet
]
--
;
}
}
...
@@ -105,8 +111,8 @@ void RPCServer::RegisterRPC(const std::string& rpc_name,
...
@@ -105,8 +111,8 @@ void RPCServer::RegisterRPC(const std::string& rpc_name,
static
int
cond
=
-
1
;
static
int
cond
=
-
1
;
rpc_cond_map_
[
rpc_name
]
=
++
cond
;
rpc_cond_map_
[
rpc_name
]
=
++
cond
;
VLOG
(
4
)
<<
"RegisterRPC rpc_name:"
<<
rpc_name
<<
", handler:
"
<<
handler
VLOG
(
3
)
<<
"RegisterRPC rpc_name: "
<<
rpc_name
<<
", handler:
"
<<
handler
<<
", cond:"
<<
rpc_cond_map_
[
rpc_name
];
<<
", cond:
"
<<
rpc_cond_map_
[
rpc_name
];
}
}
void
RPCServer
::
SetCond
(
const
std
::
string
&
rpc_name
)
{
void
RPCServer
::
SetCond
(
const
std
::
string
&
rpc_name
)
{
...
@@ -120,7 +126,7 @@ void RPCServer::SetCond(const std::string& rpc_name) {
...
@@ -120,7 +126,7 @@ void RPCServer::SetCond(const std::string& rpc_name) {
}
}
void
RPCServer
::
WaitCond
(
const
std
::
string
&
rpc_name
)
{
void
RPCServer
::
WaitCond
(
const
std
::
string
&
rpc_name
)
{
VLOG
(
4
)
<<
"RPCServer WaitCond
"
<<
rpc_name
;
VLOG
(
3
)
<<
"RPCServer WaitCond in
"
<<
rpc_name
;
int
cond
=
0
;
int
cond
=
0
;
{
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
...
@@ -130,6 +136,7 @@ void RPCServer::WaitCond(const std::string& rpc_name) {
...
@@ -130,6 +136,7 @@ void RPCServer::WaitCond(const std::string& rpc_name) {
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
rpc_cond_
.
wait
(
rpc_cond_
.
wait
(
lock
,
[
=
]
{
return
(
cur_cond_
.
load
()
==
cond
||
exit_flag_
.
load
());
});
lock
,
[
=
]
{
return
(
cur_cond_
.
load
()
==
cond
||
exit_flag_
.
load
());
});
VLOG
(
3
)
<<
"RPCServer WaitCond out "
<<
rpc_name
;
}
}
void
RPCServer
::
RegisterVar
(
const
std
::
string
&
var_name
,
void
RPCServer
::
RegisterVar
(
const
std
::
string
&
var_name
,
...
@@ -151,7 +158,7 @@ void RPCServer::RegisterVar(const std::string& var_name,
...
@@ -151,7 +158,7 @@ void RPCServer::RegisterVar(const std::string& var_name,
}
}
rpc_cond_
.
notify_all
();
rpc_cond_
.
notify_all
();
VLOG
(
4
)
<<
"RegisterVar context:"
<<
h
.
String
();
VLOG
(
3
)
<<
"RegisterVar context:"
<<
h
.
String
();
}
}
void
RPCServer
::
IncreaseVarBarrier
(
const
std
::
string
&
var_name
)
{
void
RPCServer
::
IncreaseVarBarrier
(
const
std
::
string
&
var_name
)
{
...
@@ -167,11 +174,11 @@ void RPCServer::IncreaseVarBarrier(const std::string& var_name) {
...
@@ -167,11 +174,11 @@ void RPCServer::IncreaseVarBarrier(const std::string& var_name) {
barrier_cond_
.
notify_all
();
barrier_cond_
.
notify_all
();
}
}
VLOG
(
4
)
<<
"IncreaseVarBarrier context:"
<<
h
.
String
();
VLOG
(
3
)
<<
"IncreaseVarBarrier context:"
<<
h
.
String
();
}
}
void
RPCServer
::
WaitVarBarrier
(
const
std
::
string
&
var_name
)
{
void
RPCServer
::
WaitVarBarrier
(
const
std
::
string
&
var_name
)
{
VLOG
(
4
)
<<
"Wait
Barrier var_name:"
<<
var_name
;
VLOG
(
3
)
<<
"WaitVar
Barrier var_name:"
<<
var_name
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
barrier_cond_
.
wait
(
lock
,
[
&
]()
{
barrier_cond_
.
wait
(
lock
,
[
&
]()
{
...
@@ -179,11 +186,11 @@ void RPCServer::WaitVarBarrier(const std::string& var_name) {
...
@@ -179,11 +186,11 @@ void RPCServer::WaitVarBarrier(const std::string& var_name) {
exit_flag_
.
load
());
exit_flag_
.
load
());
});
});
VLOG
(
4
)
<<
"Wait
Barrier context: "
<<
var_map_
[
var_name
].
String
();
VLOG
(
3
)
<<
"WaitVar
Barrier context: "
<<
var_map_
[
var_name
].
String
();
}
}
void
RPCServer
::
SetVarCond
(
const
std
::
string
&
var_name
)
{
void
RPCServer
::
SetVarCond
(
const
std
::
string
&
var_name
)
{
VLOG
(
4
)
<<
"SetVarCond var_name:"
<<
var_name
;
VLOG
(
3
)
<<
"SetVarCond var_name:"
<<
var_name
;
{
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
if
(
var_map_
.
find
(
var_name
)
!=
var_map_
.
end
())
{
if
(
var_map_
.
find
(
var_name
)
!=
var_map_
.
end
())
{
...
@@ -193,14 +200,14 @@ void RPCServer::SetVarCond(const std::string& var_name) {
...
@@ -193,14 +200,14 @@ void RPCServer::SetVarCond(const std::string& var_name) {
}
}
void
RPCServer
::
WaitVarCond
(
const
std
::
string
&
var_name
)
{
void
RPCServer
::
WaitVarCond
(
const
std
::
string
&
var_name
)
{
VLOG
(
4
)
<<
"WaitVarCond var_name:"
<<
var_name
;
VLOG
(
3
)
<<
"WaitVarCond var_name:"
<<
var_name
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
rpc_cond_
.
wait
(
lock
,
[
=
]
{
rpc_cond_
.
wait
(
lock
,
[
=
]
{
return
(
var_map_
.
find
(
var_name
)
!=
var_map_
.
end
()
||
exit_flag_
.
load
());
return
(
var_map_
.
find
(
var_name
)
!=
var_map_
.
end
()
||
exit_flag_
.
load
());
});
});
VLOG
(
4
)
<<
"WaitVarCond var_name:"
<<
var_name
<<
" end"
;
VLOG
(
3
)
<<
"WaitVarCond var_name:"
<<
var_name
<<
" end"
;
}
}
MonomerHandle
RPCServer
::
GetMonomer
(
const
std
::
string
&
var_name
)
{
MonomerHandle
RPCServer
::
GetMonomer
(
const
std
::
string
&
var_name
)
{
...
...
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
浏览文件 @
53d558cd
...
@@ -137,7 +137,9 @@ void ListenAndServOp::RunSyncLoop(
...
@@ -137,7 +137,9 @@ void ListenAndServOp::RunSyncLoop(
while
(
true
)
{
while
(
true
)
{
// Get from multiple trainers, we don't care about the order in which
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
// the gradients arrives, just add suffix 0~n and merge the gradient.
VLOG
(
3
)
<<
"wait all clients to send gradient"
;
rpc_service_
->
SetCond
(
distributed
::
kRequestSend
);
rpc_service_
->
SetCond
(
distributed
::
kRequestSend
);
VLOG
(
3
)
<<
"wait all clients to send send_barrier"
;
rpc_service_
->
WaitBarrier
(
distributed
::
kRequestSend
);
rpc_service_
->
WaitBarrier
(
distributed
::
kRequestSend
);
if
(
rpc_service_
->
IsExit
())
{
if
(
rpc_service_
->
IsExit
())
{
...
@@ -168,12 +170,16 @@ void ListenAndServOp::RunSyncLoop(
...
@@ -168,12 +170,16 @@ void ListenAndServOp::RunSyncLoop(
}
}
ParallelExecuteBlocks
(
parallel_blkids
,
executor
,
optimize_prepared
,
program
,
ParallelExecuteBlocks
(
parallel_blkids
,
executor
,
optimize_prepared
,
program
,
recv_scope
);
recv_scope
);
VLOG
(
2
)
<<
"run all blocks spent "
<<
GetTimestamp
()
-
ts
<<
"(ms)"
;
VLOG
(
3
)
<<
"run all blocks spent "
<<
GetTimestamp
()
-
ts
<<
"(ms)"
;
VLOG
(
3
)
<<
"ResetReceivedVars"
;
ResetReceivedVars
(
recv_scope
,
dev_ctx
,
rpc_service_
->
NeedResetAllVars
());
ResetReceivedVars
(
recv_scope
,
dev_ctx
,
rpc_service_
->
NeedResetAllVars
());
VLOG
(
3
)
<<
"wait all clients to get parameters back"
;
rpc_service_
->
SetCond
(
distributed
::
kRequestGet
);
rpc_service_
->
SetCond
(
distributed
::
kRequestGet
);
VLOG
(
3
)
<<
"wait all clients to send fetch_barrier"
;
rpc_service_
->
WaitBarrier
(
distributed
::
kRequestGet
);
rpc_service_
->
WaitBarrier
(
distributed
::
kRequestGet
);
VLOG
(
3
)
<<
"ResetBarrierCounter"
;
rpc_service_
->
ResetBarrierCounter
();
rpc_service_
->
ResetBarrierCounter
();
}
// while(true)
}
// while(true)
}
}
...
...
paddle/fluid/operators/grid_sampler_op.cc
浏览文件 @
53d558cd
...
@@ -43,12 +43,14 @@ class GridSampleOp : public framework::OperatorWithKernel {
...
@@ -43,12 +43,14 @@ class GridSampleOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE
(
grid_dims
[
3
]
==
2
,
"Input(Grid) dims[3] should be 2."
);
PADDLE_ENFORCE
(
grid_dims
[
3
]
==
2
,
"Input(Grid) dims[3] should be 2."
);
PADDLE_ENFORCE_EQ
(
grid_dims
[
0
],
x_dims
[
0
],
PADDLE_ENFORCE_EQ
(
grid_dims
[
0
],
x_dims
[
0
],
"Input(X) and Input(Grid) dims[0] should be equal."
);
"Input(X) and Input(Grid) dims[0] should be equal."
);
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
grid_dims
[
1
],
x_dims
[
2
],
grid_dims
[
1
],
x_dims
[
2
],
"Input(X) dims[2] and Input(Grid) dims[1] should be equal."
);
"Input(X) dims[2] and Input(Grid) dims[1] should be equal."
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
grid_dims
[
2
],
x_dims
[
3
],
grid_dims
[
2
],
x_dims
[
3
],
"Input(X) dims[3] and Input(Grid) dims[2] should be equal."
);
"Input(X) dims[3] and Input(Grid) dims[2] should be equal."
);
}
ctx
->
SetOutputDim
(
"Output"
,
x_dims
);
ctx
->
SetOutputDim
(
"Output"
,
x_dims
);
ctx
->
ShareLoD
(
"X"
,
"Output"
);
ctx
->
ShareLoD
(
"X"
,
"Output"
);
...
...
paddle/fluid/operators/ngraph/CMakeLists.txt
浏览文件 @
53d558cd
if
(
WITH_NGRAPH
)
if
(
WITH_NGRAPH
)
cc_library
(
ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph
)
cc_library
(
ngraph_engine SRCS ngraph_engine.cc DEPS ngraph_bridge framework_proto
)
cc_library
(
ngraph_engine SRCS ngraph_engine.cc DEPS ngraph_bridge framework_proto
)
op_library
(
ngraph_engine_op DEPS ngraph_engine op_registry op_info device_context
)
op_library
(
ngraph_engine_op DEPS ngraph_engine op_registry op_info device_context
)
endif
()
endif
()
paddle/fluid/
framework
/ngraph_bridge.cc
→
paddle/fluid/
operators/ngraph
/ngraph_bridge.cc
浏览文件 @
53d558cd
...
@@ -17,39 +17,39 @@ limitations under the License. */
...
@@ -17,39 +17,39 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "ngraph/ngraph.hpp"
#include "ngraph/ngraph.hpp"
#include "paddle/fluid/framework/ngraph_bridge.h"
#include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/ngraph/ngraph_ops.h"
#include "paddle/fluid/operators/ngraph/ngraph_ops.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/ngraph_helper.h"
#include "paddle/fluid/platform/ngraph_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
operators
{
namespace
NG_OPS
=
paddle
::
operators
::
ngraphs
;
namespace
NG_OPS
=
paddle
::
operators
::
ngraphs
;
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
function
<
void
(
const
std
::
shared_ptr
<
OperatorBase
>&
,
std
::
function
<
void
(
const
std
::
shared_ptr
<
framework
::
OperatorBase
>&
,
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
)
>>
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
)
>>
NgraphBridge
::
NG_NODE_MAP
=
{
NgraphBridge
::
NG_NODE_MAP
=
{
{
"elementwise_add"
,
NG_OPS
::
BuildElementwiseAddNode
},
{
"elementwise_add"
,
NG_OPS
::
BuildElementwiseAddNode
},
{
"elementwise_add_grad"
,
NG_OPS
::
BuildElementwiseAddGradNode
},
{
"elementwise_add_grad"
,
NG_OPS
::
BuildElementwiseAddGradNode
},
{
"fill_constant"
,
paddle
::
operators
::
ngraphs
::
BuildFillConstantNode
},
{
"fill_constant"
,
NG_OPS
::
BuildFillConstantNode
},
{
"mean"
,
paddle
::
operators
::
ngraphs
::
BuildMeanNode
},
{
"mean"
,
NG_OPS
::
BuildMeanNode
},
{
"mean_grad"
,
paddle
::
operators
::
ngraphs
::
BuildMeanGradNode
},
{
"mean_grad"
,
NG_OPS
::
BuildMeanGradNode
},
{
"mul"
,
paddle
::
operators
::
ngraphs
::
BuildMulNode
},
{
"mul"
,
NG_OPS
::
BuildMulNode
},
{
"mul_grad"
,
paddle
::
operators
::
ngraphs
::
BuildMulGradNode
},
{
"mul_grad"
,
NG_OPS
::
BuildMulGradNode
},
{
"softmax"
,
paddle
::
operators
::
ngraphs
::
BuildSoftmaxNode
},
{
"softmax"
,
NG_OPS
::
BuildSoftmaxNode
},
{
"softmax_grad"
,
paddle
::
operators
::
ngraphs
::
BuildSoftmaxGradNode
},
{
"softmax_grad"
,
NG_OPS
::
BuildSoftmaxGradNode
},
{
"scale"
,
paddle
::
operators
::
ngraphs
::
BuildScaleNode
},
{
"scale"
,
NG_OPS
::
BuildScaleNode
},
{
"relu"
,
paddle
::
operators
::
ngraphs
::
BuildUnaryNode
<
ngraph
::
op
::
Relu
>
},
{
"relu"
,
NG_OPS
::
BuildUnaryNode
<
ngraph
::
op
::
Relu
>
},
{
"tanh"
,
paddle
::
operators
::
ngraphs
::
BuildUnaryNode
<
ngraph
::
op
::
Tanh
>
},
{
"tanh"
,
NG_OPS
::
BuildUnaryNode
<
ngraph
::
op
::
Tanh
>
},
{
"top_k"
,
paddle
::
operators
::
ngraphs
::
BuildTopKNode
}};
{
"top_k"
,
NG_OPS
::
BuildTopKNode
}};
void
NgraphBridge
::
BuildNgNode
(
const
std
::
shared_ptr
<
OperatorBase
>&
op
)
{
void
NgraphBridge
::
BuildNgNode
(
const
std
::
shared_ptr
<
framework
::
OperatorBase
>&
op
)
{
auto
&
op_type
=
op
->
Type
();
auto
&
op_type
=
op
->
Type
();
NG_NODE_MAP
[
op_type
](
op
,
ngb_node_map_
);
NG_NODE_MAP
[
op_type
](
op
,
ngb_node_map_
);
}
}
}
// namespace
framework
}
// namespace
operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/
framework
/ngraph_bridge.h
→
paddle/fluid/
operators/ngraph
/ngraph_bridge.h
浏览文件 @
53d558cd
...
@@ -21,16 +21,16 @@ limitations under the License. */
...
@@ -21,16 +21,16 @@ limitations under the License. */
#include "ngraph/node.hpp"
#include "ngraph/node.hpp"
namespace
paddle
{
#include "paddle/fluid/framework/operator.h"
namespace
framework
{
class
OperatorBase
;
namespace
paddle
{
namespace
operators
{
class
NgraphBridge
{
class
NgraphBridge
{
public:
public:
static
std
::
map
<
static
std
::
map
<
std
::
string
,
std
::
string
,
std
::
function
<
void
(
const
std
::
shared_ptr
<
OperatorBase
>&
,
std
::
function
<
void
(
const
std
::
shared_ptr
<
framework
::
OperatorBase
>&
,
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
)
>>
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
)
>>
NG_NODE_MAP
;
NG_NODE_MAP
;
...
@@ -41,7 +41,7 @@ class NgraphBridge {
...
@@ -41,7 +41,7 @@ class NgraphBridge {
var_node_map
)
var_node_map
)
:
ngb_node_map_
(
var_node_map
)
{}
:
ngb_node_map_
(
var_node_map
)
{}
void
BuildNgNode
(
const
std
::
shared_ptr
<
OperatorBase
>&
op
);
void
BuildNgNode
(
const
std
::
shared_ptr
<
framework
::
OperatorBase
>&
op
);
private:
private:
std
::
shared_ptr
<
std
::
shared_ptr
<
...
@@ -49,5 +49,5 @@ class NgraphBridge {
...
@@ -49,5 +49,5 @@ class NgraphBridge {
ngb_node_map_
;
ngb_node_map_
;
};
};
}
// namespace
framework
}
// namespace
operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/ngraph/ngraph_engine.cc
浏览文件 @
53d558cd
...
@@ -24,11 +24,11 @@ limitations under the License. */
...
@@ -24,11 +24,11 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/ngraph_bridge.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
#include "paddle/fluid/operators/ngraph/ngraph_engine.h"
#include "paddle/fluid/operators/ngraph/ngraph_engine.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -88,15 +88,14 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
...
@@ -88,15 +88,14 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
int
pivot
=
left
;
int
pivot
=
left
;
while
(
pivot
<
right
)
{
while
(
pivot
<
right
)
{
auto
op_type
=
ops
.
at
(
pivot
)
->
Type
();
auto
op_type
=
ops
.
at
(
pivot
)
->
Type
();
if
(
paddle
::
framework
::
NgraphBridge
::
NG_NODE_MAP
.
find
(
op_type
)
==
if
(
NgraphBridge
::
NG_NODE_MAP
.
find
(
op_type
)
==
paddle
::
framework
::
NgraphBridge
::
NG_NODE_MAP
.
end
())
{
NgraphBridge
::
NG_NODE_MAP
.
end
())
{
++
pivot
;
++
pivot
;
}
else
{
}
else
{
int
start
=
pivot
,
end
=
start
;
int
start
=
pivot
,
end
=
start
;
while
(
pivot
<
right
&&
while
(
pivot
<
right
&&
(
paddle
::
framework
::
NgraphBridge
::
NG_NODE_MAP
.
find
(
(
NgraphBridge
::
NG_NODE_MAP
.
find
(
ops
.
at
(
pivot
)
->
Type
())
!=
ops
.
at
(
pivot
)
->
Type
())
!=
NgraphBridge
::
NG_NODE_MAP
.
end
()))
{
paddle
::
framework
::
NgraphBridge
::
NG_NODE_MAP
.
end
()))
{
++
pivot
;
++
pivot
;
++
end
;
++
end
;
}
}
...
@@ -283,7 +282,7 @@ void NgraphEngine::BuildNgNodes() {
...
@@ -283,7 +282,7 @@ void NgraphEngine::BuildNgNodes() {
}
}
}
}
}
}
framework
::
NgraphBridge
ngb
(
var_node_map_
);
NgraphBridge
ngb
(
var_node_map_
);
for
(
auto
&
op
:
fused_ops_
)
{
for
(
auto
&
op
:
fused_ops_
)
{
ngb
.
BuildNgNode
(
op
);
ngb
.
BuildNgNode
(
op
);
}
}
...
...
python/paddle/fluid/contrib/int8_inference/utility.py
浏览文件 @
53d558cd
...
@@ -32,10 +32,13 @@ class Calibrator(object):
...
@@ -32,10 +32,13 @@ class Calibrator(object):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
program
=
kwargs
[
'program'
]
self
.
program
=
kwargs
[
'program'
]
self
.
iterations
=
kwargs
[
'iterations'
]
self
.
pretrained_model
=
kwargs
[
'pretrained_model'
]
self
.
pretrained_model
=
kwargs
[
'pretrained_model'
]
self
.
debug
=
kwargs
[
'debug'
]
self
.
debug
=
kwargs
[
'debug'
]
if
'debug'
in
kwargs
else
False
self
.
algo
=
kwargs
[
'algo'
]
self
.
algo
=
kwargs
[
'algo'
]
self
.
output
=
kwargs
[
'output'
]
self
.
feed_var_names
=
kwargs
[
'feed_var_names'
]
self
.
fetch_list
=
kwargs
[
'fetch_list'
]
self
.
exe
=
kwargs
[
'exe'
]
self
.
_conv_input_var_name
=
[]
self
.
_conv_input_var_name
=
[]
self
.
_conv_output_var_name
=
[]
self
.
_conv_output_var_name
=
[]
...
@@ -54,17 +57,38 @@ class Calibrator(object):
...
@@ -54,17 +57,38 @@ class Calibrator(object):
self
.
_u8_output_var
=
[]
self
.
_u8_output_var
=
[]
self
.
_s8_output_var
=
[]
self
.
_s8_output_var
=
[]
self
.
_persistable_vars
=
[]
self
.
_persistable_vars
=
[]
self
.
_sampling_data
=
{}
def
generate_sampling_program
(
self
):
self
.
__init_analysis
()
self
.
__init_analysis
()
self
.
__generate_output_program
()
self
.
__generate_output_program
()
def
generate_quantized_data
(
self
,
sampling_data
):
def
save_int8_model
(
self
):
self
.
__sampling
(
sampling_data
)
self
.
__sampling
(
s
elf
.
_s
ampling_data
)
self
.
__save_scale
()
self
.
__save_scale
()
self
.
__update_program
()
self
.
__update_program
()
self
.
__update_output_program_attr
()
self
.
__update_output_program_attr
()
self
.
__display_debug
()
self
.
__display_debug
()
self
.
__save_offline_model
()
def
sample_data
(
self
):
'''
Sampling the tensor data of variable.
'''
for
i
in
self
.
sampling_program
.
list_vars
():
if
i
.
name
in
self
.
sampling_vars
:
np_data
=
np
.
array
(
fluid
.
global_scope
().
find_var
(
i
.
name
)
.
get_tensor
())
if
i
.
name
not
in
self
.
_sampling_data
:
self
.
_sampling_data
[
i
.
name
]
=
[]
self
.
_sampling_data
[
i
.
name
].
append
(
np_data
)
def
__save_offline_model
(
self
):
'''
Save the quantized model to the disk.
'''
fluid
.
io
.
save_inference_model
(
self
.
output
,
self
.
feed_var_names
,
self
.
fetch_list
,
self
.
exe
,
self
.
sampling_program
)
def
__display_debug
(
self
):
def
__display_debug
(
self
):
if
self
.
debug
:
if
self
.
debug
:
...
...
python/paddle/fluid/contrib/tests/test_calibration.py
浏览文件 @
53d558cd
...
@@ -26,7 +26,7 @@ import paddle.fluid.profiler as profiler
...
@@ -26,7 +26,7 @@ import paddle.fluid.profiler as profiler
from
PIL
import
Image
,
ImageEnhance
from
PIL
import
Image
,
ImageEnhance
import
math
import
math
sys
.
path
.
append
(
'..'
)
sys
.
path
.
append
(
'..'
)
import
int8_inference.utility
as
ut
import
int8_inference.utility
as
int8_utility
random
.
seed
(
0
)
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
...
@@ -120,13 +120,13 @@ class TestCalibration(unittest.TestCase):
...
@@ -120,13 +120,13 @@ class TestCalibration(unittest.TestCase):
def
setUp
(
self
):
def
setUp
(
self
):
# TODO(guomingz): Put the download process in the cmake.
# TODO(guomingz): Put the download process in the cmake.
# Download and unzip test data set
# Download and unzip test data set
imagenet_dl_url
=
'http://paddle-inference-dist.
bj
.bcebos.com/int8/calibration_test_data.tar.gz'
imagenet_dl_url
=
'http://paddle-inference-dist.
cdn
.bcebos.com/int8/calibration_test_data.tar.gz'
zip_file_name
=
imagenet_dl_url
.
split
(
'/'
)[
-
1
]
zip_file_name
=
imagenet_dl_url
.
split
(
'/'
)[
-
1
]
cmd
=
'rm -rf data {} && mkdir data && wget {} && tar xvf {} -C data'
.
format
(
cmd
=
'rm -rf data {} && mkdir data && wget {} && tar xvf {} -C data'
.
format
(
zip_file_name
,
imagenet_dl_url
,
zip_file_name
)
zip_file_name
,
imagenet_dl_url
,
zip_file_name
)
os
.
system
(
cmd
)
os
.
system
(
cmd
)
# resnet50 fp32 data
# resnet50 fp32 data
resnet50_fp32_model_url
=
'http://paddle-inference-dist.
bj
.bcebos.com/int8/resnet50_int8_model.tar.gz'
resnet50_fp32_model_url
=
'http://paddle-inference-dist.
cdn
.bcebos.com/int8/resnet50_int8_model.tar.gz'
resnet50_zip_name
=
resnet50_fp32_model_url
.
split
(
'/'
)[
-
1
]
resnet50_zip_name
=
resnet50_fp32_model_url
.
split
(
'/'
)[
-
1
]
resnet50_unzip_folder_name
=
'resnet50_fp32'
resnet50_unzip_folder_name
=
'resnet50_fp32'
cmd
=
'rm -rf {} {} && mkdir {} && wget {} && tar xvf {} -C {}'
.
format
(
cmd
=
'rm -rf {} {} && mkdir {} && wget {} && tar xvf {} -C {}'
.
format
(
...
@@ -135,8 +135,7 @@ class TestCalibration(unittest.TestCase):
...
@@ -135,8 +135,7 @@ class TestCalibration(unittest.TestCase):
resnet50_zip_name
,
resnet50_unzip_folder_name
)
resnet50_zip_name
,
resnet50_unzip_folder_name
)
os
.
system
(
cmd
)
os
.
system
(
cmd
)
self
.
iterations
=
100
self
.
iterations
=
50
self
.
skip_batch_num
=
5
def
run_program
(
self
,
model_path
,
generate_int8
=
False
,
algo
=
'direct'
):
def
run_program
(
self
,
model_path
,
generate_int8
=
False
,
algo
=
'direct'
):
image_shape
=
[
3
,
224
,
224
]
image_shape
=
[
3
,
224
,
224
]
...
@@ -163,16 +162,15 @@ class TestCalibration(unittest.TestCase):
...
@@ -163,16 +162,15 @@ class TestCalibration(unittest.TestCase):
print
(
"Start calibration ..."
)
print
(
"Start calibration ..."
)
calibrator
=
ut
.
Calibrator
(
calibrator
=
int8_utility
.
Calibrator
(
program
=
infer_program
,
program
=
infer_program
,
pretrained_model
=
model_path
,
pretrained_model
=
model_path
,
iterations
=
100
,
algo
=
algo
,
debug
=
Fals
e
,
exe
=
ex
e
,
algo
=
algo
)
output
=
int8_model
,
feed_var_names
=
feed_dict
,
sampling_data
=
{}
fetch_list
=
fetch_targets
)
calibrator
.
generate_sampling_program
()
test_info
=
[]
test_info
=
[]
cnt
=
0
cnt
=
0
for
batch_id
,
data
in
enumerate
(
val_reader
()):
for
batch_id
,
data
in
enumerate
(
val_reader
()):
...
@@ -192,13 +190,7 @@ class TestCalibration(unittest.TestCase):
...
@@ -192,13 +190,7 @@ class TestCalibration(unittest.TestCase):
feed_dict
[
1
]:
label
},
feed_dict
[
1
]:
label
},
fetch_list
=
fetch_targets
)
fetch_list
=
fetch_targets
)
if
generate_int8
:
if
generate_int8
:
for
i
in
calibrator
.
sampling_program
.
list_vars
():
calibrator
.
sample_data
()
if
i
.
name
in
calibrator
.
sampling_vars
:
np_data
=
np
.
array
(
fluid
.
global_scope
().
find_var
(
i
.
name
)
.
get_tensor
())
if
i
.
name
not
in
sampling_data
:
sampling_data
[
i
.
name
]
=
[]
sampling_data
[
i
.
name
].
append
(
np_data
)
test_info
.
append
(
np
.
mean
(
acc1
)
*
len
(
data
))
test_info
.
append
(
np
.
mean
(
acc1
)
*
len
(
data
))
cnt
+=
len
(
data
)
cnt
+=
len
(
data
)
...
@@ -209,9 +201,8 @@ class TestCalibration(unittest.TestCase):
...
@@ -209,9 +201,8 @@ class TestCalibration(unittest.TestCase):
break
break
if
generate_int8
:
if
generate_int8
:
calibrator
.
generate_quantized_data
(
sampling_data
)
calibrator
.
save_int8_model
()
fluid
.
io
.
save_inference_model
(
int8_model
,
feed_dict
,
fetch_targets
,
exe
,
calibrator
.
sampling_program
)
print
(
print
(
"Calibration is done and the corresponding files were generated at {}"
.
"Calibration is done and the corresponding files were generated at {}"
.
format
(
os
.
path
.
abspath
(
"calibration_out"
)))
format
(
os
.
path
.
abspath
(
"calibration_out"
)))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录