Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b6c3b69a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
b6c3b69a
编写于
1月 31, 2019
作者:
G
guoshengCS
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into fix-beam-search-size
test=develop
上级
5dfce931
46a6cac9
变更
41
隐藏空白更改
内联
并排
Showing
41 changed file
with
1037 addition
and
89 deletion
+1037
-89
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+1
-0
paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc
paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc
+80
-0
paddle/fluid/framework/ir/identity_scale_op_clean_pass.h
paddle/fluid/framework/ir/identity_scale_op_clean_pass.h
+33
-0
paddle/fluid/framework/scope.cc
paddle/fluid/framework/scope.cc
+1
-5
paddle/fluid/inference/analysis/ir_pass_manager.cc
paddle/fluid/inference/analysis/ir_pass_manager.cc
+3
-3
paddle/fluid/inference/api/analysis_config.cc
paddle/fluid/inference/api/analysis_config.cc
+5
-0
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+2
-1
paddle/fluid/inference/api/analysis_predictor_tester.cc
paddle/fluid/inference/api/analysis_predictor_tester.cc
+1
-1
paddle/fluid/inference/api/paddle_analysis_config.h
paddle/fluid/inference/api/paddle_analysis_config.h
+5
-2
paddle/fluid/inference/api/paddle_pass_builder.h
paddle/fluid/inference/api/paddle_pass_builder.h
+2
-0
paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
...le/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
+1
-1
paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
...nference/tests/api/analyzer_text_classification_tester.cc
+1
-1
paddle/fluid/memory/allocation/legacy_allocator.cc
paddle/fluid/memory/allocation/legacy_allocator.cc
+64
-12
paddle/fluid/memory/allocation/legacy_allocator.h
paddle/fluid/memory/allocation/legacy_allocator.h
+47
-0
paddle/fluid/operators/batch_norm_op.cc
paddle/fluid/operators/batch_norm_op.cc
+4
-2
paddle/fluid/operators/detection/CMakeLists.txt
paddle/fluid/operators/detection/CMakeLists.txt
+1
-0
paddle/fluid/operators/detection/bbox_util.h
paddle/fluid/operators/detection/bbox_util.h
+24
-0
paddle/fluid/operators/detection/box_clip_op.cc
paddle/fluid/operators/detection/box_clip_op.cc
+86
-0
paddle/fluid/operators/detection/box_clip_op.cu
paddle/fluid/operators/detection/box_clip_op.cu
+74
-0
paddle/fluid/operators/detection/box_clip_op.h
paddle/fluid/operators/detection/box_clip_op.h
+50
-0
paddle/fluid/operators/jit/benchmark.cc
paddle/fluid/operators/jit/benchmark.cc
+4
-0
paddle/fluid/operators/jit/gen/blas.cc
paddle/fluid/operators/jit/gen/blas.cc
+1
-1
paddle/fluid/operators/jit/gen/blas.h
paddle/fluid/operators/jit/gen/blas.h
+1
-0
paddle/fluid/operators/jit/helper.h
paddle/fluid/operators/jit/helper.h
+15
-8
paddle/fluid/operators/jit/more/mix/mix.cc
paddle/fluid/operators/jit/more/mix/mix.cc
+10
-43
paddle/fluid/operators/jit/more/mkl/mkl.cc
paddle/fluid/operators/jit/more/mkl/mkl.cc
+1
-1
paddle/fluid/operators/math/fc_compute.h
paddle/fluid/operators/math/fc_compute.h
+6
-4
paddle/fluid/operators/math/softmax_impl.h
paddle/fluid/operators/math/softmax_impl.h
+3
-2
paddle/fluid/operators/ngraph/ngraph_bridge.cc
paddle/fluid/operators/ngraph/ngraph_bridge.cc
+2
-0
paddle/fluid/operators/ngraph/ngraph_ops.h
paddle/fluid/operators/ngraph/ngraph_ops.h
+1
-0
paddle/fluid/operators/ngraph/ops/conv2d_op.h
paddle/fluid/operators/ngraph/ops/conv2d_op.h
+235
-0
paddle/fluid/platform/place.cc
paddle/fluid/platform/place.cc
+6
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+8
-0
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+13
-1
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+51
-0
python/paddle/fluid/tests/test_detection.py
python/paddle/fluid/tests/test_detection.py
+11
-0
python/paddle/fluid/tests/unittests/ngraph/test_conv2d_ngraph_op.py
...dle/fluid/tests/unittests/ngraph/test_conv2d_ngraph_op.py
+52
-0
python/paddle/fluid/tests/unittests/test_box_clip_op.py
python/paddle/fluid/tests/unittests/test_box_clip_op.py
+70
-0
python/paddle/fluid/tests/unittests/test_inference_model_io.py
...n/paddle/fluid/tests/unittests/test_inference_model_io.py
+2
-1
python/paddle/fluid/tests/unittests/test_peak_gpumem_monitor.py
.../paddle/fluid/tests/unittests/test_peak_gpumem_monitor.py
+59
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
b6c3b69a
...
...
@@ -325,6 +325,7 @@ paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name', 'axis'], varargs=None, keywords=None, defaults=('encode_center_size', True, None, 0))
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.box_clip ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None))
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1))
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
b6c3b69a
...
...
@@ -65,6 +65,7 @@ pass_library(conv_elementwise_add2_act_fuse_pass inference)
pass_library
(
conv_elementwise_add_fuse_pass inference
)
pass_library
(
conv_affine_channel_fuse_pass inference
)
pass_library
(
transpose_flatten_concat_fuse_pass inference
)
pass_library
(
identity_scale_op_clean_pass base
)
# There may be many transpose-flatten structures in a model, and the output of
# these structures will be used as inputs to the concat Op. This pattern will
...
...
paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc
0 → 100644
浏览文件 @
b6c3b69a
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/identity_scale_op_clean_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
std
::
unique_ptr
<
ir
::
Graph
>
IdentityScaleOpCleanPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
FusePassBase
::
Init
(
"identity_scale_op_clean"
,
graph
.
get
());
// pre_op -> scale_in -> scale_op -> scale_out
// ->
// pre_op -> scale_out
GraphPatternDetector
detector
;
auto
pre_op
=
detector
.
mutable_pattern
()
->
NewNode
(
"pre_op"
)
->
assert_is_op
();
auto
scale_in
=
detector
.
mutable_pattern
()
->
NewNode
(
"scale_in"
)
->
assert_is_op_input
(
"scale"
)
->
AsIntermediate
();
auto
scale_op
=
detector
.
mutable_pattern
()
->
NewNode
(
"scale_fuse"
)
->
assert_is_op
(
"scale"
)
->
assert_op_attr
<
float
>
(
"scale"
,
1.
)
->
assert_op_attr
<
float
>
(
"bias"
,
0.
);
auto
scale_out
=
detector
.
mutable_pattern
()
->
NewNode
(
"scale_out"
)
->
assert_is_op_output
(
"scale"
);
pre_op
->
LinksTo
({
scale_in
});
scale_op
->
LinksFrom
({
scale_in
}).
LinksTo
({
scale_out
});
GraphPatternDetector
::
handle_t
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
Node
*
scale_op_var
=
subgraph
.
at
(
scale_op
);
Node
*
scale_in_var
=
subgraph
.
at
(
scale_in
);
Node
*
scale_out_var
=
subgraph
.
at
(
scale_out
);
Node
*
pre_op_var
=
subgraph
.
at
(
pre_op
);
// Link pre_op directly to scale_out
const
std
::
string
scale_in_name
=
scale_in_var
->
Name
();
const
std
::
string
scale_out_name
=
scale_out_var
->
Name
();
// Remove links in graph
GraphSafeRemoveNodes
(
graph
,
{
scale_in_var
,
scale_op_var
});
// Modify proto message
auto
*
pre_op_desc
=
pre_op_var
->
Op
();
for
(
auto
&
parameter
:
*
pre_op_desc
->
Proto
()
->
mutable_outputs
())
{
auto
*
arguments
=
parameter
.
mutable_arguments
();
auto
it
=
std
::
find
(
arguments
->
begin
(),
arguments
->
end
(),
scale_in_name
);
PADDLE_ENFORCE
(
it
!=
arguments
->
end
());
*
it
=
scale_out_name
;
}
IR_NODE_LINK_TO
(
pre_op_var
,
scale_out_var
);
};
detector
(
graph
.
get
(),
handler
);
return
graph
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
identity_scale_op_clean_pass
,
paddle
::
framework
::
ir
::
IdentityScaleOpCleanPass
);
paddle/fluid/framework/ir/identity_scale_op_clean_pass.h
0 → 100644
浏览文件 @
b6c3b69a
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
IdentityScaleOpCleanPass
:
public
FusePassBase
{
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
private:
virtual
~
IdentityScaleOpCleanPass
()
=
default
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/scope.cc
浏览文件 @
b6c3b69a
...
...
@@ -22,11 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/string/printf.h"
DEFINE_bool
(
benchmark
,
false
,
"Doing memory benchmark. It will make deleting scope synchronized, "
"and add some memory usage logs."
"Default cuda is asynchronous device, set to True will"
"force op run in synchronous mode."
);
DECLARE_bool
(
benchmark
);
DEFINE_bool
(
eager_delete_scope
,
true
,
...
...
paddle/fluid/inference/analysis/ir_pass_manager.cc
浏览文件 @
b6c3b69a
...
...
@@ -83,7 +83,6 @@ void IRPassManager::CreatePasses(Argument *argument,
new
std
::
string
(
GetOrCreateModelOptCacheDir
(
model_opt_cache_dir
)));
}
// graph_ = pass->Apply(std::move(graph_));
pre_pass
=
pass_name
;
passes_
.
emplace_back
(
std
::
move
(
pass
));
...
...
@@ -97,8 +96,9 @@ std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
PADDLE_ENFORCE
(
graph
.
get
());
// Apply all the passes
for
(
const
auto
&
pass
:
passes_
)
{
if
(
pass
->
Type
()
==
"graph_viz_pass"
)
continue
;
PrettyLogEndl
(
Style
::
H2
(),
"--- Running IR pass [%s]"
,
pass
->
Type
());
if
(
pass
->
Type
()
!=
"graph_viz_pass"
)
{
PrettyLogEndl
(
Style
::
H2
(),
"--- Running IR pass [%s]"
,
pass
->
Type
());
}
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
}
return
std
::
move
(
graph
);
...
...
paddle/fluid/inference/api/analysis_config.cc
浏览文件 @
b6c3b69a
...
...
@@ -318,4 +318,9 @@ NativeConfig AnalysisConfig::ToNativeConfig() const {
return
config
;
}
void
AnalysisConfig
::
SwitchIrDebug
(
int
x
)
{
ir_debug_
=
x
;
Update
();
}
}
// namespace paddle
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
b6c3b69a
...
...
@@ -58,7 +58,8 @@ namespace {
bool
IsPersistable
(
const
framework
::
VarDesc
*
var
)
{
if
(
var
->
Persistable
()
&&
var
->
GetType
()
!=
framework
::
proto
::
VarType
::
FEED_MINIBATCH
&&
var
->
GetType
()
!=
framework
::
proto
::
VarType
::
FETCH_LIST
)
{
var
->
GetType
()
!=
framework
::
proto
::
VarType
::
FETCH_LIST
&&
var
->
GetType
()
!=
framework
::
proto
::
VarType
::
RAW
)
{
return
true
;
}
return
false
;
...
...
paddle/fluid/inference/api/analysis_predictor_tester.cc
浏览文件 @
b6c3b69a
...
...
@@ -196,7 +196,7 @@ TEST(AnalysisPredictor, memory_optim) {
AnalysisConfig
config
(
FLAGS_dirname
);
config
.
DisableGpu
();
config
.
EnableMemoryOptim
(
true
);
config
.
pass_builder
()
->
TurnOn
Debug
();
config
.
SwitchIr
Debug
();
auto
native_predictor
=
CreatePaddlePredictor
<
NativeConfig
>
(
config
.
ToNativeConfig
());
...
...
paddle/fluid/inference/api/paddle_analysis_config.h
浏览文件 @
b6c3b69a
...
...
@@ -140,9 +140,12 @@ struct AnalysisConfig {
*/
bool
tensorrt_engine_enabled
()
const
{
return
use_tensorrt_
;
}
/** Control whther to debug IR graph analysis phase.
/** \brief Control whether to debug IR graph analysis phase.
*
* This will generate DOT files for visualizing the computation graph after
* each analysis pass applied.
*/
void
SwitchIrDebug
(
int
x
=
true
)
{
ir_debug_
=
x
;
}
void
SwitchIrDebug
(
int
x
=
true
)
;
/** Turn on MKLDNN.
*/
...
...
paddle/fluid/inference/api/paddle_pass_builder.h
浏览文件 @
b6c3b69a
...
...
@@ -117,6 +117,7 @@ class CpuPassStrategy : public PassStrategy {
"conv_bn_fuse_pass"
,
//
"conv_eltwiseadd_bn_fuse_pass"
,
//
"is_test_pass"
,
//
"identity_scale_op_clean_pass"
,
//
});
use_gpu_
=
false
;
}
...
...
@@ -155,6 +156,7 @@ class GpuPassStrategy : public PassStrategy {
GpuPassStrategy
()
:
PassStrategy
({})
{
passes_
.
assign
({
"infer_clean_graph_pass"
,
//
"identity_scale_op_clean_pass"
,
//
"conv_affine_channel_fuse_pass"
,
//
"conv_eltwiseadd_affine_channel_fuse_pass"
,
//
"conv_bn_fuse_pass"
,
//
...
...
paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc
浏览文件 @
b6c3b69a
...
...
@@ -142,7 +142,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
cfg
->
SetModel
(
FLAGS_infer_model
+
"/model"
,
FLAGS_infer_model
+
"/params"
);
cfg
->
DisableGpu
();
cfg
->
SwitchSpecifyInputNames
();
cfg
->
pass_builder
()
->
TurnOn
Debug
();
cfg
->
SwitchIr
Debug
();
cfg
->
SetCpuMathLibraryNumThreads
(
FLAGS_paddle_num_threads
);
if
(
use_mkldnn
)
{
cfg
->
EnableMKLDNN
();
...
...
paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
浏览文件 @
b6c3b69a
...
...
@@ -69,7 +69,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
TEST
(
Analyzer_Text_Classification
,
profile
)
{
AnalysisConfig
cfg
;
SetConfig
(
&
cfg
);
cfg
.
pass_builder
()
->
TurnOn
Debug
();
cfg
.
SwitchIr
Debug
();
std
::
vector
<
PaddleTensor
>
outputs
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
...
...
paddle/fluid/memory/allocation/legacy_allocator.cc
浏览文件 @
b6c3b69a
...
...
@@ -35,6 +35,7 @@ DEFINE_bool(init_allocated_mem, false,
"To find this error in time, we use init_allocated_mem to indicate "
"that initializing the allocated memory with a small value "
"during unit testing."
);
DECLARE_bool
(
benchmark
);
DECLARE_double
(
fraction_of_gpu_memory_to_use
);
namespace
paddle
{
...
...
@@ -59,11 +60,6 @@ size_t memory_usage(const platform::Place &p);
using
BuddyAllocator
=
detail
::
BuddyAllocator
;
std
::
unordered_map
<
/*device id*/
int
,
std
::
pair
<
/*current memory usage*/
uint64_t
,
/*peak memory usage*/
uint64_t
>>
gpu_mem_info
;
BuddyAllocator
*
GetCPUBuddyAllocator
()
{
// We tried thread_local for inference::RNN1 model, but that not works much
// for multi-thread test.
...
...
@@ -144,6 +140,8 @@ BuddyAllocator *GetGPUBuddyAllocator(int gpu_id) {
devices
=
platform
::
GetSelectedDevices
();
int
gpu_num
=
devices
.
size
();
allocation
::
GPUMemMonitor
.
Initialize
(
devices
.
size
());
a_arr
=
new
BuddyAllocator
*
[
gpu_num
];
for
(
size_t
i
=
0
;
i
<
devices
.
size
();
++
i
)
{
int
dev_id
=
devices
[
i
];
...
...
@@ -204,12 +202,7 @@ void *Alloc<platform::CUDAPlace>(const platform::CUDAPlace &place,
<<
string
::
HumanReadableSize
(
Used
<
platform
::
CUDAPlace
>
(
place
));
platform
::
SetDeviceId
(
cur_dev
);
}
else
{
gpu_mem_info
[
place
.
device
].
first
+=
size
;
if
(
gpu_mem_info
[
place
.
device
].
first
>
gpu_mem_info
[
place
.
device
].
second
)
{
gpu_mem_info
[
place
.
device
].
second
=
gpu_mem_info
[
place
.
device
].
first
;
VLOG
(
3
)
<<
"device: "
<<
place
.
device
<<
" peak memory usage : "
<<
(
gpu_mem_info
[
place
.
device
].
second
>>
20
)
<<
" MiB"
;
}
if
(
FLAGS_benchmark
)
allocation
::
GPUMemMonitor
.
Add
(
place
.
device
,
size
);
if
(
FLAGS_init_allocated_mem
)
{
cudaMemset
(
ptr
,
0xEF
,
size
);
}
...
...
@@ -225,7 +218,7 @@ void Free<platform::CUDAPlace>(const platform::CUDAPlace &place, void *p,
size_t
size
)
{
#ifdef PADDLE_WITH_CUDA
GetGPUBuddyAllocator
(
place
.
device
)
->
Free
(
p
);
gpu_mem_info
[
place
.
device
].
first
-=
size
;
if
(
FLAGS_benchmark
)
allocation
::
GPUMemMonitor
.
Minus
(
place
.
device
,
size
)
;
#else
PADDLE_THROW
(
"'CUDAPlace' is not supported in CPU only device."
);
#endif
...
...
@@ -335,6 +328,8 @@ size_t Usage::operator()(const platform::CUDAPinnedPlace &cuda_pinned) const {
namespace
allocation
{
LegacyMemMonitor
GPUMemMonitor
;
Allocation
*
LegacyAllocator
::
AllocateImpl
(
size_t
size
,
Allocator
::
Attr
attr
)
{
void
*
ptr
=
boost
::
apply_visitor
(
legacy
::
AllocVisitor
(
size
),
place_
);
return
new
Allocation
(
ptr
,
size
,
place_
);
...
...
@@ -346,6 +341,63 @@ void LegacyAllocator::Free(Allocation *allocation) {
allocation
->
place
());
delete
allocation
;
}
bool
MemInfo
::
Add
(
const
size_t
&
size
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
usage_
+=
size
;
bool
peak_point
=
usage_
>
peak_usage_
;
if
(
peak_point
)
peak_usage_
=
usage_
;
return
peak_point
;
}
void
MemInfo
::
Minus
(
const
size_t
&
size
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
usage_
-=
size
;
}
uint64_t
MemInfo
::
GetPeakUsage
()
{
return
peak_usage_
;
}
LegacyMemMonitor
::~
LegacyMemMonitor
()
{
for
(
auto
&
item
:
gpu_mem_info_
)
delete
item
.
second
;
}
void
LegacyMemMonitor
::
Initialize
(
const
int
&
device_num
)
{
for
(
auto
i
=
0
;
i
<
device_num
;
++
i
)
{
gpu_mem_info_
[
i
]
=
new
MemInfo
();
}
}
void
LegacyMemMonitor
::
Add
(
const
int
&
device
,
const
size_t
&
size
)
{
if
(
gpu_mem_info_
[
device
]
->
Add
(
size
))
{
VLOG
(
3
)
<<
"#LegacyMemMonitor# device: "
<<
device
<<
" peak memory usage : "
<<
(
gpu_mem_info_
[
device
]
->
GetPeakUsage
()
>>
20
)
<<
" MiB"
;
}
}
void
LegacyMemMonitor
::
Minus
(
const
int
&
device
,
const
size_t
&
size
)
{
gpu_mem_info_
[
device
]
->
Minus
(
size
);
}
uint64_t
LegacyMemMonitor
::
GetMemUsage
(
const
int
&
device
)
{
return
gpu_mem_info_
.
find
(
device
)
==
gpu_mem_info_
.
end
()
?
0
:
gpu_mem_info_
[
device
]
->
GetPeakUsage
();
}
void
LegacyMemMonitor
::
PrintMemUsage
()
{
std
::
vector
<
int
>
devices
;
for
(
const
auto
&
item
:
gpu_mem_info_
)
{
devices
.
emplace_back
(
item
.
first
);
}
std
::
sort
(
devices
.
begin
(),
devices
.
end
());
for
(
const
auto
&
device
:
devices
)
{
std
::
cout
<<
"Device : "
<<
device
<<
" Peak Memory Usage : "
<<
(
gpu_mem_info_
[
device
]
->
GetPeakUsage
()
>>
20
)
<<
" MiB"
<<
std
::
endl
;
}
}
}
// namespace allocation
}
// namespace memory
}
// namespace paddle
paddle/fluid/memory/allocation/legacy_allocator.h
浏览文件 @
b6c3b69a
...
...
@@ -13,12 +13,59 @@
// limitations under the License.
#pragma once
#include <algorithm>
#include <mutex> // NOLINT
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
memory
{
namespace
allocation
{
class
MemInfo
{
public:
MemInfo
()
:
usage_
(
0
),
peak_usage_
(
0
)
{}
MemInfo
(
const
MemInfo
&
)
=
delete
;
MemInfo
&
operator
=
(
const
MemInfo
&
)
=
delete
;
// return a flag to indicate current operation will create a peak point or not
bool
Add
(
const
size_t
&
);
void
Minus
(
const
size_t
&
);
uint64_t
GetPeakUsage
();
private:
/* current memory usage*/
uint64_t
usage_
;
uint64_t
peak_usage_
;
std
::
mutex
mutex_
;
};
class
LegacyMemMonitor
{
public:
// used to store the GPU memory usage of each devices
using
MemUsage
=
std
::
unordered_map
<
/*device id*/
int
,
/*mem usage info node*/
MemInfo
*>
;
MemUsage
GetMemUsageInfo
()
{
return
gpu_mem_info_
;
}
~
LegacyMemMonitor
();
void
Initialize
(
const
int
&
);
void
Add
(
const
int
&
,
const
size_t
&
);
void
Minus
(
const
int
&
,
const
size_t
&
);
uint64_t
GetMemUsage
(
const
int
&
);
void
PrintMemUsage
();
protected:
MemUsage
gpu_mem_info_
;
};
extern
LegacyMemMonitor
GPUMemMonitor
;
class
LegacyAllocatorPrivate
;
class
LegacyAllocator
:
public
Allocator
{
public:
...
...
paddle/fluid/operators/batch_norm_op.cc
浏览文件 @
b6c3b69a
...
...
@@ -589,8 +589,10 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
op
->
SetInput
(
"SavedVariance"
,
Output
(
"SavedVariance"
));
// used when setting use_global_stats True during training
op
->
SetInput
(
"Mean"
,
Output
(
"MeanOut"
));
op
->
SetInput
(
"Variance"
,
Output
(
"VarianceOut"
));
if
(
boost
::
get
<
bool
>
(
GetAttr
(
"use_global_stats"
)))
{
op
->
SetInput
(
"Mean"
,
Output
(
"MeanOut"
));
op
->
SetInput
(
"Variance"
,
Output
(
"VarianceOut"
));
}
op
->
SetAttrMap
(
Attrs
());
...
...
paddle/fluid/operators/detection/CMakeLists.txt
浏览文件 @
b6c3b69a
...
...
@@ -31,6 +31,7 @@ detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc
polygon_box_transform_op.cu
)
detection_library
(
rpn_target_assign_op SRCS rpn_target_assign_op.cc
)
detection_library
(
generate_proposal_labels_op SRCS generate_proposal_labels_op.cc
)
detection_library
(
box_clip_op SRCS box_clip_op.cc box_clip_op.cu
)
detection_library
(
yolov3_loss_op SRCS yolov3_loss_op.cc
)
if
(
WITH_GPU
)
...
...
paddle/fluid/operators/detection/bbox_util.h
浏览文件 @
b6c3b69a
...
...
@@ -99,5 +99,29 @@ void BboxOverlaps(const framework::Tensor& r_boxes,
}
}
template
<
class
T
>
void
ClipTiledBoxes
(
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
im_info
,
const
framework
::
Tensor
&
input_boxes
,
framework
::
Tensor
*
out
)
{
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
im_info_data
=
im_info
.
data
<
T
>
();
const
T
*
input_boxes_data
=
input_boxes
.
data
<
T
>
();
T
zero
(
0
);
T
im_w
=
round
(
im_info_data
[
1
]
/
im_info_data
[
2
]);
T
im_h
=
round
(
im_info_data
[
0
]
/
im_info_data
[
2
]);
for
(
int64_t
i
=
0
;
i
<
input_boxes
.
numel
();
++
i
)
{
if
(
i
%
4
==
0
)
{
out_data
[
i
]
=
std
::
max
(
std
::
min
(
input_boxes_data
[
i
],
im_w
-
1
),
zero
);
}
else
if
(
i
%
4
==
1
)
{
out_data
[
i
]
=
std
::
max
(
std
::
min
(
input_boxes_data
[
i
],
im_h
-
1
),
zero
);
}
else
if
(
i
%
4
==
2
)
{
out_data
[
i
]
=
std
::
max
(
std
::
min
(
input_boxes_data
[
i
],
im_w
-
1
),
zero
);
}
else
{
out_data
[
i
]
=
std
::
max
(
std
::
min
(
input_boxes_data
[
i
],
im_h
-
1
),
zero
);
}
}
}
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detection/box_clip_op.cc
0 → 100644
浏览文件 @
b6c3b69a
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/box_clip_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
class
BoxClipOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(Input) of BoxClipOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"ImInfo"
),
"Input(ImInfo) of BoxClipOp should not be null."
);
auto
input_box_dims
=
ctx
->
GetInputDim
(
"Input"
);
auto
im_info_dims
=
ctx
->
GetInputDim
(
"ImInfo"
);
if
(
ctx
->
IsRuntime
())
{
auto
input_box_size
=
input_box_dims
.
size
();
PADDLE_ENFORCE_EQ
(
input_box_dims
[
input_box_size
-
1
],
4
,
"The last dimension of Input must be 4"
);
PADDLE_ENFORCE_EQ
(
im_info_dims
.
size
(),
2
,
"The rank of Input(Input) in BoxClipOp must be 2"
);
PADDLE_ENFORCE_EQ
(
im_info_dims
[
1
],
3
,
"The last dimension of ImInfo must be 3"
);
}
ctx
->
ShareDim
(
"Input"
,
/*->*/
"Output"
);
ctx
->
ShareLoD
(
"Input"
,
/*->*/
"Output"
);
}
};
class
BoxClipOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Input"
,
"(LoDTensor) "
"Input is a LoDTensor with shape [..., 4] holds 4 points"
"in last dimension in format [xmin, ymin, xmax, ymax]"
);
AddInput
(
"ImInfo"
,
"(Tensor) Information for image reshape is in shape (N, 3), "
"in format (height, width, im_scale)"
);
AddOutput
(
"Output"
,
"(LoDTensor) "
"Output is a LoDTensor with the same shape as Input"
"and it is the result after clip"
);
AddComment
(
R"DOC(
This operator clips input boxes to original input images.
For each input box, The formula is given as follows:
$$xmin = \max(\min(xmin, im_w - 1), 0)$$
$$ymin = \max(\min(ymin, im_h - 1), 0)$$
$$xmax = \max(\min(xmax, im_w - 1), 0)$$
$$ymax = \max(\min(ymax, im_h - 1), 0)$$
where im_w and im_h are computed from ImInfo, the formula is given as follows:
$$im_w = \round(width / im_scale)$$
$$im_h = \round(height / im_scale)$$
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
box_clip
,
ops
::
BoxClipOp
,
ops
::
BoxClipOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
box_clip
,
ops
::
BoxClipKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
BoxClipKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/detection/box_clip_op.cu
0 → 100644
浏览文件 @
b6c3b69a
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detection/box_clip_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTenso
=
framework
::
LoDTensor
;
static
constexpr
int
ImInfoSize
=
3
;
template
<
typename
T
,
int
BlockSize
>
static
__global__
void
GPUBoxClip
(
const
T
*
input
,
const
size_t
*
lod
,
const
size_t
width
,
const
T
*
im_info
,
T
*
output
)
{
T
im_w
=
round
(
im_info
[
blockIdx
.
x
*
ImInfoSize
+
1
]
/
im_info
[
blockIdx
.
x
*
ImInfoSize
+
2
]);
T
im_h
=
round
(
im_info
[
blockIdx
.
x
*
ImInfoSize
]
/
im_info
[
blockIdx
.
x
*
ImInfoSize
+
2
]);
for
(
int
i
=
threadIdx
.
x
;
i
<
(
lod
[
blockIdx
.
x
+
1
]
-
lod
[
blockIdx
.
x
])
*
width
;
i
+=
BlockSize
)
{
int
idx
=
lod
[
blockIdx
.
x
]
*
width
+
i
;
T
im_size
=
(
idx
%
2
==
0
)
?
im_w
:
im_h
;
output
[
idx
]
=
max
(
min
(
input
[
idx
],
im_size
-
1
),
T
(
0.
));
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
GPUBoxClipKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
context
.
GetPlace
()),
"This kernel only runs on GPU device."
);
auto
*
input
=
context
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
im_info
=
context
.
Input
<
Tensor
>
(
"ImInfo"
);
auto
*
output
=
context
.
Output
<
LoDTensor
>
(
"Output"
);
const
int64_t
num
=
input
->
dims
()[
0
];
const
int64_t
bbox_width
=
input
->
numel
()
/
num
;
auto
lod
=
input
->
lod
();
framework
::
LoD
abs_offset_lod
=
framework
::
ToAbsOffset
(
lod
);
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
stream
=
dev_ctx
.
stream
();
const
size_t
batch_size
=
lod
.
back
().
size
()
-
1
;
T
*
output_data
=
output
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
GPUBoxClip
<
T
,
512
><<<
batch_size
,
512
,
0
,
stream
>>>
(
input
->
data
<
T
>
(),
abs_offset_lod
[
0
].
CUDAMutableData
(
dev_ctx
.
GetPlace
()),
bbox_width
,
im_info
->
data
<
T
>
(),
output_data
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
box_clip
,
ops
::
GPUBoxClipKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
GPUBoxClipKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/detection/box_clip_op.h
0 → 100644
浏览文件 @
b6c3b69a
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
DeviceContext
,
typename
T
>
class
BoxClipKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
input_box
=
context
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
im_info
=
context
.
Input
<
LoDTensor
>
(
"ImInfo"
);
auto
*
output_box
=
context
.
Output
<
LoDTensor
>
(
"Output"
);
auto
&
dev_ctx
=
context
.
template
device_context
<
platform
::
CPUDeviceContext
>();
output_box
->
mutable_data
<
T
>
(
context
.
GetPlace
());
if
(
input_box
->
lod
().
size
())
{
PADDLE_ENFORCE_EQ
(
input_box
->
lod
().
size
(),
1UL
,
"Only support 1 level of LoD."
);
}
auto
box_lod
=
input_box
->
lod
().
back
();
int64_t
n
=
static_cast
<
int64_t
>
(
box_lod
.
size
()
-
1
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
Tensor
im_info_slice
=
im_info
->
Slice
(
i
,
i
+
1
);
Tensor
box_slice
=
input_box
->
Slice
(
box_lod
[
i
],
box_lod
[
i
+
1
]);
Tensor
output_slice
=
output_box
->
Slice
(
box_lod
[
i
],
box_lod
[
i
+
1
]);
ClipTiledBoxes
<
T
>
(
dev_ctx
,
im_info_slice
,
box_slice
,
&
output_slice
);
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/benchmark.cc
浏览文件 @
b6c3b69a
...
...
@@ -93,6 +93,7 @@ std::vector<int> TestSizes() {
template
<
typename
KernelTuples
,
typename
...
Args
>
struct
BenchFunc
{
// return this function avg time
// TODO(TJ): clear cache every time
double
operator
()(
const
typename
KernelTuples
::
func_type
tgt
,
Args
...
args
)
{
for
(
int
i
=
0
;
i
<
FLAGS_burning
;
++
i
)
{
tgt
(
args
...);
...
...
@@ -172,6 +173,9 @@ void BenchXYZNKernel() {
RandomVec
<
T
>
(
d
,
y_data
);
BenchAllImpls
<
KT
,
jit
::
XYZNTuples
<
T
>
,
PlaceType
>
(
d
,
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
z_data
,
d
);
// test inplace
BenchAllImpls
<
KT
,
jit
::
XYZNTuples
<
T
>
,
PlaceType
>
(
d
,
x
.
data
<
T
>
(),
z_data
,
z_data
,
d
);
}
}
...
...
paddle/fluid/operators/jit/gen/blas.cc
浏览文件 @
b6c3b69a
...
...
@@ -155,7 +155,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx)
;
\
return platform::MayIUse(platform::avx)
&& attr <= 1024;
\
} \
size_t CodeSize(const int& d) const override { \
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \
...
...
paddle/fluid/operators/jit/gen/blas.h
浏览文件 @
b6c3b69a
...
...
@@ -61,6 +61,7 @@ class VXXJitCode : public JitCode {
base
+=
"_Vec"
;
}
base
+=
(
with_relu_
?
"_Relu"
:
""
);
base
+=
"_D"
+
std
::
to_string
(
num_
);
return
base
.
c_str
();
}
void
genCode
()
override
;
...
...
paddle/fluid/operators/jit/helper.h
浏览文件 @
b6c3b69a
...
...
@@ -118,26 +118,33 @@ typename KernelTuples::func_type Get(
return
GetRefer
<
KT
,
KernelTuples
>
();
}
template
<
KernelType
KT
,
typename
KernelTuples
>
class
KernelFuncs
Cache
{
template
<
KernelType
KT
,
typename
KernelTuples
,
typename
PlaceType
>
class
KernelFuncs
{
public:
KernelFuncs
Cache
()
=
default
;
static
KernelFuncs
Cache
&
Instanc
e
()
{
static
thread_local
KernelFuncs
Cache
<
KT
,
KernelTuples
>
g_func_cache
;
KernelFuncs
()
=
default
;
static
KernelFuncs
&
Cach
e
()
{
static
thread_local
KernelFuncs
<
KT
,
KernelTuples
,
PlaceType
>
g_func_cache
;
return
g_func_cache
;
}
bool
Has
(
int
key
)
const
{
return
funcs_
.
find
(
key
)
!=
funcs_
.
end
();
}
typename
KernelTuples
::
func_type
At
(
int
key
)
{
return
funcs_
.
at
(
key
);
}
void
Insert
(
int
key
,
typename
KernelTuples
::
func_type
func
)
{
funcs_
.
emplace
(
key
,
func
);
}
typename
KernelTuples
::
func_type
At
(
int
key
)
{
if
(
Has
(
key
))
{
return
funcs_
.
at
(
key
);
}
auto
func
=
Get
<
KT
,
KernelTuples
,
PlaceType
>
(
key
);
Insert
(
key
,
func
);
return
func
;
}
private:
std
::
unordered_map
<
int
,
typename
KernelTuples
::
func_type
>
funcs_
;
DISABLE_COPY_AND_ASSIGN
(
KernelFuncs
Cache
);
DISABLE_COPY_AND_ASSIGN
(
KernelFuncs
);
};
const
char
*
to_string
(
KernelType
kt
);
...
...
paddle/fluid/operators/jit/more/mix/mix.cc
浏览文件 @
b6c3b69a
...
...
@@ -49,49 +49,16 @@ void VTanh(const T* x, T* y, int n) {
}
void
Softmax
(
const
T
*
x
,
T
*
y
,
int
n
,
int
bs
)
{
typename
XRNTuples
<
T
>::
func_type
compute_hmax
{
nullptr
};
typename
XRNTuples
<
T
>::
func_type
compute_hsum
{
nullptr
};
typename
AXYNTuples
<
T
>::
func_type
compute_vscal
{
nullptr
};
typename
AXYNTuples
<
T
>::
func_type
compute_vaddbias
{
nullptr
};
typename
XYNTuples
<
T
>::
func_type
compute_vexp
{
nullptr
};
if
(
!
KernelFuncsCache
<
kHMax
,
XRNTuples
<
T
>>::
Instance
().
Has
(
n
))
{
compute_hmax
=
Get
<
kHMax
,
XRNTuples
<
T
>
,
platform
::
CPUPlace
>
(
n
);
KernelFuncsCache
<
kHMax
,
XRNTuples
<
T
>>::
Instance
().
Insert
(
n
,
compute_hmax
);
}
else
{
compute_hmax
=
KernelFuncsCache
<
kHMax
,
XRNTuples
<
T
>>::
Instance
().
At
(
n
);
}
if
(
!
KernelFuncsCache
<
kHSum
,
XRNTuples
<
T
>>::
Instance
().
Has
(
n
))
{
compute_hsum
=
Get
<
kHSum
,
XRNTuples
<
T
>
,
platform
::
CPUPlace
>
(
n
);
KernelFuncsCache
<
kHSum
,
XRNTuples
<
T
>>::
Instance
().
Insert
(
n
,
compute_hsum
);
}
else
{
compute_hsum
=
KernelFuncsCache
<
kHSum
,
XRNTuples
<
T
>>::
Instance
().
At
(
n
);
}
if
(
!
KernelFuncsCache
<
kVScal
,
AXYNTuples
<
T
>>::
Instance
().
Has
(
n
))
{
compute_vscal
=
Get
<
kVScal
,
AXYNTuples
<
T
>
,
platform
::
CPUPlace
>
(
n
);
KernelFuncsCache
<
kVScal
,
AXYNTuples
<
T
>>::
Instance
().
Insert
(
n
,
compute_vscal
);
}
else
{
compute_vscal
=
KernelFuncsCache
<
kVScal
,
AXYNTuples
<
T
>>::
Instance
().
At
(
n
);
}
if
(
!
KernelFuncsCache
<
kVAddBias
,
AXYNTuples
<
T
>>::
Instance
().
Has
(
n
))
{
compute_vaddbias
=
Get
<
kVAddBias
,
AXYNTuples
<
T
>
,
platform
::
CPUPlace
>
(
n
);
KernelFuncsCache
<
kVAddBias
,
AXYNTuples
<
T
>>::
Instance
().
Insert
(
n
,
compute_vaddbias
);
}
else
{
compute_vaddbias
=
KernelFuncsCache
<
kVAddBias
,
AXYNTuples
<
T
>>::
Instance
().
At
(
n
);
}
if
(
!
KernelFuncsCache
<
kVExp
,
XYNTuples
<
T
>>::
Instance
().
Has
(
n
))
{
compute_vexp
=
Get
<
KernelType
::
kVExp
,
XYNTuples
<
T
>
,
platform
::
CPUPlace
>
(
n
);
KernelFuncsCache
<
kVExp
,
XYNTuples
<
T
>>::
Instance
().
Insert
(
n
,
compute_vexp
);
}
else
{
compute_vexp
=
KernelFuncsCache
<
kVExp
,
XYNTuples
<
T
>>::
Instance
().
At
(
n
);
}
auto
compute_hmax
=
KernelFuncs
<
kHMax
,
XRNTuples
<
T
>
,
platform
::
CPUPlace
>::
Cache
().
At
(
n
);
auto
compute_hsum
=
KernelFuncs
<
kHSum
,
XRNTuples
<
T
>
,
platform
::
CPUPlace
>::
Cache
().
At
(
n
);
auto
compute_vscal
=
KernelFuncs
<
kVScal
,
AXYNTuples
<
T
>
,
platform
::
CPUPlace
>::
Cache
().
At
(
n
);
auto
compute_vaddbias
=
KernelFuncs
<
kVAddBias
,
AXYNTuples
<
T
>
,
platform
::
CPUPlace
>::
Cache
().
At
(
n
);
auto
compute_vexp
=
KernelFuncs
<
kVExp
,
XYNTuples
<
T
>
,
platform
::
CPUPlace
>::
Cache
().
At
(
n
);
for
(
int
i
=
0
;
i
<
bs
;
++
i
)
{
T
scalar
;
...
...
paddle/fluid/operators/jit/more/mkl/mkl.cc
浏览文件 @
b6c3b69a
...
...
@@ -136,7 +136,7 @@ bool VMulKernel<float>::UseMe(const int& d) const {
template
<
>
bool
VAddKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
512f
)
&&
d
>
512
;
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
>
512
;
}
template
<
>
...
...
paddle/fluid/operators/math/fc_compute.h
浏览文件 @
b6c3b69a
...
...
@@ -30,15 +30,17 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
return
;
}
if
(
relu
)
{
auto
compute
=
jit
::
Get
<
jit
::
kVAddRelu
,
jit
::
XYZNTuples
<
T
>
,
platform
::
CPUPlace
>
(
N
);
auto
compute
=
jit
::
KernelFuncs
<
jit
::
kVAddRelu
,
jit
::
XYZNTuples
<
T
>
,
platform
::
CPUPlace
>::
Cache
()
.
At
(
N
);
for
(
int
i
=
0
;
i
<
M
;
i
++
)
{
T
*
dst
=
Y
+
i
*
N
;
compute
(
B
,
dst
,
dst
,
N
);
}
}
else
{
auto
compute
=
jit
::
Get
<
jit
::
kVAdd
,
jit
::
XYZNTuples
<
T
>
,
platform
::
CPUPlace
>
(
N
);
auto
compute
=
jit
::
KernelFuncs
<
jit
::
kVAdd
,
jit
::
XYZNTuples
<
T
>
,
platform
::
CPUPlace
>::
Cache
()
.
At
(
N
);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
...
...
paddle/fluid/operators/math/softmax_impl.h
浏览文件 @
b6c3b69a
...
...
@@ -82,8 +82,9 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
const
int
kClassDim
=
1
;
// 2D data. Batch x C
auto
compute_softmax
=
jit
::
Get
<
jit
::
kSoftmax
,
jit
::
SoftmaxTuples
<
float
>
,
platform
::
CPUPlace
>
(
in_dims
[
kClassDim
]);
jit
::
KernelFuncs
<
jit
::
kSoftmax
,
jit
::
SoftmaxTuples
<
float
>
,
platform
::
CPUPlace
>::
Cache
()
.
At
(
in_dims
[
kClassDim
]);
compute_softmax
(
in_data
,
out_data
,
in_dims
[
kClassDim
],
in_dims
[
kBatchDim
]);
}
};
...
...
paddle/fluid/operators/ngraph/ngraph_bridge.cc
浏览文件 @
b6c3b69a
...
...
@@ -31,6 +31,8 @@ std::map<std::string,
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
)
>>
NgraphBridge
::
NG_NODE_MAP
=
{
{
"conv2d"
,
NG_OPS
::
BuildConv2dNode
},
{
"conv2d_grad"
,
NG_OPS
::
BuildConv2dGradNode
},
{
"elementwise_add"
,
NG_OPS
::
BuildElementwiseAddNode
},
{
"elementwise_add_grad"
,
NG_OPS
::
BuildElementwiseAddGradNode
},
{
"fill_constant"
,
NG_OPS
::
BuildFillConstantNode
},
...
...
paddle/fluid/operators/ngraph/ngraph_ops.h
浏览文件 @
b6c3b69a
...
...
@@ -22,6 +22,7 @@ limitations under the License. */
#pragma once
#include "ops/binary_unnary_op.h"
#include "ops/conv2d_op.h"
#include "ops/elementwise_add_op.h"
#include "ops/fill_constant_op.h"
#include "ops/mean_op.h"
...
...
paddle/fluid/operators/ngraph/ops/conv2d_op.h
0 → 100644
浏览文件 @
b6c3b69a
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "ngraph/ngraph.hpp"
#include "paddle/fluid/platform/ngraph_helper.h"
namespace
paddle
{
namespace
operators
{
namespace
ngraphs
{
std
::
shared_ptr
<
ngraph
::
Node
>
GroupedConvolution
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
data_batch
,
const
std
::
shared_ptr
<
ngraph
::
Node
>&
filters
,
const
ngraph
::
Strides
strides
,
const
ngraph
::
Strides
dilations
,
const
ngraph
::
CoordinateDiff
&
paddings
,
size_t
groups
)
{
auto
&
data_shape
=
data_batch
->
get_shape
();
auto
&
filter_shape
=
filters
->
get_shape
();
ngraph
::
NodeVector
ng_slices
;
for
(
size_t
i
=
0
;
i
<
groups
;
++
i
)
{
size_t
channel_step
=
filter_shape
.
at
(
1
);
const
std
::
vector
<
size_t
>
lower_bound
{
0
,
i
*
channel_step
,
0
,
0
};
const
std
::
vector
<
size_t
>
upper_bound
{
data_shape
.
at
(
0
),
(
i
+
1
)
*
channel_step
,
data_shape
.
at
(
2
),
data_shape
.
at
(
3
)};
auto
data_slice
=
std
::
make_shared
<
ngraph
::
op
::
Slice
>
(
data_batch
,
lower_bound
,
upper_bound
);
size_t
filter_step
=
filter_shape
.
at
(
0
)
/
groups
;
const
std
::
vector
<
size_t
>
filter_lower_bound
{
i
*
filter_step
,
0
,
0
,
0
};
const
std
::
vector
<
size_t
>
filter_upper_bound
{
(
i
+
1
)
*
filter_step
,
filter_shape
.
at
(
1
),
filter_shape
.
at
(
2
),
filter_shape
.
at
(
3
)};
auto
filter_slice
=
std
::
make_shared
<
ngraph
::
op
::
Slice
>
(
filters
,
filter_lower_bound
,
filter_upper_bound
);
auto
ng_conv
=
std
::
make_shared
<
ngraph
::
op
::
Convolution
>
(
data_slice
,
filter_slice
,
strides
,
dilations
,
paddings
,
paddings
);
ng_slices
.
push_back
(
ng_conv
);
}
size_t
concat_axis
=
1
;
return
std
::
make_shared
<
ngraph
::
op
::
Concat
>
(
ng_slices
,
concat_axis
);
}
std
::
shared_ptr
<
ngraph
::
Node
>
GroupedGradConvolutionFilter
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
data_batch
,
const
std
::
shared_ptr
<
ngraph
::
Node
>&
filters
,
const
std
::
shared_ptr
<
ngraph
::
Node
>&
doutput
,
const
ngraph
::
Strides
strides
,
const
ngraph
::
Strides
dilations
,
const
ngraph
::
CoordinateDiff
&
paddings
,
size_t
groups
)
{
auto
&
data_shape
=
data_batch
->
get_shape
();
auto
&
filter_shape
=
filters
->
get_shape
();
auto
&
out_shape
=
doutput
->
get_shape
();
ngraph
::
NodeVector
ng_slices
;
for
(
size_t
i
=
0
;
i
<
groups
;
++
i
)
{
size_t
channel_step
=
filter_shape
.
at
(
1
);
const
std
::
vector
<
size_t
>
lower_bound
{
0
,
i
*
channel_step
,
0
,
0
};
const
std
::
vector
<
size_t
>
upper_bound
{
data_shape
.
at
(
0
),
(
i
+
1
)
*
channel_step
,
data_shape
.
at
(
2
),
data_shape
.
at
(
3
)};
auto
data_slice
=
std
::
make_shared
<
ngraph
::
op
::
Slice
>
(
data_batch
,
lower_bound
,
upper_bound
);
size_t
filter_step
=
data_shape
.
at
(
0
);
const
std
::
vector
<
size_t
>
filter_lower_bound
{
i
*
filter_step
,
0
,
0
,
0
};
const
std
::
vector
<
size_t
>
filter_upper_bound
{
(
i
+
1
)
*
filter_step
,
filter_shape
.
at
(
1
),
filter_shape
.
at
(
2
),
filter_shape
.
at
(
3
)};
auto
filter_slice
=
std
::
make_shared
<
ngraph
::
op
::
Slice
>
(
filters
,
filter_lower_bound
,
filter_upper_bound
);
const
std
::
vector
<
size_t
>
olower_bound
{
0
,
i
*
filter_step
,
0
,
0
};
const
std
::
vector
<
size_t
>
oupper_bound
{
out_shape
.
at
(
0
),
(
i
+
1
)
*
filter_step
,
out_shape
.
at
(
2
),
out_shape
.
at
(
3
)};
auto
out_slice
=
std
::
make_shared
<
ngraph
::
op
::
Slice
>
(
doutput
,
olower_bound
,
oupper_bound
);
auto
ng_conv
=
std
::
make_shared
<
ngraph
::
op
::
ConvolutionBackpropFilters
>
(
data_slice
,
filter_slice
->
get_shape
(),
out_slice
,
strides
,
dilations
,
paddings
,
paddings
,
ngraph
::
Strides
{
1
,
1
});
ng_slices
.
push_back
(
ng_conv
);
}
size_t
concat_axis
=
0
;
return
std
::
make_shared
<
ngraph
::
op
::
Concat
>
(
ng_slices
,
concat_axis
);
}
std
::
shared_ptr
<
ngraph
::
Node
>
GroupedGradConvolutionData
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
data_batch
,
const
std
::
shared_ptr
<
ngraph
::
Node
>&
filters
,
const
std
::
shared_ptr
<
ngraph
::
Node
>&
doutput
,
const
ngraph
::
Strides
strides
,
const
ngraph
::
Strides
dilations
,
const
ngraph
::
CoordinateDiff
&
paddings
,
size_t
groups
)
{
auto
&
data_shape
=
data_batch
->
get_shape
();
auto
&
filter_shape
=
filters
->
get_shape
();
auto
&
out_shape
=
doutput
->
get_shape
();
ngraph
::
NodeVector
ng_slices
;
for
(
size_t
i
=
0
;
i
<
groups
;
++
i
)
{
size_t
channel_step
=
filter_shape
.
at
(
1
);
const
std
::
vector
<
size_t
>
lower_bound
{
0
,
i
*
channel_step
,
0
,
0
};
const
std
::
vector
<
size_t
>
upper_bound
{
data_shape
.
at
(
0
),
(
i
+
1
)
*
channel_step
,
data_shape
.
at
(
2
),
data_shape
.
at
(
3
)};
auto
data_slice
=
std
::
make_shared
<
ngraph
::
op
::
Slice
>
(
data_batch
,
lower_bound
,
upper_bound
);
size_t
filter_step
=
data_shape
.
at
(
0
);
const
std
::
vector
<
size_t
>
filter_lower_bound
{
i
*
filter_step
,
0
,
0
,
0
};
const
std
::
vector
<
size_t
>
filter_upper_bound
{
(
i
+
1
)
*
filter_step
,
filter_shape
.
at
(
1
),
filter_shape
.
at
(
2
),
filter_shape
.
at
(
3
)};
auto
filter_slice
=
std
::
make_shared
<
ngraph
::
op
::
Slice
>
(
filters
,
filter_lower_bound
,
filter_upper_bound
);
const
std
::
vector
<
size_t
>
olower_bound
{
0
,
i
*
filter_step
,
0
,
0
};
const
std
::
vector
<
size_t
>
oupper_bound
{
out_shape
.
at
(
0
),
(
i
+
1
)
*
filter_step
,
out_shape
.
at
(
2
),
out_shape
.
at
(
3
)};
auto
out_slice
=
std
::
make_shared
<
ngraph
::
op
::
Slice
>
(
doutput
,
olower_bound
,
oupper_bound
);
auto
ng_conv
=
std
::
make_shared
<
ngraph
::
op
::
ConvolutionBackpropData
>
(
data_slice
->
get_shape
(),
filter_slice
,
out_slice
,
strides
,
dilations
,
paddings
,
paddings
,
ngraph
::
Strides
{
1
,
1
});
ng_slices
.
push_back
(
ng_conv
);
}
size_t
concat_axis
=
1
;
return
std
::
make_shared
<
ngraph
::
op
::
Concat
>
(
ng_slices
,
concat_axis
);
}
void
BuildConv2dNode
(
const
std
::
shared_ptr
<
paddle
::
framework
::
OperatorBase
>&
op
,
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
ngb_node_map
)
{
auto
op_attrs
=
paddle
::
framework
::
AttrReader
(
op
->
Attrs
());
auto
filters
=
paddle
::
platform
::
GetInputNode
(
op
,
"Filter"
,
ngb_node_map
);
auto
input
=
paddle
::
platform
::
GetInputNode
(
op
,
"Input"
,
ngb_node_map
);
std
::
vector
<
int
>
strides
=
op_attrs
.
Get
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
op_attrs
.
Get
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
dilations
=
op_attrs
.
Get
<
std
::
vector
<
int
>>
(
"dilations"
);
const
ngraph
::
Strides
ng_strides
{
static_cast
<
size_t
>
(
strides
.
at
(
0
)),
static_cast
<
size_t
>
(
strides
.
at
(
1
))};
const
ngraph
::
Strides
ng_dilations
{
static_cast
<
size_t
>
(
dilations
.
at
(
0
)),
static_cast
<
size_t
>
(
dilations
.
at
(
1
))};
const
ngraph
::
CoordinateDiff
ng_paddings
{
static_cast
<
std
::
ptrdiff_t
>
(
paddings
.
at
(
0
)),
static_cast
<
std
::
ptrdiff_t
>
(
paddings
.
at
(
1
))};
int
groups
=
static_cast
<
size_t
>
(
op_attrs
.
Get
<
int
>
(
"groups"
));
PADDLE_ENFORCE_GE
(
groups
,
1
,
"conv groups needs be no less than 1"
);
std
::
shared_ptr
<
ngraph
::
Node
>
result
;
if
(
groups
==
1
)
{
result
=
std
::
make_shared
<
ngraph
::
op
::
Convolution
>
(
input
,
filters
,
ng_strides
,
ng_dilations
,
ng_paddings
,
ng_paddings
);
}
else
{
result
=
GroupedConvolution
(
input
,
filters
,
ng_strides
,
ng_dilations
,
ng_paddings
,
groups
);
}
paddle
::
platform
::
SetOutputNode
(
op
,
"Output"
,
result
,
ngb_node_map
);
}
void
BuildConv2dGradNode
(
const
std
::
shared_ptr
<
paddle
::
framework
::
OperatorBase
>&
op
,
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
ngb_node_map
)
{
auto
op_attrs
=
paddle
::
framework
::
AttrReader
(
op
->
Attrs
());
auto
filter
=
paddle
::
platform
::
GetInputNode
(
op
,
"Filter"
,
ngb_node_map
);
auto
input
=
paddle
::
platform
::
GetInputNode
(
op
,
"Input"
,
ngb_node_map
);
auto
doutput
=
paddle
::
platform
::
GetInputNode
(
op
,
"Output@GRAD"
,
ngb_node_map
);
int
groups
=
op_attrs
.
Get
<
int
>
(
"groups"
);
std
::
vector
<
int
>
strides
=
op_attrs
.
Get
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
op_attrs
.
Get
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
dilations
=
op_attrs
.
Get
<
std
::
vector
<
int
>>
(
"dilations"
);
const
ngraph
::
Strides
ng_strides
{
static_cast
<
size_t
>
(
strides
.
at
(
0
)),
static_cast
<
size_t
>
(
strides
.
at
(
1
))};
const
ngraph
::
Strides
ng_dilations
{
static_cast
<
size_t
>
(
dilations
.
at
(
0
)),
static_cast
<
size_t
>
(
dilations
.
at
(
1
))};
const
ngraph
::
CoordinateDiff
ng_paddings
{
static_cast
<
std
::
ptrdiff_t
>
(
paddings
.
at
(
0
)),
static_cast
<
std
::
ptrdiff_t
>
(
paddings
.
at
(
1
))};
std
::
shared_ptr
<
ngraph
::
Node
>
dfilter
;
std
::
shared_ptr
<
ngraph
::
Node
>
dinput
;
if
(
groups
==
1
)
{
dfilter
=
std
::
make_shared
<
ngraph
::
op
::
ConvolutionBackpropFilters
>
(
input
,
filter
->
get_shape
(),
doutput
,
ng_strides
,
ng_dilations
,
ng_paddings
,
ng_paddings
,
ngraph
::
Strides
{
1
,
1
});
dinput
=
std
::
make_shared
<
ngraph
::
op
::
ConvolutionBackpropData
>
(
input
->
get_shape
(),
filter
,
doutput
,
ng_strides
,
ng_dilations
,
ng_paddings
,
ng_paddings
,
ngraph
::
Strides
{
1
,
1
});
}
else
{
dfilter
=
GroupedGradConvolutionFilter
(
input
,
filter
,
doutput
,
ng_strides
,
ng_dilations
,
ng_paddings
,
groups
);
dinput
=
GroupedGradConvolutionData
(
input
,
filter
,
doutput
,
ng_strides
,
ng_dilations
,
ng_paddings
,
groups
);
}
paddle
::
platform
::
SetOutputNode
(
op
,
"Filter@GRAD"
,
dfilter
,
ngb_node_map
);
paddle
::
platform
::
SetOutputNode
(
op
,
"Input@GRAD"
,
dinput
,
ngb_node_map
);
}
}
// namespace ngraphs
}
// namespace operators
}
// namespace paddle
paddle/fluid/platform/place.cc
浏览文件 @
b6c3b69a
...
...
@@ -14,6 +14,12 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
DEFINE_bool
(
benchmark
,
false
,
"Doing memory benchmark. It will make deleting scope synchronized, "
"and add some memory usage logs."
"Default cuda is asynchronous device, set to True will"
"force op run in synchronous mode."
);
namespace
paddle
{
namespace
platform
{
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
b6c3b69a
...
...
@@ -37,6 +37,7 @@ limitations under the License. */
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/memory/allocation/allocator_strategy.h"
#include "paddle/fluid/memory/allocation/legacy_allocator.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/py_func_op.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
...
...
@@ -127,6 +128,13 @@ PYBIND11_MODULE(core, m) {
m
.
add_object
(
"_cleanup"
,
py
::
capsule
([]()
{
ScopePool
::
Instance
().
Clear
();
}));
m
.
def
(
"get_mem_usage"
,
[](
int
device
)
{
return
memory
::
allocation
::
GPUMemMonitor
.
GetMemUsage
(
device
);
});
m
.
def
(
"print_mem_usage"
,
[]()
{
return
memory
::
allocation
::
GPUMemMonitor
.
PrintMemUsage
();
});
py
::
class_
<
imperative
::
VarBase
>
(
m
,
"VarBase"
,
R"DOC()DOC"
)
// .def(py::init<>())
.
def
(
py
::
init
<
bool
>
(),
py
::
arg
(
"stop_gradient"
)
=
false
)
...
...
python/paddle/fluid/io.py
浏览文件 @
b6c3b69a
...
...
@@ -21,9 +21,10 @@ import shutil
import
six
from
functools
import
reduce
from
paddle.fluid
import
layers
from
paddle.fluid.executor
import
Executor
from
paddle.fluid.evaluator
import
Evaluator
from
paddle.fluid.framework
import
Program
,
Parameter
,
default_main_program
,
default_startup_program
,
Variable
from
paddle.fluid.framework
import
Program
,
Parameter
,
default_main_program
,
default_startup_program
,
Variable
,
program_guard
from
.
import
core
__all__
=
[
...
...
@@ -931,6 +932,17 @@ def save_inference_model(dirname,
if
main_program
is
None
:
main_program
=
default_main_program
()
# fix the bug that the activation op's output as target will be pruned.
# will affect the inference performance.
# TODO(Superjomn) add an IR pass to remove 1-scale op.
with
program_guard
(
main_program
):
uniq_target_vars
=
[]
for
var
in
target_vars
:
if
isinstance
(
var
,
Variable
):
var1
=
layers
.
scale
(
var
,
1.
)
uniq_target_vars
.
append
(
var1
)
target_vars
=
uniq_target_vars
# when a pserver and a trainer running on the same machine, mkdir may conflict
try
:
os
.
makedirs
(
dirname
)
...
...
python/paddle/fluid/layers/detection.py
浏览文件 @
b6c3b69a
...
...
@@ -49,6 +49,7 @@ __all__ = [
'box_coder'
,
'polygon_box_transform'
,
'yolov3_loss'
,
'box_clip'
,
'multiclass_nms'
,
]
...
...
@@ -2055,6 +2056,54 @@ def generate_proposals(scores,
return
rpn_rois
,
rpn_roi_probs
def
box_clip
(
input
,
im_info
,
name
=
None
):
"""
Clip the box into the size given by im_info
For each input box, The formula is given as follows:
.. code-block:: text
xmin = max(min(xmin, im_w - 1), 0)
ymin = max(min(ymin, im_h - 1), 0)
xmax = max(min(xmax, im_w - 1), 0)
ymax = max(min(ymax, im_h - 1), 0)
where im_w and im_h are computed from im_info:
.. code-block:: text
im_h = round(height / scale)
im_w = round(weight / scale)
Args:
input(variable): The input box, the last dimension is 4.
im_info(variable): The information of image with shape [N, 3] with
layout (height, width, scale). height and width
is the input size and scale is the ratio of input
size and original size.
name (str): The name of this layer. It is optional.
Returns:
Variable: The cliped tensor variable.
Examples:
.. code-block:: python
boxes = fluid.layers.data(
name='data', shape=[8, 4], dtype='float32', lod_level=1)
im_info = fluid.layers.data(name='im_info', shape=[3])
out = fluid.layers.box_clip(
input=boxes, im_info=im_info, inplace=True)
"""
helper
=
LayerHelper
(
"box_clip"
,
**
locals
())
output
=
helper
.
create_variable_for_type_inference
(
dtype
=
input
.
dtype
)
inputs
=
{
"Input"
:
input
,
"ImInfo"
:
im_info
}
helper
.
append_op
(
type
=
"box_clip"
,
inputs
=
inputs
,
outputs
=
{
"Output"
:
output
})
return
output
def
multiclass_nms
(
bboxes
,
scores
,
score_threshold
,
...
...
@@ -2132,9 +2181,11 @@ def multiclass_nms(bboxes,
(After version 1.3, when no boxes detected, the lod is changed
from {0} to {1})
Examples:
.. code-block:: python
boxes = fluid.layers.data(name='bboxes', shape=[81, 4],
dtype='float32', lod_level=1)
scores = fluid.layers.data(name='scores', shape=[81],
...
...
python/paddle/fluid/tests/test_detection.py
浏览文件 @
b6c3b69a
...
...
@@ -482,6 +482,17 @@ class TestYoloDetection(unittest.TestCase):
self
.
assertIsNotNone
(
loss
)
class
TestBoxClip
(
unittest
.
TestCase
):
def
test_box_clip
(
self
):
program
=
Program
()
with
program_guard
(
program
):
input_box
=
layers
.
data
(
name
=
'input_box'
,
shape
=
[
7
,
4
],
dtype
=
'float32'
,
lod_level
=
1
)
im_info
=
layers
.
data
(
name
=
'im_info'
,
shape
=
[
3
],
dtype
=
'float32'
)
out
=
layers
.
box_clip
(
input_box
,
im_info
)
self
.
assertIsNotNone
(
out
)
class
TestMulticlassNMS
(
unittest
.
TestCase
):
def
test_multiclass_nms
(
self
):
program
=
Program
()
...
...
python/paddle/fluid/tests/unittests/ngraph/test_conv2d_ngraph_op.py
0 → 100644
浏览文件 @
b6c3b69a
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
from
paddle.fluid.tests.unittests.test_conv2d_op
import
*
class
TestNGRAPH
(
TestConv2dOp
):
def
init_kernel_type
(
self
):
super
(
TestNGRAPH
,
self
).
init_kernel_type
()
class
TestNGRAPHWithPad
(
TestWithPad
):
def
init_kernel_type
(
self
):
super
(
TestNGRAPHWithPad
,
self
).
init_kernel_type
()
class
TestNGRAPHWithStride
(
TestWithStride
):
def
init_kernel_type
(
self
):
super
(
TestNGRAPHWithStride
,
self
).
init_kernel_type
()
class
TestNGRAPHWithGroup
(
TestWithGroup
):
def
init_kernel_type
(
self
):
super
(
TestNGRAPHWithGroup
,
self
).
init_kernel_type
()
class
TestNGRAPHWith1x1
(
TestWith1x1
):
def
init_kernel_type
(
self
):
super
(
TestNGRAPHWith1x1
,
self
).
init_kernel_type
()
class
TestNGRAPHWithInput1x1Filter1x1
(
TestWithInput1x1Filter1x1
):
def
init_kernel_type
(
self
):
super
(
TestNGRAPHWithInput1x1Filter1x1
,
self
).
init_kernel_type
()
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_box_clip_op.py
0 → 100644
浏览文件 @
b6c3b69a
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
sys
import
math
from
op_test
import
OpTest
import
copy
def
box_clip
(
input_box
,
im_info
,
output_box
):
im_w
=
round
(
im_info
[
1
]
/
im_info
[
2
])
im_h
=
round
(
im_info
[
0
]
/
im_info
[
2
])
output_box
[:,
:,
0
]
=
np
.
maximum
(
np
.
minimum
(
input_box
[:,
:,
0
],
im_w
-
1
),
0
)
output_box
[:,
:,
1
]
=
np
.
maximum
(
np
.
minimum
(
input_box
[:,
:,
1
],
im_h
-
1
),
0
)
output_box
[:,
:,
2
]
=
np
.
maximum
(
np
.
minimum
(
input_box
[:,
:,
2
],
im_w
-
1
),
0
)
output_box
[:,
:,
3
]
=
np
.
maximum
(
np
.
minimum
(
input_box
[:,
:,
3
],
im_h
-
1
),
0
)
def
batch_box_clip
(
input_boxes
,
im_info
,
lod
):
n
=
input_boxes
.
shape
[
0
]
m
=
input_boxes
.
shape
[
1
]
output_boxes
=
np
.
zeros
((
n
,
m
,
4
),
dtype
=
np
.
float32
)
cur_offset
=
0
for
i
in
range
(
len
(
lod
)):
box_clip
(
input_boxes
[
cur_offset
:(
cur_offset
+
lod
[
i
]),
:,
:],
im_info
[
i
,
:],
output_boxes
[
cur_offset
:(
cur_offset
+
lod
[
i
]),
:,
:])
cur_offset
+=
lod
[
i
]
return
output_boxes
class
TestBoxClipOp
(
OpTest
):
def
test_check_output
(
self
):
self
.
check_output
()
def
setUp
(
self
):
self
.
op_type
=
"box_clip"
lod
=
[[
1
,
2
,
3
]]
input_boxes
=
np
.
random
.
random
((
6
,
10
,
4
))
*
5
im_info
=
np
.
array
([[
5
,
8
,
1.
],
[
6
,
6
,
1.
],
[
7
,
5
,
1.
]])
output_boxes
=
batch_box_clip
(
input_boxes
,
im_info
,
lod
[
0
])
self
.
inputs
=
{
'Input'
:
(
input_boxes
.
astype
(
'float32'
),
lod
),
'ImInfo'
:
im_info
.
astype
(
'float32'
),
}
self
.
outputs
=
{
'Output'
:
output_boxes
}
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_inference_model_io.py
浏览文件 @
b6c3b69a
...
...
@@ -82,7 +82,8 @@ class TestBook(unittest.TestCase):
self
.
assertEqual
(
feed_var_names
,
[
"x"
,
"y"
])
self
.
assertEqual
(
len
(
fetch_vars
),
1
)
self
.
assertEqual
(
str
(
fetch_vars
[
0
]),
str
(
avg_cost
))
print
(
"fetch %s"
%
str
(
fetch_vars
[
0
]))
self
.
assertTrue
(
"scale"
in
str
(
fetch_vars
[
0
]))
self
.
assertEqual
(
expected
,
actual
)
...
...
python/paddle/fluid/tests/unittests/test_peak_gpumem_monitor.py
0 → 100644
浏览文件 @
b6c3b69a
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
os
os
.
environ
[
'FLAGS_benchmark'
]
=
'True'
import
numpy
import
paddle.fluid.core
as
core
from
paddle.fluid.executor
import
Executor
from
paddle.fluid.layers
import
mul
,
data
class
TestPeakMemoryMonitoring
(
unittest
.
TestCase
):
def
test_mul
(
self
):
a
=
data
(
name
=
'a'
,
shape
=
[
784
],
dtype
=
'float32'
)
b
=
data
(
name
=
'b'
,
shape
=
[
784
,
100
],
dtype
=
'float32'
,
append_batch_size
=
False
)
out
=
mul
(
x
=
a
,
y
=
b
)
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
a_np
=
numpy
.
random
.
random
((
100
,
784
)).
astype
(
'float32'
)
b_np
=
numpy
.
random
.
random
((
784
,
100
)).
astype
(
'float32'
)
self
.
assertEqual
(
0
,
core
.
get_mem_usage
(
0
))
exe
=
Executor
(
place
)
outs
=
exe
.
run
(
feed
=
{
'a'
:
a_np
,
'b'
:
b_np
},
fetch_list
=
[
out
])
out
=
outs
[
0
]
#disable this assert since ctest will ignore the os.environ setting
#self.assertGreater(core.get_mem_usage(0), 0)
raised
=
False
try
:
core
.
print_mem_usage
()
except
:
raised
=
True
self
.
assertFalse
(
raised
,
'Exception raised'
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录